1//===-- NVPTXISelLowering.cpp - NVPTX DAG Lowering Implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the interfaces that NVPTX uses to lower LLVM code into a
10// selection DAG.
11//
12//===----------------------------------------------------------------------===//
13
14#include "NVPTXISelLowering.h"
15#include "MCTargetDesc/NVPTXBaseInfo.h"
16#include "NVPTX.h"
17#include "NVPTXISelDAGToDAG.h"
18#include "NVPTXSelectionDAGInfo.h"
19#include "NVPTXSubtarget.h"
20#include "NVPTXTargetMachine.h"
21#include "NVPTXTargetObjectFile.h"
22#include "NVPTXUtilities.h"
23#include "NVVMProperties.h"
24#include "llvm/ADT/APFloat.h"
25#include "llvm/ADT/APInt.h"
26#include "llvm/ADT/STLExtras.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/ADT/StringRef.h"
29#include "llvm/CodeGen/Analysis.h"
30#include "llvm/CodeGen/ISDOpcodes.h"
31#include "llvm/CodeGen/MachineFunction.h"
32#include "llvm/CodeGen/MachineJumpTableInfo.h"
33#include "llvm/CodeGen/MachineMemOperand.h"
34#include "llvm/CodeGen/SDPatternMatch.h"
35#include "llvm/CodeGen/SelectionDAG.h"
36#include "llvm/CodeGen/SelectionDAGNodes.h"
37#include "llvm/CodeGen/TargetCallingConv.h"
38#include "llvm/CodeGen/TargetLowering.h"
39#include "llvm/CodeGen/ValueTypes.h"
40#include "llvm/CodeGenTypes/MachineValueType.h"
41#include "llvm/IR/Argument.h"
42#include "llvm/IR/Attributes.h"
43#include "llvm/IR/Constants.h"
44#include "llvm/IR/DataLayout.h"
45#include "llvm/IR/DerivedTypes.h"
46#include "llvm/IR/DiagnosticInfo.h"
47#include "llvm/IR/FPEnv.h"
48#include "llvm/IR/Function.h"
49#include "llvm/IR/GlobalValue.h"
50#include "llvm/IR/IRBuilder.h"
51#include "llvm/IR/Instruction.h"
52#include "llvm/IR/Instructions.h"
53#include "llvm/IR/IntrinsicsNVPTX.h"
54#include "llvm/IR/Module.h"
55#include "llvm/IR/Type.h"
56#include "llvm/IR/Value.h"
57#include "llvm/Support/Alignment.h"
58#include "llvm/Support/AtomicOrdering.h"
59#include "llvm/Support/Casting.h"
60#include "llvm/Support/CodeGen.h"
61#include "llvm/Support/CommandLine.h"
62#include "llvm/Support/ErrorHandling.h"
63#include "llvm/Support/KnownBits.h"
64#include "llvm/Support/NVPTXAddrSpace.h"
65#include "llvm/Support/raw_ostream.h"
66#include "llvm/Target/TargetMachine.h"
67#include "llvm/Target/TargetOptions.h"
68#include <algorithm>
69#include <cassert>
70#include <cmath>
71#include <cstdint>
72#include <iterator>
73#include <optional>
74#include <string>
75#include <tuple>
76#include <utility>
77#include <vector>
78
79#define DEBUG_TYPE "nvptx-lower"
80
81using namespace llvm;
82
83static cl::opt<bool> sched4reg(
84 "nvptx-sched4reg",
85 cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(Val: false));
86
87static cl::opt<unsigned> FMAContractLevelOpt(
88 "nvptx-fma-level", cl::Hidden,
89 cl::desc("NVPTX Specific: FMA contraction (0: don't do it"
90 " 1: do it 2: do it aggressively"),
91 cl::init(Val: 2));
92
// Precision of the lowering of f32 fdiv. A value given explicitly on the
// command line takes priority over per-node fast-math flags (see
// getDivF32Level below, which checks getNumOccurrences()).
static cl::opt<NVPTX::DivPrecisionLevel> UsePrecDivF32(
    "nvptx-prec-divf32", cl::Hidden,
    cl::desc(
        "NVPTX Specific: Override the precision of the lowering for f32 fdiv"),
    cl::values(
        clEnumValN(NVPTX::DivPrecisionLevel::Approx, "0", "Use div.approx"),
        clEnumValN(NVPTX::DivPrecisionLevel::Full, "1", "Use div.full"),
        clEnumValN(NVPTX::DivPrecisionLevel::IEEE754, "2",
                   "Use IEEE Compliant F32 div.rnd if available (default)"),
        clEnumValN(NVPTX::DivPrecisionLevel::IEEE754_NoFTZ, "3",
                   "Use IEEE Compliant F32 div.rnd if available, no FTZ")),
    cl::init(Val: NVPTX::DivPrecisionLevel::IEEE754));
105
// Precision of the f32 fsqrt lowering: sqrt.rn when true, sqrt.approx when
// false. A value given explicitly on the command line takes priority over
// per-node fast-math flags (see usePrecSqrtF32 below, which checks
// getNumOccurrences()).
static cl::opt<bool> UsePrecSqrtF32(
    "nvptx-prec-sqrtf32", cl::Hidden,
    cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."),
    cl::init(Val: true));
110
/// Whereas CUDA's implementation (see libdevice) uses ex2.approx for exp2(), it
/// does NOT use lg2.approx for log2, so this is disabled by default.
/// NOTE: unlike the options above, this one is deliberately not cl::Hidden.
static cl::opt<bool> UseApproxLog2F32(
    "nvptx-approx-log2f32",
    cl::desc("NVPTX Specific: whether to use lg2.approx for log2"),
    cl::init(Val: false));
117
118NVPTX::DivPrecisionLevel
119NVPTXTargetLowering::getDivF32Level(const MachineFunction &MF,
120 const SDNode &N) const {
121 // If nvptx-prec-div32=N is used on the command-line, always honor it
122 if (UsePrecDivF32.getNumOccurrences() > 0)
123 return UsePrecDivF32;
124
125 const SDNodeFlags Flags = N.getFlags();
126 if (Flags.hasApproximateFuncs())
127 return NVPTX::DivPrecisionLevel::Approx;
128
129 return NVPTX::DivPrecisionLevel::IEEE754;
130}
131
132bool NVPTXTargetLowering::usePrecSqrtF32(const SDNode *N) const {
133 // If nvptx-prec-sqrtf32 is used on the command-line, always honor it
134 if (UsePrecSqrtF32.getNumOccurrences() > 0)
135 return UsePrecSqrtF32;
136
137 if (N) {
138 const SDNodeFlags Flags = N->getFlags();
139 if (Flags.hasApproximateFuncs())
140 return false;
141 }
142
143 return true;
144}
145
146bool NVPTXTargetLowering::useF32FTZ(const MachineFunction &MF) const {
147 return MF.getDenormalMode(FPType: APFloat::IEEEsingle()).Output ==
148 DenormalMode::PreserveSign;
149}
150
// Returns true for the fixed-length vector types that get custom
// INTRINSIC_W_CHAIN (LDU) handling — see the setOperationAction loop in the
// constructor below, which calls this for each fixed-length vector VT.
static bool IsPTXVectorType(MVT VT) {
  switch (VT.SimpleTy) {
  default:
    return false;
  case MVT::v2i1:
  case MVT::v4i1:
  case MVT::v2i8:
  case MVT::v4i8:
  case MVT::v8i8:    // <2 x i8x4>
  case MVT::v16i8:   // <4 x i8x4>
  case MVT::v2i16:
  case MVT::v4i16:
  case MVT::v8i16:   // <4 x i16x2>
  case MVT::v2i32:
  case MVT::v4i32:
  case MVT::v2i64:
  case MVT::v2f16:
  case MVT::v4f16:
  case MVT::v8f16:   // <4 x f16x2>
  case MVT::v2bf16:
  case MVT::v4bf16:
  case MVT::v8bf16:  // <4 x bf16x2>
  case MVT::v2f32:
  case MVT::v4f32:
  case MVT::v2f64:
  case MVT::v4i64:
  case MVT::v4f64:
  case MVT::v8i32:
  case MVT::v8f32:
  case MVT::v16f16:  // <8 x f16x2>
  case MVT::v16bf16: // <8 x bf16x2>
  case MVT::v16i16:  // <8 x i16x2>
  case MVT::v32i8:   // <8 x i8x4>
    return true;
  }
}
187
// When legalizing vector loads/stores, this function is called, which does two
// things:
// 1. Determines whether the vector is something we want to custom lower,
//    std::nullopt is returned if we do not want to custom lower it.
// 2. If we do want to handle it, returns two parameters:
//    - unsigned int NumElts - The number of elements in the final vector
//    - MVT EltVT - The type of the elements in the final vector
static std::optional<std::pair<unsigned int, MVT>>
getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
                       unsigned AddressSpace) {
  const bool CanLowerTo256Bit = STI.has256BitVectorLoadStore(AS: AddressSpace);

  // A 256-bit scalar integer load/store is handled as <4 x i64> when the
  // target supports 256-bit accesses for this address space.
  if (CanLowerTo256Bit && VectorEVT.isScalarInteger() &&
      VectorEVT.getSizeInBits() == 256)
    return {{4, MVT::i64}};

  if (!VectorEVT.isSimple())
    return std::nullopt;
  const MVT VectorVT = VectorEVT.getSimpleVT();

  // Non-vector scalars: only the 128-bit types are custom lowered, each as
  // <2 x i64>.
  if (!VectorVT.isVector()) {
    if (VectorVT == MVT::i128 || VectorVT == MVT::f128)
      return {{2, MVT::i64}};
    return std::nullopt;
  }

  const MVT EltVT = VectorVT.getVectorElementType();
  const unsigned NumElts = VectorVT.getVectorNumElements();

  // The size of the PTX virtual register that holds a packed type.
  unsigned PackRegSize;

  // We only handle "native" vector sizes for now, e.g. <4 x double> is not
  // legal. We can (and should) split that into 2 stores of <2 x double> here
  // but I'm leaving that as a TODO for now.
  switch (VectorVT.SimpleTy) {
  default:
    return std::nullopt;

  case MVT::v4i64:
  case MVT::v4f64:
    // This is a "native" vector type iff the address space is global and the
    // target supports 256-bit loads/stores
    if (!CanLowerTo256Bit)
      return std::nullopt;
    [[fallthrough]];
  case MVT::v2i8:
  case MVT::v2i64:
  case MVT::v2f64:
    // This is a "native" vector type: no repacking needed, keep the original
    // element count and element type.
    return std::pair(NumElts, EltVT);

  case MVT::v16f16:  // <8 x f16x2>
  case MVT::v16bf16: // <8 x bf16x2>
  case MVT::v16i16:  // <8 x i16x2>
  case MVT::v32i8:   // <8 x i8x4>
    // This can be upsized into a "native" vector type iff the address space is
    // global and the target supports 256-bit loads/stores.
    if (!CanLowerTo256Bit)
      return std::nullopt;
    [[fallthrough]];
  case MVT::v2i16:  // <1 x i16x2>
  case MVT::v2f16:  // <1 x f16x2>
  case MVT::v2bf16: // <1 x bf16x2>
  case MVT::v4i8:   // <1 x i8x4>
  case MVT::v4i16:  // <2 x i16x2>
  case MVT::v4f16:  // <2 x f16x2>
  case MVT::v4bf16: // <2 x bf16x2>
  case MVT::v8i8:   // <2 x i8x4>
  case MVT::v8f16:  // <4 x f16x2>
  case MVT::v8bf16: // <4 x bf16x2>
  case MVT::v8i16:  // <4 x i16x2>
  case MVT::v16i8:  // <4 x i8x4>
    // 16-bit (and i8x4) elements pack into 32-bit registers.
    PackRegSize = 32;
    break;

  case MVT::v8f32: // <4 x f32x2>
  case MVT::v8i32: // <4 x i32x2>
    // This is a "native" vector type iff the address space is global and the
    // target supports 256-bit loads/stores
    if (!CanLowerTo256Bit)
      return std::nullopt;
    [[fallthrough]];
  case MVT::v2f32: // <1 x f32x2>
  case MVT::v4f32: // <2 x f32x2>
  case MVT::v2i32: // <1 x i32x2>
  case MVT::v4i32: // <2 x i32x2>
    // Without f32x2 instructions there is no 64-bit packed form; keep the
    // original element count and type.
    if (!STI.hasF32x2Instructions())
      return std::pair(NumElts, EltVT);
    PackRegSize = 64;
    break;
  }

  // If we reach here, then we can pack 2 or more elements into a single 32-bit
  // or 64-bit PTX register and treat the vector as a new vector containing
  // packed elements.

  // Number of elements to pack in one word.
  const unsigned NPerReg = PackRegSize / EltVT.getSizeInBits();

  return std::pair(NumElts / NPerReg, MVT::getVectorVT(VT: EltVT, NumElements: NPerReg));
}
290
291/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
292/// legal-ish MVTs that compose it. Unlike ComputeValueVTs, this will legalize
293/// the types as required by the calling convention (with special handling for
294/// i8s).
295/// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
296/// same number of types as the Ins/Outs arrays in LowerFormalArguments,
297/// LowerCall, and LowerReturn.
298static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
299 LLVMContext &Ctx, CallingConv::ID CallConv,
300 Type *Ty, SmallVectorImpl<EVT> &ValueVTs,
301 SmallVectorImpl<uint64_t> &Offsets,
302 uint64_t StartingOffset = 0) {
303 SmallVector<EVT, 16> TempVTs;
304 SmallVector<uint64_t, 16> TempOffsets;
305 ComputeValueVTs(TLI, DL, Ty, ValueVTs&: TempVTs, /*MemVTs=*/nullptr, FixedOffsets: &TempOffsets,
306 StartingOffset);
307
308 for (const auto [VT, Off] : zip(t&: TempVTs, u&: TempOffsets)) {
309 MVT RegisterVT = TLI.getRegisterTypeForCallingConv(Context&: Ctx, CC: CallConv, VT);
310 unsigned NumRegs = TLI.getNumRegistersForCallingConv(Context&: Ctx, CC: CallConv, VT);
311
312 // Since we actually can load/store b8, we need to ensure that we'll use
313 // the original sized type for any i8s or i8 vectors.
314 if (VT.getScalarType() == MVT::i8) {
315 if (RegisterVT == MVT::i16)
316 RegisterVT = MVT::i8;
317 else if (RegisterVT == MVT::v2i16)
318 RegisterVT = MVT::v2i8;
319 else
320 assert(RegisterVT == MVT::v4i8 &&
321 "Expected v4i8, v2i16, or i16 for i8 RegisterVT");
322 }
323
324 // TODO: This is horribly incorrect for cases where the vector elements are
325 // not a multiple of bytes (ex i1) and legal or i8. However, this problem
326 // has existed for as long as NVPTX has and no one has complained, so we'll
327 // leave it for now.
328 for (unsigned I : seq(Size: NumRegs)) {
329 ValueVTs.push_back(Elt: RegisterVT);
330 Offsets.push_back(Elt: Off + I * RegisterVT.getStoreSize());
331 }
332 }
333}
334
335// We return an EVT that can hold N VTs
336// If the VT is a vector, the resulting EVT is a flat vector with the same
337// element type as VT's element type.
338static EVT getVectorizedVT(EVT VT, unsigned N, LLVMContext &C) {
339 if (N == 1)
340 return VT;
341
342 return VT.isVector() ? EVT::getVectorVT(Context&: C, VT: VT.getScalarType(),
343 NumElements: VT.getVectorNumElements() * N)
344 : EVT::getVectorVT(Context&: C, VT, NumElements: N);
345}
346
347static SDValue getExtractVectorizedValue(SDValue V, unsigned I, EVT VT,
348 const SDLoc &dl, SelectionDAG &DAG) {
349 if (V.getValueType() == VT) {
350 assert(I == 0 && "Index must be 0 for scalar value");
351 return V;
352 }
353
354 if (!VT.isVector())
355 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT, N1: V,
356 N2: DAG.getVectorIdxConstant(Val: I, DL: dl));
357
358 return DAG.getNode(
359 Opcode: ISD::EXTRACT_SUBVECTOR, DL: dl, VT, N1: V,
360 N2: DAG.getVectorIdxConstant(Val: I * VT.getVectorNumElements(), DL: dl));
361}
362
363template <typename T>
364static inline SDValue getBuildVectorizedValue(unsigned N, const SDLoc &dl,
365 SelectionDAG &DAG, T GetElement) {
366 if (N == 1)
367 return GetElement(0);
368
369 SmallVector<SDValue, 8> Values;
370 for (const unsigned I : llvm::seq(Size: N)) {
371 SDValue Val = GetElement(I);
372 if (Val.getValueType().isVector())
373 DAG.ExtractVectorElements(Op: Val, Args&: Values);
374 else
375 Values.push_back(Elt: Val);
376 }
377
378 EVT VT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: Values[0].getValueType(),
379 NumElements: Values.size());
380 return DAG.getBuildVector(VT, DL: dl, Ops: Values);
381}
382
383/// PromoteScalarIntegerPTX
384/// Used to make sure the arguments/returns are suitable for passing
385/// and promote them to a larger size if they're not.
386///
387/// The promoted type is placed in \p PromoteVT if the function returns true.
388static EVT promoteScalarIntegerPTX(const EVT VT) {
389 if (VT.isScalarInteger()) {
390 switch (PowerOf2Ceil(A: VT.getFixedSizeInBits())) {
391 default:
392 llvm_unreachable(
393 "Promotion is not suitable for scalars of size larger than 64-bits");
394 case 1:
395 return MVT::i1;
396 case 2:
397 case 4:
398 case 8:
399 return MVT::i8;
400 case 16:
401 return MVT::i16;
402 case 32:
403 return MVT::i32;
404 case 64:
405 return MVT::i64;
406 }
407 }
408 return VT;
409}
410
411// Check whether we can merge loads/stores of some of the pieces of a
412// flattened function parameter or return value into a single vector
413// load/store.
414//
415// The flattened parameter is represented as a list of EVTs and
416// offsets, and the whole structure is aligned to ParamAlignment. This
417// function determines whether we can load/store pieces of the
418// parameter starting at index Idx using a single vectorized op of
419// size AccessSize. If so, it returns the number of param pieces
420// covered by the vector op. Otherwise, it returns 1.
421template <typename T>
422static unsigned canMergeParamLoadStoresStartingAt(
423 unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs,
424 const SmallVectorImpl<T> &Offsets, Align ParamAlignment) {
425
426 // Can't vectorize if param alignment is not sufficient.
427 if (ParamAlignment < AccessSize)
428 return 1;
429 // Can't vectorize if offset is not aligned.
430 if (Offsets[Idx] & (AccessSize - 1))
431 return 1;
432
433 EVT EltVT = ValueVTs[Idx];
434 unsigned EltSize = EltVT.getStoreSize();
435
436 // Element is too large to vectorize.
437 if (EltSize >= AccessSize)
438 return 1;
439
440 unsigned NumElts = AccessSize / EltSize;
441 // Can't vectorize if AccessBytes if not a multiple of EltSize.
442 if (AccessSize != EltSize * NumElts)
443 return 1;
444
445 // We don't have enough elements to vectorize.
446 if (Idx + NumElts > ValueVTs.size())
447 return 1;
448
449 // PTX ISA can only deal with 2- and 4-element vector ops.
450 if (NumElts != 4 && NumElts != 2)
451 return 1;
452
453 for (unsigned j = Idx + 1; j < Idx + NumElts; ++j) {
454 // Types do not match.
455 if (ValueVTs[j] != EltVT)
456 return 1;
457
458 // Elements are not contiguous.
459 if (Offsets[j] - Offsets[j - 1] != EltSize)
460 return 1;
461 }
462 // OK. We can vectorize ValueVTs[i..i+NumElts)
463 return NumElts;
464}
465
466// Computes whether and how we can vectorize the loads/stores of a
467// flattened function parameter or return value.
468//
469// The flattened parameter is represented as the list of ValueVTs and
470// Offsets, and is aligned to ParamAlignment bytes. We return a vector
471// of the same size as ValueVTs indicating how each piece should be
472// loaded/stored (i.e. as a scalar, or as part of a vector
473// load/store).
474template <typename T>
475static SmallVector<unsigned, 16>
476VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
477 const SmallVectorImpl<T> &Offsets, Align ParamAlignment,
478 bool IsVAArg = false) {
479 // Set vector size to match ValueVTs and mark all elements as
480 // scalars by default.
481
482 if (IsVAArg)
483 return SmallVector<unsigned>(ValueVTs.size(), 1);
484
485 SmallVector<unsigned, 16> VectorInfo;
486
487 const auto GetNumElts = [&](unsigned I) -> unsigned {
488 for (const unsigned AccessSize : {16, 8, 4, 2}) {
489 const unsigned NumElts = canMergeParamLoadStoresStartingAt(
490 I, AccessSize, ValueVTs, Offsets, ParamAlignment);
491 assert((NumElts == 1 || NumElts == 2 || NumElts == 4) &&
492 "Unexpected vectorization size");
493 if (NumElts != 1)
494 return NumElts;
495 }
496 return 1;
497 };
498
499 // Check what we can vectorize using 128/64/32-bit accesses.
500 for (unsigned I = 0, E = ValueVTs.size(); I != E;) {
501 const unsigned NumElts = GetNumElts(I);
502 VectorInfo.push_back(Elt: NumElts);
503 I += NumElts;
504 }
505 assert(std::accumulate(VectorInfo.begin(), VectorInfo.end(), 0u) ==
506 ValueVTs.size());
507 return VectorInfo;
508}
509
510// NVPTXTargetLowering Constructor.
511NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
512 const NVPTXSubtarget &STI)
513 : TargetLowering(TM, STI), nvTM(&TM), STI(STI), GlobalUniqueCallSite(0) {
514 // always lower memset, memcpy, and memmove intrinsics to load/store
515 // instructions, rather
516 // then generating calls to memset, mempcy or memmove.
517 MaxStoresPerMemset = MaxStoresPerMemsetOptSize = (unsigned)0xFFFFFFFF;
518 MaxStoresPerMemcpy = MaxStoresPerMemcpyOptSize = (unsigned) 0xFFFFFFFF;
519 MaxStoresPerMemmove = MaxStoresPerMemmoveOptSize = (unsigned) 0xFFFFFFFF;
520
521 setBooleanContents(ZeroOrNegativeOneBooleanContent);
522 setBooleanVectorContents(ZeroOrNegativeOneBooleanContent);
523
524 // Jump is Expensive. Don't create extra control flow for 'and', 'or'
525 // condition branches.
526 setJumpIsExpensive(true);
527
528 // Wide divides are _very_ slow. Try to reduce the width of the divide if
529 // possible.
530 addBypassSlowDiv(SlowBitWidth: 64, FastBitWidth: 32);
531
532 // By default, use the Source scheduling
533 if (sched4reg)
534 setSchedulingPreference(Sched::RegPressure);
535 else
536 setSchedulingPreference(Sched::Source);
537
538 auto setFP16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
539 LegalizeAction NoF16Action) {
540 bool IsOpSupported = STI.allowFP16Math();
541 switch (Op) {
542 // Several FP16 instructions are available on sm_80 only.
543 case ISD::FMINNUM:
544 case ISD::FMAXNUM:
545 case ISD::FMAXNUM_IEEE:
546 case ISD::FMINNUM_IEEE:
547 case ISD::FMAXIMUM:
548 case ISD::FMINIMUM:
549 case ISD::FMAXIMUMNUM:
550 case ISD::FMINIMUMNUM:
551 IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
552 break;
553 case ISD::FEXP2:
554 IsOpSupported &= STI.getSmVersion() >= 75 && STI.getPTXVersion() >= 70;
555 break;
556 }
557 setOperationAction(Op, VT, Action: IsOpSupported ? Action : NoF16Action);
558 };
559
560 auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
561 LegalizeAction NoBF16Action) {
562 bool IsOpSupported = STI.hasNativeBF16Support(Opcode: Op);
563 setOperationAction(
564 Op, VT, Action: IsOpSupported ? Action : NoBF16Action);
565 };
566
567 auto setI16x2OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
568 LegalizeAction NoI16x2Action) {
569 bool IsOpSupported = false;
570 // instructions are available on sm_90 only
571 switch (Op) {
572 case ISD::ADD:
573 case ISD::SMAX:
574 case ISD::SMIN:
575 case ISD::UMIN:
576 case ISD::UMAX:
577 IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 80;
578 break;
579 }
580 setOperationAction(Op, VT, Action: IsOpSupported ? Action : NoI16x2Action);
581 };
582
583 addRegisterClass(VT: MVT::i1, RC: &NVPTX::B1RegClass);
584 addRegisterClass(VT: MVT::i16, RC: &NVPTX::B16RegClass);
585 addRegisterClass(VT: MVT::v2i16, RC: &NVPTX::B32RegClass);
586 addRegisterClass(VT: MVT::v4i8, RC: &NVPTX::B32RegClass);
587 addRegisterClass(VT: MVT::i32, RC: &NVPTX::B32RegClass);
588 addRegisterClass(VT: MVT::i64, RC: &NVPTX::B64RegClass);
589 addRegisterClass(VT: MVT::f32, RC: &NVPTX::B32RegClass);
590 addRegisterClass(VT: MVT::f64, RC: &NVPTX::B64RegClass);
591 addRegisterClass(VT: MVT::f16, RC: &NVPTX::B16RegClass);
592 addRegisterClass(VT: MVT::v2f16, RC: &NVPTX::B32RegClass);
593 addRegisterClass(VT: MVT::bf16, RC: &NVPTX::B16RegClass);
594 addRegisterClass(VT: MVT::v2bf16, RC: &NVPTX::B32RegClass);
595
596 if (STI.hasF32x2Instructions()) {
597 addRegisterClass(VT: MVT::v2f32, RC: &NVPTX::B64RegClass);
598 addRegisterClass(VT: MVT::v2i32, RC: &NVPTX::B64RegClass);
599 }
600
601 // Conversion to/from FP16/FP16x2 is always legal.
602 setOperationAction(Op: ISD::BUILD_VECTOR, VT: MVT::v2f16, Action: Custom);
603 setOperationAction(Op: ISD::EXTRACT_VECTOR_ELT, VT: MVT::v2f16, Action: Custom);
604 setOperationAction(Op: ISD::INSERT_VECTOR_ELT, VT: MVT::v2f16, Action: Expand);
605 setOperationAction(Op: ISD::VECTOR_SHUFFLE, VT: MVT::v2f16, Action: Expand);
606
607 setOperationAction(Op: ISD::READCYCLECOUNTER, VT: MVT::i64, Action: Legal);
608 if (STI.getSmVersion() >= 30 && STI.getPTXVersion() > 31)
609 setOperationAction(Op: ISD::READSTEADYCOUNTER, VT: MVT::i64, Action: Legal);
610
611 setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote);
612 setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand);
613
614 // Conversion to/from BFP16/BFP16x2 is always legal.
615 setOperationAction(Op: ISD::BUILD_VECTOR, VT: MVT::v2bf16, Action: Custom);
616 setOperationAction(Op: ISD::EXTRACT_VECTOR_ELT, VT: MVT::v2bf16, Action: Custom);
617 setOperationAction(Op: ISD::INSERT_VECTOR_ELT, VT: MVT::v2bf16, Action: Expand);
618 setOperationAction(Op: ISD::VECTOR_SHUFFLE, VT: MVT::v2bf16, Action: Expand);
619
620 setBF16OperationAction(ISD::SETCC, MVT::v2bf16, Legal, Expand);
621 setBF16OperationAction(ISD::SETCC, MVT::bf16, Legal, Promote);
622 if (getOperationAction(Op: ISD::SETCC, VT: MVT::bf16) == Promote)
623 AddPromotedToType(Opc: ISD::SETCC, OrigVT: MVT::bf16, DestVT: MVT::f32);
624
625 // Conversion to/from i16/i16x2 is always legal.
626 setOperationAction(Op: ISD::BUILD_VECTOR, VT: MVT::v2i16, Action: Custom);
627 setOperationAction(Op: ISD::EXTRACT_VECTOR_ELT, VT: MVT::v2i16, Action: Custom);
628 setOperationAction(Op: ISD::INSERT_VECTOR_ELT, VT: MVT::v2i16, Action: Expand);
629 setOperationAction(Op: ISD::VECTOR_SHUFFLE, VT: MVT::v2i16, Action: Expand);
630
631 setOperationAction(Op: ISD::BUILD_VECTOR, VT: MVT::v4i8, Action: Custom);
632 setOperationAction(Op: ISD::EXTRACT_VECTOR_ELT, VT: MVT::v4i8, Action: Custom);
633 setOperationAction(Op: ISD::INSERT_VECTOR_ELT, VT: MVT::v4i8, Action: Custom);
634 setOperationAction(Op: ISD::VECTOR_SHUFFLE, VT: MVT::v4i8, Action: Custom);
635
636 // No support for these operations with v2f32/v2i32
637 setOperationAction(Ops: ISD::INSERT_VECTOR_ELT, VTs: {MVT::v2f32, MVT::v2i32}, Action: Expand);
638 setOperationAction(Ops: ISD::VECTOR_SHUFFLE, VTs: {MVT::v2f32, MVT::v2i32}, Action: Expand);
639
640 setOperationAction(Op: ISD::TRUNCATE, VT: MVT::v2i16, Action: Expand);
641 setOperationAction(Ops: {ISD::ANY_EXTEND, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND},
642 VT: MVT::v2i32, Action: Expand);
643
644 // Need custom lowering in case the index is dynamic.
645 if (STI.hasF32x2Instructions())
646 setOperationAction(Ops: ISD::EXTRACT_VECTOR_ELT, VTs: {MVT::v2f32, MVT::v2i32},
647 Action: Custom);
648
649 // Custom conversions to/from v2i8.
650 setOperationAction(Op: ISD::BITCAST, VT: MVT::v2i8, Action: Custom);
651
652 // Only logical ops can be done on v4i8/v2i32 directly, others must be done
653 // elementwise.
654 setOperationAction(
655 Ops: {ISD::ABS, ISD::ADD, ISD::ADDC, ISD::ADDE,
656 ISD::BITREVERSE, ISD::CTLZ, ISD::CTPOP, ISD::CTTZ,
657 ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FSHL, ISD::FSHR,
658 ISD::MUL, ISD::MULHS, ISD::MULHU, ISD::PARITY,
659 ISD::ROTL, ISD::ROTR, ISD::SADDO, ISD::SADDO_CARRY,
660 ISD::SADDSAT, ISD::SDIV, ISD::SDIVREM, ISD::SELECT_CC,
661 ISD::SETCC, ISD::SHL, ISD::SINT_TO_FP, ISD::SMAX,
662 ISD::SMIN, ISD::SMULO, ISD::SMUL_LOHI, ISD::SRA,
663 ISD::SREM, ISD::SRL, ISD::SSHLSAT, ISD::SSUBO,
664 ISD::SSUBO_CARRY, ISD::SSUBSAT, ISD::SUB, ISD::SUBC,
665 ISD::SUBE, ISD::UADDO, ISD::UADDO_CARRY, ISD::UADDSAT,
666 ISD::UDIV, ISD::UDIVREM, ISD::UINT_TO_FP, ISD::UMAX,
667 ISD::UMIN, ISD::UMULO, ISD::UMUL_LOHI, ISD::UREM,
668 ISD::USHLSAT, ISD::USUBO, ISD::USUBO_CARRY, ISD::VSELECT,
669 ISD::USUBSAT},
670 VTs: {MVT::v4i8, MVT::v2i32}, Action: Expand);
671
672 // Operations not directly supported by NVPTX.
673 for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
674 MVT::v2f32, MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16,
675 MVT::v4i8, MVT::i32, MVT::v2i32, MVT::i64}) {
676 setOperationAction(Op: ISD::SELECT_CC, VT, Action: Expand);
677 setOperationAction(Op: ISD::BR_CC, VT, Action: Expand);
678 }
679
680 // We don't want ops like FMINIMUM or UMAX to be lowered to SETCC+VSELECT.
681 setOperationAction(Ops: ISD::VSELECT, VTs: {MVT::v2f32, MVT::v2i32}, Action: Expand);
682
683 // Some SIGN_EXTEND_INREG can be done using cvt instruction.
684 // For others we will expand to a SHL/SRA pair.
685 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i64, Action: Legal);
686 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i32, Action: Legal);
687 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i16, Action: Legal);
688 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i8 , Action: Legal);
689 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i1, Action: Expand);
690 setOperationAction(Ops: ISD::SIGN_EXTEND_INREG, VTs: {MVT::v2i16, MVT::v2i32}, Action: Expand);
691
692 setOperationAction(Op: ISD::SHL_PARTS, VT: MVT::i32 , Action: Custom);
693 setOperationAction(Op: ISD::SRA_PARTS, VT: MVT::i32 , Action: Custom);
694 setOperationAction(Op: ISD::SRL_PARTS, VT: MVT::i32 , Action: Custom);
695 setOperationAction(Op: ISD::SHL_PARTS, VT: MVT::i64 , Action: Custom);
696 setOperationAction(Op: ISD::SRA_PARTS, VT: MVT::i64 , Action: Custom);
697 setOperationAction(Op: ISD::SRL_PARTS, VT: MVT::i64 , Action: Custom);
698
699 setOperationAction(Op: ISD::BITREVERSE, VT: MVT::i32, Action: Legal);
700 setOperationAction(Op: ISD::BITREVERSE, VT: MVT::i64, Action: Legal);
701
702 setOperationAction(Ops: {ISD::ROTL, ISD::ROTR},
703 VTs: {MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64},
704 Action: Expand);
705
706 if (STI.hasHWROT32()) {
707 setOperationAction(Ops: {ISD::FSHL, ISD::FSHR}, VT: MVT::i32, Action: Legal);
708 setOperationAction(Ops: {ISD::ROTL, ISD::ROTR, ISD::FSHL, ISD::FSHR}, VT: MVT::i64,
709 Action: Custom);
710 }
711
712 setOperationAction(Op: ISD::BR_JT, VT: MVT::Other, Action: STI.hasBrx() ? Legal : Expand);
713 setOperationAction(Op: ISD::BRIND, VT: MVT::Other, Action: Expand);
714
715 // We want to legalize constant related memmove and memcopy
716 // intrinsics.
717 setOperationAction(Op: ISD::INTRINSIC_W_CHAIN, VT: MVT::Other, Action: Custom);
718
719 // FP extload/truncstore is not legal in PTX. We need to expand all these.
720 for (auto FloatVTs :
721 {MVT::fp_valuetypes(), MVT::fp_fixedlen_vector_valuetypes()}) {
722 for (MVT ValVT : FloatVTs) {
723 for (MVT MemVT : FloatVTs) {
724 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT, MemVT, Action: Expand);
725 setTruncStoreAction(ValVT, MemVT, Action: Expand);
726 }
727 }
728 }
729
730 // To improve CodeGen we'll legalize any-extend loads to zext loads. This is
731 // how they'll be lowered in ISel anyway, and by doing this a little earlier
732 // we allow for more DAG combine opportunities.
733 for (auto IntVTs :
734 {MVT::integer_valuetypes(), MVT::integer_fixedlen_vector_valuetypes()})
735 for (MVT ValVT : IntVTs)
736 for (MVT MemVT : IntVTs)
737 if (isTypeLegal(VT: ValVT))
738 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT, MemVT, Action: Custom);
739
740 // PTX does not support load / store predicate registers
741 setOperationAction(Ops: {ISD::LOAD, ISD::STORE}, VT: MVT::i1, Action: Custom);
742 for (MVT VT : MVT::integer_valuetypes()) {
743 setLoadExtAction(ExtTypes: {ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}, ValVT: VT, MemVT: MVT::i1,
744 Action: Promote);
745 setTruncStoreAction(ValVT: VT, MemVT: MVT::i1, Action: Expand);
746 }
747
748 // Disable generations of extload/truncstore for v2i32/v2i16/v2i8. The generic
749 // expansion for these nodes when they are unaligned is incorrect if the
750 // type is a vector.
751 //
752 // TODO: Fix the generic expansion for these nodes found in
753 // TargetLowering::expandUnalignedLoad/Store.
754 setLoadExtAction(ExtTypes: {ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, ValVT: MVT::v2i16,
755 MemVT: MVT::v2i8, Action: Expand);
756 setLoadExtAction(ExtTypes: {ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, ValVT: MVT::v2i32,
757 MemVTs: {MVT::v2i8, MVT::v2i16}, Action: Expand);
758 setTruncStoreAction(ValVT: MVT::v2i16, MemVT: MVT::v2i8, Action: Expand);
759 setTruncStoreAction(ValVT: MVT::v2i32, MemVT: MVT::v2i16, Action: Expand);
760 setTruncStoreAction(ValVT: MVT::v2i32, MemVT: MVT::v2i8, Action: Expand);
761
762 // Register custom handling for illegal type loads/stores. We'll try to custom
763 // lower almost all illegal types and logic in the lowering will discard cases
764 // we can't handle.
765 setOperationAction(Ops: {ISD::LOAD, ISD::STORE}, VTs: {MVT::i128, MVT::i256, MVT::f128},
766 Action: Custom);
767 for (MVT VT : MVT::fixedlen_vector_valuetypes())
768 if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
769 setOperationAction(Ops: {ISD::STORE, ISD::LOAD, ISD::MSTORE, ISD::MLOAD}, VT,
770 Action: Custom);
771
772 // Custom legalization for LDU intrinsics.
773 // TODO: The logic to lower these is not very robust and we should rewrite it.
774 // Perhaps LDU should not be represented as an intrinsic at all.
775 setOperationAction(Op: ISD::INTRINSIC_W_CHAIN, VT: MVT::i8, Action: Custom);
776 for (MVT VT : MVT::fixedlen_vector_valuetypes())
777 if (IsPTXVectorType(VT))
778 setOperationAction(Op: ISD::INTRINSIC_W_CHAIN, VT, Action: Custom);
779
780 setCondCodeAction(CCs: {ISD::SETNE, ISD::SETEQ, ISD::SETUGE, ISD::SETULE,
781 ISD::SETUGT, ISD::SETULT, ISD::SETGT, ISD::SETLT,
782 ISD::SETGE, ISD::SETLE},
783 VT: MVT::i1, Action: Expand);
784
785 // This is legal in NVPTX
786 setOperationAction(Op: ISD::ConstantFP, VT: MVT::f64, Action: Legal);
787 setOperationAction(Op: ISD::ConstantFP, VT: MVT::f32, Action: Legal);
788 setOperationAction(Op: ISD::ConstantFP, VT: MVT::f16, Action: Legal);
789 setOperationAction(Op: ISD::ConstantFP, VT: MVT::bf16, Action: Legal);
790
791 setOperationAction(Ops: ISD::DYNAMIC_STACKALLOC, VTs: {MVT::i32, MVT::i64}, Action: Custom);
792 setOperationAction(Ops: {ISD::STACKRESTORE, ISD::STACKSAVE}, VT: MVT::Other, Action: Custom);
793
794 // TRAP can be lowered to PTX trap
795 setOperationAction(Op: ISD::TRAP, VT: MVT::Other, Action: Legal);
796 // DEBUGTRAP can be lowered to PTX brkpt
797 setOperationAction(Op: ISD::DEBUGTRAP, VT: MVT::Other, Action: Legal);
798
799 // Support varargs.
800 setOperationAction(Op: ISD::VASTART, VT: MVT::Other, Action: Custom);
801 setOperationAction(Op: ISD::VAARG, VT: MVT::Other, Action: Custom);
802 setOperationAction(Op: ISD::VACOPY, VT: MVT::Other, Action: Expand);
803 setOperationAction(Op: ISD::VAEND, VT: MVT::Other, Action: Expand);
804
805 setOperationAction(Ops: {ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX},
806 VTs: {MVT::i16, MVT::i32, MVT::i64}, Action: Legal);
807
808 setOperationAction(Ops: {ISD::CTPOP, ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF}, VT: MVT::i16,
809 Action: Promote);
810 setOperationAction(Ops: {ISD::CTPOP, ISD::CTLZ}, VT: MVT::i32, Action: Legal);
811 setOperationAction(Ops: {ISD::CTPOP, ISD::CTLZ}, VT: MVT::i64, Action: Custom);
812
813 setI16x2OperationAction(ISD::ABS, MVT::v2i16, Legal, Custom);
814 setI16x2OperationAction(ISD::SMIN, MVT::v2i16, Legal, Custom);
815 setI16x2OperationAction(ISD::SMAX, MVT::v2i16, Legal, Custom);
816 setI16x2OperationAction(ISD::UMIN, MVT::v2i16, Legal, Custom);
817 setI16x2OperationAction(ISD::UMAX, MVT::v2i16, Legal, Custom);
818 setI16x2OperationAction(ISD::CTPOP, MVT::v2i16, Legal, Expand);
819 setI16x2OperationAction(ISD::CTLZ, MVT::v2i16, Legal, Expand);
820
821 setI16x2OperationAction(ISD::ADD, MVT::v2i16, Legal, Custom);
822 setI16x2OperationAction(ISD::SUB, MVT::v2i16, Legal, Custom);
823 setI16x2OperationAction(ISD::MUL, MVT::v2i16, Legal, Custom);
824 setI16x2OperationAction(ISD::SHL, MVT::v2i16, Legal, Custom);
825 setI16x2OperationAction(ISD::SREM, MVT::v2i16, Legal, Custom);
826 setI16x2OperationAction(ISD::UREM, MVT::v2i16, Legal, Custom);
827
828 // Other arithmetic and logic ops are unsupported.
829 setOperationAction(Ops: {ISD::SDIV, ISD::UDIV, ISD::SRA, ISD::SRL, ISD::MULHS,
830 ISD::MULHU, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
831 ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::SETCC},
832 VTs: {MVT::v2i16, MVT::v2i32}, Action: Expand);
833
834 // v2i32 is not supported for any arithmetic operations
835 setOperationAction(Ops: {ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX,
836 ISD::CTPOP, ISD::CTLZ, ISD::ADD, ISD::SUB, ISD::MUL,
837 ISD::SHL, ISD::SRA, ISD::SRL, ISD::OR, ISD::AND, ISD::XOR,
838 ISD::SREM, ISD::UREM},
839 VT: MVT::v2i32, Action: Expand);
840
841 setOperationAction(Op: ISD::ADDC, VT: MVT::i32, Action: Legal);
842 setOperationAction(Op: ISD::ADDE, VT: MVT::i32, Action: Legal);
843 setOperationAction(Op: ISD::SUBC, VT: MVT::i32, Action: Legal);
844 setOperationAction(Op: ISD::SUBE, VT: MVT::i32, Action: Legal);
845 if (STI.getPTXVersion() >= 43) {
846 setOperationAction(Op: ISD::ADDC, VT: MVT::i64, Action: Legal);
847 setOperationAction(Op: ISD::ADDE, VT: MVT::i64, Action: Legal);
848 setOperationAction(Op: ISD::SUBC, VT: MVT::i64, Action: Legal);
849 setOperationAction(Op: ISD::SUBE, VT: MVT::i64, Action: Legal);
850 }
851
852 setOperationAction(Op: ISD::CTTZ, VT: MVT::i16, Action: Expand);
853 setOperationAction(Ops: ISD::CTTZ, VTs: {MVT::v2i16, MVT::v2i32}, Action: Expand);
854 setOperationAction(Op: ISD::CTTZ, VT: MVT::i32, Action: Expand);
855 setOperationAction(Op: ISD::CTTZ, VT: MVT::i64, Action: Expand);
856
857 // PTX does not directly support SELP of i1, so promote to i32 first
858 setOperationAction(Op: ISD::SELECT, VT: MVT::i1, Action: Custom);
859
860 // PTX cannot multiply two i64s in a single instruction.
861 setOperationAction(Op: ISD::SMUL_LOHI, VT: MVT::i64, Action: Expand);
862 setOperationAction(Op: ISD::UMUL_LOHI, VT: MVT::i64, Action: Expand);
863
864 // We have some custom DAG combine patterns for these nodes
865 setTargetDAGCombine({ISD::ADD,
866 ISD::AND,
867 ISD::EXTRACT_VECTOR_ELT,
868 ISD::FADD,
869 ISD::FMAXNUM,
870 ISD::FMINNUM,
871 ISD::FMAXIMUM,
872 ISD::FMINIMUM,
873 ISD::FMAXIMUMNUM,
874 ISD::FMINIMUMNUM,
875 ISD::MUL,
876 ISD::SELECT,
877 ISD::SHL,
878 ISD::SREM,
879 ISD::UREM,
880 ISD::VSELECT,
881 ISD::BUILD_VECTOR,
882 ISD::ADDRSPACECAST,
883 ISD::LOAD,
884 ISD::STORE,
885 ISD::ZERO_EXTEND,
886 ISD::SIGN_EXTEND,
887 ISD::INTRINSIC_WO_CHAIN});
888
889 // If the vector operands require register coalescing, scalarize instead
890 if (STI.hasF32x2Instructions())
891 setTargetDAGCombine({ISD::FMA, ISD::FMUL, ISD::FSUB});
892
893 // setcc for f16x2 and bf16x2 needs special handling to prevent
894 // legalizer's attempt to scalarize it due to v2i1 not being legal.
895 if (STI.allowFP16Math() || STI.hasBF16Math())
896 setTargetDAGCombine(ISD::SETCC);
897
898 // Vector reduction operations. These may be turned into shuffle or tree
899 // reductions depending on what instructions are available for each type.
900 for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
901 MVT EltVT = VT.getVectorElementType();
902 if (EltVT == MVT::f32 || EltVT == MVT::f64) {
903 setOperationAction(Ops: {ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
904 ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
905 VT, Action: Custom);
906 }
907 }
908
909 // Promote fp16 arithmetic if fp16 hardware isn't available or the
910 // user passed --nvptx-no-fp16-math. The flag is useful because,
911 // although sm_53+ GPUs have some sort of FP16 support in
912 // hardware, only sm_53 and sm_60 have full implementation. Others
913 // only have token amount of hardware and are likely to run faster
914 // by using fp32 units instead.
915 for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
916 setFP16OperationAction(Op, MVT::f16, Legal, Promote);
917 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
918 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
919 // bf16 must be promoted to f32.
920 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
921 if (getOperationAction(Op, VT: MVT::bf16) == Promote)
922 AddPromotedToType(Opc: Op, OrigVT: MVT::bf16, DestVT: MVT::f32);
923 setOperationAction(Op, VT: MVT::v2f32,
924 Action: STI.hasF32x2Instructions() ? Legal : Expand);
925 }
926
927 // On SM80, we select add/mul/sub as fma to avoid promotion to float
928 for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
929 for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
930 if (!STI.hasNativeBF16Support(Opcode: Op) && STI.hasNativeBF16Support(Opcode: ISD::FMA)) {
931 setOperationAction(Op, VT, Action: Custom);
932 }
933 }
934 }
935
936 // f16/f16x2 neg was introduced in PTX 60, SM_53.
937 const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
938 STI.getPTXVersion() >= 60 &&
939 STI.allowFP16Math();
940 for (const auto &VT : {MVT::f16, MVT::v2f16})
941 setOperationAction(Op: ISD::FNEG, VT,
942 Action: IsFP16FP16x2NegAvailable ? Legal : Expand);
943
944 setBF16OperationAction(ISD::FNEG, MVT::bf16, Legal, Expand);
945 setBF16OperationAction(ISD::FNEG, MVT::v2bf16, Legal, Expand);
946 setOperationAction(Op: ISD::FNEG, VT: MVT::v2f32, Action: Expand);
947 // (would be) Library functions.
948
949 // These map to conversion instructions for scalar FP types.
950 for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
951 ISD::FROUNDEVEN, ISD::FTRUNC}) {
952 setOperationAction(Op, VT: MVT::f16, Action: Legal);
953 setOperationAction(Op, VT: MVT::f32, Action: Legal);
954 setOperationAction(Op, VT: MVT::f64, Action: Legal);
955 setOperationAction(Op, VT: MVT::v2f16, Action: Expand);
956 setOperationAction(Op, VT: MVT::v2bf16, Action: Expand);
957 setOperationAction(Op, VT: MVT::v2f32, Action: Expand);
958 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
959 if (getOperationAction(Op, VT: MVT::bf16) == Promote)
960 AddPromotedToType(Opc: Op, OrigVT: MVT::bf16, DestVT: MVT::f32);
961 }
962
963 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71) {
964 setOperationAction(Op: ISD::BF16_TO_FP, VT: MVT::f32, Action: Expand);
965 }
966 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
967 for (MVT VT : {MVT::bf16, MVT::f32, MVT::f64}) {
968 setOperationAction(Op: ISD::FP_EXTEND, VT, Action: Custom);
969 setOperationAction(Op: ISD::FP_ROUND, VT, Action: Custom);
970 }
971 }
972
973 // Expand v2f32 = fp_extend
974 setOperationAction(Op: ISD::FP_EXTEND, VT: MVT::v2f32, Action: Expand);
975 // Expand v2[b]f16 = fp_round v2f32
976 setOperationAction(Ops: ISD::FP_ROUND, VTs: {MVT::v2bf16, MVT::v2f16}, Action: Expand);
977
978 // sm_80 only has conversions between f32 and bf16. Custom lower all other
979 // bf16 conversions.
980 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
981 for (MVT VT : {MVT::i1, MVT::i16, MVT::i32, MVT::i64}) {
982 setOperationAction(
983 Ops: {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
984 VT, Action: Custom);
985 }
986 setOperationAction(
987 Ops: {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
988 VT: MVT::bf16, Action: Custom);
989 }
990
991 setOperationAction(Op: ISD::FROUND, VT: MVT::f16, Action: Promote);
992 setOperationAction(Op: ISD::FROUND, VT: MVT::v2f16, Action: Expand);
993 setOperationAction(Op: ISD::FROUND, VT: MVT::v2bf16, Action: Expand);
994 setOperationAction(Op: ISD::FROUND, VT: MVT::f32, Action: Custom);
995 setOperationAction(Op: ISD::FROUND, VT: MVT::f64, Action: Custom);
996 setOperationAction(Op: ISD::FROUND, VT: MVT::bf16, Action: Promote);
997 AddPromotedToType(Opc: ISD::FROUND, OrigVT: MVT::bf16, DestVT: MVT::f32);
998
999 // 'Expand' implements FCOPYSIGN without calling an external library.
1000 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::f16, Action: Expand);
1001 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::v2f16, Action: Expand);
1002 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::bf16, Action: Expand);
1003 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::v2bf16, Action: Expand);
1004 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::f32, Action: Custom);
1005 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::f64, Action: Custom);
1006
1007 // These map to corresponding instructions for f32/f64. f16 must be
1008 // promoted to f32. v2f16 is expanded to f16, which is then promoted
1009 // to f32.
1010 for (const auto &Op :
1011 {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS, ISD::FTANH}) {
1012 setOperationAction(Op, VT: MVT::f16, Action: Promote);
1013 setOperationAction(Op, VT: MVT::f32, Action: Legal);
1014 // only div/rem/sqrt are legal for f64
1015 if (Op == ISD::FDIV || Op == ISD::FREM || Op == ISD::FSQRT) {
1016 setOperationAction(Op, VT: MVT::f64, Action: Legal);
1017 }
1018 setOperationAction(Ops: Op, VTs: {MVT::v2f16, MVT::v2bf16, MVT::v2f32}, Action: Expand);
1019 setOperationAction(Op, VT: MVT::bf16, Action: Promote);
1020 AddPromotedToType(Opc: Op, OrigVT: MVT::bf16, DestVT: MVT::f32);
1021 }
1022 setOperationAction(Ops: ISD::FREM, VTs: {MVT::f32, MVT::f64}, Action: Custom);
1023
1024 setOperationAction(Ops: ISD::FABS, VTs: {MVT::f32, MVT::f64}, Action: Legal);
1025 setOperationAction(Op: ISD::FABS, VT: MVT::v2f32, Action: Expand);
1026 if (STI.getPTXVersion() >= 65) {
1027 setFP16OperationAction(ISD::FABS, MVT::f16, Legal, Promote);
1028 setFP16OperationAction(ISD::FABS, MVT::v2f16, Legal, Expand);
1029 } else {
1030 setOperationAction(Op: ISD::FABS, VT: MVT::f16, Action: Promote);
1031 setOperationAction(Op: ISD::FABS, VT: MVT::v2f16, Action: Expand);
1032 }
1033 setBF16OperationAction(ISD::FABS, MVT::v2bf16, Legal, Expand);
1034 setBF16OperationAction(ISD::FABS, MVT::bf16, Legal, Promote);
1035 if (getOperationAction(Op: ISD::FABS, VT: MVT::bf16) == Promote)
1036 AddPromotedToType(Opc: ISD::FABS, OrigVT: MVT::bf16, DestVT: MVT::f32);
1037
1038 for (const auto &Op :
1039 {ISD::FMINNUM, ISD::FMAXNUM, ISD::FMINIMUMNUM, ISD::FMAXIMUMNUM}) {
1040 setOperationAction(Op, VT: MVT::f32, Action: Legal);
1041 setOperationAction(Op, VT: MVT::f64, Action: Legal);
1042 setFP16OperationAction(Op, MVT::f16, Legal, Promote);
1043 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
1044 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
1045 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
1046 if (getOperationAction(Op, VT: MVT::bf16) == Promote)
1047 AddPromotedToType(Opc: Op, OrigVT: MVT::bf16, DestVT: MVT::f32);
1048 setOperationAction(Op, VT: MVT::v2f32, Action: Expand);
1049 }
1050 bool SupportsF32MinMaxNaN =
1051 STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
1052 for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
1053 setOperationAction(Op, VT: MVT::f32, Action: SupportsF32MinMaxNaN ? Legal : Expand);
1054 setFP16OperationAction(Op, MVT::f16, Legal, Expand);
1055 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
1056 setBF16OperationAction(Op, MVT::bf16, Legal, Expand);
1057 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
1058 setOperationAction(Op, VT: MVT::v2f32, Action: Expand);
1059 }
1060
1061 // Custom lowering for inline asm with 128-bit operands
1062 setOperationAction(Op: ISD::CopyToReg, VT: MVT::i128, Action: Custom);
1063 setOperationAction(Op: ISD::CopyFromReg, VT: MVT::i128, Action: Custom);
1064
1065 // FEXP2 support:
1066 // - f32
1067 // - f16/f16x2 (sm_70+, PTX 7.0+)
1068 // - bf16/bf16x2 (sm_90+, PTX 7.8+)
1069 // When f16/bf16 types aren't supported, they are promoted/expanded to f32.
1070 setOperationAction(Op: ISD::FEXP2, VT: MVT::f32, Action: Legal);
1071 setOperationAction(Op: ISD::FEXP2, VT: MVT::v2f32, Action: Expand);
1072 setFP16OperationAction(ISD::FEXP2, MVT::f16, Legal, Promote);
1073 setFP16OperationAction(ISD::FEXP2, MVT::v2f16, Legal, Expand);
1074 setBF16OperationAction(ISD::FEXP2, MVT::bf16, Legal, Promote);
1075 setBF16OperationAction(ISD::FEXP2, MVT::v2bf16, Legal, Expand);
1076
1077 // FLOG2 supports f32 only
1078 // f16/bf16 types aren't supported, but they are promoted/expanded to f32.
1079 if (UseApproxLog2F32) {
1080 setOperationAction(Op: ISD::FLOG2, VT: MVT::f32, Action: Legal);
1081 setOperationPromotedToType(Opc: ISD::FLOG2, OrigVT: MVT::f16, DestVT: MVT::f32);
1082 setOperationPromotedToType(Opc: ISD::FLOG2, OrigVT: MVT::bf16, DestVT: MVT::f32);
1083 setOperationAction(Ops: ISD::FLOG2, VTs: {MVT::v2f16, MVT::v2bf16, MVT::v2f32},
1084 Action: Expand);
1085 }
1086
1087 setOperationAction(Ops: ISD::ADDRSPACECAST, VTs: {MVT::i32, MVT::i64}, Action: Custom);
1088
1089 setOperationAction(Ops: ISD::ATOMIC_LOAD_SUB, VTs: {MVT::i32, MVT::i64}, Action: Expand);
1090
1091 // atom.b128 is legal in PTX but since we don't represent i128 as a legal
1092 // type, we need to custom lower it.
1093 setOperationAction(Ops: {ISD::ATOMIC_CMP_SWAP, ISD::ATOMIC_SWAP}, VT: MVT::i128,
1094 Action: Custom);
1095
1096 // Now deduce the information based on the above mentioned
1097 // actions
1098 computeRegisterProperties(TRI: STI.getRegisterInfo());
1099
1100 // PTX support for 16-bit CAS is emulated. Only use 32+
1101 setMinCmpXchgSizeInBits(STI.getMinCmpXchgSizeInBits());
1102 setMaxAtomicSizeInBitsSupported(STI.hasAtomSwap128() ? 128 : 64);
1103 setMaxDivRemBitWidthSupported(64);
1104
1105 // Custom lowering for tcgen05.ld vector operands
1106 setOperationAction(Ops: ISD::INTRINSIC_W_CHAIN,
1107 VTs: {MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
1108 MVT::v32i32, MVT::v64i32, MVT::v128i32, MVT::v2f32,
1109 MVT::v4f32, MVT::v8f32, MVT::v16f32, MVT::v32f32,
1110 MVT::v64f32, MVT::v128f32},
1111 Action: Custom);
1112
1113 // Custom lowering for tcgen05.st vector operands
1114 setOperationAction(Ops: ISD::INTRINSIC_VOID,
1115 VTs: {MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
1116 MVT::v32i32, MVT::v64i32, MVT::v128i32, MVT::Other},
1117 Action: Custom);
1118
1119 // Enable custom lowering for the following:
1120 // * MVT::i128 - clusterlaunchcontrol
1121 // * MVT::i32 - prmt
1122 // * MVT::v4f32 - cvt_rs fp{4/6/8}x4 intrinsics
1123 // * MVT::Other - internal.addrspace.wrap
1124 setOperationAction(Ops: ISD::INTRINSIC_WO_CHAIN,
1125 VTs: {MVT::i32, MVT::i128, MVT::v4f32, MVT::Other}, Action: Custom);
1126
1127 // Custom lowering for bswap
1128 setOperationAction(Ops: ISD::BSWAP, VTs: {MVT::i16, MVT::i32, MVT::i64, MVT::v2i16},
1129 Action: Custom);
1130}
1131
1132TargetLoweringBase::LegalizeTypeAction
1133NVPTXTargetLowering::getPreferredVectorAction(MVT VT) const {
1134 if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
1135 VT.getScalarType() == MVT::i1)
1136 return TypeSplitVector;
1137 return TargetLoweringBase::getPreferredVectorAction(VT);
1138}
1139
1140SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
1141 int Enabled, int &ExtraSteps,
1142 bool &UseOneConst,
1143 bool Reciprocal) const {
1144 if (!(Enabled == ReciprocalEstimate::Enabled ||
1145 (Enabled == ReciprocalEstimate::Unspecified && !usePrecSqrtF32())))
1146 return SDValue();
1147
1148 if (ExtraSteps == ReciprocalEstimate::Unspecified)
1149 ExtraSteps = 0;
1150
1151 SDLoc DL(Operand);
1152 EVT VT = Operand.getValueType();
1153 bool Ftz = useF32FTZ(MF: DAG.getMachineFunction());
1154
1155 auto MakeIntrinsicCall = [&](Intrinsic::ID IID) {
1156 return DAG.getNode(Opcode: ISD::INTRINSIC_WO_CHAIN, DL, VT,
1157 N1: DAG.getConstant(Val: IID, DL, VT: MVT::i32), N2: Operand);
1158 };
1159
1160 // The sqrt and rsqrt refinement processes assume we always start out with an
1161 // approximation of the rsqrt. Therefore, if we're going to do any refinement
1162 // (i.e. ExtraSteps > 0), we must return an rsqrt. But if we're *not* doing
1163 // any refinement, we must return a regular sqrt.
1164 if (Reciprocal || ExtraSteps > 0) {
1165 if (VT == MVT::f32)
1166 return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_rsqrt_approx_ftz_f
1167 : Intrinsic::nvvm_rsqrt_approx_f);
1168 else if (VT == MVT::f64)
1169 return MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d);
1170 else
1171 return SDValue();
1172 } else {
1173 if (VT == MVT::f32)
1174 return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f
1175 : Intrinsic::nvvm_sqrt_approx_f);
1176 else {
1177 // There's no sqrt.approx.f64 instruction, so we emit
1178 // reciprocal(rsqrt(x)). This is faster than
1179 // select(x == 0, 0, x * rsqrt(x)). (In fact, it's faster than plain
1180 // x * rsqrt(x).)
1181 return DAG.getNode(
1182 Opcode: ISD::INTRINSIC_WO_CHAIN, DL, VT,
1183 N1: DAG.getConstant(Val: Intrinsic::nvvm_rcp_approx_ftz_d, DL, VT: MVT::i32),
1184 N2: MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d));
1185 }
1186 }
1187}
1188
1189static Align getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx,
1190 const DataLayout &DL);
1191
// Build the PTX ".callprototype" directive string used for indirect calls.
// The prototype describes how the return value and each parameter will be
// declared in the .param space, so that the `call` instruction type-checks
// against the callee. UniqueCallSite distinguishes prototypes per call site.
std::string NVPTXTargetLowering::getPrototype(
    const DataLayout &DL, Type *RetTy, const ArgListTy &Args,
    const SmallVectorImpl<ISD::OutputArg> &Outs,
    std::optional<unsigned> FirstVAArg, const CallBase &CB,
    unsigned UniqueCallSite) const {
  auto PtrVT = getPointerTy(DL);

  std::string Prototype;
  raw_string_ostream O(Prototype);
  O << "prototype_" << UniqueCallSite << " : .callprototype ";

  if (RetTy->isVoidTy()) {
    O << "()";
  } else {
    O << "(";
    if (shouldPassAsArray(Ty: RetTy)) {
      // Aggregate-style returns become an aligned .b8 array parameter.
      const Align RetAlign = getArgumentAlignment(CB: &CB, Ty: RetTy, Idx: 0, DL);
      O << ".param .align " << RetAlign.value() << " .b8 _["
        << DL.getTypeAllocSize(Ty: RetTy) << "]";
    } else if (RetTy->isFloatingPointTy() || RetTy->isIntegerTy()) {
      unsigned size = 0;
      if (auto *ITy = dyn_cast<IntegerType>(Val: RetTy)) {
        size = ITy->getBitWidth();
      } else {
        assert(RetTy->isFloatingPointTy() &&
               "Floating point type expected here");
        size = RetTy->getPrimitiveSizeInBits();
      }
      // PTX ABI requires all scalar return values to be at least 32
      // bits in size. fp16 normally uses .b16 as its storage type in
      // PTX, so its size must be adjusted here, too.
      size = promoteScalarArgumentSize(size);

      O << ".param .b" << size << " _";
    } else if (isa<PointerType>(Val: RetTy)) {
      // Pointers are returned as an integer of pointer width.
      O << ".param .b" << PtrVT.getSizeInBits() << " _";
    } else {
      llvm_unreachable("Unknown return type");
    }
    O << ") ";
  }
  O << "_ (";

  bool first = true;

  // Walk the fixed (non-variadic) arguments. A single IR argument may expand
  // into several entries in Outs; consume the run sharing its OrigArgIndex.
  const unsigned NumArgs = FirstVAArg.value_or(u: Args.size());
  auto AllOuts = ArrayRef(Outs);
  for (const unsigned I : llvm::seq(Size: NumArgs)) {
    const auto ArgOuts =
        AllOuts.take_while(Pred: [I](auto O) { return O.OrigArgIndex == I; });
    AllOuts = AllOuts.drop_front(N: ArgOuts.size());

    Type *Ty = Args[I].Ty;
    if (!first) {
      O << ", ";
    }
    first = false;

    if (ArgOuts[0].Flags.isByVal()) {
      // Indirect calls need strict ABI alignment so we disable optimizations by
      // not providing a function to optimize.
      Type *ETy = Args[I].IndirectType;
      Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
      Align ParamByValAlign =
          getFunctionByValParamAlign(/*F=*/nullptr, ArgTy: ETy, InitialAlign, DL);

      O << ".param .align " << ParamByValAlign.value() << " .b8 _["
        << ArgOuts[0].Flags.getByValSize() << "]";
    } else {
      if (shouldPassAsArray(Ty)) {
        // Aggregates passed by value get an aligned .b8 array, like returns.
        Align ParamAlign =
            getArgumentAlignment(CB: &CB, Ty, Idx: I + AttributeList::FirstArgIndex, DL);
        O << ".param .align " << ParamAlign.value() << " .b8 _["
          << DL.getTypeAllocSize(Ty) << "]";
        continue;
      }
      // i8 types in IR will be i16 types in SDAG
      assert((getValueType(DL, Ty) == ArgOuts[0].VT ||
              (getValueType(DL, Ty) == MVT::i8 && ArgOuts[0].VT == MVT::i16)) &&
             "type mismatch between callee prototype and arguments");
      // scalar type: integers are promoted to at least 32 bits, pointers use
      // pointer width, and FP types use their natural bit width.
      unsigned sz = 0;
      if (auto *ITy = dyn_cast<IntegerType>(Val: Ty)) {
        sz = promoteScalarArgumentSize(size: ITy->getBitWidth());
      } else if (isa<PointerType>(Val: Ty)) {
        sz = PtrVT.getSizeInBits();
      } else {
        sz = Ty->getPrimitiveSizeInBits();
      }
      O << ".param .b" << sz << " _";
    }
  }

  // Variadic tail: declared as a single unsized byte array with the maximum
  // required alignment; its real size is patched in later.
  if (FirstVAArg)
    O << (first ? "" : ",") << " .param .align "
      << STI.getMaxRequiredAlignment() << " .b8 _[]";
  O << ")";
  if (shouldEmitPTXNoReturn(V: &CB, TM: *nvTM))
    O << " .noreturn";
  O << ";";

  return Prototype;
}
1295
1296static Align getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx,
1297 const DataLayout &DL) {
1298 if (!CB) {
1299 // CallSite is zero, fallback to ABI type alignment
1300 return DL.getABITypeAlign(Ty);
1301 }
1302
1303 const Function *DirectCallee = CB->getCalledFunction();
1304
1305 if (!DirectCallee) {
1306 // We don't have a direct function symbol, but that may be because of
1307 // constant cast instructions in the call.
1308
1309 // With bitcast'd call targets, the instruction will be the call
1310 if (const auto *CI = dyn_cast<CallInst>(Val: CB)) {
1311 // Check if we have call alignment metadata
1312 if (MaybeAlign StackAlign = getAlign(*CI, Idx))
1313 return StackAlign.value();
1314 }
1315 DirectCallee = getMaybeBitcastedCallee(CB);
1316 }
1317
1318 // Check for function alignment information if we found that the
1319 // ultimate target is a Function
1320 if (DirectCallee)
1321 return getFunctionArgumentAlignment(F: DirectCallee, Ty, Idx, DL);
1322
1323 // Call is indirect, fall back to the ABI type alignment
1324 return DL.getABITypeAlign(Ty);
1325}
1326
1327static bool shouldConvertToIndirectCall(const CallBase *CB,
1328 const GlobalAddressSDNode *Func) {
1329 if (!Func)
1330 return false;
1331 if (auto *CalleeFunc = dyn_cast<Function>(Val: Func->getGlobal()))
1332 return CB->getFunctionType() != CalleeFunc->getFunctionType();
1333 return false;
1334}
1335
1336static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
1337 const DataLayout &DL,
1338 const TargetLowering &TL) {
1339 if (Ptr->getOpcode() == ISD::FrameIndex) {
1340 auto Ty = TL.getPointerTy(DL, AS: ADDRESS_SPACE_LOCAL);
1341 Ptr = DAG.getAddrSpaceCast(dl: SDLoc(), VT: Ty, Ptr, SrcAS: ADDRESS_SPACE_GENERIC,
1342 DestAS: ADDRESS_SPACE_LOCAL);
1343
1344 return MachinePointerInfo(ADDRESS_SPACE_LOCAL);
1345 }
1346
1347 // Peel of an addrspacecast to generic and load directly from the specific
1348 // address space.
1349 if (Ptr->getOpcode() == ISD::ADDRSPACECAST) {
1350 const auto *ASC = cast<AddrSpaceCastSDNode>(Val&: Ptr);
1351 if (ASC->getDestAddressSpace() == ADDRESS_SPACE_GENERIC) {
1352 Ptr = ASC->getOperand(Num: 0);
1353 return MachinePointerInfo(ASC->getSrcAddressSpace());
1354 }
1355 }
1356
1357 return MachinePointerInfo();
1358}
1359
1360static ISD::NodeType getExtOpcode(const ISD::ArgFlagsTy &Flags) {
1361 if (Flags.isSExt())
1362 return ISD::SIGN_EXTEND;
1363 if (Flags.isZExt())
1364 return ISD::ZERO_EXTEND;
1365 return ISD::ANY_EXTEND;
1366}
1367
1368static SDValue correctParamType(SDValue V, EVT ExpectedVT,
1369 ISD::ArgFlagsTy Flags, SelectionDAG &DAG,
1370 SDLoc dl) {
1371 const EVT ActualVT = V.getValueType();
1372 assert((ActualVT == ExpectedVT ||
1373 (ExpectedVT.isInteger() && ActualVT.isInteger())) &&
1374 "Non-integer argument type size mismatch");
1375 if (ExpectedVT.bitsGT(VT: ActualVT))
1376 return DAG.getNode(Opcode: getExtOpcode(Flags), DL: dl, VT: ExpectedVT, Operand: V);
1377 if (ExpectedVT.bitsLT(VT: ActualVT))
1378 return DAG.getNode(Opcode: ISD::TRUNCATE, DL: dl, VT: ExpectedVT, Operand: V);
1379
1380 return V;
1381}
1382
1383SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
1384 SmallVectorImpl<SDValue> &InVals) const {
1385
1386 if (CLI.IsVarArg && (STI.getPTXVersion() < 60 || STI.getSmVersion() < 30))
1387 report_fatal_error(
1388 reason: "Support for variadic functions (unsized array parameter) introduced "
1389 "in PTX ISA version 6.0 and requires target sm_30.");
1390
1391 SelectionDAG &DAG = CLI.DAG;
1392 SDLoc dl = CLI.DL;
1393 const SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
1394 SDValue Callee = CLI.Callee;
1395 ArgListTy &Args = CLI.getArgs();
1396 Type *RetTy = CLI.RetTy;
1397 const CallBase *CB = CLI.CB;
1398 const DataLayout &DL = DAG.getDataLayout();
1399 LLVMContext &Ctx = *DAG.getContext();
1400
1401 const auto GetI32 = [&](const unsigned I) {
1402 return DAG.getConstant(Val: I, DL: dl, VT: MVT::i32);
1403 };
1404
1405 const unsigned UniqueCallSite = GlobalUniqueCallSite++;
1406 const SDValue CallChain = CLI.Chain;
1407 const SDValue StartChain =
1408 DAG.getCALLSEQ_START(Chain: CallChain, InSize: UniqueCallSite, OutSize: 0, DL: dl);
1409 SDValue DeclareGlue = StartChain.getValue(R: 1);
1410
1411 SmallVector<SDValue, 16> CallPrereqs{StartChain};
1412
1413 const auto MakeDeclareScalarParam = [&](SDValue Symbol, unsigned Size) {
1414 // PTX ABI requires integral types to be at least 32 bits in size. FP16 is
1415 // loaded/stored using i16, so it's handled here as well.
1416 const unsigned SizeBits = promoteScalarArgumentSize(size: Size * 8);
1417 SDValue Declare =
1418 DAG.getNode(Opcode: NVPTXISD::DeclareScalarParam, DL: dl, ResultTys: {MVT::Other, MVT::Glue},
1419 Ops: {StartChain, Symbol, GetI32(SizeBits), DeclareGlue});
1420 CallPrereqs.push_back(Elt: Declare);
1421 DeclareGlue = Declare.getValue(R: 1);
1422 return Declare;
1423 };
1424
1425 const auto MakeDeclareArrayParam = [&](SDValue Symbol, Align Align,
1426 unsigned Size) {
1427 SDValue Declare = DAG.getNode(
1428 Opcode: NVPTXISD::DeclareArrayParam, DL: dl, ResultTys: {MVT::Other, MVT::Glue},
1429 Ops: {StartChain, Symbol, GetI32(Align.value()), GetI32(Size), DeclareGlue});
1430 CallPrereqs.push_back(Elt: Declare);
1431 DeclareGlue = Declare.getValue(R: 1);
1432 return Declare;
1433 };
1434
1435 // Variadic arguments.
1436 //
1437 // Normally, for each argument, we declare a param scalar or a param
1438 // byte array in the .param space, and store the argument value to that
1439 // param scalar or array starting at offset 0.
1440 //
1441 // In the case of the first variadic argument, we declare a vararg byte array
1442 // with size 0. The exact size of this array isn't known at this point, so
1443 // it'll be patched later. All the variadic arguments will be stored to this
1444 // array at a certain offset (which gets tracked by 'VAOffset'). The offset is
1445 // initially set to 0, so it can be used for non-variadic arguments (which use
1446 // 0 offset) to simplify the code.
1447 //
1448 // After all vararg is processed, 'VAOffset' holds the size of the
1449 // vararg byte array.
1450 assert((CLI.IsVarArg || CLI.Args.size() == CLI.NumFixedArgs) &&
1451 "Non-VarArg function with extra arguments");
1452
1453 const unsigned FirstVAArg = CLI.NumFixedArgs; // position of first variadic
1454 unsigned VAOffset = 0; // current offset in the param array
1455
1456 const SDValue VADeclareParam =
1457 CLI.Args.size() > FirstVAArg
1458 ? MakeDeclareArrayParam(getCallParamSymbol(DAG, I: FirstVAArg, T: MVT::i32),
1459 Align(STI.getMaxRequiredAlignment()), 0)
1460 : SDValue();
1461
1462 // Args.size() and Outs.size() need not match.
1463 // Outs.size() will be larger
1464 // * if there is an aggregate argument with multiple fields (each field
1465 // showing up separately in Outs)
1466 // * if there is a vector argument with more than typical vector-length
1467 // elements (generally if more than 4) where each vector element is
1468 // individually present in Outs.
1469 // So a different index should be used for indexing into Outs/OutVals.
1470 // See similar issue in LowerFormalArguments.
1471 auto AllOuts = ArrayRef(CLI.Outs);
1472 auto AllOutVals = ArrayRef(CLI.OutVals);
1473 assert(AllOuts.size() == AllOutVals.size() &&
1474 "Outs and OutVals must be the same size");
1475 // Declare the .params or .reg need to pass values
1476 // to the function
1477 for (const auto E : llvm::enumerate(First&: Args)) {
1478 const auto ArgI = E.index();
1479 const auto Arg = E.value();
1480 const auto ArgOuts =
1481 AllOuts.take_while(Pred: [&](auto O) { return O.OrigArgIndex == ArgI; });
1482 const auto ArgOutVals = AllOutVals.take_front(N: ArgOuts.size());
1483 AllOuts = AllOuts.drop_front(N: ArgOuts.size());
1484 AllOutVals = AllOutVals.drop_front(N: ArgOuts.size());
1485
1486 const bool IsVAArg = (ArgI >= FirstVAArg);
1487 const bool IsByVal = Arg.IsByVal;
1488
1489 const SDValue ParamSymbol =
1490 getCallParamSymbol(DAG, I: IsVAArg ? FirstVAArg : ArgI, T: MVT::i32);
1491
1492 assert((!IsByVal || Arg.IndirectType) &&
1493 "byval arg must have indirect type");
1494 Type *ETy = (IsByVal ? Arg.IndirectType : Arg.Ty);
1495
1496 const Align ArgAlign = [&]() {
1497 if (IsByVal) {
1498 // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
1499 // so we don't need to worry whether it's naturally aligned or not.
1500 // See TargetLowering::LowerCallTo().
1501 const Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
1502 return getFunctionByValParamAlign(F: CB->getCalledFunction(), ArgTy: ETy,
1503 InitialAlign, DL);
1504 }
1505 return getArgumentAlignment(CB, Ty: Arg.Ty, Idx: ArgI + 1, DL);
1506 }();
1507
1508 const unsigned TySize = DL.getTypeAllocSize(Ty: ETy);
1509 assert((!IsByVal || TySize == ArgOuts[0].Flags.getByValSize()) &&
1510 "type size mismatch");
1511
1512 const SDValue ArgDeclare = [&]() {
1513 if (IsVAArg)
1514 return VADeclareParam;
1515
1516 if (IsByVal || shouldPassAsArray(Ty: Arg.Ty))
1517 return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TySize);
1518
1519 assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
1520 assert((ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint()) &&
1521 "Only int and float types are supported as non-array arguments");
1522
1523 return MakeDeclareScalarParam(ParamSymbol, TySize);
1524 }();
1525
1526 if (IsByVal) {
1527 assert(ArgOutVals.size() == 1 && "We must pass only one value as byval");
1528 SDValue SrcPtr = ArgOutVals[0];
1529 const auto PointerInfo = refinePtrAS(Ptr&: SrcPtr, DAG, DL, TL: *this);
1530 const Align BaseSrcAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
1531
1532 if (IsVAArg)
1533 VAOffset = alignTo(Size: VAOffset, A: ArgAlign);
1534
1535 SmallVector<EVT, 4> ValueVTs, MemVTs;
1536 SmallVector<TypeSize, 4> Offsets;
1537 ComputeValueVTs(TLI: *this, DL, Ty: ETy, ValueVTs, MemVTs: &MemVTs, Offsets: &Offsets);
1538
1539 unsigned J = 0;
1540 const auto VI = VectorizePTXValueVTs(ValueVTs: MemVTs, Offsets, ParamAlignment: ArgAlign, IsVAArg);
1541 for (const unsigned NumElts : VI) {
1542 EVT LoadVT = getVectorizedVT(VT: MemVTs[J], N: NumElts, C&: Ctx);
1543 Align SrcAlign = commonAlignment(A: BaseSrcAlign, Offset: Offsets[J]);
1544 SDValue SrcAddr = DAG.getObjectPtrOffset(SL: dl, Ptr: SrcPtr, Offset: Offsets[J]);
1545 SDValue SrcLoad =
1546 DAG.getLoad(VT: LoadVT, dl, Chain: CallChain, Ptr: SrcAddr, PtrInfo: PointerInfo, Alignment: SrcAlign);
1547
1548 TypeSize ParamOffset = Offsets[J].getWithIncrement(RHS: VAOffset);
1549 Align ParamAlign = commonAlignment(A: ArgAlign, Offset: ParamOffset);
1550 SDValue ParamAddr =
1551 DAG.getObjectPtrOffset(SL: dl, Ptr: ParamSymbol, Offset: ParamOffset);
1552 SDValue StoreParam = DAG.getStore(
1553 Chain: ArgDeclare, dl, Val: SrcLoad, Ptr: ParamAddr,
1554 PtrInfo: MachinePointerInfo(NVPTX::AddressSpace::DeviceParam), Alignment: ParamAlign);
1555 CallPrereqs.push_back(Elt: StoreParam);
1556
1557 J += NumElts;
1558 }
1559 if (IsVAArg)
1560 VAOffset += TySize;
1561 } else {
1562 SmallVector<EVT, 16> VTs;
1563 SmallVector<uint64_t, 16> Offsets;
1564 ComputePTXValueVTs(TLI: *this, DL, Ctx, CallConv: CLI.CallConv, Ty: Arg.Ty, ValueVTs&: VTs, Offsets,
1565 StartingOffset: VAOffset);
1566 assert(VTs.size() == Offsets.size() && "Size mismatch");
1567 assert(VTs.size() == ArgOuts.size() && "Size mismatch");
1568
1569 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
1570 // than 32-bits are sign extended or zero extended, depending on
1571 // whether they are signed or unsigned types. This case applies
1572 // only to scalar parameters and not to aggregate values.
1573 const bool ExtendIntegerParam =
1574 Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty: Arg.Ty) < 32;
1575
1576 const auto GetStoredValue = [&](const unsigned I) {
1577 SDValue StVal = ArgOutVals[I];
1578 assert(promoteScalarIntegerPTX(StVal.getValueType()) ==
1579 StVal.getValueType() &&
1580 "OutVal type should always be legal");
1581
1582 const EVT VTI = promoteScalarIntegerPTX(VT: VTs[I]);
1583 const EVT StoreVT =
1584 ExtendIntegerParam ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
1585
1586 return correctParamType(V: StVal, ExpectedVT: StoreVT, Flags: ArgOuts[I].Flags, DAG, dl);
1587 };
1588
1589 unsigned J = 0;
1590 const auto VI = VectorizePTXValueVTs(ValueVTs: VTs, Offsets, ParamAlignment: ArgAlign, IsVAArg);
1591 for (const unsigned NumElts : VI) {
1592 const EVT EltVT = promoteScalarIntegerPTX(VT: VTs[J]);
1593
1594 unsigned Offset;
1595 if (IsVAArg) {
1596 // TODO: We may need to support vector types that can be passed
1597 // as scalars in variadic arguments.
1598 assert(NumElts == 1 &&
1599 "Vectorization should be disabled for vaargs.");
1600
1601 // Align each part of the variadic argument to their type.
1602 VAOffset = alignTo(Size: VAOffset, A: DAG.getEVTAlign(MemoryVT: EltVT));
1603 Offset = VAOffset;
1604
1605 const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
1606 VAOffset += DL.getTypeAllocSize(Ty: TheStoreType.getTypeForEVT(Context&: Ctx));
1607 } else {
1608 assert(VAOffset == 0 && "VAOffset must be 0 for non-VA args");
1609 Offset = Offsets[J];
1610 }
1611
1612 SDValue Ptr =
1613 DAG.getObjectPtrOffset(SL: dl, Ptr: ParamSymbol, Offset: TypeSize::getFixed(ExactSize: Offset));
1614
1615 const MaybeAlign CurrentAlign = ExtendIntegerParam
1616 ? MaybeAlign(std::nullopt)
1617 : commonAlignment(A: ArgAlign, Offset);
1618
1619 SDValue Val =
1620 getBuildVectorizedValue(N: NumElts, dl, DAG, GetElement: [&](unsigned K) {
1621 return GetStoredValue(J + K);
1622 });
1623
1624 SDValue StoreParam = DAG.getStore(
1625 Chain: ArgDeclare, dl, Val, Ptr,
1626 PtrInfo: MachinePointerInfo(NVPTX::AddressSpace::DeviceParam), Alignment: CurrentAlign);
1627 CallPrereqs.push_back(Elt: StoreParam);
1628
1629 J += NumElts;
1630 }
1631 }
1632 }
1633
1634 // Handle Result
1635 if (!Ins.empty()) {
1636 const SDValue RetSymbol = DAG.getExternalSymbol(Sym: "retval0", VT: MVT::i32);
1637 const unsigned ResultSize = DL.getTypeAllocSize(Ty: RetTy);
1638 if (shouldPassAsArray(Ty: RetTy)) {
1639 const Align RetAlign = getArgumentAlignment(CB, Ty: RetTy, Idx: 0, DL);
1640 MakeDeclareArrayParam(RetSymbol, RetAlign, ResultSize);
1641 } else {
1642 MakeDeclareScalarParam(RetSymbol, ResultSize);
1643 }
1644 }
1645
1646 // Set the size of the vararg param byte array if the callee is a variadic
1647 // function and the variadic part is not empty.
1648 if (VADeclareParam) {
1649 SDValue DeclareParamOps[] = {VADeclareParam.getOperand(i: 0),
1650 VADeclareParam.getOperand(i: 1),
1651 VADeclareParam.getOperand(i: 2), GetI32(VAOffset),
1652 VADeclareParam.getOperand(i: 4)};
1653 DAG.MorphNodeTo(N: VADeclareParam.getNode(), Opc: VADeclareParam.getOpcode(),
1654 VTs: VADeclareParam->getVTList(), Ops: DeclareParamOps);
1655 }
1656
1657 const auto *Func = dyn_cast<GlobalAddressSDNode>(Val: Callee.getNode());
1658 // If the type of the callsite does not match that of the function, convert
1659 // the callsite to an indirect call.
1660 const bool ConvertToIndirectCall = shouldConvertToIndirectCall(CB, Func);
1661
1662 // Both indirect calls and libcalls have nullptr Func. In order to distinguish
1663 // between them we must rely on the call site value which is valid for
1664 // indirect calls but is always null for libcalls.
1665 const bool IsIndirectCall = (!Func && CB) || ConvertToIndirectCall;
1666
1667 if (isa<ExternalSymbolSDNode>(Val: Callee)) {
1668 Function* CalleeFunc = nullptr;
1669
1670 // Try to find the callee in the current module.
1671 Callee = DAG.getSymbolFunctionGlobalAddress(Op: Callee, TargetFunction: &CalleeFunc);
1672 assert(CalleeFunc != nullptr && "Libcall callee must be set.");
1673
1674 // Set the "libcall callee" attribute to indicate that the function
1675 // must always have a declaration.
1676 CalleeFunc->addFnAttr(Kind: "nvptx-libcall-callee", Val: "true");
1677 }
1678
1679 if (IsIndirectCall) {
1680 // This is indirect function call case : PTX requires a prototype of the
1681 // form
1682 // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
1683 // to be emitted, and the label has to used as the last arg of call
1684 // instruction.
1685 // The prototype is embedded in a string and put as the operand for a
1686 // CallPrototype SDNode which will print out to the value of the string.
1687 const bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
1688 std::string Proto =
1689 getPrototype(DL, RetTy, Args, Outs: CLI.Outs,
1690 FirstVAArg: HasVAArgs ? std::optional(FirstVAArg) : std::nullopt, CB: *CB,
1691 UniqueCallSite);
1692 const char *ProtoStr = nvTM->getStrPool().save(S: Proto).data();
1693 const SDValue PrototypeDeclare = DAG.getNode(
1694 Opcode: NVPTXISD::CallPrototype, DL: dl, VT: MVT::Other,
1695 Ops: {StartChain, DAG.getTargetExternalSymbol(Sym: ProtoStr, VT: MVT::i32)});
1696 CallPrereqs.push_back(Elt: PrototypeDeclare);
1697 }
1698
1699 const unsigned Proto = IsIndirectCall ? UniqueCallSite : 0;
1700 const unsigned NumArgs =
1701 std::min<unsigned>(a: CLI.NumFixedArgs + 1, b: Args.size());
1702 /// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns,
1703 /// NumParams, Callee, Proto)
1704 const SDValue CallToken = DAG.getTokenFactor(DL: dl, Vals&: CallPrereqs);
1705 const SDValue Call = DAG.getNode(
1706 Opcode: NVPTXISD::CALL, DL: dl, VT: MVT::Other,
1707 Ops: {CallToken, GetI32(CLI.IsConvergent), GetI32(IsIndirectCall),
1708 GetI32(Ins.empty() ? 0 : 1), GetI32(NumArgs), Callee, GetI32(Proto)});
1709
1710 SmallVector<SDValue, 16> LoadChains{Call};
1711 SmallVector<SDValue, 16> ProxyRegOps;
1712 if (!Ins.empty()) {
1713 SmallVector<EVT, 16> VTs;
1714 SmallVector<uint64_t, 16> Offsets;
1715 ComputePTXValueVTs(TLI: *this, DL, Ctx, CallConv: CLI.CallConv, Ty: RetTy, ValueVTs&: VTs, Offsets);
1716 assert(VTs.size() == Ins.size() && "Bad value decomposition");
1717
1718 const Align RetAlign = getArgumentAlignment(CB, Ty: RetTy, Idx: 0, DL);
1719 const SDValue RetSymbol = DAG.getExternalSymbol(Sym: "retval0", VT: MVT::i32);
1720
1721 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
1722 // 32-bits are sign extended or zero extended, depending on whether
1723 // they are signed or unsigned types.
1724 const bool ExtendIntegerRetVal =
1725 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty: RetTy) < 32;
1726
1727 unsigned I = 0;
1728 const auto VI = VectorizePTXValueVTs(ValueVTs: VTs, Offsets, ParamAlignment: RetAlign);
1729 for (const unsigned NumElts : VI) {
1730 const MaybeAlign CurrentAlign =
1731 ExtendIntegerRetVal ? MaybeAlign(std::nullopt)
1732 : commonAlignment(A: RetAlign, Offset: Offsets[I]);
1733
1734 const EVT VTI = promoteScalarIntegerPTX(VT: VTs[I]);
1735 const EVT LoadVT =
1736 ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
1737 const EVT VecVT = getVectorizedVT(VT: LoadVT, N: NumElts, C&: Ctx);
1738 SDValue Ptr =
1739 DAG.getObjectPtrOffset(SL: dl, Ptr: RetSymbol, Offset: TypeSize::getFixed(ExactSize: Offsets[I]));
1740
1741 SDValue R = DAG.getLoad(
1742 VT: VecVT, dl, Chain: Call, Ptr,
1743 PtrInfo: MachinePointerInfo(NVPTX::AddressSpace::DeviceParam), Alignment: CurrentAlign);
1744
1745 LoadChains.push_back(Elt: R.getValue(R: 1));
1746 for (const unsigned J : llvm::seq(Size: NumElts))
1747 ProxyRegOps.push_back(Elt: getExtractVectorizedValue(V: R, I: J, VT: LoadVT, dl, DAG));
1748 I += NumElts;
1749 }
1750 }
1751
1752 const SDValue EndToken = DAG.getTokenFactor(DL: dl, Vals&: LoadChains);
1753 const SDValue CallEnd = DAG.getCALLSEQ_END(Chain: EndToken, Size1: UniqueCallSite,
1754 Size2: UniqueCallSite + 1, Glue: SDValue(), DL: dl);
1755
1756 // Append ProxyReg instructions to the chain to make sure that `callseq_end`
1757 // will not get lost. Otherwise, during libcalls expansion, the nodes can become
1758 // dangling.
1759 for (const auto [I, Reg] : llvm::enumerate(First&: ProxyRegOps)) {
1760 SDValue Proxy =
1761 DAG.getNode(Opcode: NVPTXISD::ProxyReg, DL: dl, VT: Reg.getValueType(), Ops: {CallEnd, Reg});
1762 SDValue Ret = correctParamType(V: Proxy, ExpectedVT: Ins[I].VT, Flags: Ins[I].Flags, DAG, dl);
1763 InVals.push_back(Elt: Ret);
1764 }
1765
1766 // set IsTailCall to false for now, until we figure out how to express
1767 // tail call optimization in PTX
1768 CLI.IsTailCall = false;
1769 return CallEnd;
1770}
1771
1772SDValue NVPTXTargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op,
1773 SelectionDAG &DAG) const {
1774
1775 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1776 const Function &Fn = DAG.getMachineFunction().getFunction();
1777
1778 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
1779 Fn,
1780 "Support for dynamic alloca introduced in PTX ISA version 7.3 and "
1781 "requires target sm_52.",
1782 SDLoc(Op).getDebugLoc()));
1783 auto Ops = {DAG.getConstant(Val: 0, DL: SDLoc(), VT: Op.getValueType()),
1784 Op.getOperand(i: 0)};
1785 return DAG.getMergeValues(Ops, dl: SDLoc());
1786 }
1787
1788 SDLoc DL(Op.getNode());
1789 SDValue Chain = Op.getOperand(i: 0);
1790 SDValue Size = Op.getOperand(i: 1);
1791 uint64_t Align = Op.getConstantOperandVal(i: 2);
1792
1793 // The alignment on a ISD::DYNAMIC_STACKALLOC node may be 0 to indicate that
1794 // the default stack alignment should be used.
1795 if (Align == 0)
1796 Align = DAG.getSubtarget().getFrameLowering()->getStackAlign().value();
1797
1798 // The size for ptx alloca instruction is 64-bit for m64 and 32-bit for m32.
1799 const MVT LocalVT = getPointerTy(DL: DAG.getDataLayout(), AS: ADDRESS_SPACE_LOCAL);
1800
1801 SDValue Alloc =
1802 DAG.getNode(Opcode: NVPTXISD::DYNAMIC_STACKALLOC, DL, ResultTys: {LocalVT, MVT::Other},
1803 Ops: {Chain, DAG.getZExtOrTrunc(Op: Size, DL, VT: LocalVT),
1804 DAG.getTargetConstant(Val: Align, DL, VT: MVT::i32)});
1805
1806 SDValue ASC = DAG.getAddrSpaceCast(
1807 dl: DL, VT: Op.getValueType(), Ptr: Alloc, SrcAS: ADDRESS_SPACE_LOCAL, DestAS: ADDRESS_SPACE_GENERIC);
1808
1809 return DAG.getMergeValues(Ops: {ASC, SDValue(Alloc.getNode(), 1)}, dl: DL);
1810}
1811
1812SDValue NVPTXTargetLowering::LowerSTACKRESTORE(SDValue Op,
1813 SelectionDAG &DAG) const {
1814 SDLoc DL(Op.getNode());
1815 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1816 const Function &Fn = DAG.getMachineFunction().getFunction();
1817
1818 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
1819 Fn,
1820 "Support for stackrestore requires PTX ISA version >= 7.3 and target "
1821 ">= sm_52.",
1822 DL.getDebugLoc()));
1823 return Op.getOperand(i: 0);
1824 }
1825
1826 const MVT LocalVT = getPointerTy(DL: DAG.getDataLayout(), AS: ADDRESS_SPACE_LOCAL);
1827 SDValue Chain = Op.getOperand(i: 0);
1828 SDValue Ptr = Op.getOperand(i: 1);
1829 SDValue ASC = DAG.getAddrSpaceCast(dl: DL, VT: LocalVT, Ptr, SrcAS: ADDRESS_SPACE_GENERIC,
1830 DestAS: ADDRESS_SPACE_LOCAL);
1831 return DAG.getNode(Opcode: NVPTXISD::STACKRESTORE, DL, VT: MVT::Other, Ops: {Chain, ASC});
1832}
1833
1834SDValue NVPTXTargetLowering::LowerSTACKSAVE(SDValue Op,
1835 SelectionDAG &DAG) const {
1836 SDLoc DL(Op.getNode());
1837 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1838 const Function &Fn = DAG.getMachineFunction().getFunction();
1839
1840 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
1841 Fn,
1842 "Support for stacksave requires PTX ISA version >= 7.3 and target >= "
1843 "sm_52.",
1844 DL.getDebugLoc()));
1845 auto Ops = {DAG.getConstant(Val: 0, DL, VT: Op.getValueType()), Op.getOperand(i: 0)};
1846 return DAG.getMergeValues(Ops, dl: DL);
1847 }
1848
1849 const MVT LocalVT = getPointerTy(DL: DAG.getDataLayout(), AS: ADDRESS_SPACE_LOCAL);
1850 SDValue Chain = Op.getOperand(i: 0);
1851 SDValue SS =
1852 DAG.getNode(Opcode: NVPTXISD::STACKSAVE, DL, ResultTys: {LocalVT, MVT::Other}, Ops: Chain);
1853 SDValue ASC = DAG.getAddrSpaceCast(
1854 dl: DL, VT: Op.getValueType(), Ptr: SS, SrcAS: ADDRESS_SPACE_LOCAL, DestAS: ADDRESS_SPACE_GENERIC);
1855 return DAG.getMergeValues(Ops: {ASC, SDValue(SS.getNode(), 1)}, dl: DL);
1856}
1857
1858// By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
1859// (see LegalizeDAG.cpp). This is slow and uses local memory.
1860// We use extract/insert/build vector just as what LegalizeOp() does in llvm 2.5
1861SDValue
1862NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
1863 SDNode *Node = Op.getNode();
1864 SDLoc dl(Node);
1865 SmallVector<SDValue, 8> Ops;
1866 unsigned NumOperands = Node->getNumOperands();
1867 for (unsigned i = 0; i < NumOperands; ++i) {
1868 SDValue SubOp = Node->getOperand(Num: i);
1869 EVT VVT = SubOp.getNode()->getValueType(ResNo: 0);
1870 EVT EltVT = VVT.getVectorElementType();
1871 unsigned NumSubElem = VVT.getVectorNumElements();
1872 for (unsigned j = 0; j < NumSubElem; ++j) {
1873 Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT: EltVT, N1: SubOp,
1874 N2: DAG.getIntPtrConstant(Val: j, DL: dl)));
1875 }
1876 }
1877 return DAG.getBuildVector(VT: Node->getValueType(ResNo: 0), DL: dl, Ops);
1878}
1879
1880static SDValue getPRMT(SDValue A, SDValue B, SDValue Selector, SDLoc DL,
1881 SelectionDAG &DAG,
1882 unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
1883 assert(A.getValueType() == MVT::i32 && B.getValueType() == MVT::i32 &&
1884 Selector.getValueType() == MVT::i32 && "PRMT must have i32 operands");
1885 return DAG.getNode(Opcode: NVPTXISD::PRMT, DL, VT: MVT::i32,
1886 Ops: {A, B, Selector, DAG.getConstant(Val: Mode, DL, VT: MVT::i32)});
1887}
1888
1889static SDValue getPRMT(SDValue A, SDValue B, uint64_t Selector, SDLoc DL,
1890 SelectionDAG &DAG,
1891 unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
1892 return getPRMT(A, B, Selector: DAG.getConstant(Val: Selector, DL, VT: MVT::i32), DL, DAG, Mode);
1893}
1894
/// Reduces the elements using the scalar operations provided. The operations
/// are sorted descending in number of inputs they take. The flags on the
/// original reduction operation will be propagated to each scalar operation.
/// Nearby elements are grouped in tree reduction, unlike the shuffle reduction
/// used in ExpandReductions and SelectionDAG.
///
/// \param Elements the scalar values to reduce.
/// \param EltTy    the common value type of the elements and the result.
/// \param Ops      (opcode, arity) pairs, widest arity first; a narrower op is
///                 only consulted when the wider one cannot make progress.
/// \param Flags    SDNodeFlags copied onto every generated scalar node.
/// \returns the single remaining value after the tree is fully reduced.
static SDValue buildTreeReduction(
    const SmallVector<SDValue> &Elements, EVT EltTy,
    ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
    const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
  // Build the reduction tree at each level, starting with all the elements.
  SmallVector<SDValue> Level = Elements;

  unsigned OpIdx = 0;
  while (Level.size() > 1) {
    // Try to reduce this level using the current operator.
    const auto [Op, NumInputs] = Ops[OpIdx];

    // Build the next level by partially reducing all elements.
    SmallVector<SDValue> ReducedLevel;
    unsigned I = 0, E = Level.size();
    for (; I + NumInputs <= E; I += NumInputs) {
      // Reduce elements in groups of [NumInputs], as much as possible.
      ReducedLevel.push_back(Elt: DAG.getNode(
          Opcode: Op, DL, VT: EltTy, Ops: ArrayRef<SDValue>(Level).slice(N: I, M: NumInputs), Flags));
    }

    if (I < E) {
      // Handle leftover elements.

      if (ReducedLevel.empty()) {
        // We didn't reduce anything at this level. We need to pick a smaller
        // operator. (Ops is sorted widest-first, so OpIdx+1 is the next
        // narrower arity; the assert guards against running out of operators.)
        ++OpIdx;
        assert(OpIdx < Ops.size() && "no smaller operators for reduction");
        continue;
      }

      // We reduced some things but there's still more left, meaning the
      // operator's number of inputs doesn't evenly divide this level size. Move
      // these elements to the next level.
      for (; I < E; ++I)
        ReducedLevel.push_back(Elt: Level[I]);
    }

    // Process the next level.
    Level = ReducedLevel;
  }

  return *Level.begin();
}
1945
1946// Get scalar reduction opcode
1947static ISD::NodeType getScalarOpcodeForReduction(unsigned ReductionOpcode) {
1948 switch (ReductionOpcode) {
1949 case ISD::VECREDUCE_FMAX:
1950 return ISD::FMAXNUM;
1951 case ISD::VECREDUCE_FMIN:
1952 return ISD::FMINNUM;
1953 case ISD::VECREDUCE_FMAXIMUM:
1954 return ISD::FMAXIMUM;
1955 case ISD::VECREDUCE_FMINIMUM:
1956 return ISD::FMINIMUM;
1957 default:
1958 llvm_unreachable("unhandled reduction opcode");
1959 }
1960}
1961
1962/// Get 3-input scalar reduction opcode
1963static std::optional<unsigned>
1964getScalar3OpcodeForReduction(unsigned ReductionOpcode) {
1965 switch (ReductionOpcode) {
1966 case ISD::VECREDUCE_FMAX:
1967 return NVPTXISD::FMAXNUM3;
1968 case ISD::VECREDUCE_FMIN:
1969 return NVPTXISD::FMINNUM3;
1970 case ISD::VECREDUCE_FMAXIMUM:
1971 return NVPTXISD::FMAXIMUM3;
1972 case ISD::VECREDUCE_FMINIMUM:
1973 return NVPTXISD::FMINIMUM3;
1974 default:
1975 return std::nullopt;
1976 }
1977}
1978
1979/// Lower reductions to either a sequence of operations or a tree if
1980/// reassociations are allowed. This method will use larger operations like
1981/// max3/min3 when the target supports them.
1982SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
1983 SelectionDAG &DAG) const {
1984 SDLoc DL(Op);
1985 const SDNodeFlags Flags = Op->getFlags();
1986 SDValue Vector = Op.getOperand(i: 0);
1987
1988 const unsigned Opcode = Op->getOpcode();
1989 const EVT EltTy = Vector.getValueType().getVectorElementType();
1990
1991 // Whether we can use 3-input min/max when expanding the reduction.
1992 const bool CanUseMinMax3 =
1993 EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
1994 STI.getPTXVersion() >= 88 &&
1995 (Opcode == ISD::VECREDUCE_FMAX || Opcode == ISD::VECREDUCE_FMIN ||
1996 Opcode == ISD::VECREDUCE_FMAXIMUM || Opcode == ISD::VECREDUCE_FMINIMUM);
1997
1998 // A list of SDNode opcodes with equivalent semantics, sorted descending by
1999 // number of inputs they take.
2000 SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
2001
2002 if (auto Opcode3Elem = getScalar3OpcodeForReduction(ReductionOpcode: Opcode);
2003 CanUseMinMax3 && Opcode3Elem)
2004 ScalarOps.push_back(Elt: {*Opcode3Elem, 3});
2005 ScalarOps.push_back(Elt: {getScalarOpcodeForReduction(ReductionOpcode: Opcode), 2});
2006
2007 SmallVector<SDValue> Elements;
2008 DAG.ExtractVectorElements(Op: Vector, Args&: Elements);
2009
2010 return buildTreeReduction(Elements, EltTy, Ops: ScalarOps, DL, Flags, DAG);
2011}
2012
2013SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
2014 // Handle bitcasting from v2i8 without hitting the default promotion
2015 // strategy which goes through stack memory.
2016 EVT FromVT = Op->getOperand(Num: 0)->getValueType(ResNo: 0);
2017 if (FromVT != MVT::v2i8) {
2018 return Op;
2019 }
2020
2021 // Pack vector elements into i16 and bitcast to final type
2022 SDLoc DL(Op);
2023 SDValue Vec0 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i8,
2024 N1: Op->getOperand(Num: 0), N2: DAG.getIntPtrConstant(Val: 0, DL));
2025 SDValue Vec1 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i8,
2026 N1: Op->getOperand(Num: 0), N2: DAG.getIntPtrConstant(Val: 1, DL));
2027 SDValue Extend0 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: MVT::i16, Operand: Vec0);
2028 SDValue Extend1 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: MVT::i16, Operand: Vec1);
2029 SDValue Const8 = DAG.getConstant(Val: 8, DL, VT: MVT::i16);
2030 SDValue AsInt = DAG.getNode(
2031 Opcode: ISD::OR, DL, VT: MVT::i16,
2032 Ops: {Extend0, DAG.getNode(Opcode: ISD::SHL, DL, VT: MVT::i16, Ops: {Extend1, Const8})});
2033 EVT ToVT = Op->getValueType(ResNo: 0);
2034 return DAG.getBitcast(VT: ToVT, V: AsInt);
2035}
2036
// We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it
// would get lowered as two constant loads and vector-packing move.
// Instead we want just a constant move:
//        mov.b32         %r2, 0x40003C00
SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
                                               SelectionDAG &DAG) const {
  EVT VT = Op->getValueType(ResNo: 0);
  // Only packed 32-bit vector types (e.g. v2f16/v2bf16/v2i16/v4i8) are
  // handled here; everything else keeps the default lowering.
  if (!(NVPTX::isPackedVectorTy(VT) && VT.is32BitVector()))
    return Op;
  SDLoc DL(Op);

  // If any operand is non-constant (and not undef), we cannot fold the whole
  // vector into one immediate.
  if (!llvm::all_of(Range: Op->ops(), P: [](SDValue Operand) {
        return Operand->isUndef() || isa<ConstantSDNode>(Val: Operand) ||
               isa<ConstantFPSDNode>(Val: Operand);
      })) {
    if (VT != MVT::v4i8)
      return Op;
    // Lower non-const v4i8 vector as byte-wise constructed i32, which allows us
    // to optimize calculation of constant parts.
    // Each GetPRMT call packs two byte operands; the final call merges the two
    // low halves into the full 4-byte value.
    auto GetPRMT = [&](const SDValue Left, const SDValue Right, bool Cast,
                       uint64_t SelectionValue) -> SDValue {
      SDValue L = Left;
      SDValue R = Right;
      if (Cast) {
        // Element operands are carried as narrower integers; PRMT needs i32.
        L = DAG.getAnyExtOrTrunc(Op: L, DL, VT: MVT::i32);
        R = DAG.getAnyExtOrTrunc(Op: R, DL, VT: MVT::i32);
      }
      return getPRMT(A: L, B: R, Selector: SelectionValue, DL, DAG);
    };
    auto PRMT__10 = GetPRMT(Op->getOperand(Num: 0), Op->getOperand(Num: 1), true, 0x3340);
    auto PRMT__32 = GetPRMT(Op->getOperand(Num: 2), Op->getOperand(Num: 3), true, 0x3340);
    auto PRMT3210 = GetPRMT(PRMT__10, PRMT__32, false, 0x5410);
    return DAG.getBitcast(VT, V: PRMT3210);
  }

  // Get value or the Nth operand as an APInt(32). Undef values treated as 0.
  auto GetOperand = [](SDValue Op, int N) -> APInt {
    const SDValue &Operand = Op->getOperand(Num: N);
    EVT VT = Op->getValueType(ResNo: 0);
    if (Operand->isUndef())
      return APInt(32, 0);
    APInt Value;
    if (VT == MVT::v2f16 || VT == MVT::v2bf16)
      Value = cast<ConstantFPSDNode>(Val: Operand)->getValueAPF().bitcastToAPInt();
    else if (VT == MVT::v2i16 || VT == MVT::v4i8)
      Value = Operand->getAsAPIntVal();
    else
      llvm_unreachable("Unsupported type");
    // i8 values are carried around as i16, so we need to zero out upper bits,
    // so they do not get in the way of combining individual byte values
    if (VT == MVT::v4i8)
      Value = Value.trunc(width: 8);
    return Value.zext(width: 32);
  };

  // Construct a 32-bit constant by shifting into place smaller values
  // (elements of the vector type VT).
  // For example, if VT has 2 elements, then N == 2:
  // ShiftAmount = 32 / N = 16
  // Value |= Op0 (b16) << 0
  // Value |= Op1 (b16) << 16
  // If N == 4:
  // ShiftAmount = 32 / N = 8
  // Value |= Op0 (b8) << 0
  // Value |= Op1 (b8) << 8
  // Value |= Op2 (b8) << 16
  // Value |= Op3 (b8) << 24
  // ...etc
  APInt Value(32, 0);
  const unsigned NumElements = VT.getVectorNumElements();
  assert(32 % NumElements == 0 && "must evenly divide bit length");
  const unsigned ShiftAmount = 32 / NumElements;
  for (unsigned ElementNo : seq(Size: NumElements))
    Value |= GetOperand(Op, ElementNo).shl(shiftAmt: ElementNo * ShiftAmount);
  SDValue Const = DAG.getConstant(Val: Value, DL, VT: MVT::i32);
  return DAG.getNode(Opcode: ISD::BITCAST, DL, VT: Op->getValueType(ResNo: 0), Operand: Const);
}
2114
SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
                                                     SelectionDAG &DAG) const {
  SDValue Index = Op->getOperand(Num: 1);
  SDValue Vector = Op->getOperand(Num: 0);
  SDLoc DL(Op);
  EVT VectorVT = Vector.getValueType();

  if (VectorVT == MVT::v4i8) {
    // Extract the byte with a PRMT. Selector nibbles are (Index | 0x7770):
    // result byte 0 comes from source byte `Index`, and bytes 1-3 select
    // byte 7, which belongs to operand B -- a constant zero -- so the result
    // is the chosen byte zero-extended to 32 bits.
    SDValue Selector = DAG.getNode(Opcode: ISD::OR, DL, VT: MVT::i32,
                                   N1: DAG.getZExtOrTrunc(Op: Index, DL, VT: MVT::i32),
                                   N2: DAG.getConstant(Val: 0x7770, DL, VT: MVT::i32));
    SDValue PRMT = getPRMT(A: DAG.getBitcast(VT: MVT::i32, V: Vector),
                           B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector, DL, DAG);
    SDValue Ext = DAG.getAnyExtOrTrunc(Op: PRMT, DL, VT: Op->getValueType(ResNo: 0));
    // The PRMT result is an 8-bit value zero-extended in an i32, so the
    // conversion to the result type cannot change the unsigned value when the
    // result keeps at least 8 bits (NUW), nor the signed value when it keeps
    // more than 8 (NSW).
    SDNodeFlags Flags;
    Flags.setNoSignedWrap(Ext.getScalarValueSizeInBits() > 8);
    Flags.setNoUnsignedWrap(Ext.getScalarValueSizeInBits() >= 8);
    Ext->setFlags(Flags);
    return Ext;
  }

  // Constant index will be matched by tablegen.
  if (isa<ConstantSDNode>(Val: Index.getNode()))
    return Op;

  // Extract individual elements and select one of them.
  // Only 2-element packed vectors are expected here (v4i8 was handled above).
  assert(NVPTX::isPackedVectorTy(VectorVT) &&
         VectorVT.getVectorNumElements() == 2 && "Unexpected vector type.");
  EVT EltVT = VectorVT.getVectorElementType();

  SDLoc dl(Op.getNode());
  SDValue E0 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT: EltVT, N1: Vector,
                           N2: DAG.getIntPtrConstant(Val: 0, DL: dl));
  SDValue E1 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT: EltVT, N1: Vector,
                           N2: DAG.getIntPtrConstant(Val: 1, DL: dl));
  return DAG.getSelectCC(DL: dl, LHS: Index, RHS: DAG.getIntPtrConstant(Val: 0, DL: dl), True: E0, False: E1,
                         Cond: ISD::CondCode::SETEQ);
}
2153
2154SDValue NVPTXTargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
2155 SelectionDAG &DAG) const {
2156 SDValue Vector = Op->getOperand(Num: 0);
2157 EVT VectorVT = Vector.getValueType();
2158
2159 if (VectorVT != MVT::v4i8)
2160 return Op;
2161 SDLoc DL(Op);
2162 SDValue Value = Op->getOperand(Num: 1);
2163 if (Value->isUndef())
2164 return Vector;
2165
2166 SDValue Index = Op->getOperand(Num: 2);
2167
2168 SDValue BFI =
2169 DAG.getNode(Opcode: NVPTXISD::BFI, DL, VT: MVT::i32,
2170 Ops: {DAG.getZExtOrTrunc(Op: Value, DL, VT: MVT::i32), Vector,
2171 DAG.getNode(Opcode: ISD::MUL, DL, VT: MVT::i32,
2172 N1: DAG.getZExtOrTrunc(Op: Index, DL, VT: MVT::i32),
2173 N2: DAG.getConstant(Val: 8, DL, VT: MVT::i32)),
2174 DAG.getConstant(Val: 8, DL, VT: MVT::i32)});
2175 return DAG.getNode(Opcode: ISD::BITCAST, DL, VT: Op->getValueType(ResNo: 0), Operand: BFI);
2176}
2177
2178SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
2179 SelectionDAG &DAG) const {
2180 SDValue V1 = Op.getOperand(i: 0);
2181 EVT VectorVT = V1.getValueType();
2182 if (VectorVT != MVT::v4i8 || Op.getValueType() != MVT::v4i8)
2183 return Op;
2184
2185 // Lower shuffle to PRMT instruction.
2186 const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Val: Op.getNode());
2187 SDValue V2 = Op.getOperand(i: 1);
2188 uint32_t Selector = 0;
2189 for (auto I : llvm::enumerate(First: SVN->getMask())) {
2190 if (I.value() != -1) // -1 is a placeholder for undef.
2191 Selector |= (I.value() << (I.index() * 4));
2192 }
2193
2194 SDLoc DL(Op);
2195 SDValue PRMT = getPRMT(A: DAG.getBitcast(VT: MVT::i32, V: V1),
2196 B: DAG.getBitcast(VT: MVT::i32, V: V2), Selector, DL, DAG);
2197 return DAG.getBitcast(VT: Op.getValueType(), V: PRMT);
2198}
/// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
/// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
/// amount, or
/// 2) returns two i64 values and take a 2 x i64 value to shift plus a shift
/// amount.
SDValue NVPTXTargetLowering::LowerShiftRightParts(SDValue Op,
                                                  SelectionDAG &DAG) const {
  assert(Op.getNumOperands() == 3 && "Not a double-shift!");
  assert(Op.getOpcode() == ISD::SRA_PARTS || Op.getOpcode() == ISD::SRL_PARTS);

  EVT VT = Op.getValueType();
  unsigned VTBits = VT.getSizeInBits();
  SDLoc dl(Op);
  SDValue ShOpLo = Op.getOperand(i: 0);
  SDValue ShOpHi = Op.getOperand(i: 1);
  SDValue ShAmt  = Op.getOperand(i: 2);
  // SRA for arithmetic (sign-filling) parts, SRL for logical.
  unsigned Opc = (Op.getOpcode() == ISD::SRA_PARTS) ? ISD::SRA : ISD::SRL;

  if (VTBits == 32 && STI.getSmVersion() >= 35) {
    // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
    // {dHi, dLo} = {aHi, aLo} >> Amt
    //   dHi = aHi >> Amt
    //   dLo = shf.r.clamp aLo, aHi, Amt

    SDValue Hi = DAG.getNode(Opcode: Opc, DL: dl, VT, N1: ShOpHi, N2: ShAmt);
    SDValue Lo =
        DAG.getNode(Opcode: NVPTXISD::FSHR_CLAMP, DL: dl, VT, N1: ShOpHi, N2: ShOpLo, N3: ShAmt);

    SDValue Ops[2] = { Lo, Hi };
    return DAG.getMergeValues(Ops, dl);
  }
  else {
    // Generic expansion without funnel shifts:
    // {dHi, dLo} = {aHi, aLo} >> Amt
    // - if (Amt>=size) then
    //      dLo = aHi >> (Amt-size)
    //      dHi = aHi >> Amt (this is either all 0 or all 1)
    //   else
    //      dLo = (aLo >>logic Amt) | (aHi << (size-Amt))
    //      dHi = aHi >> Amt

    SDValue RevShAmt = DAG.getNode(Opcode: ISD::SUB, DL: dl, VT: MVT::i32,
                                   N1: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32),
                                   N2: ShAmt);
    // Low part contribution from aLo (always a logical shift).
    SDValue Tmp1 = DAG.getNode(Opcode: ISD::SRL, DL: dl, VT, N1: ShOpLo, N2: ShAmt);
    SDValue ExtraShAmt = DAG.getNode(Opcode: ISD::SUB, DL: dl, VT: MVT::i32, N1: ShAmt,
                                     N2: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32));
    // Bits of aHi that spill into the low part.
    SDValue Tmp2 = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpHi, N2: RevShAmt);
    SDValue FalseVal = DAG.getNode(Opcode: ISD::OR, DL: dl, VT, N1: Tmp1, N2: Tmp2);
    SDValue TrueVal = DAG.getNode(Opcode: Opc, DL: dl, VT, N1: ShOpHi, N2: ExtraShAmt);

    // Select between the two expansions based on Amt >= size.
    SDValue Cmp = DAG.getSetCC(DL: dl, VT: MVT::i1, LHS: ShAmt,
                               RHS: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32),
                               Cond: ISD::SETGE);
    SDValue Hi = DAG.getNode(Opcode: Opc, DL: dl, VT, N1: ShOpHi, N2: ShAmt);
    SDValue Lo = DAG.getNode(Opcode: ISD::SELECT, DL: dl, VT, N1: Cmp, N2: TrueVal, N3: FalseVal);

    SDValue Ops[2] = { Lo, Hi };
    return DAG.getMergeValues(Ops, dl);
  }
}
2259
/// LowerShiftLeftParts - Lower SHL_PARTS, which
/// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
/// amount, or
/// 2) returns two i64 values and take a 2 x i64 value to shift plus a shift
/// amount.
SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op,
                                                 SelectionDAG &DAG) const {
  assert(Op.getNumOperands() == 3 && "Not a double-shift!");
  assert(Op.getOpcode() == ISD::SHL_PARTS);

  EVT VT = Op.getValueType();
  unsigned VTBits = VT.getSizeInBits();
  SDLoc dl(Op);
  SDValue ShOpLo = Op.getOperand(i: 0);
  SDValue ShOpHi = Op.getOperand(i: 1);
  SDValue ShAmt  = Op.getOperand(i: 2);

  if (VTBits == 32 && STI.getSmVersion() >= 35) {
    // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
    // {dHi, dLo} = {aHi, aLo} << Amt
    //   dHi = shf.l.clamp aLo, aHi, Amt
    //   dLo = aLo << Amt

    SDValue Hi =
        DAG.getNode(Opcode: NVPTXISD::FSHL_CLAMP, DL: dl, VT, N1: ShOpHi, N2: ShOpLo, N3: ShAmt);
    SDValue Lo = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpLo, N2: ShAmt);

    SDValue Ops[2] = { Lo, Hi };
    return DAG.getMergeValues(Ops, dl);
  }
  else {
    // Generic expansion without funnel shifts:
    // {dHi, dLo} = {aHi, aLo} << Amt
    // - if (Amt>=size) then
    //      dLo = aLo << Amt (all 0)
    //      dHi = aLo << (Amt-size)
    //   else
    //      dLo = aLo << Amt
    //      dHi = (aHi << Amt) | (aLo >> (size-Amt))

    SDValue RevShAmt = DAG.getNode(Opcode: ISD::SUB, DL: dl, VT: MVT::i32,
                                   N1: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32),
                                   N2: ShAmt);
    // High part contribution from aHi.
    SDValue Tmp1 = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpHi, N2: ShAmt);
    SDValue ExtraShAmt = DAG.getNode(Opcode: ISD::SUB, DL: dl, VT: MVT::i32, N1: ShAmt,
                                     N2: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32));
    // Bits of aLo that spill into the high part.
    SDValue Tmp2 = DAG.getNode(Opcode: ISD::SRL, DL: dl, VT, N1: ShOpLo, N2: RevShAmt);
    SDValue FalseVal = DAG.getNode(Opcode: ISD::OR, DL: dl, VT, N1: Tmp1, N2: Tmp2);
    SDValue TrueVal = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpLo, N2: ExtraShAmt);

    // Select between the two expansions based on Amt >= size.
    SDValue Cmp = DAG.getSetCC(DL: dl, VT: MVT::i1, LHS: ShAmt,
                               RHS: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32),
                               Cond: ISD::SETGE);
    SDValue Lo = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpLo, N2: ShAmt);
    SDValue Hi = DAG.getNode(Opcode: ISD::SELECT, DL: dl, VT, N1: Cmp, N2: TrueVal, N3: FalseVal);

    SDValue Ops[2] = { Lo, Hi };
    return DAG.getMergeValues(Ops, dl);
  }
}
2319
2320/// If the types match, convert the generic copysign to the NVPTXISD version,
2321/// otherwise bail ensuring that mismatched cases are properly expaned.
2322SDValue NVPTXTargetLowering::LowerFCOPYSIGN(SDValue Op,
2323 SelectionDAG &DAG) const {
2324 EVT VT = Op.getValueType();
2325 SDLoc DL(Op);
2326
2327 SDValue In1 = Op.getOperand(i: 0);
2328 SDValue In2 = Op.getOperand(i: 1);
2329 EVT SrcVT = In2.getValueType();
2330
2331 if (!SrcVT.bitsEq(VT))
2332 return SDValue();
2333
2334 return DAG.getNode(Opcode: NVPTXISD::FCOPYSIGN, DL, VT, N1: In1, N2: In2);
2335}
2336
2337SDValue NVPTXTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const {
2338 EVT VT = Op.getValueType();
2339
2340 if (VT == MVT::f32)
2341 return LowerFROUND32(Op, DAG);
2342
2343 if (VT == MVT::f64)
2344 return LowerFROUND64(Op, DAG);
2345
2346 llvm_unreachable("unhandled type");
2347}
2348
// This is the rounding method used in CUDA libdevice in C like code:
// float roundf(float A)
// {
//   float RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f));
//   RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
//   return abs(A) < 0.5 ? (float)(int)A : RoundedA;
// }
SDValue NVPTXTargetLowering::LowerFROUND32(SDValue Op,
                                           SelectionDAG &DAG) const {
  SDLoc SL(Op);
  SDValue A = Op.getOperand(i: 0);
  EVT VT = Op.getValueType();

  SDValue AbsA = DAG.getNode(Opcode: ISD::FABS, DL: SL, VT, Operand: A);

  // RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f))
  // Build +/-0.5 with A's sign by bit-twiddling the IEEE-754 representation,
  // then add it and truncate toward zero.
  SDValue Bitcast  = DAG.getNode(Opcode: ISD::BITCAST, DL: SL, VT: MVT::i32, Operand: A);
  const unsigned SignBitMask = 0x80000000;
  SDValue Sign = DAG.getNode(Opcode: ISD::AND, DL: SL, VT: MVT::i32, N1: Bitcast,
                             N2: DAG.getConstant(Val: SignBitMask, DL: SL, VT: MVT::i32));
  const unsigned PointFiveInBits = 0x3F000000; // bit pattern of 0.5f
  SDValue PointFiveWithSignRaw =
      DAG.getNode(Opcode: ISD::OR, DL: SL, VT: MVT::i32, N1: Sign,
                  N2: DAG.getConstant(Val: PointFiveInBits, DL: SL, VT: MVT::i32));
  SDValue PointFiveWithSign =
      DAG.getNode(Opcode: ISD::BITCAST, DL: SL, VT, Operand: PointFiveWithSignRaw);
  SDValue AdjustedA = DAG.getNode(Opcode: ISD::FADD, DL: SL, VT, N1: A, N2: PointFiveWithSign);
  SDValue RoundedA = DAG.getNode(Opcode: ISD::FTRUNC, DL: SL, VT, Operand: AdjustedA);

  // RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
  // Magnitudes >= 2^23 have no fractional bits, so A is already integral.
  EVT SetCCVT = getSetCCResultType(DL: DAG.getDataLayout(), Ctx&: *DAG.getContext(), VT);
  SDValue IsLarge =
      DAG.getSetCC(DL: SL, VT: SetCCVT, LHS: AbsA, RHS: DAG.getConstantFP(Val: pow(x: 2.0, y: 23.0), DL: SL, VT),
                   Cond: ISD::SETOGT);
  RoundedA = DAG.getNode(Opcode: ISD::SELECT, DL: SL, VT, N1: IsLarge, N2: A, N3: RoundedA);

  // return abs(A) < 0.5 ? (float)(int)A : RoundedA;
  // Small values round to a signed zero, preserved via FTRUNC of A itself.
  SDValue IsSmall =DAG.getSetCC(DL: SL, VT: SetCCVT, LHS: AbsA,
                                DAG.getConstantFP(Val: 0.5, DL: SL, VT), Cond: ISD::SETOLT);
  SDValue RoundedAForSmallA = DAG.getNode(Opcode: ISD::FTRUNC, DL: SL, VT, Operand: A);
  return DAG.getNode(Opcode: ISD::SELECT, DL: SL, VT, N1: IsSmall, N2: RoundedAForSmallA, N3: RoundedA);
}
2391
2392// The implementation of round(double) is similar to that of round(float) in
2393// that they both separate the value range into three regions and use a method
2394// specific to the region to round the values. However, round(double) first
2395// calculates the round of the absolute value and then adds the sign back while
2396// round(float) directly rounds the value with sign.
2397SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
2398 SelectionDAG &DAG) const {
2399 SDLoc SL(Op);
2400 SDValue A = Op.getOperand(i: 0);
2401 EVT VT = Op.getValueType();
2402
2403 SDValue AbsA = DAG.getNode(Opcode: ISD::FABS, DL: SL, VT, Operand: A);
2404
2405 // double RoundedA = (double) (int) (abs(A) + 0.5f);
2406 SDValue AdjustedA = DAG.getNode(Opcode: ISD::FADD, DL: SL, VT, N1: AbsA,
2407 N2: DAG.getConstantFP(Val: 0.5, DL: SL, VT));
2408 SDValue RoundedA = DAG.getNode(Opcode: ISD::FTRUNC, DL: SL, VT, Operand: AdjustedA);
2409
2410 // RoundedA = abs(A) < 0.5 ? (double)0 : RoundedA;
2411 EVT SetCCVT = getSetCCResultType(DL: DAG.getDataLayout(), Ctx&: *DAG.getContext(), VT);
2412 SDValue IsSmall =DAG.getSetCC(DL: SL, VT: SetCCVT, LHS: AbsA,
2413 RHS: DAG.getConstantFP(Val: 0.5, DL: SL, VT), Cond: ISD::SETOLT);
2414 RoundedA = DAG.getNode(Opcode: ISD::SELECT, DL: SL, VT, N1: IsSmall,
2415 N2: DAG.getConstantFP(Val: 0, DL: SL, VT),
2416 N3: RoundedA);
2417
2418 // Add sign to rounded_A
2419 RoundedA = DAG.getNode(Opcode: ISD::FCOPYSIGN, DL: SL, VT, N1: RoundedA, N2: A);
2420 DAG.getNode(Opcode: ISD::FTRUNC, DL: SL, VT, Operand: A);
2421
2422 // RoundedA = abs(A) > 0x1.0p52 ? A : RoundedA;
2423 SDValue IsLarge =
2424 DAG.getSetCC(DL: SL, VT: SetCCVT, LHS: AbsA, RHS: DAG.getConstantFP(Val: pow(x: 2.0, y: 52.0), DL: SL, VT),
2425 Cond: ISD::SETOGT);
2426 return DAG.getNode(Opcode: ISD::SELECT, DL: SL, VT, N1: IsLarge, N2: A, N3: RoundedA);
2427}
2428
2429static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) {
2430 EVT VT = N->getValueType(ResNo: 0);
2431 EVT NVT = MVT::f32;
2432 if (VT.isVector()) {
2433 NVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NVT, EC: VT.getVectorElementCount());
2434 }
2435 SDLoc DL(N);
2436 SDValue Tmp0 = DAG.getFPExtendOrRound(Op: N->getOperand(Num: 0), DL, VT: NVT);
2437 SDValue Tmp1 = DAG.getFPExtendOrRound(Op: N->getOperand(Num: 1), DL, VT: NVT);
2438 SDValue Res = DAG.getNode(Opcode: N->getOpcode(), DL, VT: NVT, N1: Tmp0, N2: Tmp1, Flags: N->getFlags());
2439 return DAG.getFPExtendOrRound(Op: Res, DL, VT);
2440}
2441
2442SDValue NVPTXTargetLowering::PromoteBinOpIfF32FTZ(SDValue Op,
2443 SelectionDAG &DAG) const {
2444 if (useF32FTZ(MF: DAG.getMachineFunction())) {
2445 return PromoteBinOpToF32(N: Op.getNode(), DAG);
2446 }
2447 return Op;
2448}
2449
2450SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
2451 SelectionDAG &DAG) const {
2452 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2453
2454 if (Op.getValueType() == MVT::bf16) {
2455 SDLoc Loc(Op);
2456 return DAG.getNode(
2457 Opcode: ISD::FP_ROUND, DL: Loc, VT: MVT::bf16,
2458 N1: DAG.getNode(Opcode: Op.getOpcode(), DL: Loc, VT: MVT::f32, Operand: Op.getOperand(i: 0)),
2459 N2: DAG.getIntPtrConstant(Val: 0, DL: Loc, /*isTarget=*/true));
2460 }
2461
2462 // Everything else is considered legal.
2463 return Op;
2464}
2465
2466SDValue NVPTXTargetLowering::LowerFP_TO_INT(SDValue Op,
2467 SelectionDAG &DAG) const {
2468 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2469
2470 if (Op.getOperand(i: 0).getValueType() == MVT::bf16) {
2471 SDLoc Loc(Op);
2472 return DAG.getNode(
2473 Opcode: Op.getOpcode(), DL: Loc, VT: Op.getValueType(),
2474 Operand: DAG.getNode(Opcode: ISD::FP_EXTEND, DL: Loc, VT: MVT::f32, Operand: Op.getOperand(i: 0)));
2475 }
2476
2477 // Everything else is considered legal.
2478 return Op;
2479}
2480
2481SDValue NVPTXTargetLowering::LowerFP_ROUND(SDValue Op,
2482 SelectionDAG &DAG) const {
2483 EVT NarrowVT = Op.getValueType();
2484 SDValue Wide = Op.getOperand(i: 0);
2485 EVT WideVT = Wide.getValueType();
2486 if (NarrowVT.getScalarType() == MVT::bf16) {
2487 const TargetLowering *TLI = STI.getTargetLowering();
2488 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 70) {
2489 return TLI->expandFP_ROUND(Node: Op.getNode(), DAG);
2490 }
2491 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
2492 // This combination was the first to support f32 -> bf16.
2493 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70) {
2494 if (WideVT.getScalarType() == MVT::f32) {
2495 return Op;
2496 }
2497 if (WideVT.getScalarType() == MVT::f64) {
2498 SDLoc Loc(Op);
2499 // Round-inexact-to-odd f64 to f32, then do the final rounding using
2500 // the hardware f32 -> bf16 instruction.
2501 SDValue rod = TLI->expandRoundInexactToOdd(
2502 ResultVT: WideVT.changeElementType(Context&: *DAG.getContext(), EltVT: MVT::f32), Op: Wide, DL: Loc,
2503 DAG);
2504 return DAG.getFPExtendOrRound(Op: rod, DL: Loc, VT: NarrowVT);
2505 }
2506 }
2507 return TLI->expandFP_ROUND(Node: Op.getNode(), DAG);
2508 }
2509 }
2510
2511 // Everything else is considered legal.
2512 return Op;
2513}
2514
2515SDValue NVPTXTargetLowering::LowerFP_EXTEND(SDValue Op,
2516 SelectionDAG &DAG) const {
2517 SDValue Narrow = Op.getOperand(i: 0);
2518 EVT NarrowVT = Narrow.getValueType();
2519 EVT WideVT = Op.getValueType();
2520 if (NarrowVT.getScalarType() == MVT::bf16) {
2521 if (WideVT.getScalarType() == MVT::f32 &&
2522 (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71)) {
2523 SDLoc Loc(Op);
2524 return DAG.getNode(Opcode: ISD::BF16_TO_FP, DL: Loc, VT: WideVT, Operand: Narrow);
2525 }
2526 if (WideVT.getScalarType() == MVT::f64 &&
2527 (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
2528 EVT F32 = NarrowVT.changeElementType(Context&: *DAG.getContext(), EltVT: MVT::f32);
2529 SDLoc Loc(Op);
2530 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 71) {
2531 Op = DAG.getNode(Opcode: ISD::FP_EXTEND, DL: Loc, VT: F32, Operand: Narrow);
2532 } else {
2533 Op = DAG.getNode(Opcode: ISD::BF16_TO_FP, DL: Loc, VT: F32, Operand: Narrow);
2534 }
2535 return DAG.getNode(Opcode: ISD::FP_EXTEND, DL: Loc, VT: WideVT, Operand: Op);
2536 }
2537 }
2538
2539 // Everything else is considered legal.
2540 return Op;
2541}
2542
2543static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) {
2544 SDLoc DL(Op);
2545 if (Op.getValueType() != MVT::v2i16)
2546 return Op;
2547 EVT EltVT = Op.getValueType().getVectorElementType();
2548 SmallVector<SDValue> VecElements;
2549 for (int I = 0, E = Op.getValueType().getVectorNumElements(); I < E; I++) {
2550 SmallVector<SDValue> ScalarArgs;
2551 llvm::transform(Range: Op->ops(), d_first: std::back_inserter(x&: ScalarArgs),
2552 F: [&](const SDUse &O) {
2553 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT,
2554 N1: O.get(), N2: DAG.getIntPtrConstant(Val: I, DL));
2555 });
2556 VecElements.push_back(Elt: DAG.getNode(Opcode: Op.getOpcode(), DL, VT: EltVT, Ops: ScalarArgs));
2557 }
2558 SDValue V =
2559 DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: Op.getValueType(), Ops: VecElements);
2560 return V;
2561}
2562
2563static SDValue lowerTcgen05St(SDValue Op, SelectionDAG &DAG,
2564 bool hasOffset = false) {
2565 // skip lowering if the vector operand is already legalized
2566 if (!Op->getOperand(Num: hasOffset ? 4 : 3).getValueType().isVector())
2567 return Op;
2568
2569 SDNode *N = Op.getNode();
2570 SDLoc DL(N);
2571 SmallVector<SDValue, 32> Ops;
2572
2573 // split the vector argument
2574 for (size_t I = 0; I < N->getNumOperands(); I++) {
2575 SDValue Val = N->getOperand(Num: I);
2576 EVT ValVT = Val.getValueType();
2577 if (ValVT.isVector()) {
2578 EVT EltVT = ValVT.getVectorElementType();
2579 for (unsigned J = 0, NElts = ValVT.getVectorNumElements(); J < NElts; J++)
2580 Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Val,
2581 N2: DAG.getIntPtrConstant(Val: J, DL)));
2582 } else
2583 Ops.push_back(Elt: Val);
2584 }
2585
2586 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
2587 SDValue Tcgen05StNode =
2588 DAG.getMemIntrinsicNode(Opcode: ISD::INTRINSIC_VOID, dl: DL, VTList: N->getVTList(), Ops,
2589 MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());
2590
2591 return Tcgen05StNode;
2592}
2593
// Lower BSWAP via the PTX PRMT (byte-permute) instruction. Each PRMT selector
// nibble picks a source byte for one result byte (see the PTX ISA 'prmt'
// description), so a single PRMT byte-reverses a 32-bit value.
static SDValue lowerBSWAP(SDValue Op, SelectionDAG &DAG) {
  SDLoc DL(Op);
  SDValue Src = Op.getOperand(i: 0);
  EVT VT = Op.getValueType();

  switch (VT.getSimpleVT().SimpleTy) {
  case MVT::i16: {
    // Widen to i32, swap the two low bytes, and truncate back. The upper
    // selector nibbles (0x77..) are don't-care since the result is truncated.
    SDValue Extended = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: MVT::i32, Operand: Src);
    SDValue Swapped =
        getPRMT(A: Extended, B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x7701, DL, DAG);
    return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i16, Operand: Swapped);
  }
  case MVT::i32: {
    // Selector 0x0123 reverses all four bytes in one PRMT.
    return getPRMT(A: Src, B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x0123, DL, DAG);
  }
  case MVT::v2i16: {
    // Reinterpret as i32 and swap bytes within each 16-bit lane (0x2301).
    SDValue Converted = DAG.getBitcast(VT: MVT::i32, V: Src);
    SDValue Swapped =
        getPRMT(A: Converted, B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x2301, DL, DAG);
    return DAG.getNode(Opcode: ISD::BITCAST, DL, VT: MVT::v2i16, Operand: Swapped);
  }
  case MVT::i64: {
    // Split into two i32 halves, byte-reverse each, then exchange the halves
    // when repacking (high half becomes low and vice versa).
    SDValue UnpackSrc =
        DAG.getNode(Opcode: NVPTXISD::UNPACK_VECTOR, DL, ResultTys: {MVT::i32, MVT::i32}, Ops: Src);
    SDValue SwappedLow =
        getPRMT(A: UnpackSrc.getValue(R: 0), B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x0123,
                DL, DAG);
    SDValue SwappedHigh =
        getPRMT(A: UnpackSrc.getValue(R: 1), B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x0123,
                DL, DAG);
    return DAG.getNode(Opcode: NVPTXISD::BUILD_VECTOR, DL, VT: MVT::i64,
                       Ops: {SwappedHigh, SwappedLow});
  }
  default:
    llvm_unreachable("unsupported type for bswap");
  }
}
2631
// Map a tcgen05.mma *disable_output_lane* intrinsic ID to the corresponding
// NVPTXISD opcode. Aborts on any other ID.
static unsigned getTcgen05MMADisableOutputLane(unsigned IID) {
  switch (IID) {
  case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
    return NVPTXISD::TCGEN05_MMA_SHARED_DISABLE_OUTPUT_LANE_CG1;
  case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2:
    return NVPTXISD::TCGEN05_MMA_SHARED_DISABLE_OUTPUT_LANE_CG2;
  case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1:
    return NVPTXISD::TCGEN05_MMA_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
  case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2:
    return NVPTXISD::TCGEN05_MMA_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
  case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1:
    return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG1;
  case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2:
    return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG2;
  case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1:
    return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
  case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2:
    return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
  case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift:
    return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
  case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift:
    return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
  case Intrinsic::
      nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift:
    return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
  case Intrinsic::
      nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift:
    return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
  // Sparse (sp) variants below mirror the dense mapping above.
  case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1:
    return NVPTXISD::TCGEN05_MMA_SP_SHARED_DISABLE_OUTPUT_LANE_CG1;
  case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2:
    return NVPTXISD::TCGEN05_MMA_SP_SHARED_DISABLE_OUTPUT_LANE_CG2;
  case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1:
    return NVPTXISD::TCGEN05_MMA_SP_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
  case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2:
    return NVPTXISD::TCGEN05_MMA_SP_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1:
    return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG1;
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2:
    return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG2;
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift:
    return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift:
    return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1:
    return NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2:
    return NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
  case Intrinsic::
      nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift:
    return NVPTXISD::
        TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
  case Intrinsic::
      nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift:
    return NVPTXISD::
        TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
  };
  llvm_unreachable("unhandled tcgen05.mma.disable_output_lane intrinsic");
}
2691
2692static SDValue LowerTcgen05MMADisableOutputLane(SDValue Op, SelectionDAG &DAG) {
2693 SDNode *N = Op.getNode();
2694 SDLoc DL(N);
2695 unsigned IID = cast<ConstantSDNode>(Val: N->getOperand(Num: 1))->getZExtValue();
2696
2697 SmallVector<SDValue, 16> Ops;
2698 // split the vector argument
2699 for (size_t I = 0; I < N->getNumOperands(); I++) {
2700 if (I == 1)
2701 continue; // skip IID
2702 SDValue Val = N->getOperand(Num: I);
2703 EVT ValVT = Val.getValueType();
2704 if (ValVT.isVector()) {
2705 EVT EltVT = ValVT.getVectorElementType();
2706 for (unsigned J = 0, NElts = ValVT.getVectorNumElements(); J < NElts; J++)
2707 Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Val,
2708 N2: DAG.getIntPtrConstant(Val: J, DL)));
2709 } else
2710 Ops.push_back(Elt: Val);
2711 }
2712
2713 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
2714 SDValue Tcgen05MMANode = DAG.getMemIntrinsicNode(
2715 Opcode: getTcgen05MMADisableOutputLane(IID), dl: DL, VTList: N->getVTList(), Ops,
2716 MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());
2717
2718 return Tcgen05MMANode;
2719}
2720
2721// Lower vector return type of tcgen05.ld intrinsics
2722static std::optional<std::pair<SDValue, SDValue>>
2723lowerTcgen05Ld(SDNode *N, SelectionDAG &DAG, bool HasOffset = false) {
2724 SDLoc DL(N);
2725 EVT ResVT = N->getValueType(ResNo: 0);
2726 if (!ResVT.isVector())
2727 return {}; // already legalized.
2728
2729 const unsigned NumElts = ResVT.getVectorNumElements();
2730
2731 // Create the return type of the instructions
2732 SmallVector<EVT, 5> ListVTs;
2733 for (unsigned i = 0; i < NumElts; ++i)
2734 ListVTs.push_back(Elt: MVT::i32);
2735
2736 ListVTs.push_back(Elt: N->getValueType(ResNo: 1)); // Chain
2737
2738 SDVTList ResVTs = DAG.getVTList(VTs: ListVTs);
2739
2740 SmallVector<SDValue, 8> Ops{N->getOperand(Num: 0), N->getOperand(Num: 1),
2741 N->getOperand(Num: 2)};
2742
2743 if (HasOffset) {
2744 Ops.push_back(Elt: N->getOperand(Num: 3)); // offset
2745 Ops.push_back(Elt: N->getOperand(Num: 4)); // Pack flag
2746 } else
2747 Ops.push_back(Elt: N->getOperand(Num: 3)); // Pack flag
2748
2749 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
2750 SDValue NewNode =
2751 DAG.getMemIntrinsicNode(Opcode: ISD::INTRINSIC_W_CHAIN, dl: DL, VTList: ResVTs, Ops,
2752 MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());
2753
2754 // split the vector result
2755 SmallVector<SDValue, 4> ScalarRes;
2756 for (unsigned i = 0; i < NumElts; ++i) {
2757 SDValue Res = NewNode.getValue(R: i);
2758 ScalarRes.push_back(Elt: Res);
2759 }
2760
2761 SDValue Chain = NewNode.getValue(R: NumElts);
2762 SDValue BuildVector = DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: ResVT, Ops: ScalarRes);
2763 return {{BuildVector, Chain}};
2764}
2765
2766static SDValue reportInvalidTensormapReplaceUsage(SDValue Op, SelectionDAG &DAG,
2767 unsigned Val) {
2768 SDNode *N = Op.getNode();
2769 SDLoc DL(N);
2770
2771 const Function &Fn = DAG.getMachineFunction().getFunction();
2772
2773 unsigned AS = 0;
2774 if (auto *MemN = dyn_cast<MemIntrinsicSDNode>(Val: N))
2775 AS = MemN->getAddressSpace();
2776 Type *PtrTy = PointerType::get(C&: *DAG.getContext(), AddressSpace: AS);
2777 Module *M = DAG.getMachineFunction().getFunction().getParent();
2778
2779 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
2780 Fn,
2781 "Intrinsic " +
2782 Intrinsic::getName(Id: N->getConstantOperandVal(Num: 1), OverloadTys: {PtrTy}, M) +
2783 " with value " + Twine(Val) +
2784 " is not supported on the given target.",
2785 DL.getDebugLoc()));
2786 return Op.getOperand(i: 0);
2787}
2788
2789static SDValue lowerTensormapReplaceElemtype(SDValue Op, SelectionDAG &DAG) {
2790 SDNode *N = Op.getNode();
2791 SDLoc DL(N);
2792
2793 // immediate argument representing elemtype
2794 unsigned Val = N->getConstantOperandVal(Num: 3);
2795
2796 if (!DAG.getSubtarget<NVPTXSubtarget>().hasTensormapReplaceElemtypeSupport(
2797 value: Val))
2798 return reportInvalidTensormapReplaceUsage(Op, DAG, Val);
2799
2800 return Op;
2801}
2802
2803static SDValue lowerTensormapReplaceSwizzleMode(SDValue Op, SelectionDAG &DAG) {
2804 SDNode *N = Op.getNode();
2805 SDLoc DL(N);
2806
2807 // immediate argument representing swizzle mode
2808 unsigned Val = N->getConstantOperandVal(Num: 3);
2809
2810 if (!DAG.getSubtarget<NVPTXSubtarget>().hasTensormapReplaceSwizzleModeSupport(
2811 value: Val))
2812 return reportInvalidTensormapReplaceUsage(Op, DAG, Val);
2813
2814 return Op;
2815}
2816
// Dispatch custom lowering for INTRINSIC_VOID nodes: tcgen05.st and
// tcgen05.mma disable_output_lane intrinsics get their vector operands
// scalarized; tensormap.replace intrinsics get their immediate operands
// validated. Anything else is returned unchanged (legal as-is).
static SDValue lowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
  SDNode *N = Op.getNode();
  SDValue Intrin = N->getOperand(Num: 1);

  // Get the intrinsic ID
  unsigned IntrinNo = cast<ConstantSDNode>(Val: Intrin.getNode())->getZExtValue();
  switch (IntrinNo) {
  default:
    break;
  // tcgen05.st variants without an offset operand.
  case Intrinsic::nvvm_tcgen05_st_16x64b_x2:
  case Intrinsic::nvvm_tcgen05_st_16x64b_x4:
  case Intrinsic::nvvm_tcgen05_st_16x64b_x8:
  case Intrinsic::nvvm_tcgen05_st_16x64b_x16:
  case Intrinsic::nvvm_tcgen05_st_16x64b_x32:
  case Intrinsic::nvvm_tcgen05_st_16x64b_x128:
  case Intrinsic::nvvm_tcgen05_st_16x128b_x1:
  case Intrinsic::nvvm_tcgen05_st_16x128b_x2:
  case Intrinsic::nvvm_tcgen05_st_16x128b_x4:
  case Intrinsic::nvvm_tcgen05_st_16x128b_x8:
  case Intrinsic::nvvm_tcgen05_st_16x128b_x16:
  case Intrinsic::nvvm_tcgen05_st_16x128b_x32:
  case Intrinsic::nvvm_tcgen05_st_16x128b_x64:
  case Intrinsic::nvvm_tcgen05_st_16x256b_x1:
  case Intrinsic::nvvm_tcgen05_st_16x256b_x2:
  case Intrinsic::nvvm_tcgen05_st_16x256b_x4:
  case Intrinsic::nvvm_tcgen05_st_16x256b_x8:
  case Intrinsic::nvvm_tcgen05_st_16x256b_x16:
  case Intrinsic::nvvm_tcgen05_st_16x256b_x32:
  case Intrinsic::nvvm_tcgen05_st_32x32b_x2:
  case Intrinsic::nvvm_tcgen05_st_32x32b_x4:
  case Intrinsic::nvvm_tcgen05_st_32x32b_x8:
  case Intrinsic::nvvm_tcgen05_st_32x32b_x16:
  case Intrinsic::nvvm_tcgen05_st_32x32b_x32:
  case Intrinsic::nvvm_tcgen05_st_16x64b_x64:
  case Intrinsic::nvvm_tcgen05_st_32x32b_x64:
  case Intrinsic::nvvm_tcgen05_st_32x32b_x128:
    return lowerTcgen05St(Op, DAG);
  // tcgen05.st 16x32bx2 variants carry an extra offset operand.
  case Intrinsic::nvvm_tcgen05_st_16x32bx2_x2:
  case Intrinsic::nvvm_tcgen05_st_16x32bx2_x4:
  case Intrinsic::nvvm_tcgen05_st_16x32bx2_x8:
  case Intrinsic::nvvm_tcgen05_st_16x32bx2_x16:
  case Intrinsic::nvvm_tcgen05_st_16x32bx2_x32:
  case Intrinsic::nvvm_tcgen05_st_16x32bx2_x64:
  case Intrinsic::nvvm_tcgen05_st_16x32bx2_x128:
    return lowerTcgen05St(Op, DAG, /* hasOffset */ true);
  // tcgen05.mma disable_output_lane family (dense and sparse variants).
  case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
  case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2:
  case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1:
  case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2:
  case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1:
  case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2:
  case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1:
  case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2:
  case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1:
  case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2:
  case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1:
  case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2:
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1:
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2:
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1:
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2:
  case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift:
  case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift:
  case Intrinsic::
      nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift:
  case Intrinsic::
      nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift:
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift:
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift:
  case Intrinsic::
      nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift:
  case Intrinsic::
      nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift:
    return LowerTcgen05MMADisableOutputLane(Op, DAG);
  // tensormap.replace intrinsics need target-support validation of their
  // immediate operands.
  case Intrinsic::nvvm_tensormap_replace_elemtype:
    return lowerTensormapReplaceElemtype(Op, DAG);
  case Intrinsic::nvvm_tensormap_replace_swizzle_mode:
    return lowerTensormapReplaceSwizzleMode(Op, DAG);
  }
  return Op;
}
2898
2899static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op,
2900 SelectionDAG &DAG) {
2901
2902 SDNode *N = Op.getNode();
2903 if (N->getOperand(Num: 1).getValueType() != MVT::i128) {
2904 // return, if the operand is already lowered
2905 return SDValue();
2906 }
2907
2908 unsigned IID =
2909 cast<ConstantSDNode>(Val: N->getOperand(Num: 0).getNode())->getZExtValue();
2910 auto Opcode = [&]() {
2911 switch (IID) {
2912 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:
2913 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_IS_CANCELED;
2914 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x:
2915 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_X;
2916 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y:
2917 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Y;
2918 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z:
2919 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z;
2920 default:
2921 llvm_unreachable("unsupported/unhandled intrinsic");
2922 }
2923 }();
2924
2925 SDLoc DL(N);
2926 SDValue TryCancelResponse = N->getOperand(Num: 1);
2927 SDValue Cast = DAG.getNode(Opcode: ISD::BITCAST, DL, VT: MVT::v2i64, Operand: TryCancelResponse);
2928 SDValue TryCancelResponse0 =
2929 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i64, N1: Cast,
2930 N2: DAG.getIntPtrConstant(Val: 0, DL));
2931 SDValue TryCancelResponse1 =
2932 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i64, N1: Cast,
2933 N2: DAG.getIntPtrConstant(Val: 1, DL));
2934
2935 return DAG.getNode(Opcode, DL, VTList: N->getVTList(),
2936 Ops: {TryCancelResponse0, TryCancelResponse1});
2937}
2938
// Lower the f32x4 stochastic-rounding (.rs) convert intrinsics: unpack the
// four f32 lanes, append the random-bits operand and a conversion-mode flag,
// and emit the matching target convert node. The e2m1 variants return i16
// (four 4-bit results packed into 16 bits -- TODO confirm against ISA), the
// rest return v4i8.
static SDValue lowerCvtRSIntrinsics(SDValue Op, SelectionDAG &DAG) {
  SDNode *N = Op.getNode();
  SDLoc DL(N);
  SDValue F32Vec = N->getOperand(Num: 1);
  SDValue RBits = N->getOperand(Num: 2); // random bits for stochastic rounding

  unsigned IntrinsicID = N->getConstantOperandVal(Num: 0);

  // Extract the 4 float elements from the vector
  SmallVector<SDValue, 6> Ops;
  for (unsigned i = 0; i < 4; ++i)
    Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::f32, N1: F32Vec,
                                N2: DAG.getIntPtrConstant(Val: i, DL)));

  using NVPTX::PTXCvtMode::CvtMode;

  // Map the intrinsic to (target opcode, result type, cvt-mode flags); the
  // *_relu_* variants additionally set the RELU flag.
  auto [OpCode, RetTy, CvtModeFlag] =
      [&]() -> std::tuple<unsigned, MVT::SimpleValueType, uint32_t> {
    switch (IntrinsicID) {
    case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
      return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8,
              CvtMode::RS | CvtMode::RELU_FLAG};
    case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
      return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
    case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
      return {NVPTXISD::CVT_E5M2X4_F32X4_RS_SF, MVT::v4i8,
              CvtMode::RS | CvtMode::RELU_FLAG};
    case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
      return {NVPTXISD::CVT_E5M2X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
    case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
      return {NVPTXISD::CVT_E2M3X4_F32X4_RS_SF, MVT::v4i8,
              CvtMode::RS | CvtMode::RELU_FLAG};
    case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
      return {NVPTXISD::CVT_E2M3X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
    case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
      return {NVPTXISD::CVT_E3M2X4_F32X4_RS_SF, MVT::v4i8,
              CvtMode::RS | CvtMode::RELU_FLAG};
    case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
      return {NVPTXISD::CVT_E3M2X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
    case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
      return {NVPTXISD::CVT_E2M1X4_F32X4_RS_SF, MVT::i16,
              CvtMode::RS | CvtMode::RELU_FLAG};
    case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite:
      return {NVPTXISD::CVT_E2M1X4_F32X4_RS_SF, MVT::i16, CvtMode::RS};
    default:
      llvm_unreachable("unsupported/unhandled intrinsic");
    }
  }();

  Ops.push_back(Elt: RBits);
  Ops.push_back(Elt: DAG.getConstant(Val: CvtModeFlag, DL, VT: MVT::i32));

  return DAG.getNode(Opcode: OpCode, DL, VT: RetTy, Ops);
}
2993
2994static SDValue lowerPrmtIntrinsic(SDValue Op, SelectionDAG &DAG) {
2995 const unsigned Mode = [&]() {
2996 switch (Op->getConstantOperandVal(Num: 0)) {
2997 case Intrinsic::nvvm_prmt:
2998 return NVPTX::PTXPrmtMode::NONE;
2999 case Intrinsic::nvvm_prmt_b4e:
3000 return NVPTX::PTXPrmtMode::B4E;
3001 case Intrinsic::nvvm_prmt_ecl:
3002 return NVPTX::PTXPrmtMode::ECL;
3003 case Intrinsic::nvvm_prmt_ecr:
3004 return NVPTX::PTXPrmtMode::ECR;
3005 case Intrinsic::nvvm_prmt_f4e:
3006 return NVPTX::PTXPrmtMode::F4E;
3007 case Intrinsic::nvvm_prmt_rc16:
3008 return NVPTX::PTXPrmtMode::RC16;
3009 case Intrinsic::nvvm_prmt_rc8:
3010 return NVPTX::PTXPrmtMode::RC8;
3011 default:
3012 llvm_unreachable("unsupported/unhandled intrinsic");
3013 }
3014 }();
3015 SDLoc DL(Op);
3016 SDValue A = Op->getOperand(Num: 1);
3017 SDValue B = Op.getNumOperands() == 4 ? Op.getOperand(i: 2)
3018 : DAG.getConstant(Val: 0, DL, VT: MVT::i32);
3019 SDValue Selector = (Op->op_end() - 1)->get();
3020 return getPRMT(A, B, Selector, DL, DAG, Mode);
3021}
3022
// Build the Intrinsic::nvvm_tcgen05_ld_red_* enumerator from a shape,
// element count, and element type.
#define TCGEN05_LD_RED_INTR(SHAPE, NUM, TYPE)                                  \
  Intrinsic::nvvm_tcgen05_ld_red_##SHAPE##_x##NUM##_##TYPE

// Build the matching NVPTXISD::TCGEN05_LD_RED_* target opcode.
#define TCGEN05_LD_RED_INST(SHAPE, NUM, TYPE)                                  \
  NVPTXISD::TCGEN05_LD_RED_##SHAPE##_X##NUM##_##TYPE

// Map a tcgen05.ld.red intrinsic ID to the corresponding NVPTXISD opcode.
// Covers the 32x32b and 16x32bx2 shapes, counts x2 through x128, and the
// f32/i32 element types. Any other ID is a lowering error.
static unsigned getTcgen05LdRedID(Intrinsic::ID IID) {
  switch (IID) {
  case TCGEN05_LD_RED_INTR(32x32b, 2, f32):
    return TCGEN05_LD_RED_INST(32x32b, 2, F32);
  case TCGEN05_LD_RED_INTR(32x32b, 4, f32):
    return TCGEN05_LD_RED_INST(32x32b, 4, F32);
  case TCGEN05_LD_RED_INTR(32x32b, 8, f32):
    return TCGEN05_LD_RED_INST(32x32b, 8, F32);
  case TCGEN05_LD_RED_INTR(32x32b, 16, f32):
    return TCGEN05_LD_RED_INST(32x32b, 16, F32);
  case TCGEN05_LD_RED_INTR(32x32b, 32, f32):
    return TCGEN05_LD_RED_INST(32x32b, 32, F32);
  case TCGEN05_LD_RED_INTR(32x32b, 64, f32):
    return TCGEN05_LD_RED_INST(32x32b, 64, F32);
  case TCGEN05_LD_RED_INTR(32x32b, 128, f32):
    return TCGEN05_LD_RED_INST(32x32b, 128, F32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 2, f32):
    return TCGEN05_LD_RED_INST(16x32bx2, 2, F32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 4, f32):
    return TCGEN05_LD_RED_INST(16x32bx2, 4, F32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 8, f32):
    return TCGEN05_LD_RED_INST(16x32bx2, 8, F32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 16, f32):
    return TCGEN05_LD_RED_INST(16x32bx2, 16, F32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 32, f32):
    return TCGEN05_LD_RED_INST(16x32bx2, 32, F32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 64, f32):
    return TCGEN05_LD_RED_INST(16x32bx2, 64, F32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 128, f32):
    return TCGEN05_LD_RED_INST(16x32bx2, 128, F32);
  case TCGEN05_LD_RED_INTR(32x32b, 2, i32):
    return TCGEN05_LD_RED_INST(32x32b, 2, I32);
  case TCGEN05_LD_RED_INTR(32x32b, 4, i32):
    return TCGEN05_LD_RED_INST(32x32b, 4, I32);
  case TCGEN05_LD_RED_INTR(32x32b, 8, i32):
    return TCGEN05_LD_RED_INST(32x32b, 8, I32);
  case TCGEN05_LD_RED_INTR(32x32b, 16, i32):
    return TCGEN05_LD_RED_INST(32x32b, 16, I32);
  case TCGEN05_LD_RED_INTR(32x32b, 32, i32):
    return TCGEN05_LD_RED_INST(32x32b, 32, I32);
  case TCGEN05_LD_RED_INTR(32x32b, 64, i32):
    return TCGEN05_LD_RED_INST(32x32b, 64, I32);
  case TCGEN05_LD_RED_INTR(32x32b, 128, i32):
    return TCGEN05_LD_RED_INST(32x32b, 128, I32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 2, i32):
    return TCGEN05_LD_RED_INST(16x32bx2, 2, I32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 4, i32):
    return TCGEN05_LD_RED_INST(16x32bx2, 4, I32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 8, i32):
    return TCGEN05_LD_RED_INST(16x32bx2, 8, I32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 16, i32):
    return TCGEN05_LD_RED_INST(16x32bx2, 16, I32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 32, i32):
    return TCGEN05_LD_RED_INST(16x32bx2, 32, I32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 64, i32):
    return TCGEN05_LD_RED_INST(16x32bx2, 64, I32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 128, i32):
    return TCGEN05_LD_RED_INST(16x32bx2, 128, I32);
  default:
    llvm_unreachable("Invalid tcgen05.ld.red intrinsic ID");
  }
}
3091
// Lower vector return type of tcgen05.ld intrinsics
//
// Rewrites a tcgen05.ld.red memory intrinsic whose first result is a vector
// into a target memory intrinsic producing one scalar result per element,
// followed by the reduction value and the chain, then reassembles the vector
// with BUILD_VECTOR. Returns {vector, reduction, chain}, or nullopt when the
// result type is not a vector (nothing to do).
static std::optional<std::tuple<SDValue, SDValue, SDValue>>
lowerTcgen05LdRed(SDNode *N, SelectionDAG &DAG) {
  SDLoc DL(N);
  EVT ResVT = N->getValueType(ResNo: 0);
  if (!ResVT.isVector())
    return {}; // already legalized.

  const unsigned NumElts = ResVT.getVectorNumElements();

  // Create the return type of the instructions
  // +1 represents the reduction value
  SmallVector<EVT, 132> ListVTs{
      NumElts + 1,
      ResVT.getVectorElementType().isFloatingPoint() ? MVT::f32 : MVT::i32};

  ListVTs.push_back(Elt: MVT::Other); // Chain

  SDVTList ResVTs = DAG.getVTList(VTs: ListVTs);

  // Prepare the Operands
  SmallVector<SDValue, 8> Ops{N->getOperand(Num: 0)}; // Chain

  // skip IID at index 1
  for (unsigned i = 2; i < N->getNumOperands(); i++)
    Ops.push_back(Elt: N->getOperand(Num: i));

  unsigned IID = cast<ConstantSDNode>(Val: N->getOperand(Num: 1))->getZExtValue();
  // Preserve the original memory VT and operand on the new node.
  MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
  SDValue NewNode =
      DAG.getMemIntrinsicNode(Opcode: getTcgen05LdRedID(IID), dl: DL, VTList: ResVTs, Ops,
                              MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());

  // Split vector result
  SmallVector<SDValue, 132> ScalarRes;
  for (unsigned i = 0; i < NumElts; ++i) {
    SDValue Res = NewNode.getValue(R: i);
    ScalarRes.push_back(Elt: Res);
  }

  // Result layout: values [0, NumElts) are the vector elements, value NumElts
  // is the reduction result, and value NumElts + 1 is the chain.
  SDValue BuildVector = DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: ResVT, Ops: ScalarRes);
  SDValue RedResult = NewNode.getValue(R: NumElts);
  SDValue Chain = NewNode.getValue(R: NumElts + 1);
  return {{BuildVector, RedResult, Chain}};
}
3137
3138static SDValue lowerIntrinsicWChain(SDValue Op, SelectionDAG &DAG) {
3139 switch (Op->getConstantOperandVal(Num: 1)) {
3140 default:
3141 return Op;
3142
3143 // These tcgen05 intrinsics return a v2i32, which is legal, so we have to
3144 // lower them through LowerOperation() instead of ReplaceNodeResults().
3145 case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
3146 case Intrinsic::nvvm_tcgen05_ld_16x128b_x1:
3147 case Intrinsic::nvvm_tcgen05_ld_32x32b_x2:
3148 if (auto Res = lowerTcgen05Ld(N: Op.getNode(), DAG))
3149 return DAG.getMergeValues(Ops: {Res->first, Res->second}, dl: SDLoc(Op));
3150 return SDValue();
3151
3152 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2:
3153 if (auto Res = lowerTcgen05Ld(N: Op.getNode(), DAG, /*HasOffset=*/true))
3154 return DAG.getMergeValues(Ops: {Res->first, Res->second}, dl: SDLoc(Op));
3155 return SDValue();
3156
3157 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_f32:
3158 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_i32:
3159 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x2_f32:
3160 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x2_i32:
3161 if (auto Res = lowerTcgen05LdRed(N: Op.getNode(), DAG))
3162 return DAG.getMergeValues(
3163 Ops: {std::get<0>(t&: *Res), std::get<1>(t&: *Res), std::get<2>(t&: *Res)}, dl: SDLoc(Op));
3164 return SDValue();
3165 }
3166}
3167
3168static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG) {
3169 switch (Op->getConstantOperandVal(Num: 0)) {
3170 default:
3171 return Op;
3172 case Intrinsic::nvvm_prmt:
3173 case Intrinsic::nvvm_prmt_b4e:
3174 case Intrinsic::nvvm_prmt_ecl:
3175 case Intrinsic::nvvm_prmt_ecr:
3176 case Intrinsic::nvvm_prmt_f4e:
3177 case Intrinsic::nvvm_prmt_rc16:
3178 case Intrinsic::nvvm_prmt_rc8:
3179 return lowerPrmtIntrinsic(Op, DAG);
3180 case Intrinsic::nvvm_internal_addrspace_wrap:
3181 return Op.getOperand(i: 1);
3182 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:
3183 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x:
3184 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y:
3185 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z:
3186 return LowerClusterLaunchControlQueryCancel(Op, DAG);
3187 case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
3188 case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
3189 case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
3190 case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
3191 case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
3192 case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
3193 case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
3194 case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
3195 case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite:
3196 case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
3197 return lowerCvtRSIntrinsics(Op, DAG);
3198 }
3199}
3200
3201// In PTX 64-bit CTLZ and CTPOP are supported, but they return a 32-bit value.
3202// Lower these into a node returning the correct type which is zero-extended
3203// back to the correct size.
3204static SDValue lowerCTLZCTPOP(SDValue Op, SelectionDAG &DAG) {
3205 SDValue V = Op->getOperand(Num: 0);
3206 assert(V.getValueType() == MVT::i64 &&
3207 "Unexpected CTLZ/CTPOP type to legalize");
3208
3209 SDLoc DL(Op);
3210 SDValue CT = DAG.getNode(Opcode: Op->getOpcode(), DL, VT: MVT::i32, Operand: V);
3211 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: MVT::i64, Operand: CT, Flags: SDNodeFlags::NonNeg);
3212}
3213
// Expand a 64-bit funnel shift (FSHL/FSHR) with a constant shift amount into
// two 32-bit funnel shifts over the unpacked halves of A and B. Returns an
// empty SDValue when the amount is not constant, deferring to default
// expansion.
static SDValue expandFSH64(SDValue A, SDValue B, SDValue ShiftAmount, SDLoc DL,
                           unsigned Opcode, SelectionDAG &DAG) {
  assert(A.getValueType() == MVT::i64 && B.getValueType() == MVT::i64);

  // Only constant shift amounts are handled here.
  const auto *AmtConst = dyn_cast<ConstantSDNode>(Val&: ShiftAmount);
  if (!AmtConst)
    return SDValue();
  const auto Amt = AmtConst->getZExtValue() & 63;

  SDValue UnpackA =
      DAG.getNode(Opcode: NVPTXISD::UNPACK_VECTOR, DL, ResultTys: {MVT::i32, MVT::i32}, Ops: A);
  SDValue UnpackB =
      DAG.getNode(Opcode: NVPTXISD::UNPACK_VECTOR, DL, ResultTys: {MVT::i32, MVT::i32}, Ops: B);

  // Arch is little endian: 0 = low bits, 1 = high bits
  SDValue ALo = UnpackA.getValue(R: 0);
  SDValue AHi = UnpackA.getValue(R: 1);
  SDValue BLo = UnpackB.getValue(R: 0);
  SDValue BHi = UnpackB.getValue(R: 1);

  // The bitfield consists of { AHi : ALo : BHi : BLo }
  //
  // * FSHL, Amt < 32  - The window will contain { AHi : ALo : BHi }
  // * FSHL, Amt >= 32 - The window will contain { ALo : BHi : BLo }
  // * FSHR, Amt < 32  - The window will contain { ALo : BHi : BLo }
  // * FSHR, Amt >= 32 - The window will contain { AHi : ALo : BHi }
  //
  // Note that Amt = 0 and Amt = 32 are special cases where 32-bit funnel shifts
  // are not needed at all. Amt = 0 is a no-op producing either A or B depending
  // on the direction. Amt = 32 can be implemented by a packing and unpacking
  // move to select and arrange the 32bit values. For simplicity, these cases
  // are not handled here explicitly and instead we rely on DAGCombiner to
  // remove the no-op funnel shifts we insert.
  auto [High, Mid, Low] = ((Opcode == ISD::FSHL) == (Amt < 32))
                              ? std::make_tuple(args&: AHi, args&: ALo, args&: BHi)
                              : std::make_tuple(args&: ALo, args&: BHi, args&: BLo);

  // Both 32-bit funnel shifts use the amount reduced modulo 32.
  SDValue NewAmt = DAG.getConstant(Val: Amt & 31, DL, VT: MVT::i32);
  SDValue RHi = DAG.getNode(Opcode, DL, VT: MVT::i32, Ops: {High, Mid, NewAmt});
  SDValue RLo = DAG.getNode(Opcode, DL, VT: MVT::i32, Ops: {Mid, Low, NewAmt});

  // Repack the two halves into an i64 (low first: little endian).
  return DAG.getNode(Opcode: NVPTXISD::BUILD_VECTOR, DL, VT: MVT::i64, Ops: {RLo, RHi});
}
3257
3258static SDValue lowerFSH(SDValue Op, SelectionDAG &DAG) {
3259 return expandFSH64(A: Op->getOperand(Num: 0), B: Op->getOperand(Num: 1), ShiftAmount: Op->getOperand(Num: 2),
3260 DL: SDLoc(Op), Opcode: Op->getOpcode(), DAG);
3261}
3262
3263static SDValue lowerROT(SDValue Op, SelectionDAG &DAG) {
3264 unsigned Opcode = Op->getOpcode() == ISD::ROTL ? ISD::FSHL : ISD::FSHR;
3265 return expandFSH64(A: Op->getOperand(Num: 0), B: Op->getOperand(Num: 0), ShiftAmount: Op->getOperand(Num: 1),
3266 DL: SDLoc(Op), Opcode, DAG);
3267}
3268
3269static SDValue lowerFREM(SDValue Op, SelectionDAG &DAG) {
3270 // Lower (frem x, y) into (sub x, (mul (ftrunc (div x, y)) y)),
3271 // i.e. "poor man's fmod()". When y is infinite, x is returned. This matches
3272 // the semantics of LLVM's frem.
3273 SDLoc DL(Op);
3274 SDValue X = Op->getOperand(Num: 0);
3275 SDValue Y = Op->getOperand(Num: 1);
3276 EVT Ty = Op.getValueType();
3277 SDNodeFlags Flags = Op->getFlags();
3278
3279 SDValue Div = DAG.getNode(Opcode: ISD::FDIV, DL, VT: Ty, N1: X, N2: Y, Flags);
3280 SDValue Trunc = DAG.getNode(Opcode: ISD::FTRUNC, DL, VT: Ty, Operand: Div, Flags);
3281 SDValue Mul = DAG.getNode(Opcode: ISD::FMUL, DL, VT: Ty, N1: Trunc, N2: Y,
3282 Flags: Flags | SDNodeFlags::AllowContract);
3283 SDValue Sub = DAG.getNode(Opcode: ISD::FSUB, DL, VT: Ty, N1: X, N2: Mul,
3284 Flags: Flags | SDNodeFlags::AllowContract);
3285
3286 if (Flags.hasNoInfs())
3287 return Sub;
3288
3289 // If Y is infinite, return X
3290 SDValue AbsY = DAG.getNode(Opcode: ISD::FABS, DL, VT: Ty, Operand: Y);
3291 SDValue Inf =
3292 DAG.getConstantFP(Val: APFloat::getInf(Sem: Ty.getFltSemantics()), DL, VT: Ty);
3293 SDValue IsInf = DAG.getSetCC(DL, VT: MVT::i1, LHS: AbsY, RHS: Inf, Cond: ISD::SETEQ);
3294 return DAG.getSelect(DL, VT: Ty, Cond: IsInf, LHS: X, RHS: Sub);
3295}
3296
// Custom-lower an i1 SELECT. If both arms are truncations, push the select
// through them and select at the wider width; otherwise expand into
// (Cond & TrueVal) | (~Cond & FalseVal).
static SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) {
  assert(Op.getValueType() == MVT::i1 && "Custom lowering enabled only for i1");

  SDValue Cond = Op->getOperand(Num: 0);
  SDValue TrueVal = Op->getOperand(Num: 1);
  SDValue FalseVal = Op->getOperand(Num: 2);
  SDLoc DL(Op);

  // If both operands are truncated, we push the select through the truncates.
  if (TrueVal.getOpcode() == ISD::TRUNCATE &&
      FalseVal.getOpcode() == ISD::TRUNCATE) {
    TrueVal = TrueVal.getOperand(i: 0);
    FalseVal = FalseVal.getOperand(i: 0);

    // Select at the narrower of the two pre-truncation types; extend or
    // truncate both arms to that common width.
    EVT VT = TrueVal.getSimpleValueType().bitsLE(VT: FalseVal.getSimpleValueType())
                 ? TrueVal.getValueType()
                 : FalseVal.getValueType();
    TrueVal = DAG.getAnyExtOrTrunc(Op: TrueVal, DL, VT);
    FalseVal = DAG.getAnyExtOrTrunc(Op: FalseVal, DL, VT);
    SDValue Select = DAG.getSelect(DL, VT, Cond, LHS: TrueVal, RHS: FalseVal);
    return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i1, Operand: Select);
  }

  // Otherwise, expand the select into a series of logical operations. These
  // often can be folded into other operations either by us or ptxas.
  // Freeze the arms first so an undef/poison unselected value cannot leak
  // through the AND/OR expansion.
  TrueVal = DAG.getFreeze(V: TrueVal);
  FalseVal = DAG.getFreeze(V: FalseVal);
  SDValue And1 = DAG.getNode(Opcode: ISD::AND, DL, VT: MVT::i1, N1: Cond, N2: TrueVal);
  SDValue NotCond = DAG.getNOT(DL, Val: Cond, VT: MVT::i1);
  SDValue And2 = DAG.getNode(Opcode: ISD::AND, DL, VT: MVT::i1, N1: NotCond, N2: FalseVal);
  SDValue Or = DAG.getNode(Opcode: ISD::OR, DL, VT: MVT::i1, N1: And1, N2: And2);
  return Or;
}
3330
// Lower a masked vector store (ISD::MSTORE) with a constant BUILD_VECTOR mask
// into an NVPTXISD::StoreV4/StoreV8 node. Masked-off lanes are encoded as a
// sentinel register 0 operand, which is recognized later during instruction
// selection (tablegen).
static SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) {
  SDNode *N = Op.getNode();

  SDValue Chain = N->getOperand(Num: 0);
  SDValue Val = N->getOperand(Num: 1);
  SDValue BasePtr = N->getOperand(Num: 2);
  SDValue Offset = N->getOperand(Num: 3);
  SDValue Mask = N->getOperand(Num: 4);

  SDLoc DL(N);
  EVT ValVT = Val.getValueType();
  MemSDNode *MemSD = cast<MemSDNode>(Val: N);
  assert(ValVT.isVector() && "Masked vector store must have vector type");
  assert(MemSD->getAlign() >= DAG.getEVTAlign(ValVT) &&
         "Unexpected alignment for masked store");

  // Only the 256-bit vector shapes are expected here; pick the matching
  // multi-value store opcode.
  unsigned Opcode = 0;
  switch (ValVT.getSimpleVT().SimpleTy) {
  default:
    llvm_unreachable("Unexpected masked vector store type");
  case MVT::v4i64:
  case MVT::v4f64: {
    Opcode = NVPTXISD::StoreV4;
    break;
  }
  case MVT::v8i32:
  case MVT::v8f32: {
    Opcode = NVPTXISD::StoreV8;
    break;
  }
  }

  SmallVector<SDValue, 8> Ops;

  // Construct the new SDNode. First operand is the chain.
  Ops.push_back(Elt: Chain);

  // The next N operands are the values to store. Encode the mask into the
  // values using the sentinel register 0 to represent a masked-off element.
  assert(Mask.getValueType().isVector() &&
         Mask.getValueType().getVectorElementType() == MVT::i1 &&
         "Mask must be a vector of i1");
  assert(Mask.getOpcode() == ISD::BUILD_VECTOR &&
         "Mask expected to be a BUILD_VECTOR");
  assert(Mask.getValueType().getVectorNumElements() ==
             ValVT.getVectorNumElements() &&
         "Mask size must be the same as the vector size");
  for (auto [I, Op] : enumerate(First: Mask->ops())) {
    // Mask elements must be constants.
    if (Op.getNode()->getAsZExtVal() == 0) {
      // Append a sentinel register 0 to the Ops vector to represent a masked
      // off element, this will be handled in tablegen
      Ops.push_back(Elt: DAG.getRegister(Reg: MCRegister::NoRegister,
                                    VT: ValVT.getVectorElementType()));
    } else {
      // Extract the element from the vector to store
      SDValue ExtVal =
          DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ValVT.getVectorElementType(),
                      N1: Val, N2: DAG.getIntPtrConstant(Val: I, DL));
      Ops.push_back(Elt: ExtVal);
    }
  }

  // Next, the pointer operand.
  Ops.push_back(Elt: BasePtr);

  // Finally, the offset operand. We expect this to always be undef, and it will
  // be ignored in lowering, but to mirror the handling of the other vector
  // store instructions we include it in the new SDNode.
  assert(Offset.getOpcode() == ISD::UNDEF &&
         "Offset operand expected to be undef");
  Ops.push_back(Elt: Offset);

  SDValue NewSt =
      DAG.getMemIntrinsicNode(Opcode, dl: DL, VTList: DAG.getVTList(VT: MVT::Other), Ops,
                              MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());

  return NewSt;
}
3410
// Dispatch custom lowering for every operation marked Custom in the NVPTX
// target lowering setup. Each case forwards to a dedicated helper; hitting
// the default case means an operation was marked Custom without a handler.
SDValue
NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
  switch (Op.getOpcode()) {
  // RETURNADDR/FRAMEADDR are not supported; an empty SDValue is returned.
  case ISD::RETURNADDR:
    return SDValue();
  case ISD::FRAMEADDR:
    return SDValue();
  case ISD::ADDRSPACECAST:
    return LowerADDRSPACECAST(Op, DAG);
  case ISD::INTRINSIC_W_CHAIN:
    return lowerIntrinsicWChain(Op, DAG);
  case ISD::INTRINSIC_WO_CHAIN:
    return lowerIntrinsicWOChain(Op, DAG);
  case ISD::INTRINSIC_VOID:
    return lowerIntrinsicVoid(Op, DAG);
  case ISD::BUILD_VECTOR:
    return LowerBUILD_VECTOR(Op, DAG);
  case ISD::BITCAST:
    return LowerBITCAST(Op, DAG);
  // EXTRACT_SUBVECTOR is legal as-is; return the node unchanged.
  case ISD::EXTRACT_SUBVECTOR:
    return Op;
  case ISD::EXTRACT_VECTOR_ELT:
    return LowerEXTRACT_VECTOR_ELT(Op, DAG);
  case ISD::INSERT_VECTOR_ELT:
    return LowerINSERT_VECTOR_ELT(Op, DAG);
  case ISD::VECTOR_SHUFFLE:
    return LowerVECTOR_SHUFFLE(Op, DAG);
  case ISD::CONCAT_VECTORS:
    return LowerCONCAT_VECTORS(Op, DAG);
  case ISD::VECREDUCE_FMAX:
  case ISD::VECREDUCE_FMIN:
  case ISD::VECREDUCE_FMAXIMUM:
  case ISD::VECREDUCE_FMINIMUM:
    return LowerVECREDUCE(Op, DAG);
  case ISD::STORE:
    return LowerSTORE(Op, DAG);
  case ISD::MSTORE: {
    // Masked stores are only legal with 256-bit vector load/store support.
    assert(STI.has256BitVectorLoadStore(
               cast<MemSDNode>(Op.getNode())->getAddressSpace()) &&
           "Masked store vector not supported on subtarget.");
    return lowerMSTORE(Op, DAG);
  }
  case ISD::LOAD:
    return LowerLOAD(Op, DAG);
  case ISD::MLOAD:
    return LowerMLOAD(Op, DAG);
  case ISD::SHL_PARTS:
    return LowerShiftLeftParts(Op, DAG);
  case ISD::SRA_PARTS:
  case ISD::SRL_PARTS:
    return LowerShiftRightParts(Op, DAG);
  case ISD::SELECT:
    return lowerSELECT(Op, DAG);
  case ISD::FROUND:
    return LowerFROUND(Op, DAG);
  case ISD::FCOPYSIGN:
    return LowerFCOPYSIGN(Op, DAG);
  case ISD::SINT_TO_FP:
  case ISD::UINT_TO_FP:
    return LowerINT_TO_FP(Op, DAG);
  case ISD::FP_TO_SINT:
  case ISD::FP_TO_UINT:
    return LowerFP_TO_INT(Op, DAG);
  case ISD::FP_ROUND:
    return LowerFP_ROUND(Op, DAG);
  case ISD::FP_EXTEND:
    return LowerFP_EXTEND(Op, DAG);
  case ISD::VAARG:
    return LowerVAARG(Op, DAG);
  case ISD::VASTART:
    return LowerVASTART(Op, DAG);
  case ISD::FSHL:
  case ISD::FSHR:
    return lowerFSH(Op, DAG);
  case ISD::ROTL:
  case ISD::ROTR:
    return lowerROT(Op, DAG);
  case ISD::ABS:
  case ISD::SMIN:
  case ISD::SMAX:
  case ISD::UMIN:
  case ISD::UMAX:
  case ISD::ADD:
  case ISD::SUB:
  case ISD::MUL:
  case ISD::SHL:
  case ISD::SREM:
  case ISD::UREM:
    return LowerVectorArith(Op, DAG);
  case ISD::DYNAMIC_STACKALLOC:
    return LowerDYNAMIC_STACKALLOC(Op, DAG);
  case ISD::STACKRESTORE:
    return LowerSTACKRESTORE(Op, DAG);
  case ISD::STACKSAVE:
    return LowerSTACKSAVE(Op, DAG);
  case ISD::CopyToReg:
    return LowerCopyToReg_128(Op, DAG);
  case ISD::FADD:
  case ISD::FSUB:
  case ISD::FMUL:
    // Used only for bf16 on SM80, where we select fma for non-ftz operation
    return PromoteBinOpIfF32FTZ(Op, DAG);
  case ISD::CTPOP:
  case ISD::CTLZ:
    return lowerCTLZCTPOP(Op, DAG);
  case ISD::FREM:
    return lowerFREM(Op, DAG);
  case ISD::BSWAP:
    return lowerBSWAP(Op, DAG);
  default:
    llvm_unreachable("Custom lowering not defined for operation");
  }
}
3524
// This will prevent AsmPrinter from trying to print the jump tables itself.
// NOTE(review): EK_Inline means the jump table is emitted as part of the
// branch sequence rather than as separate data.
unsigned NVPTXTargetLowering::getJumpTableEncoding() const {
  return MachineJumpTableInfo::EK_Inline;
}
3529
// Custom-lower ADDRSPACECAST. Casts to/from generic space are left as-is.
// Casts between two non-generic spaces are invalid, except for
// shared <-> shared-cluster, which is lowered as a round-trip through the
// generic space; any other non-generic pair produces undef.
SDValue NVPTXTargetLowering::LowerADDRSPACECAST(SDValue Op,
                                                SelectionDAG &DAG) const {
  AddrSpaceCastSDNode *N = cast<AddrSpaceCastSDNode>(Val: Op.getNode());
  unsigned SrcAS = N->getSrcAddressSpace();
  unsigned DestAS = N->getDestAddressSpace();
  if (SrcAS != llvm::ADDRESS_SPACE_GENERIC &&
      DestAS != llvm::ADDRESS_SPACE_GENERIC) {
    // Shared and SharedCluster can be converted to each other through generic
    // space
    if ((SrcAS == llvm::ADDRESS_SPACE_SHARED &&
         DestAS == llvm::ADDRESS_SPACE_SHARED_CLUSTER) ||
        (SrcAS == llvm::ADDRESS_SPACE_SHARED_CLUSTER &&
         DestAS == llvm::ADDRESS_SPACE_SHARED)) {
      SDLoc DL(Op.getNode());
      const MVT GenerictVT =
          getPointerTy(DL: DAG.getDataLayout(), AS: ADDRESS_SPACE_GENERIC);
      // First cast the source pointer into the generic space, then from
      // generic into the destination space.
      SDValue GenericConversion = DAG.getAddrSpaceCast(
          dl: DL, VT: GenerictVT, Ptr: Op.getOperand(i: 0), SrcAS, DestAS: ADDRESS_SPACE_GENERIC);
      SDValue SharedClusterConversion =
          DAG.getAddrSpaceCast(dl: DL, VT: Op.getValueType(), Ptr: GenericConversion,
                               SrcAS: ADDRESS_SPACE_GENERIC, DestAS);
      return SharedClusterConversion;
    }

    // Any other non-generic to non-generic cast is invalid; fold to undef.
    return DAG.getUNDEF(VT: Op.getValueType());
  }

  // Casts involving the generic space need no custom handling.
  return Op;
}
3559
// This function is almost a copy of SelectionDAG::expandVAArg().
// The only diff is that this one produces loads from local address space.
SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
  const TargetLowering *TLI = STI.getTargetLowering();
  SDLoc DL(Op);

  SDNode *Node = Op.getNode();
  // Operands: chain, va_list pointer, source value, alignment constant.
  const Value *V = cast<SrcValueSDNode>(Val: Node->getOperand(Num: 2))->getValue();
  EVT VT = Node->getValueType(ResNo: 0);
  auto *Ty = VT.getTypeForEVT(Context&: *DAG.getContext());
  SDValue Tmp1 = Node->getOperand(Num: 0);
  SDValue Tmp2 = Node->getOperand(Num: 1);
  const MaybeAlign MA(Node->getConstantOperandVal(Num: 3));

  // Load the current va_list cursor.
  SDValue VAListLoad = DAG.getLoad(VT: TLI->getPointerTy(DL: DAG.getDataLayout()), dl: DL,
                                   Chain: Tmp1, Ptr: Tmp2, PtrInfo: MachinePointerInfo(V));
  SDValue VAList = VAListLoad;

  // If the argument requires more alignment than the va_list guarantees,
  // round the cursor up: VAList = (VAList + MA - 1) & -MA.
  if (MA && *MA > TLI->getMinStackArgumentAlignment()) {
    VAList = DAG.getNode(
        Opcode: ISD::ADD, DL, VT: VAList.getValueType(), N1: VAList,
        N2: DAG.getConstant(Val: MA->value() - 1, DL, VT: VAList.getValueType()));

    VAList = DAG.getNode(Opcode: ISD::AND, DL, VT: VAList.getValueType(), N1: VAList,
                         N2: DAG.getSignedConstant(Val: -(int64_t)MA->value(), DL,
                                                VT: VAList.getValueType()));
  }

  // Increment the pointer, VAList, to the next vaarg
  Tmp1 = DAG.getNode(Opcode: ISD::ADD, DL, VT: VAList.getValueType(), N1: VAList,
                     N2: DAG.getConstant(Val: DAG.getDataLayout().getTypeAllocSize(Ty),
                                      DL, VT: VAList.getValueType()));

  // Store the incremented VAList to the legalized pointer
  Tmp1 = DAG.getStore(Chain: VAListLoad.getValue(R: 1), dl: DL, Val: Tmp1, Ptr: Tmp2,
                      PtrInfo: MachinePointerInfo(V));

  // The vararg data itself lives in the local address space.
  const Value *SrcV = Constant::getNullValue(
      Ty: PointerType::get(C&: *DAG.getContext(), AddressSpace: ADDRESS_SPACE_LOCAL));

  // Load the actual argument out of the pointer VAList
  return DAG.getLoad(VT, dl: DL, Chain: Tmp1, Ptr: VAList, PtrInfo: MachinePointerInfo(SrcV));
}
3603
// Lower VASTART: initialize the va_list (ap) object with the address of the
// function's vararg parameter array.
SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
  const TargetLowering *TLI = STI.getTargetLowering();
  SDLoc DL(Op);
  EVT PtrVT = TLI->getPointerTy(DL: DAG.getDataLayout());

  // Store the address of unsized array <function>_vararg[] in the ap object.
  // The vararg "parameter" is identified by index -1.
  SDValue VAReg = getParamSymbol(DAG, /* vararg */ I: -1, T: PtrVT);

  const Value *SV = cast<SrcValueSDNode>(Val: Op.getOperand(i: 2))->getValue();
  return DAG.getStore(Chain: Op.getOperand(i: 0), dl: DL, Val: VAReg, Ptr: Op.getOperand(i: 1),
                      PtrInfo: MachinePointerInfo(SV));
}
3616
// Convert a masked load (ISD::MLOAD) with a constant mask and poison/undef
// passthru into a plain load plus a per-byte "used bytes" mask. Bit i of the
// returned mask marks byte i of the whole vector as live; if the subtarget
// cannot emit the used-bytes-mask pragma, the mask is widened to UINT32_MAX
// (all bytes used).
static std::pair<MemSDNode *, uint32_t>
convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG,
                                    const NVPTXSubtarget &STI) {
  SDValue Chain = N->getOperand(Num: 0);
  SDValue BasePtr = N->getOperand(Num: 1);
  SDValue Mask = N->getOperand(Num: 3);
  [[maybe_unused]] SDValue Passthru = N->getOperand(Num: 4);

  SDLoc DL(N);
  EVT ResVT = N->getValueType(ResNo: 0);
  assert(ResVT.isVector() && "Masked vector load must have vector type");
  // While we only expect poison passthru vectors as an input to the backend,
  // when the legalization framework splits a poison vector in half, it creates
  // two undef vectors, so we can technically expect those too.
  assert((Passthru.getOpcode() == ISD::POISON ||
          Passthru.getOpcode() == ISD::UNDEF) &&
         "Passthru operand expected to be poison or undef");

  // Extract the mask and convert it to a uint32_t representing the used bytes
  // of the entire vector load
  uint32_t UsedBytesMask = 0;
  uint32_t ElementSizeInBits = ResVT.getVectorElementType().getSizeInBits();
  assert(ElementSizeInBits % 8 == 0 && "Unexpected element size");
  uint32_t ElementSizeInBytes = ElementSizeInBits / 8;
  // One mask bit per byte of an element, all set.
  uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;

  // Walk the mask last-element-first so the first element ends up in the
  // lowest bits of UsedBytesMask.
  for (SDValue Op : reverse(C: Mask->ops())) {
    // We technically only want to do this shift for every
    // iteration *but* the first, but in the first iteration UsedBytesMask is 0,
    // so this shift is a no-op.
    UsedBytesMask <<= ElementSizeInBytes;

    // Mask elements must be constants.
    if (Op->getAsZExtVal() != 0)
      UsedBytesMask |= ElementMask;
  }

  assert(UsedBytesMask != 0 && UsedBytesMask != UINT32_MAX &&
         "Unexpected masked load with elements masked all on or all off");

  // Create a new load sd node to be handled normally by ReplaceLoadVector.
  MemSDNode *NewLD = cast<MemSDNode>(
      Val: DAG.getLoad(VT: ResVT, dl: DL, Chain, Ptr: BasePtr, MMO: N->getMemOperand()).getNode());

  // If our subtarget does not support the used bytes mask pragma, "drop" the
  // mask by setting it to UINT32_MAX
  if (!STI.hasUsedBytesMaskPragma())
    UsedBytesMask = UINT32_MAX;

  return {NewLD, UsedBytesMask};
}
3668
/// replaceLoadVector - Convert vector loads into multi-output scalar loads.
///
/// Turns a (possibly masked) vector load into an NVPTXISD::LoadV2/V4/V8
/// target node with one scalar result per element, then reassembles the
/// original vector value with a build-vector + bitcast. Returns
/// {loaded value, chain}, or nullopt when the load should be handled by the
/// default legalizer (extending load, unsupported shape, or insufficient
/// alignment).
static std::optional<std::pair<SDValue, SDValue>>
replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
  MemSDNode *LD = cast<MemSDNode>(Val: N);
  const EVT ResVT = LD->getValueType(ResNo: 0);
  const EVT MemVT = LD->getMemoryVT();

  // If we're doing sign/zero extension as part of the load, avoid lowering to
  // a LoadV node. TODO: consider relaxing this restriction.
  if (ResVT != MemVT)
    return std::nullopt;

  const auto NumEltsAndEltVT =
      getVectorLoweringShape(VectorEVT: ResVT, STI, AddressSpace: LD->getAddressSpace());
  if (!NumEltsAndEltVT)
    return std::nullopt;
  const auto [NumElts, EltVT] = NumEltsAndEltVT.value();

  Align Alignment = LD->getAlign();
  const auto &TD = DAG.getDataLayout();
  Align PrefAlign = TD.getPrefTypeAlign(Ty: MemVT.getTypeForEVT(Context&: *DAG.getContext()));
  if (Alignment < PrefAlign) {
    // This load is not sufficiently aligned, so bail out and let this vector
    // load be scalarized. Note that we may still be able to emit smaller
    // vector loads. For example, if we are loading a <4 x float> with an
    // alignment of 8, this check will fail but the legalizer will try again
    // with 2 x <2 x float>, which will succeed with an alignment of 8.
    return std::nullopt;
  }

  // If we have a masked load, convert it to a normal load now
  std::optional<uint32_t> UsedBytesMask = std::nullopt;
  if (LD->getOpcode() == ISD::MLOAD)
    std::tie(args&: LD, args&: UsedBytesMask) =
        convertMLOADToLoadWithUsedBytesMask(N: LD, DAG, STI);

  // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
  // Therefore, we must ensure the type is legal. For i1 and i8, we set the
  // loaded type to i16 and propagate the "real" type as the memory type.
  const MVT LoadEltVT = (EltVT.getSizeInBits() < 16) ? MVT::i16 : EltVT;

  unsigned Opcode;
  switch (NumElts) {
  default:
    return std::nullopt;
  case 2:
    Opcode = NVPTXISD::LoadV2;
    break;
  case 4:
    Opcode = NVPTXISD::LoadV4;
    break;
  case 8:
    Opcode = NVPTXISD::LoadV8;
    break;
  }
  // Result types: NumElts scalars plus the chain.
  auto ListVTs = SmallVector<EVT, 9>(NumElts, LoadEltVT);
  ListVTs.push_back(Elt: MVT::Other);
  SDVTList LdResVTs = DAG.getVTList(VTs: ListVTs);

  SDLoc DL(LD);

  // Copy regular operands
  SmallVector<SDValue, 8> OtherOps(LD->ops());

  // Append the used-bytes mask (all bytes used unless this was an MLOAD).
  OtherOps.push_back(
      Elt: DAG.getConstant(Val: UsedBytesMask.value_or(UINT32_MAX), DL, VT: MVT::i32));

  // The select routine does not have access to the LoadSDNode instance, so
  // pass along the extension information
  OtherOps.push_back(
      Elt: DAG.getIntPtrConstant(Val: cast<LoadSDNode>(Val: LD)->getExtensionType(), DL));

  SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, dl: DL, VTList: LdResVTs, Ops: OtherOps, MemVT,
                                          MMO: LD->getMemOperand());

  SmallVector<SDValue> ScalarRes;
  if (EltVT.isVector()) {
    assert(EVT(EltVT.getVectorElementType()) == ResVT.getVectorElementType());
    assert(NumElts * EltVT.getVectorNumElements() ==
           ResVT.getVectorNumElements());
    // Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
    // into individual elements.
    for (const unsigned I : llvm::seq(Size: NumElts)) {
      SDValue SubVector = NewLD.getValue(R: I);
      DAG.ExtractVectorElements(Op: SubVector, Args&: ScalarRes);
    }
  } else {
    for (const unsigned I : llvm::seq(Size: NumElts)) {
      SDValue Res = NewLD.getValue(R: I);
      // Undo the i16 widening applied to sub-16-bit elements above.
      if (LoadEltVT != EltVT)
        Res = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: EltVT, Operand: Res);
      ScalarRes.push_back(Elt: Res);
    }
  }

  SDValue LoadChain = NewLD.getValue(R: NumElts);

  // Reassemble the scalars and bitcast back to the requested vector type.
  const MVT BuildVecVT =
      MVT::getVectorVT(VT: EltVT.getScalarType(), NumElements: ScalarRes.size());
  SDValue BuildVec = DAG.getBuildVector(VT: BuildVecVT, DL, Ops: ScalarRes);
  SDValue LoadValue = DAG.getBitcast(VT: ResVT, V: BuildVec);

  return {{LoadValue, LoadChain}};
}
3773
3774static void replaceLoadVector(SDNode *N, SelectionDAG &DAG,
3775 SmallVectorImpl<SDValue> &Results,
3776 const NVPTXSubtarget &STI) {
3777 if (auto Res = replaceLoadVector(N, DAG, STI))
3778 Results.append(IL: {Res->first, Res->second});
3779}
3780
3781static SDValue lowerLoadVector(SDNode *N, SelectionDAG &DAG,
3782 const NVPTXSubtarget &STI) {
3783 if (auto Res = replaceLoadVector(N, DAG, STI))
3784 return DAG.getMergeValues(Ops: {Res->first, Res->second}, dl: SDLoc(N));
3785 return SDValue();
3786}
3787
3788// v = ld i1* addr
3789// =>
3790// v1 = ld i8* addr (-> i16)
3791// v = trunc i16 to i1
3792static SDValue lowerLOADi1(LoadSDNode *LD, SelectionDAG &DAG) {
3793 SDLoc dl(LD);
3794 assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
3795 assert(LD->getValueType(0) == MVT::i1 && "Custom lowering for i1 load only");
3796 SDValue newLD = DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl, VT: MVT::i16, Chain: LD->getChain(),
3797 Ptr: LD->getBasePtr(), PtrInfo: LD->getPointerInfo(),
3798 MemVT: MVT::i8, Alignment: LD->getAlign(),
3799 MMOFlags: LD->getMemOperand()->getFlags());
3800 SDValue result = DAG.getNode(Opcode: ISD::TRUNCATE, DL: dl, VT: MVT::i1, Operand: newLD);
3801 // The legalizer (the caller) is expecting two values from the legalized
3802 // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
3803 // in LegalizeDAG.cpp which also uses MergeValues.
3804 return DAG.getMergeValues(Ops: {result, LD->getChain()}, dl);
3805}
3806
3807SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
3808 LoadSDNode *LD = cast<LoadSDNode>(Val&: Op);
3809
3810 if (Op.getValueType() == MVT::i1)
3811 return lowerLOADi1(LD, DAG);
3812
3813 // To improve CodeGen we'll legalize any-extend loads to zext loads. This is
3814 // how they'll be lowered in ISel anyway, and by doing this a little earlier
3815 // we allow for more DAG combine opportunities.
3816 if (LD->getExtensionType() == ISD::EXTLOAD) {
3817 assert(LD->getValueType(0).isInteger() && LD->getMemoryVT().isInteger() &&
3818 "Unexpected fpext-load");
3819 return DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl: SDLoc(Op), VT: Op.getValueType(),
3820 Chain: LD->getChain(), Ptr: LD->getBasePtr(), MemVT: LD->getMemoryVT(),
3821 MMO: LD->getMemOperand());
3822 }
3823
3824 llvm_unreachable("Unexpected custom lowering for load");
3825}
3826
/// Custom-lower a masked load (MLOAD) of a packed vector type by rewriting it
/// into an NVPTXISD::MLoad node that carries an explicit "used bytes" mask.
/// Returns an empty SDValue for non-packed types so default handling applies.
SDValue NVPTXTargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
  // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
  // masked loads of these types and have to handle them here.
  // v2f32 also needs to be handled here if the subtarget has f32x2
  // instructions, making it legal.
  //
  // Note: misaligned masked loads should never reach this point
  // because the override of isLegalMaskedLoad in NVPTXTargetTransformInfo.cpp
  // will validate alignment. Therefore, we do not need to special case handle
  // them here.
  EVT VT = Op.getValueType();
  if (NVPTX::isPackedVectorTy(VT)) {
    // Fold the lane mask into a byte-granularity mask attached to a plain
    // load; Result is a (new load node, used-bytes mask) pair.
    auto Result = convertMLOADToLoadWithUsedBytesMask(
        N: cast<MemSDNode>(Val: Op.getNode()), DAG, STI);
    MemSDNode *LD = std::get<0>(in&: Result);
    uint32_t UsedBytesMask = std::get<1>(in&: Result);

    SDLoc DL(LD);

    // Copy regular operands
    SmallVector<SDValue, 8> OtherOps(LD->ops());

    // Append the used-bytes mask as an extra i32 operand of the MLoad node.
    OtherOps.push_back(Elt: DAG.getConstant(Val: UsedBytesMask, DL, VT: MVT::i32));

    // We currently are not lowering extending loads, but pass the extension
    // type anyway as later handling expects it.
    OtherOps.push_back(
        Elt: DAG.getIntPtrConstant(Val: cast<LoadSDNode>(Val: LD)->getExtensionType(), DL));
    SDValue NewLD =
        DAG.getMemIntrinsicNode(Opcode: NVPTXISD::MLoad, dl: DL, VTList: LD->getVTList(), Ops: OtherOps,
                                MemVT: LD->getMemoryVT(), MMO: LD->getMemOperand());
    return NewLD;
  }
  return SDValue();
}
3862
/// Lower a (non-truncating, sufficiently aligned) vector store into an NVPTX
/// StoreV2/StoreV4/StoreV8 target node whose operands are the individual
/// elements (or packed sub-vectors) of the stored value. Returns an empty
/// SDValue when the store cannot or should not be lowered this way, so the
/// caller falls back to default legalization.
static SDValue lowerSTOREVector(SDValue Op, SelectionDAG &DAG,
                                const NVPTXSubtarget &STI) {
  MemSDNode *N = cast<MemSDNode>(Val: Op.getNode());
  SDValue Val = N->getOperand(Num: 1);
  SDLoc DL(N);
  const EVT ValVT = Val.getValueType();
  const EVT MemVT = N->getMemoryVT();

  // If we're truncating as part of the store, avoid lowering to a StoreV node.
  // TODO: consider relaxing this restriction.
  if (ValVT != MemVT)
    return SDValue();

  // Decide how many operands the vector store should have and what type each
  // operand is (a scalar, or a packed sub-vector like v2f16).
  const auto NumEltsAndEltVT =
      getVectorLoweringShape(VectorEVT: ValVT, STI, AddressSpace: N->getAddressSpace());
  if (!NumEltsAndEltVT)
    return SDValue();
  const auto [NumElts, EltVT] = NumEltsAndEltVT.value();

  const DataLayout &TD = DAG.getDataLayout();

  Align Alignment = N->getAlign();
  Align PrefAlign = TD.getPrefTypeAlign(Ty: ValVT.getTypeForEVT(Context&: *DAG.getContext()));
  if (Alignment < PrefAlign) {
    // This store is not sufficiently aligned, so bail out and let this vector
    // store be scalarized. Note that we may still be able to emit smaller
    // vector stores. For example, if we are storing a <4 x float> with an
    // alignment of 8, this check will fail but the legalizer will try again
    // with 2 x <2 x float>, which will succeed with an alignment of 8.
    return SDValue();
  }

  // Pick the StoreV node matching the operand count; other counts are not
  // representable as a single vector store.
  unsigned Opcode;
  switch (NumElts) {
  default:
    return SDValue();
  case 2:
    Opcode = NVPTXISD::StoreV2;
    break;
  case 4:
    Opcode = NVPTXISD::StoreV4;
    break;
  case 8:
    Opcode = NVPTXISD::StoreV8;
    break;
  }

  SmallVector<SDValue, 8> Ops;

  // First is the chain
  Ops.push_back(Elt: N->getOperand(Num: 0));

  // Then the split values
  if (EltVT.isVector()) {
    assert(EVT(EltVT.getVectorElementType()) == ValVT.getVectorElementType());
    assert(NumElts * EltVT.getVectorNumElements() ==
           ValVT.getVectorNumElements());
    // Combine individual elements into v2[i,f,bf]16/v4i8 subvectors to be
    // stored as b32s
    const unsigned NumEltsPerSubVector = EltVT.getVectorNumElements();
    for (const unsigned I : llvm::seq(Size: NumElts)) {
      SmallVector<SDValue, 4> SubVectorElts;
      DAG.ExtractVectorElements(Op: Val, Args&: SubVectorElts, Start: I * NumEltsPerSubVector,
                                Count: NumEltsPerSubVector);
      Ops.push_back(Elt: DAG.getBuildVector(VT: EltVT, DL, Ops: SubVectorElts));
    }
  } else {
    SDValue V = DAG.getBitcast(VT: MVT::getVectorVT(VT: EltVT, NumElements: NumElts), V: Val);
    for (const unsigned I : llvm::seq(Size: NumElts)) {
      SDValue ExtVal = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: V,
                                   N2: DAG.getIntPtrConstant(Val: I, DL));

      // Since StoreV2 is a target node, we cannot rely on DAG type
      // legalization. Therefore, we must ensure the type is legal. For i1 and
      // i8, we set the stored type to i16 and propagate the "real" type as the
      // memory type.
      if (EltVT.getSizeInBits() < 16)
        ExtVal = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: MVT::i16, Operand: ExtVal);
      Ops.push_back(Elt: ExtVal);
    }
  }

  // Then any remaining arguments
  Ops.append(in_start: N->op_begin() + 2, in_end: N->op_end());

  SDValue NewSt =
      DAG.getMemIntrinsicNode(Opcode, dl: DL, VTList: DAG.getVTList(VT: MVT::Other), Ops,
                              MemVT: N->getMemoryVT(), MMO: N->getMemOperand());

  return NewSt;
}
3955
3956SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
3957 StoreSDNode *Store = cast<StoreSDNode>(Val&: Op);
3958 EVT VT = Store->getMemoryVT();
3959
3960 if (VT == MVT::i1)
3961 return LowerSTOREi1(Op, DAG);
3962
3963 // Lower store of any other vector type, including v2f32 as we want to break
3964 // it apart since this is not a widely-supported type.
3965 return lowerSTOREVector(Op, DAG, STI);
3966}
3967
3968// st i1 v, addr
3969// =>
3970// v1 = zxt v to i16
3971// st.u8 i16, addr
3972SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
3973 SDNode *Node = Op.getNode();
3974 SDLoc dl(Node);
3975 StoreSDNode *ST = cast<StoreSDNode>(Val: Node);
3976 SDValue Tmp1 = ST->getChain();
3977 SDValue Tmp2 = ST->getBasePtr();
3978 SDValue Tmp3 = ST->getValue();
3979 assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
3980 Tmp3 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: dl, VT: MVT::i16, Operand: Tmp3);
3981 SDValue Result =
3982 DAG.getTruncStore(Chain: Tmp1, dl, Val: Tmp3, Ptr: Tmp2, PtrInfo: ST->getPointerInfo(), SVT: MVT::i8,
3983 Alignment: ST->getAlign(), MMOFlags: ST->getMemOperand()->getFlags());
3984 return Result;
3985}
3986
3987SDValue NVPTXTargetLowering::LowerCopyToReg_128(SDValue Op,
3988 SelectionDAG &DAG) const {
3989 // Change the CopyToReg to take in two 64-bit operands instead of a 128-bit
3990 // operand so that it can pass the legalization.
3991
3992 assert(Op.getOperand(1).getValueType() == MVT::i128 &&
3993 "Custom lowering for 128-bit CopyToReg only");
3994
3995 SDNode *Node = Op.getNode();
3996 SDLoc DL(Node);
3997
3998 SDValue Cast = DAG.getBitcast(VT: MVT::v2i64, V: Op->getOperand(Num: 2));
3999 SDValue Lo = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i64, N1: Cast,
4000 N2: DAG.getIntPtrConstant(Val: 0, DL));
4001 SDValue Hi = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i64, N1: Cast,
4002 N2: DAG.getIntPtrConstant(Val: 1, DL));
4003
4004 SmallVector<SDValue, 5> NewOps(Op->getNumOperands() + 1);
4005 SmallVector<EVT, 3> ResultsType(Node->values());
4006
4007 NewOps[0] = Op->getOperand(Num: 0); // Chain
4008 NewOps[1] = Op->getOperand(Num: 1); // Dst Reg
4009 NewOps[2] = Lo; // Lower 64-bit
4010 NewOps[3] = Hi; // Higher 64-bit
4011 if (Op.getNumOperands() == 4)
4012 NewOps[4] = Op->getOperand(Num: 3); // Glue if exists
4013
4014 return DAG.getNode(Opcode: ISD::CopyToReg, DL, ResultTys: ResultsType, Ops: NewOps);
4015}
4016
4017unsigned NVPTXTargetLowering::getNumRegisters(
4018 LLVMContext &Context, EVT VT,
4019 std::optional<MVT> RegisterVT = std::nullopt) const {
4020 if (VT == MVT::i128 && RegisterVT == MVT::i128)
4021 return 1;
4022 return TargetLoweringBase::getNumRegisters(Context, VT, RegisterVT);
4023}
4024
4025bool NVPTXTargetLowering::splitValueIntoRegisterParts(
4026 SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
4027 unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
4028 if (Val.getValueType() == MVT::i128 && NumParts == 1) {
4029 Parts[0] = Val;
4030 return true;
4031 }
4032 return false;
4033}
4034
4035// This creates target external symbol for a function parameter.
4036// Name of the symbol is composed from its index and the function name.
4037// Negative index corresponds to special parameter (unsized array) used for
4038// passing variable arguments.
4039SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int I,
4040 EVT T) const {
4041 StringRef SavedStr = nvTM->getStrPool().save(
4042 S: getParamName(F: &DAG.getMachineFunction().getFunction(), Idx: I));
4043 return DAG.getExternalSymbol(Sym: SavedStr.data(), VT: T);
4044}
4045
4046SDValue NVPTXTargetLowering::getCallParamSymbol(SelectionDAG &DAG, int I,
4047 EVT T) const {
4048 const StringRef SavedStr = nvTM->getStrPool().save(S: "param" + Twine(I));
4049 return DAG.getExternalSymbol(Sym: SavedStr.data(), VT: T);
4050}
4051
/// Lower the formal arguments of the current function: each IR argument is
/// materialized either as its parameter symbol (byval/grid_constant kernels),
/// as a MoveParam + addrspacecast (byval device functions), or as one or more
/// vectorized loads from the parameter space. The resulting per-`Ins` values
/// are appended to \p InVals in order.
SDValue NVPTXTargetLowering::LowerFormalArguments(
    SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
    const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
    SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
  const DataLayout &DL = DAG.getDataLayout();
  LLVMContext &Ctx = *DAG.getContext();
  auto PtrVT = getPointerTy(DL: DAG.getDataLayout());

  const Function &F = DAG.getMachineFunction().getFunction();
  const bool IsKernel = isKernelFunction(F);

  SDValue Root = DAG.getRoot();
  // NOTE(review): OutChains is never appended to in this function, so the
  // setRoot at the end is currently dead — presumably a leftover from an
  // earlier scheme; confirm before removing.
  SmallVector<SDValue, 16> OutChains;

  // argTypes.size() (or theArgs.size()) and Ins.size() need not match.
  // Ins.size() will be larger
  //   * if there is an aggregate argument with multiple fields (each field
  //     showing up separately in Ins)
  //   * if there is a vector argument with more than typical vector-length
  //     elements (generally if more than 4) where each vector element is
  //     individually present in Ins.
  // So a different index should be used for indexing into Ins.
  // See similar issue in LowerCall.

  // Peel off, for each IR argument, the run of Ins entries that belong to it
  // (they share the argument's OrigArgIndex).
  auto AllIns = ArrayRef(Ins);
  for (const auto &Arg : F.args()) {
    const auto ArgIns = AllIns.take_while(
        Pred: [&](auto I) { return I.OrigArgIndex == Arg.getArgNo(); });
    AllIns = AllIns.drop_front(N: ArgIns.size());

    Type *Ty = Arg.getType();

    if (ArgIns.empty())
      report_fatal_error(reason: "Empty parameter types are not supported");

    if (Arg.use_empty()) {
      // argument is dead
      for (const auto &In : ArgIns) {
        assert(!In.Used && "Arg.use_empty() is true but Arg is used?");
        InVals.push_back(Elt: DAG.getUNDEF(VT: In.VT));
      }
      continue;
    }

    // Symbol naming this argument in the .param space.
    SDValue ArgSymbol = getParamSymbol(DAG, I: Arg.getArgNo(), T: PtrVT);

    // In the following cases, assign a node order of "i+1"
    // to newly created nodes. The SDNodes for params have to
    // appear in the same order as their order of appearance
    // in the original function. "i+1" holds that order.
    if (Arg.hasByValAttr()) {
      // Param has ByVal attribute
      // Return MoveParam(param symbol).
      // Ideally, the param symbol can be returned directly,
      // but when SDNode builder decides to use it in a CopyToReg(),
      // machine instruction fails because TargetExternalSymbol
      // (not lowered) is target dependent, and CopyToReg assumes
      // the source is lowered.
      assert(ArgIns.size() == 1 && "ByVal argument must be a pointer");
      const auto &ByvalIn = ArgIns[0];
      assert(getValueType(DL, Ty) == ByvalIn.VT &&
             "Ins type did not match function type");
      assert(ByvalIn.VT == PtrVT && "ByVal argument must be a pointer");

      SDValue P;
      if (IsKernel) {
        assert(isParamGridConstant(Arg) && "ByVal argument must be lowered to "
                                           "grid_constant by NVPTXLowerArgs");
        P = ArgSymbol;
        P.getNode()->setIROrder(Arg.getArgNo() + 1);
      } else {
        // Device functions: take the parameter's local-space address and cast
        // it to the generic address space for the function body to use.
        P = DAG.getNode(Opcode: NVPTXISD::MoveParam, DL: dl, VT: ByvalIn.VT, Operand: ArgSymbol);
        P.getNode()->setIROrder(Arg.getArgNo() + 1);
        P = DAG.getAddrSpaceCast(dl, VT: ByvalIn.VT, Ptr: P, SrcAS: ADDRESS_SPACE_LOCAL,
                                 DestAS: ADDRESS_SPACE_GENERIC);
      }
      InVals.push_back(Elt: P);
    } else {
      // Decompose the argument into its constituent PTX value types and byte
      // offsets, then load them in vectorized chunks.
      SmallVector<EVT, 16> VTs;
      SmallVector<uint64_t, 16> Offsets;
      ComputePTXValueVTs(TLI: *this, DL, Ctx, CallConv, Ty, ValueVTs&: VTs, Offsets);
      assert(VTs.size() == ArgIns.size() && "Size mismatch");
      assert(VTs.size() == Offsets.size() && "Size mismatch");

      const Align ArgAlign = getFunctionArgumentAlignment(
          F: &F, Ty, Idx: Arg.getArgNo() + AttributeList::FirstArgIndex, DL);

      unsigned I = 0;
      const auto VI = VectorizePTXValueVTs(ValueVTs: VTs, Offsets, ParamAlignment: ArgAlign);
      for (const unsigned NumElts : VI) {
        // i1 is loaded/stored as i8
        const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
        const EVT VecVT = getVectorizedVT(VT: LoadVT, N: NumElts, C&: Ctx);

        SDValue VecAddr = DAG.getObjectPtrOffset(
            SL: dl, Ptr: ArgSymbol, Offset: TypeSize::getFixed(ExactSize: Offsets[I]));

        // Parameter memory is read-only for the function, hence the
        // dereferenceable + invariant flags on the load.
        const Align PartAlign = commonAlignment(A: ArgAlign, Offset: Offsets[I]);
        const unsigned AS = IsKernel ? NVPTX::AddressSpace::EntryParam
                                     : NVPTX::AddressSpace::DeviceParam;
        SDValue P = DAG.getLoad(VT: VecVT, dl, Chain: Root, Ptr: VecAddr,
                                PtrInfo: MachinePointerInfo(AS), Alignment: PartAlign,
                                MMOFlags: MachineMemOperand::MODereferenceable |
                                    MachineMemOperand::MOInvariant);
        P.getNode()->setIROrder(Arg.getArgNo() + 1);
        for (const unsigned J : llvm::seq(Size: NumElts)) {
          SDValue Elt = getExtractVectorizedValue(V: P, I: J, VT: LoadVT, dl, DAG);

          // Coerce each element to the type/extension the Ins entry expects.
          Elt = correctParamType(V: Elt, ExpectedVT: ArgIns[I + J].VT, Flags: ArgIns[I + J].Flags,
                                 DAG, dl);
          InVals.push_back(Elt);
        }
        I += NumElts;
      }
    }
  }

  if (!OutChains.empty())
    DAG.setRoot(DAG.getTokenFactor(DL: dl, Vals&: OutChains));

  return Chain;
}
4174
/// Lower a return: non-void results are written, in vectorized chunks, to the
/// "func_retval0" parameter-space symbol, after which a RET_GLUE node ends the
/// function. Integer results narrower than 32 bits are widened per the PTX
/// interoperability ABI.
SDValue
NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
                                 bool isVarArg,
                                 const SmallVectorImpl<ISD::OutputArg> &Outs,
                                 const SmallVectorImpl<SDValue> &OutVals,
                                 const SDLoc &dl, SelectionDAG &DAG) const {
  const Function &F = DAG.getMachineFunction().getFunction();
  Type *RetTy = F.getReturnType();

  // void returns need no stores — just terminate.
  if (RetTy->isVoidTy()) {
    assert(OutVals.empty() && Outs.empty() && "Return value expected for void");
    return DAG.getNode(Opcode: NVPTXISD::RET_GLUE, DL: dl, VT: MVT::Other, Operand: Chain);
  }

  const DataLayout &DL = DAG.getDataLayout();
  LLVMContext &Ctx = *DAG.getContext();

  const SDValue RetSymbol = DAG.getExternalSymbol(Sym: "func_retval0", VT: MVT::i32);
  const auto RetAlign = getFunctionParamOptimizedAlign(F: &F, ArgTy: RetTy, DL);

  // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
  // 32-bits are sign extended or zero extended, depending on whether
  // they are signed or unsigned types.
  const bool ExtendIntegerRetVal =
      RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty: RetTy) < 32;

  // Decompose the return type into PTX value types and byte offsets.
  SmallVector<EVT, 16> VTs;
  SmallVector<uint64_t, 16> Offsets;
  ComputePTXValueVTs(TLI: *this, DL, Ctx, CallConv, Ty: RetTy, ValueVTs&: VTs, Offsets);
  assert(VTs.size() == OutVals.size() && "Bad return value decomposition");

  // Fetch OutVals[I], coerced to the type it must have when stored (widened
  // to i32 for short integers, i8 for i1).
  const auto GetRetVal = [&](unsigned I) -> SDValue {
    SDValue RetVal = OutVals[I];
    assert(promoteScalarIntegerPTX(RetVal.getValueType()) ==
               RetVal.getValueType() &&
           "OutVal type should always be legal");

    const EVT VTI = promoteScalarIntegerPTX(VT: VTs[I]);
    const EVT StoreVT =
        ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
    return correctParamType(V: RetVal, ExpectedVT: StoreVT, Flags: Outs[I].Flags, DAG, dl);
  };

  // Store the return value in vectorized chunks at increasing offsets from
  // the func_retval0 symbol.
  unsigned I = 0;
  const auto VI = VectorizePTXValueVTs(ValueVTs: VTs, Offsets, ParamAlignment: RetAlign);
  for (const unsigned NumElts : VI) {
    const MaybeAlign CurrentAlign = ExtendIntegerRetVal
                                        ? MaybeAlign(std::nullopt)
                                        : commonAlignment(A: RetAlign, Offset: Offsets[I]);

    SDValue Val = getBuildVectorizedValue(
        N: NumElts, dl, DAG, GetElement: [&](unsigned K) { return GetRetVal(I + K); });

    SDValue Ptr =
        DAG.getObjectPtrOffset(SL: dl, Ptr: RetSymbol, Offset: TypeSize::getFixed(ExactSize: Offsets[I]));

    Chain = DAG.getStore(Chain, dl, Val, Ptr,
                         PtrInfo: MachinePointerInfo(NVPTX::AddressSpace::DeviceParam),
                         Alignment: CurrentAlign);

    I += NumElts;
  }

  return DAG.getNode(Opcode: NVPTXISD::RET_GLUE, DL: dl, VT: MVT::Other, Operand: Chain);
}
4240
4241void NVPTXTargetLowering::LowerAsmOperandForConstraint(
4242 SDValue Op, StringRef Constraint, std::vector<SDValue> &Ops,
4243 SelectionDAG &DAG) const {
4244 if (Constraint.size() > 1)
4245 return;
4246 TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
4247}
4248
4249// llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
4250// TgtMemIntrinsic
4251// because we need the information that is only available in the "Value" type
4252// of destination
4253// pointer. In particular, the address space information.
4254void NVPTXTargetLowering::getTgtMemIntrinsic(
4255 SmallVectorImpl<IntrinsicInfo> &Infos, const CallBase &I,
4256 MachineFunction &MF, unsigned Intrinsic) const {
4257 IntrinsicInfo Info;
4258 switch (Intrinsic) {
4259 default:
4260 return;
4261 case Intrinsic::nvvm_match_all_sync_i32p:
4262 case Intrinsic::nvvm_match_all_sync_i64p:
4263 Info.opc = ISD::INTRINSIC_W_CHAIN;
4264 // memVT is bogus. These intrinsics have IntrInaccessibleMemOnly attribute
4265 // in order to model data exchange with other threads, but perform no real
4266 // memory accesses.
4267 Info.memVT = MVT::i1;
4268
4269 // Our result depends on both our and other thread's arguments.
4270 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
4271 Infos.push_back(Elt: Info);
4272 return;
4273 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col:
4274 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row:
4275 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col_stride:
4276 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride:
4277 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col:
4278 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row:
4279 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride:
4280 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride:
4281 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col:
4282 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row:
4283 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride:
4284 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride:
4285 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col:
4286 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row:
4287 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride:
4288 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride:
4289 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col:
4290 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row:
4291 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride:
4292 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride:
4293 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col:
4294 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row:
4295 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride:
4296 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride: {
4297 Info.opc = ISD::INTRINSIC_W_CHAIN;
4298 Info.memVT = MVT::v8f16;
4299 Info.ptrVal = I.getArgOperand(i: 0);
4300 Info.offset = 0;
4301 Info.flags = MachineMemOperand::MOLoad;
4302 Info.align = Align(16);
4303 Infos.push_back(Elt: Info);
4304 return;
4305 }
4306 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col:
4307 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col_stride:
4308 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col_stride:
4309 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col:
4310 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row:
4311 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row_stride:
4312 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row_stride:
4313 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row:
4314 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col:
4315 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col_stride:
4316 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row:
4317 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row_stride:
4318 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col:
4319 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col_stride:
4320 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col_stride:
4321 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col:
4322 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row:
4323 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row_stride:
4324 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row_stride:
4325 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row:
4326 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col:
4327 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col_stride:
4328 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row:
4329 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row_stride: {
4330 Info.opc = ISD::INTRINSIC_W_CHAIN;
4331 Info.memVT = MVT::v2i32;
4332 Info.ptrVal = I.getArgOperand(i: 0);
4333 Info.offset = 0;
4334 Info.flags = MachineMemOperand::MOLoad;
4335 Info.align = Align(8);
4336 Infos.push_back(Elt: Info);
4337 return;
4338 }
4339
4340 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col:
4341 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col_stride:
4342 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col_stride:
4343 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col:
4344 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row:
4345 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row_stride:
4346 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row_stride:
4347 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row:
4348 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col:
4349 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col_stride:
4350 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row:
4351 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row_stride:
4352 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col:
4353 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col_stride:
4354 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row:
4355 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row_stride:
4356
4357 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col:
4358 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col_stride:
4359 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col_stride:
4360 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col:
4361 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row:
4362 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row_stride:
4363 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row_stride:
4364 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row:
4365 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col:
4366 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col_stride:
4367 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row:
4368 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row_stride:
4369 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col:
4370 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col_stride:
4371 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
4372 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride:
4373 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16:
4374 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16:
4375 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8:
4376 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64:
4377 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32:
4378 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64:
4379 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32: {
4380 Info.opc = ISD::INTRINSIC_W_CHAIN;
4381 Info.memVT = MVT::v4i32;
4382 Info.ptrVal = I.getArgOperand(i: 0);
4383 Info.offset = 0;
4384 Info.flags = MachineMemOperand::MOLoad;
4385 Info.align = Align(16);
4386 Infos.push_back(Elt: Info);
4387 return;
4388 }
4389
4390 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col:
4391 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col_stride:
4392 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col_stride:
4393 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col:
4394 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row:
4395 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row_stride:
4396 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row_stride:
4397 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row:
4398
4399 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col:
4400 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col_stride:
4401 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col_stride:
4402 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col:
4403 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row:
4404 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row_stride:
4405 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row_stride:
4406 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row:
4407 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row:
4408 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row_stride:
4409 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col:
4410 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col_stride:
4411 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row:
4412 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row_stride:
4413 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row_stride:
4414 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row:
4415 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col:
4416 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col_stride:
4417 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
4418 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col:
4419 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16:
4420 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16:
4421 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64:
4422 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32: {
4423 Info.opc = ISD::INTRINSIC_W_CHAIN;
4424 Info.memVT = MVT::i32;
4425 Info.ptrVal = I.getArgOperand(i: 0);
4426 Info.offset = 0;
4427 Info.flags = MachineMemOperand::MOLoad;
4428 Info.align = Align(4);
4429 Infos.push_back(Elt: Info);
4430 return;
4431 }
4432
4433 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col:
4434 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row:
4435 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride:
4436 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride:
4437 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col:
4438 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row:
4439 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride:
4440 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride:
4441 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col:
4442 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row:
4443 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride:
4444 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride: {
4445 Info.opc = ISD::INTRINSIC_W_CHAIN;
4446 Info.memVT = MVT::v4f16;
4447 Info.ptrVal = I.getArgOperand(i: 0);
4448 Info.offset = 0;
4449 Info.flags = MachineMemOperand::MOLoad;
4450 Info.align = Align(16);
4451 Infos.push_back(Elt: Info);
4452 return;
4453 }
4454
4455 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col:
4456 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row:
4457 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride:
4458 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride:
4459 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col:
4460 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row:
4461 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride:
4462 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride:
4463 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col:
4464 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row:
4465 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride:
4466 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride:
  // WMMA m16n16k8 f32 accumulator ("c") fragment loads: each thread reads
  // 8 f32 elements as one 16-byte-aligned load from the pointer operand.
  case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col:
  case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row:
  case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col_stride:
  case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row_stride: {
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::v8f32;
    Info.ptrVal = I.getArgOperand(i: 0);
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOLoad;
    Info.align = Align(16);
    Infos.push_back(Elt: Info);
    return;
  }

  // WMMA bf16 "a"/"b" fragment loads and s32 accumulator loads. All of
  // these read 8 x 32-bit values, so they share a v8i32 memVT with
  // 16-byte alignment.
  case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col:
  case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col_stride:
  case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row:
  case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row_stride:

  case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col:
  case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col_stride:
  case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row:
  case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row_stride:

  case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col:
  case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col_stride:
  case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row:
  case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row_stride:
  case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col:
  case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col_stride:
  case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row:
  case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row_stride:
  case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col:
  case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col_stride:
  case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row:
  case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row_stride: {
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::v8i32;
    Info.ptrVal = I.getArgOperand(i: 0);
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOLoad;
    Info.align = Align(16);
    Infos.push_back(Elt: Info);
    return;
  }

  // Sub-integer WMMA accumulator loads and 2-register ldmatrix variants:
  // 2 x 32-bit values per thread, 8-byte aligned.
  case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col:
  case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col_stride:
  case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row:
  case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row_stride:
  case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col:
  case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col_stride:
  case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
  case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride:
  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16:
  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16:
  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8:
  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64:
  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32:
  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64:
  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32: {
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::v2i32;
    Info.ptrVal = I.getArgOperand(i: 0);
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOLoad;
    Info.align = Align(8);
    Infos.push_back(Elt: Info);
    return;
  }

  // Double-precision WMMA "a"/"b" fragment loads: a single f64 element,
  // 8-byte aligned.
  case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col:
  case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col_stride:
  case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row:
  case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row_stride:

  case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col:
  case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col_stride:
  case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row:
  case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row_stride: {
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::f64;
    Info.ptrVal = I.getArgOperand(i: 0);
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOLoad;
    Info.align = Align(8);
    Infos.push_back(Elt: Info);
    return;
  }

  // Double-precision WMMA accumulator loads: 2 x f64, 16-byte aligned.
  case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col:
  case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col_stride:
  case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row:
  case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row_stride: {
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::v2f64;
    Info.ptrVal = I.getArgOperand(i: 0);
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOLoad;
    Info.align = Align(16);
    Infos.push_back(Elt: Info);
    return;
  }
4570
  // WMMA f16 result ("d") stores: 4 x f16 written through the pointer
  // operand as one 16-byte-aligned store. These produce no value, so the
  // node is INTRINSIC_VOID rather than INTRINSIC_W_CHAIN.
  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride:
  case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col:
  case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row:
  case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride:
  case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride:
  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col:
  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row:
  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride:
  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride: {
    Info.opc = ISD::INTRINSIC_VOID;
    Info.memVT = MVT::v4f16;
    Info.ptrVal = I.getArgOperand(i: 0);
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOStore;
    Info.align = Align(16);
    Infos.push_back(Elt: Info);
    return;
  }

  // WMMA f32 result stores: 8 x f32, 16-byte aligned.
  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col:
  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row:
  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride:
  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride:
  case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col:
  case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row:
  case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride:
  case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride:
  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col:
  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row:
  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride:
  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride:
  case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col:
  case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row:
  case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col_stride:
  case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row_stride: {
    Info.opc = ISD::INTRINSIC_VOID;
    Info.memVT = MVT::v8f32;
    Info.ptrVal = I.getArgOperand(i: 0);
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOStore;
    Info.align = Align(16);
    Infos.push_back(Elt: Info);
    return;
  }

  // WMMA s32 result stores: 8 x i32, 16-byte aligned.
  case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col:
  case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col_stride:
  case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row:
  case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row_stride:
  case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col:
  case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col_stride:
  case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row:
  case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row_stride:
  case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col:
  case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col_stride:
  case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row:
  case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row_stride: {
    Info.opc = ISD::INTRINSIC_VOID;
    Info.memVT = MVT::v8i32;
    Info.ptrVal = I.getArgOperand(i: 0);
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOStore;
    Info.align = Align(16);
    Infos.push_back(Elt: Info);
    return;
  }

  // Sub-integer WMMA result stores and 2-register stmatrix variants:
  // 2 x i32 per thread, 8-byte aligned.
  case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col:
  case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col_stride:
  case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row:
  case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row_stride:
  case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
  case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
  case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
  case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride:
  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16:
  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16:
  case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: {
    Info.opc = ISD::INTRINSIC_VOID;
    Info.memVT = MVT::v2i32;
    Info.ptrVal = I.getArgOperand(i: 0);
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOStore;
    Info.align = Align(8);
    Infos.push_back(Elt: Info);
    return;
  }

  // Double-precision WMMA result stores: 2 x f64, 16-byte aligned.
  case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col:
  case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col_stride:
  case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row:
  case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row_stride: {
    Info.opc = ISD::INTRINSIC_VOID;
    Info.memVT = MVT::v2f64;
    Info.ptrVal = I.getArgOperand(i: 0);
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOStore;
    Info.align = Align(16);
    Infos.push_back(Elt: Info);
    return;
  }

  // Single-register stmatrix variants: one i32, 4-byte aligned.
  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16:
  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16:
  case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: {
    Info.opc = ISD::INTRINSIC_VOID;
    Info.memVT = MVT::i32;
    Info.ptrVal = I.getArgOperand(i: 0);
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOStore;
    Info.align = Align(4);
    Infos.push_back(Elt: Info);
    return;
  }

  // Four-register stmatrix variants: 4 x i32, 16-byte aligned.
  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16:
  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16:
  case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: {
    Info.opc = ISD::INTRINSIC_VOID;
    Info.memVT = MVT::v4i32;
    Info.ptrVal = I.getArgOperand(i: 0);
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOStore;
    Info.align = Align(16);
    Infos.push_back(Elt: Info);
    return;
  }
4701
  // Scoped (cta/sys) generic-address-space atomics. An atomic is a
  // read-modify-write, so both MOLoad and MOStore are set. The access
  // width comes from the intrinsic's result type; alignment is left
  // unset (Info.align.reset()).
  case Intrinsic::nvvm_atomic_add_gen_f_cta:
  case Intrinsic::nvvm_atomic_add_gen_f_sys:
  case Intrinsic::nvvm_atomic_add_gen_i_cta:
  case Intrinsic::nvvm_atomic_add_gen_i_sys:
  case Intrinsic::nvvm_atomic_and_gen_i_cta:
  case Intrinsic::nvvm_atomic_and_gen_i_sys:
  case Intrinsic::nvvm_atomic_cas_gen_i_cta:
  case Intrinsic::nvvm_atomic_cas_gen_i_sys:
  case Intrinsic::nvvm_atomic_dec_gen_i_cta:
  case Intrinsic::nvvm_atomic_dec_gen_i_sys:
  case Intrinsic::nvvm_atomic_inc_gen_i_cta:
  case Intrinsic::nvvm_atomic_inc_gen_i_sys:
  case Intrinsic::nvvm_atomic_max_gen_i_cta:
  case Intrinsic::nvvm_atomic_max_gen_i_sys:
  case Intrinsic::nvvm_atomic_min_gen_i_cta:
  case Intrinsic::nvvm_atomic_min_gen_i_sys:
  case Intrinsic::nvvm_atomic_or_gen_i_cta:
  case Intrinsic::nvvm_atomic_or_gen_i_sys:
  case Intrinsic::nvvm_atomic_exch_gen_i_cta:
  case Intrinsic::nvvm_atomic_exch_gen_i_sys:
  case Intrinsic::nvvm_atomic_xor_gen_i_cta:
  case Intrinsic::nvvm_atomic_xor_gen_i_sys: {
    auto &DL = I.getDataLayout();
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = getValueType(DL, Ty: I.getType());
    Info.ptrVal = I.getArgOperand(i: 0);
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
    Info.align.reset();
    Infos.push_back(Elt: Info);
    return;
  }
4734
  // prefetch.tensormap: modeled as a pointer-sized, dereferenceable load
  // of the tensormap descriptor; no value is produced (INTRINSIC_VOID).
  case Intrinsic::nvvm_prefetch_tensormap: {
    auto &DL = I.getDataLayout();
    Info.opc = ISD::INTRINSIC_VOID;
    Info.memVT = getPointerTy(DL);
    Info.ptrVal = I.getArgOperand(i: 0);
    Info.offset = 0;
    Info.flags =
        MachineMemOperand::MOLoad | MachineMemOperand::MODereferenceable;
    Info.align.reset();
    Infos.push_back(Elt: Info);
    return;
  }

  // tensormap.replace of 64-bit fields (global address / global stride):
  // an i64 store into the descriptor pointed to by operand 0.
  case Intrinsic::nvvm_tensormap_replace_global_address:
  case Intrinsic::nvvm_tensormap_replace_global_stride: {
    Info.opc = ISD::INTRINSIC_VOID;
    Info.memVT = MVT::i64;
    Info.ptrVal = I.getArgOperand(i: 0);
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOStore;
    Info.align.reset();
    Infos.push_back(Elt: Info);
    return;
  }

  // tensormap.replace of 32-bit fields (rank, dims, strides, modes, ...):
  // an i32 store into the descriptor.
  case Intrinsic::nvvm_tensormap_replace_rank:
  case Intrinsic::nvvm_tensormap_replace_box_dim:
  case Intrinsic::nvvm_tensormap_replace_global_dim:
  case Intrinsic::nvvm_tensormap_replace_element_stride:
  case Intrinsic::nvvm_tensormap_replace_elemtype:
  case Intrinsic::nvvm_tensormap_replace_interleave_layout:
  case Intrinsic::nvvm_tensormap_replace_swizzle_mode:
  case Intrinsic::nvvm_tensormap_replace_swizzle_atomicity:
  case Intrinsic::nvvm_tensormap_replace_fill_mode: {
    Info.opc = ISD::INTRINSIC_VOID;
    Info.memVT = MVT::i32;
    Info.ptrVal = I.getArgOperand(i: 0);
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOStore;
    Info.align.reset();
    Infos.push_back(Elt: Info);
    return;
  }
4778
  // ldu.global (load via read-only data cache): access type comes from
  // the intrinsic's result type, and the alignment is supplied explicitly
  // by the caller as a constant second argument.
  case Intrinsic::nvvm_ldu_global_i:
  case Intrinsic::nvvm_ldu_global_f:
  case Intrinsic::nvvm_ldu_global_p: {
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = getValueType(DL: I.getDataLayout(), Ty: I.getType());
    Info.ptrVal = I.getArgOperand(i: 0);
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOLoad;
    // Alignment is encoded as a ConstantInt operand of the intrinsic.
    Info.align = cast<ConstantInt>(Val: I.getArgOperand(i: 1))->getMaybeAlignValue();

    Infos.push_back(Elt: Info);
    return;
  }
  // Texture fetch / tld4 gather intrinsics returning v4f32. There is no
  // IR-level pointer for a texture access, so ptrVal is null; the access
  // is still modeled as a 16-byte-aligned load of the 4-element result.
  case Intrinsic::nvvm_tex_1d_v4f32_s32:
  case Intrinsic::nvvm_tex_1d_v4f32_f32:
  case Intrinsic::nvvm_tex_1d_level_v4f32_f32:
  case Intrinsic::nvvm_tex_1d_grad_v4f32_f32:
  case Intrinsic::nvvm_tex_1d_array_v4f32_s32:
  case Intrinsic::nvvm_tex_1d_array_v4f32_f32:
  case Intrinsic::nvvm_tex_1d_array_level_v4f32_f32:
  case Intrinsic::nvvm_tex_1d_array_grad_v4f32_f32:
  case Intrinsic::nvvm_tex_2d_v4f32_s32:
  case Intrinsic::nvvm_tex_2d_v4f32_f32:
  case Intrinsic::nvvm_tex_2d_level_v4f32_f32:
  case Intrinsic::nvvm_tex_2d_grad_v4f32_f32:
  case Intrinsic::nvvm_tex_2d_array_v4f32_s32:
  case Intrinsic::nvvm_tex_2d_array_v4f32_f32:
  case Intrinsic::nvvm_tex_2d_array_level_v4f32_f32:
  case Intrinsic::nvvm_tex_2d_array_grad_v4f32_f32:
  case Intrinsic::nvvm_tex_3d_v4f32_s32:
  case Intrinsic::nvvm_tex_3d_v4f32_f32:
  case Intrinsic::nvvm_tex_3d_level_v4f32_f32:
  case Intrinsic::nvvm_tex_3d_grad_v4f32_f32:
  case Intrinsic::nvvm_tex_cube_v4f32_f32:
  case Intrinsic::nvvm_tex_cube_level_v4f32_f32:
  case Intrinsic::nvvm_tex_cube_array_v4f32_f32:
  case Intrinsic::nvvm_tex_cube_array_level_v4f32_f32:
  case Intrinsic::nvvm_tld4_r_2d_v4f32_f32:
  case Intrinsic::nvvm_tld4_g_2d_v4f32_f32:
  case Intrinsic::nvvm_tld4_b_2d_v4f32_f32:
  case Intrinsic::nvvm_tld4_a_2d_v4f32_f32:
  case Intrinsic::nvvm_tex_unified_1d_v4f32_s32:
  case Intrinsic::nvvm_tex_unified_1d_v4f32_f32:
  case Intrinsic::nvvm_tex_unified_1d_level_v4f32_f32:
  case Intrinsic::nvvm_tex_unified_1d_grad_v4f32_f32:
  case Intrinsic::nvvm_tex_unified_1d_array_v4f32_s32:
  case Intrinsic::nvvm_tex_unified_1d_array_v4f32_f32:
  case Intrinsic::nvvm_tex_unified_1d_array_level_v4f32_f32:
  case Intrinsic::nvvm_tex_unified_1d_array_grad_v4f32_f32:
  case Intrinsic::nvvm_tex_unified_2d_v4f32_s32:
  case Intrinsic::nvvm_tex_unified_2d_v4f32_f32:
  case Intrinsic::nvvm_tex_unified_2d_level_v4f32_f32:
  case Intrinsic::nvvm_tex_unified_2d_grad_v4f32_f32:
  case Intrinsic::nvvm_tex_unified_2d_array_v4f32_s32:
  case Intrinsic::nvvm_tex_unified_2d_array_v4f32_f32:
  case Intrinsic::nvvm_tex_unified_2d_array_level_v4f32_f32:
  case Intrinsic::nvvm_tex_unified_2d_array_grad_v4f32_f32:
  case Intrinsic::nvvm_tex_unified_3d_v4f32_s32:
  case Intrinsic::nvvm_tex_unified_3d_v4f32_f32:
  case Intrinsic::nvvm_tex_unified_3d_level_v4f32_f32:
  case Intrinsic::nvvm_tex_unified_3d_grad_v4f32_f32:
  case Intrinsic::nvvm_tex_unified_cube_v4f32_f32:
  case Intrinsic::nvvm_tex_unified_cube_level_v4f32_f32:
  case Intrinsic::nvvm_tex_unified_cube_array_v4f32_f32:
  case Intrinsic::nvvm_tex_unified_cube_array_level_v4f32_f32:
  case Intrinsic::nvvm_tex_unified_cube_grad_v4f32_f32:
  case Intrinsic::nvvm_tex_unified_cube_array_grad_v4f32_f32:
  case Intrinsic::nvvm_tld4_unified_r_2d_v4f32_f32:
  case Intrinsic::nvvm_tld4_unified_g_2d_v4f32_f32:
  case Intrinsic::nvvm_tld4_unified_b_2d_v4f32_f32:
  case Intrinsic::nvvm_tld4_unified_a_2d_v4f32_f32:
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::v4f32;
    Info.ptrVal = nullptr;
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOLoad;
    Info.align = Align(16);
    Infos.push_back(Elt: Info);
    return;

  // Texture fetch / tld4 gather intrinsics returning four 32-bit integers
  // (both signed v4s32 and unsigned v4u32 flavors share a v4i32 memVT).
  case Intrinsic::nvvm_tex_1d_v4s32_s32:
  case Intrinsic::nvvm_tex_1d_v4s32_f32:
  case Intrinsic::nvvm_tex_1d_level_v4s32_f32:
  case Intrinsic::nvvm_tex_1d_grad_v4s32_f32:
  case Intrinsic::nvvm_tex_1d_array_v4s32_s32:
  case Intrinsic::nvvm_tex_1d_array_v4s32_f32:
  case Intrinsic::nvvm_tex_1d_array_level_v4s32_f32:
  case Intrinsic::nvvm_tex_1d_array_grad_v4s32_f32:
  case Intrinsic::nvvm_tex_2d_v4s32_s32:
  case Intrinsic::nvvm_tex_2d_v4s32_f32:
  case Intrinsic::nvvm_tex_2d_level_v4s32_f32:
  case Intrinsic::nvvm_tex_2d_grad_v4s32_f32:
  case Intrinsic::nvvm_tex_2d_array_v4s32_s32:
  case Intrinsic::nvvm_tex_2d_array_v4s32_f32:
  case Intrinsic::nvvm_tex_2d_array_level_v4s32_f32:
  case Intrinsic::nvvm_tex_2d_array_grad_v4s32_f32:
  case Intrinsic::nvvm_tex_3d_v4s32_s32:
  case Intrinsic::nvvm_tex_3d_v4s32_f32:
  case Intrinsic::nvvm_tex_3d_level_v4s32_f32:
  case Intrinsic::nvvm_tex_3d_grad_v4s32_f32:
  case Intrinsic::nvvm_tex_cube_v4s32_f32:
  case Intrinsic::nvvm_tex_cube_level_v4s32_f32:
  case Intrinsic::nvvm_tex_cube_array_v4s32_f32:
  case Intrinsic::nvvm_tex_cube_array_level_v4s32_f32:
  case Intrinsic::nvvm_tex_cube_v4u32_f32:
  case Intrinsic::nvvm_tex_cube_level_v4u32_f32:
  case Intrinsic::nvvm_tex_cube_array_v4u32_f32:
  case Intrinsic::nvvm_tex_cube_array_level_v4u32_f32:
  case Intrinsic::nvvm_tex_1d_v4u32_s32:
  case Intrinsic::nvvm_tex_1d_v4u32_f32:
  case Intrinsic::nvvm_tex_1d_level_v4u32_f32:
  case Intrinsic::nvvm_tex_1d_grad_v4u32_f32:
  case Intrinsic::nvvm_tex_1d_array_v4u32_s32:
  case Intrinsic::nvvm_tex_1d_array_v4u32_f32:
  case Intrinsic::nvvm_tex_1d_array_level_v4u32_f32:
  case Intrinsic::nvvm_tex_1d_array_grad_v4u32_f32:
  case Intrinsic::nvvm_tex_2d_v4u32_s32:
  case Intrinsic::nvvm_tex_2d_v4u32_f32:
  case Intrinsic::nvvm_tex_2d_level_v4u32_f32:
  case Intrinsic::nvvm_tex_2d_grad_v4u32_f32:
  case Intrinsic::nvvm_tex_2d_array_v4u32_s32:
  case Intrinsic::nvvm_tex_2d_array_v4u32_f32:
  case Intrinsic::nvvm_tex_2d_array_level_v4u32_f32:
  case Intrinsic::nvvm_tex_2d_array_grad_v4u32_f32:
  case Intrinsic::nvvm_tex_3d_v4u32_s32:
  case Intrinsic::nvvm_tex_3d_v4u32_f32:
  case Intrinsic::nvvm_tex_3d_level_v4u32_f32:
  case Intrinsic::nvvm_tex_3d_grad_v4u32_f32:
  case Intrinsic::nvvm_tld4_r_2d_v4s32_f32:
  case Intrinsic::nvvm_tld4_g_2d_v4s32_f32:
  case Intrinsic::nvvm_tld4_b_2d_v4s32_f32:
  case Intrinsic::nvvm_tld4_a_2d_v4s32_f32:
  case Intrinsic::nvvm_tld4_r_2d_v4u32_f32:
  case Intrinsic::nvvm_tld4_g_2d_v4u32_f32:
  case Intrinsic::nvvm_tld4_b_2d_v4u32_f32:
  case Intrinsic::nvvm_tld4_a_2d_v4u32_f32:
  case Intrinsic::nvvm_tex_unified_1d_v4s32_s32:
  case Intrinsic::nvvm_tex_unified_1d_v4s32_f32:
  case Intrinsic::nvvm_tex_unified_1d_level_v4s32_f32:
  case Intrinsic::nvvm_tex_unified_1d_grad_v4s32_f32:
  case Intrinsic::nvvm_tex_unified_1d_array_v4s32_s32:
  case Intrinsic::nvvm_tex_unified_1d_array_v4s32_f32:
  case Intrinsic::nvvm_tex_unified_1d_array_level_v4s32_f32:
  case Intrinsic::nvvm_tex_unified_1d_array_grad_v4s32_f32:
  case Intrinsic::nvvm_tex_unified_2d_v4s32_s32:
  case Intrinsic::nvvm_tex_unified_2d_v4s32_f32:
  case Intrinsic::nvvm_tex_unified_2d_level_v4s32_f32:
  case Intrinsic::nvvm_tex_unified_2d_grad_v4s32_f32:
  case Intrinsic::nvvm_tex_unified_2d_array_v4s32_s32:
  case Intrinsic::nvvm_tex_unified_2d_array_v4s32_f32:
  case Intrinsic::nvvm_tex_unified_2d_array_level_v4s32_f32:
  case Intrinsic::nvvm_tex_unified_2d_array_grad_v4s32_f32:
  case Intrinsic::nvvm_tex_unified_3d_v4s32_s32:
  case Intrinsic::nvvm_tex_unified_3d_v4s32_f32:
  case Intrinsic::nvvm_tex_unified_3d_level_v4s32_f32:
  case Intrinsic::nvvm_tex_unified_3d_grad_v4s32_f32:
  case Intrinsic::nvvm_tex_unified_1d_v4u32_s32:
  case Intrinsic::nvvm_tex_unified_1d_v4u32_f32:
  case Intrinsic::nvvm_tex_unified_1d_level_v4u32_f32:
  case Intrinsic::nvvm_tex_unified_1d_grad_v4u32_f32:
  case Intrinsic::nvvm_tex_unified_1d_array_v4u32_s32:
  case Intrinsic::nvvm_tex_unified_1d_array_v4u32_f32:
  case Intrinsic::nvvm_tex_unified_1d_array_level_v4u32_f32:
  case Intrinsic::nvvm_tex_unified_1d_array_grad_v4u32_f32:
  case Intrinsic::nvvm_tex_unified_2d_v4u32_s32:
  case Intrinsic::nvvm_tex_unified_2d_v4u32_f32:
  case Intrinsic::nvvm_tex_unified_2d_level_v4u32_f32:
  case Intrinsic::nvvm_tex_unified_2d_grad_v4u32_f32:
  case Intrinsic::nvvm_tex_unified_2d_array_v4u32_s32:
  case Intrinsic::nvvm_tex_unified_2d_array_v4u32_f32:
  case Intrinsic::nvvm_tex_unified_2d_array_level_v4u32_f32:
  case Intrinsic::nvvm_tex_unified_2d_array_grad_v4u32_f32:
  case Intrinsic::nvvm_tex_unified_3d_v4u32_s32:
  case Intrinsic::nvvm_tex_unified_3d_v4u32_f32:
  case Intrinsic::nvvm_tex_unified_3d_level_v4u32_f32:
  case Intrinsic::nvvm_tex_unified_3d_grad_v4u32_f32:
  case Intrinsic::nvvm_tex_unified_cube_v4s32_f32:
  case Intrinsic::nvvm_tex_unified_cube_level_v4s32_f32:
  case Intrinsic::nvvm_tex_unified_cube_array_v4s32_f32:
  case Intrinsic::nvvm_tex_unified_cube_array_level_v4s32_f32:
  case Intrinsic::nvvm_tex_unified_cube_v4u32_f32:
  case Intrinsic::nvvm_tex_unified_cube_level_v4u32_f32:
  case Intrinsic::nvvm_tex_unified_cube_array_v4u32_f32:
  case Intrinsic::nvvm_tex_unified_cube_array_level_v4u32_f32:
  case Intrinsic::nvvm_tex_unified_cube_grad_v4s32_f32:
  case Intrinsic::nvvm_tex_unified_cube_grad_v4u32_f32:
  case Intrinsic::nvvm_tex_unified_cube_array_grad_v4s32_f32:
  case Intrinsic::nvvm_tex_unified_cube_array_grad_v4u32_f32:
  case Intrinsic::nvvm_tld4_unified_r_2d_v4s32_f32:
  case Intrinsic::nvvm_tld4_unified_g_2d_v4s32_f32:
  case Intrinsic::nvvm_tld4_unified_b_2d_v4s32_f32:
  case Intrinsic::nvvm_tld4_unified_a_2d_v4s32_f32:
  case Intrinsic::nvvm_tld4_unified_r_2d_v4u32_f32:
  case Intrinsic::nvvm_tld4_unified_g_2d_v4u32_f32:
  case Intrinsic::nvvm_tld4_unified_b_2d_v4u32_f32:
  case Intrinsic::nvvm_tld4_unified_a_2d_v4u32_f32:
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::v4i32;
    Info.ptrVal = nullptr;
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOLoad;
    Info.align = Align(16);
    Infos.push_back(Elt: Info);
    return;

  // Surface loads (suld) of 8-bit elements, in all geometry /
  // vector-width / out-of-bounds-mode (clamp, trap, zero) combinations.
  // Like textures, surfaces have no IR pointer, so ptrVal is null.
  // NOTE(review): memVT records only the element width (i8), not the
  // vector width of v2/v4 variants — presumably matching how the
  // instruction is selected; verify against the suld selection code.
  case Intrinsic::nvvm_suld_1d_i8_clamp:
  case Intrinsic::nvvm_suld_1d_v2i8_clamp:
  case Intrinsic::nvvm_suld_1d_v4i8_clamp:
  case Intrinsic::nvvm_suld_1d_array_i8_clamp:
  case Intrinsic::nvvm_suld_1d_array_v2i8_clamp:
  case Intrinsic::nvvm_suld_1d_array_v4i8_clamp:
  case Intrinsic::nvvm_suld_2d_i8_clamp:
  case Intrinsic::nvvm_suld_2d_v2i8_clamp:
  case Intrinsic::nvvm_suld_2d_v4i8_clamp:
  case Intrinsic::nvvm_suld_2d_array_i8_clamp:
  case Intrinsic::nvvm_suld_2d_array_v2i8_clamp:
  case Intrinsic::nvvm_suld_2d_array_v4i8_clamp:
  case Intrinsic::nvvm_suld_3d_i8_clamp:
  case Intrinsic::nvvm_suld_3d_v2i8_clamp:
  case Intrinsic::nvvm_suld_3d_v4i8_clamp:
  case Intrinsic::nvvm_suld_1d_i8_trap:
  case Intrinsic::nvvm_suld_1d_v2i8_trap:
  case Intrinsic::nvvm_suld_1d_v4i8_trap:
  case Intrinsic::nvvm_suld_1d_array_i8_trap:
  case Intrinsic::nvvm_suld_1d_array_v2i8_trap:
  case Intrinsic::nvvm_suld_1d_array_v4i8_trap:
  case Intrinsic::nvvm_suld_2d_i8_trap:
  case Intrinsic::nvvm_suld_2d_v2i8_trap:
  case Intrinsic::nvvm_suld_2d_v4i8_trap:
  case Intrinsic::nvvm_suld_2d_array_i8_trap:
  case Intrinsic::nvvm_suld_2d_array_v2i8_trap:
  case Intrinsic::nvvm_suld_2d_array_v4i8_trap:
  case Intrinsic::nvvm_suld_3d_i8_trap:
  case Intrinsic::nvvm_suld_3d_v2i8_trap:
  case Intrinsic::nvvm_suld_3d_v4i8_trap:
  case Intrinsic::nvvm_suld_1d_i8_zero:
  case Intrinsic::nvvm_suld_1d_v2i8_zero:
  case Intrinsic::nvvm_suld_1d_v4i8_zero:
  case Intrinsic::nvvm_suld_1d_array_i8_zero:
  case Intrinsic::nvvm_suld_1d_array_v2i8_zero:
  case Intrinsic::nvvm_suld_1d_array_v4i8_zero:
  case Intrinsic::nvvm_suld_2d_i8_zero:
  case Intrinsic::nvvm_suld_2d_v2i8_zero:
  case Intrinsic::nvvm_suld_2d_v4i8_zero:
  case Intrinsic::nvvm_suld_2d_array_i8_zero:
  case Intrinsic::nvvm_suld_2d_array_v2i8_zero:
  case Intrinsic::nvvm_suld_2d_array_v4i8_zero:
  case Intrinsic::nvvm_suld_3d_i8_zero:
  case Intrinsic::nvvm_suld_3d_v2i8_zero:
  case Intrinsic::nvvm_suld_3d_v4i8_zero:
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::i8;
    Info.ptrVal = nullptr;
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOLoad;
    Info.align = Align(16);
    Infos.push_back(Elt: Info);
    return;

  // Surface loads of 16-bit elements (same structure as the i8 group).
  case Intrinsic::nvvm_suld_1d_i16_clamp:
  case Intrinsic::nvvm_suld_1d_v2i16_clamp:
  case Intrinsic::nvvm_suld_1d_v4i16_clamp:
  case Intrinsic::nvvm_suld_1d_array_i16_clamp:
  case Intrinsic::nvvm_suld_1d_array_v2i16_clamp:
  case Intrinsic::nvvm_suld_1d_array_v4i16_clamp:
  case Intrinsic::nvvm_suld_2d_i16_clamp:
  case Intrinsic::nvvm_suld_2d_v2i16_clamp:
  case Intrinsic::nvvm_suld_2d_v4i16_clamp:
  case Intrinsic::nvvm_suld_2d_array_i16_clamp:
  case Intrinsic::nvvm_suld_2d_array_v2i16_clamp:
  case Intrinsic::nvvm_suld_2d_array_v4i16_clamp:
  case Intrinsic::nvvm_suld_3d_i16_clamp:
  case Intrinsic::nvvm_suld_3d_v2i16_clamp:
  case Intrinsic::nvvm_suld_3d_v4i16_clamp:
  case Intrinsic::nvvm_suld_1d_i16_trap:
  case Intrinsic::nvvm_suld_1d_v2i16_trap:
  case Intrinsic::nvvm_suld_1d_v4i16_trap:
  case Intrinsic::nvvm_suld_1d_array_i16_trap:
  case Intrinsic::nvvm_suld_1d_array_v2i16_trap:
  case Intrinsic::nvvm_suld_1d_array_v4i16_trap:
  case Intrinsic::nvvm_suld_2d_i16_trap:
  case Intrinsic::nvvm_suld_2d_v2i16_trap:
  case Intrinsic::nvvm_suld_2d_v4i16_trap:
  case Intrinsic::nvvm_suld_2d_array_i16_trap:
  case Intrinsic::nvvm_suld_2d_array_v2i16_trap:
  case Intrinsic::nvvm_suld_2d_array_v4i16_trap:
  case Intrinsic::nvvm_suld_3d_i16_trap:
  case Intrinsic::nvvm_suld_3d_v2i16_trap:
  case Intrinsic::nvvm_suld_3d_v4i16_trap:
  case Intrinsic::nvvm_suld_1d_i16_zero:
  case Intrinsic::nvvm_suld_1d_v2i16_zero:
  case Intrinsic::nvvm_suld_1d_v4i16_zero:
  case Intrinsic::nvvm_suld_1d_array_i16_zero:
  case Intrinsic::nvvm_suld_1d_array_v2i16_zero:
  case Intrinsic::nvvm_suld_1d_array_v4i16_zero:
  case Intrinsic::nvvm_suld_2d_i16_zero:
  case Intrinsic::nvvm_suld_2d_v2i16_zero:
  case Intrinsic::nvvm_suld_2d_v4i16_zero:
  case Intrinsic::nvvm_suld_2d_array_i16_zero:
  case Intrinsic::nvvm_suld_2d_array_v2i16_zero:
  case Intrinsic::nvvm_suld_2d_array_v4i16_zero:
  case Intrinsic::nvvm_suld_3d_i16_zero:
  case Intrinsic::nvvm_suld_3d_v2i16_zero:
  case Intrinsic::nvvm_suld_3d_v4i16_zero:
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::i16;
    Info.ptrVal = nullptr;
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOLoad;
    Info.align = Align(16);
    Infos.push_back(Elt: Info);
    return;

  // Surface loads of 32-bit elements.
  case Intrinsic::nvvm_suld_1d_i32_clamp:
  case Intrinsic::nvvm_suld_1d_v2i32_clamp:
  case Intrinsic::nvvm_suld_1d_v4i32_clamp:
  case Intrinsic::nvvm_suld_1d_array_i32_clamp:
  case Intrinsic::nvvm_suld_1d_array_v2i32_clamp:
  case Intrinsic::nvvm_suld_1d_array_v4i32_clamp:
  case Intrinsic::nvvm_suld_2d_i32_clamp:
  case Intrinsic::nvvm_suld_2d_v2i32_clamp:
  case Intrinsic::nvvm_suld_2d_v4i32_clamp:
  case Intrinsic::nvvm_suld_2d_array_i32_clamp:
  case Intrinsic::nvvm_suld_2d_array_v2i32_clamp:
  case Intrinsic::nvvm_suld_2d_array_v4i32_clamp:
  case Intrinsic::nvvm_suld_3d_i32_clamp:
  case Intrinsic::nvvm_suld_3d_v2i32_clamp:
  case Intrinsic::nvvm_suld_3d_v4i32_clamp:
  case Intrinsic::nvvm_suld_1d_i32_trap:
  case Intrinsic::nvvm_suld_1d_v2i32_trap:
  case Intrinsic::nvvm_suld_1d_v4i32_trap:
  case Intrinsic::nvvm_suld_1d_array_i32_trap:
  case Intrinsic::nvvm_suld_1d_array_v2i32_trap:
  case Intrinsic::nvvm_suld_1d_array_v4i32_trap:
  case Intrinsic::nvvm_suld_2d_i32_trap:
  case Intrinsic::nvvm_suld_2d_v2i32_trap:
  case Intrinsic::nvvm_suld_2d_v4i32_trap:
  case Intrinsic::nvvm_suld_2d_array_i32_trap:
  case Intrinsic::nvvm_suld_2d_array_v2i32_trap:
  case Intrinsic::nvvm_suld_2d_array_v4i32_trap:
  case Intrinsic::nvvm_suld_3d_i32_trap:
  case Intrinsic::nvvm_suld_3d_v2i32_trap:
  case Intrinsic::nvvm_suld_3d_v4i32_trap:
  case Intrinsic::nvvm_suld_1d_i32_zero:
  case Intrinsic::nvvm_suld_1d_v2i32_zero:
  case Intrinsic::nvvm_suld_1d_v4i32_zero:
  case Intrinsic::nvvm_suld_1d_array_i32_zero:
  case Intrinsic::nvvm_suld_1d_array_v2i32_zero:
  case Intrinsic::nvvm_suld_1d_array_v4i32_zero:
  case Intrinsic::nvvm_suld_2d_i32_zero:
  case Intrinsic::nvvm_suld_2d_v2i32_zero:
  case Intrinsic::nvvm_suld_2d_v4i32_zero:
  case Intrinsic::nvvm_suld_2d_array_i32_zero:
  case Intrinsic::nvvm_suld_2d_array_v2i32_zero:
  case Intrinsic::nvvm_suld_2d_array_v4i32_zero:
  case Intrinsic::nvvm_suld_3d_i32_zero:
  case Intrinsic::nvvm_suld_3d_v2i32_zero:
  case Intrinsic::nvvm_suld_3d_v4i32_zero:
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::i32;
    Info.ptrVal = nullptr;
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOLoad;
    Info.align = Align(16);
    Infos.push_back(Elt: Info);
    return;

  // Surface loads of 64-bit elements (only scalar and v2 variants exist
  // for i64; there is no v4i64 surface load).
  case Intrinsic::nvvm_suld_1d_i64_clamp:
  case Intrinsic::nvvm_suld_1d_v2i64_clamp:
  case Intrinsic::nvvm_suld_1d_array_i64_clamp:
  case Intrinsic::nvvm_suld_1d_array_v2i64_clamp:
  case Intrinsic::nvvm_suld_2d_i64_clamp:
  case Intrinsic::nvvm_suld_2d_v2i64_clamp:
  case Intrinsic::nvvm_suld_2d_array_i64_clamp:
  case Intrinsic::nvvm_suld_2d_array_v2i64_clamp:
  case Intrinsic::nvvm_suld_3d_i64_clamp:
  case Intrinsic::nvvm_suld_3d_v2i64_clamp:
  case Intrinsic::nvvm_suld_1d_i64_trap:
  case Intrinsic::nvvm_suld_1d_v2i64_trap:
  case Intrinsic::nvvm_suld_1d_array_i64_trap:
  case Intrinsic::nvvm_suld_1d_array_v2i64_trap:
  case Intrinsic::nvvm_suld_2d_i64_trap:
  case Intrinsic::nvvm_suld_2d_v2i64_trap:
  case Intrinsic::nvvm_suld_2d_array_i64_trap:
  case Intrinsic::nvvm_suld_2d_array_v2i64_trap:
  case Intrinsic::nvvm_suld_3d_i64_trap:
  case Intrinsic::nvvm_suld_3d_v2i64_trap:
  case Intrinsic::nvvm_suld_1d_i64_zero:
  case Intrinsic::nvvm_suld_1d_v2i64_zero:
  case Intrinsic::nvvm_suld_1d_array_i64_zero:
  case Intrinsic::nvvm_suld_1d_array_v2i64_zero:
  case Intrinsic::nvvm_suld_2d_i64_zero:
  case Intrinsic::nvvm_suld_2d_v2i64_zero:
  case Intrinsic::nvvm_suld_2d_array_i64_zero:
  case Intrinsic::nvvm_suld_2d_array_v2i64_zero:
  case Intrinsic::nvvm_suld_3d_i64_zero:
  case Intrinsic::nvvm_suld_3d_v2i64_zero:
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::i64;
    Info.ptrVal = nullptr;
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOLoad;
    Info.align = Align(16);
    Infos.push_back(Elt: Info);
    return;

  // tcgen05.ld groups below are keyed by the number of 32-bit registers
  // each variant produces (x1 -> v1i32, x2 -> v2i32, ...), with separate
  // groups for the "red" f32-typed reduction variants. Alignment is left
  // unset for all of them.

  // Single-register tcgen05.ld variants: one 32-bit result.
  case Intrinsic::nvvm_tcgen05_ld_16x64b_x1:
  case Intrinsic::nvvm_tcgen05_ld_32x32b_x1:
  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x1: {
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::v1i32;
    Info.ptrVal = I.getArgOperand(i: 0);
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOLoad;
    Info.align.reset();
    Infos.push_back(Elt: Info);
    return;
  }

  // Two-register integer variants: v2i32.
  case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
  case Intrinsic::nvvm_tcgen05_ld_16x128b_x1:
  case Intrinsic::nvvm_tcgen05_ld_32x32b_x2:
  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x2_i32: {
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::v2i32;
    Info.ptrVal = I.getArgOperand(i: 0);
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOLoad;
    Info.align.reset();
    Infos.push_back(Elt: Info);
    return;
  }

  // Two-register f32 reduction variants: v2f32.
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x2_f32: {
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::v2f32;
    Info.ptrVal = I.getArgOperand(i: 0);
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOLoad;
    Info.align.reset();
    Infos.push_back(Elt: Info);
    return;
  }

  // Four-register integer variants: v4i32.
  case Intrinsic::nvvm_tcgen05_ld_16x64b_x4:
  case Intrinsic::nvvm_tcgen05_ld_16x128b_x2:
  case Intrinsic::nvvm_tcgen05_ld_32x32b_x4:
  case Intrinsic::nvvm_tcgen05_ld_16x256b_x1:
  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x4:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x4_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x4_i32: {
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::v4i32;
    Info.ptrVal = I.getArgOperand(i: 0);
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOLoad;
    Info.align.reset();
    Infos.push_back(Elt: Info);
    return;
  }

  // Four-register f32 reduction variants: v4f32.
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x4_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x4_f32: {
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::v4f32;
    Info.ptrVal = I.getArgOperand(i: 0);
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOLoad;
    Info.align.reset();
    Infos.push_back(Elt: Info);
    return;
  }

  // Eight-register integer variants: v8i32.
  case Intrinsic::nvvm_tcgen05_ld_16x64b_x8:
  case Intrinsic::nvvm_tcgen05_ld_16x128b_x4:
  case Intrinsic::nvvm_tcgen05_ld_16x256b_x2:
  case Intrinsic::nvvm_tcgen05_ld_32x32b_x8:
  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x8:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x8_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x8_i32: {
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::v8i32;
    Info.ptrVal = I.getArgOperand(i: 0);
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOLoad;
    Info.align.reset();
    Infos.push_back(Elt: Info);
    return;
  }

  // Eight-register f32 reduction variants: v8f32.
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x8_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x8_f32: {
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::v8f32;
    Info.ptrVal = I.getArgOperand(i: 0);
    Info.offset = 0;
    Info.flags = MachineMemOperand::MOLoad;
    Info.align.reset();
    Infos.push_back(Elt: Info);
    return;
  }
5283
5284 case Intrinsic::nvvm_tcgen05_ld_16x64b_x16:
5285 case Intrinsic::nvvm_tcgen05_ld_16x128b_x8:
5286 case Intrinsic::nvvm_tcgen05_ld_16x256b_x4:
5287 case Intrinsic::nvvm_tcgen05_ld_32x32b_x16:
5288 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x16:
5289 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x16_i32:
5290 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x16_i32: {
5291 Info.opc = ISD::INTRINSIC_W_CHAIN;
5292 Info.memVT = MVT::v16i32;
5293 Info.ptrVal = I.getArgOperand(i: 0);
5294 Info.offset = 0;
5295 Info.flags = MachineMemOperand::MOLoad;
5296 Info.align.reset();
5297 Infos.push_back(Elt: Info);
5298 return;
5299 }
5300
5301 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x16_f32:
5302 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x16_f32: {
5303 Info.opc = ISD::INTRINSIC_W_CHAIN;
5304 Info.memVT = MVT::v16f32;
5305 Info.ptrVal = I.getArgOperand(i: 0);
5306 Info.offset = 0;
5307 Info.flags = MachineMemOperand::MOLoad;
5308 Info.align.reset();
5309 Infos.push_back(Elt: Info);
5310 return;
5311 }
5312
5313 case Intrinsic::nvvm_tcgen05_ld_16x64b_x32:
5314 case Intrinsic::nvvm_tcgen05_ld_16x128b_x16:
5315 case Intrinsic::nvvm_tcgen05_ld_16x256b_x8:
5316 case Intrinsic::nvvm_tcgen05_ld_32x32b_x32:
5317 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x32:
5318 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x32_i32:
5319 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x32_i32: {
5320 Info.opc = ISD::INTRINSIC_W_CHAIN;
5321 Info.memVT = MVT::v32i32;
5322 Info.ptrVal = I.getArgOperand(i: 0);
5323 Info.offset = 0;
5324 Info.flags = MachineMemOperand::MOLoad;
5325 Info.align.reset();
5326 Infos.push_back(Elt: Info);
5327 return;
5328 }
5329
5330 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x32_f32:
5331 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x32_f32: {
5332 Info.opc = ISD::INTRINSIC_W_CHAIN;
5333 Info.memVT = MVT::v32f32;
5334 Info.ptrVal = I.getArgOperand(i: 0);
5335 Info.offset = 0;
5336 Info.flags = MachineMemOperand::MOLoad;
5337 Info.align.reset();
5338 Infos.push_back(Elt: Info);
5339 return;
5340 }
5341
5342 case Intrinsic::nvvm_tcgen05_ld_16x64b_x64:
5343 case Intrinsic::nvvm_tcgen05_ld_16x128b_x32:
5344 case Intrinsic::nvvm_tcgen05_ld_16x256b_x16:
5345 case Intrinsic::nvvm_tcgen05_ld_32x32b_x64:
5346 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x64:
5347 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x64_i32:
5348 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x64_i32: {
5349 Info.opc = ISD::INTRINSIC_W_CHAIN;
5350 Info.memVT = MVT::v64i32;
5351 Info.ptrVal = I.getArgOperand(i: 0);
5352 Info.offset = 0;
5353 Info.flags = MachineMemOperand::MOLoad;
5354 Info.align.reset();
5355 Infos.push_back(Elt: Info);
5356 return;
5357 }
5358
5359 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x64_f32:
5360 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x64_f32: {
5361 Info.opc = ISD::INTRINSIC_W_CHAIN;
5362 Info.memVT = MVT::v64f32;
5363 Info.ptrVal = I.getArgOperand(i: 0);
5364 Info.offset = 0;
5365 Info.flags = MachineMemOperand::MOLoad;
5366 Info.align.reset();
5367 Infos.push_back(Elt: Info);
5368 return;
5369 }
5370
5371 case Intrinsic::nvvm_tcgen05_ld_16x64b_x128:
5372 case Intrinsic::nvvm_tcgen05_ld_16x128b_x64:
5373 case Intrinsic::nvvm_tcgen05_ld_16x256b_x32:
5374 case Intrinsic::nvvm_tcgen05_ld_32x32b_x128:
5375 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x128:
5376 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x128_i32:
5377 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x128_i32: {
5378 Info.opc = ISD::INTRINSIC_W_CHAIN;
5379 Info.memVT = MVT::v128i32;
5380 Info.ptrVal = I.getArgOperand(i: 0);
5381 Info.offset = 0;
5382 Info.flags = MachineMemOperand::MOLoad;
5383 Info.align.reset();
5384 Infos.push_back(Elt: Info);
5385 return;
5386 }
5387
5388 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x128_f32:
5389 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x128_f32: {
5390 Info.opc = ISD::INTRINSIC_W_CHAIN;
5391 Info.memVT = MVT::v128f32;
5392 Info.ptrVal = I.getArgOperand(i: 0);
5393 Info.offset = 0;
5394 Info.flags = MachineMemOperand::MOLoad;
5395 Info.align.reset();
5396 Infos.push_back(Elt: Info);
5397 return;
5398 }
5399
5400 case Intrinsic::nvvm_tcgen05_st_16x64b_x1:
5401 case Intrinsic::nvvm_tcgen05_st_32x32b_x1:
5402 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x1: {
5403 Info.opc = ISD::INTRINSIC_VOID;
5404 Info.memVT = MVT::i32;
5405 Info.ptrVal = I.getArgOperand(i: 0);
5406 Info.offset = 0;
5407 Info.flags = MachineMemOperand::MOStore;
5408 Info.align.reset();
5409 Infos.push_back(Elt: Info);
5410 return;
5411 }
5412
5413 case Intrinsic::nvvm_tcgen05_st_16x64b_x2:
5414 case Intrinsic::nvvm_tcgen05_st_16x128b_x1:
5415 case Intrinsic::nvvm_tcgen05_st_32x32b_x2:
5416 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x2: {
5417 Info.opc = ISD::INTRINSIC_VOID;
5418 Info.memVT = MVT::v2i32;
5419 Info.ptrVal = I.getArgOperand(i: 0);
5420 Info.offset = 0;
5421 Info.flags = MachineMemOperand::MOStore;
5422 Info.align.reset();
5423 Infos.push_back(Elt: Info);
5424 return;
5425 }
5426
5427 case Intrinsic::nvvm_tcgen05_st_16x64b_x4:
5428 case Intrinsic::nvvm_tcgen05_st_16x128b_x2:
5429 case Intrinsic::nvvm_tcgen05_st_16x256b_x1:
5430 case Intrinsic::nvvm_tcgen05_st_32x32b_x4:
5431 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x4: {
5432 Info.opc = ISD::INTRINSIC_VOID;
5433 Info.memVT = MVT::v4i32;
5434 Info.ptrVal = I.getArgOperand(i: 0);
5435 Info.offset = 0;
5436 Info.flags = MachineMemOperand::MOStore;
5437 Info.align.reset();
5438 Infos.push_back(Elt: Info);
5439 return;
5440 }
5441
5442 case Intrinsic::nvvm_tcgen05_st_16x64b_x8:
5443 case Intrinsic::nvvm_tcgen05_st_16x128b_x4:
5444 case Intrinsic::nvvm_tcgen05_st_16x256b_x2:
5445 case Intrinsic::nvvm_tcgen05_st_32x32b_x8:
5446 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x8: {
5447 Info.opc = ISD::INTRINSIC_VOID;
5448 Info.memVT = MVT::v8i32;
5449 Info.ptrVal = I.getArgOperand(i: 0);
5450 Info.offset = 0;
5451 Info.flags = MachineMemOperand::MOStore;
5452 Info.align.reset();
5453 Infos.push_back(Elt: Info);
5454 return;
5455 }
5456
5457 case Intrinsic::nvvm_tcgen05_st_16x64b_x16:
5458 case Intrinsic::nvvm_tcgen05_st_16x128b_x8:
5459 case Intrinsic::nvvm_tcgen05_st_16x256b_x4:
5460 case Intrinsic::nvvm_tcgen05_st_32x32b_x16:
5461 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x16: {
5462 Info.opc = ISD::INTRINSIC_VOID;
5463 Info.memVT = MVT::v16i32;
5464 Info.ptrVal = I.getArgOperand(i: 0);
5465 Info.offset = 0;
5466 Info.flags = MachineMemOperand::MOStore;
5467 Info.align.reset();
5468 Infos.push_back(Elt: Info);
5469 return;
5470 }
5471
5472 case Intrinsic::nvvm_tcgen05_st_16x64b_x32:
5473 case Intrinsic::nvvm_tcgen05_st_16x128b_x16:
5474 case Intrinsic::nvvm_tcgen05_st_16x256b_x8:
5475 case Intrinsic::nvvm_tcgen05_st_32x32b_x32:
5476 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x32: {
5477 Info.opc = ISD::INTRINSIC_VOID;
5478 Info.memVT = MVT::v32i32;
5479 Info.ptrVal = I.getArgOperand(i: 0);
5480 Info.offset = 0;
5481 Info.flags = MachineMemOperand::MOStore;
5482 Info.align.reset();
5483 Infos.push_back(Elt: Info);
5484 return;
5485 }
5486
5487 case Intrinsic::nvvm_tcgen05_st_16x64b_x64:
5488 case Intrinsic::nvvm_tcgen05_st_16x128b_x32:
5489 case Intrinsic::nvvm_tcgen05_st_16x256b_x16:
5490 case Intrinsic::nvvm_tcgen05_st_32x32b_x64:
5491 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x64: {
5492 Info.opc = ISD::INTRINSIC_VOID;
5493 Info.memVT = MVT::v64i32;
5494 Info.ptrVal = I.getArgOperand(i: 0);
5495 Info.offset = 0;
5496 Info.flags = MachineMemOperand::MOStore;
5497 Info.align.reset();
5498 Infos.push_back(Elt: Info);
5499 return;
5500 }
5501
5502 case Intrinsic::nvvm_tcgen05_st_16x64b_x128:
5503 case Intrinsic::nvvm_tcgen05_st_16x128b_x64:
5504 case Intrinsic::nvvm_tcgen05_st_16x256b_x32:
5505 case Intrinsic::nvvm_tcgen05_st_32x32b_x128:
5506 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x128: {
5507 Info.opc = ISD::INTRINSIC_VOID;
5508 Info.memVT = MVT::v128i32;
5509 Info.ptrVal = I.getArgOperand(i: 0);
5510 Info.offset = 0;
5511 Info.flags = MachineMemOperand::MOStore;
5512 Info.align.reset();
5513 Infos.push_back(Elt: Info);
5514 return;
5515 }
5516 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
5517 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1:
5518 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1:
5519 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1:
5520 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1:
5521 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1:
5522 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift:
5523 case Intrinsic::
5524 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift:
5525 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1:
5526 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1:
5527 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift:
5528 case Intrinsic::
5529 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift: {
5530 // We are reading and writing back to TMem
5531 Info.opc = ISD::INTRINSIC_VOID;
5532 Info.memVT = MVT::v4i32;
5533 Info.ptrVal = I.getArgOperand(i: 0);
5534 Info.offset = 0;
5535 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
5536 Info.align = Align(16);
5537 Infos.push_back(Elt: Info);
5538 return;
5539 }
5540
5541 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2:
5542 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2:
5543 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2:
5544 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2:
5545 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2:
5546 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2:
5547 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2:
5548 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2:
5549 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift:
5550 case Intrinsic::
5551 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift:
5552 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift:
5553 case Intrinsic::
5554 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift: {
5555 // We are reading and writing back to TMem
5556 Info.opc = ISD::INTRINSIC_VOID;
5557 Info.memVT = MVT::v8i32;
5558 Info.ptrVal = I.getArgOperand(i: 0);
5559 Info.offset = 0;
5560 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
5561 Info.align = Align(16);
5562 Infos.push_back(Elt: Info);
5563 return;
5564 }
5565 }
5566}
5567
5568// Helper for getting a function parameter name. Name is composed from
5569// its index and the function name. Negative index corresponds to special
5570// parameter (unsized array) used for passing variable arguments.
5571std::string NVPTXTargetLowering::getParamName(const Function *F,
5572 int Idx) const {
5573 std::string ParamName;
5574 raw_string_ostream ParamStr(ParamName);
5575
5576 ParamStr << getTargetMachine().getSymbol(GV: F)->getName();
5577 if (Idx < 0)
5578 ParamStr << "_vararg";
5579 else
5580 ParamStr << "_param_" << Idx;
5581
5582 return ParamName;
5583}
5584
5585/// isLegalAddressingMode - Return true if the addressing mode represented
5586/// by AM is legal for this target, for a load/store of the specified type.
5587/// Used to guide target specific optimizations, like loop strength reduction
5588/// (LoopStrengthReduce.cpp) and memory optimization for address mode
5589/// (CodeGenPrepare.cpp)
5590bool NVPTXTargetLowering::isLegalAddressingMode(const DataLayout &DL,
5591 const AddrMode &AM, Type *Ty,
5592 unsigned AS, Instruction *I) const {
5593 // AddrMode - This represents an addressing mode of:
5594 // BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
5595 //
5596 // The legal address modes are
5597 // - [avar]
5598 // - [areg]
5599 // - [areg+immoff]
5600 // - [immAddr]
5601
5602 // immoff must fit in a signed 32-bit int
5603 if (!APInt(64, AM.BaseOffs).isSignedIntN(N: 32))
5604 return false;
5605
5606 if (AM.BaseGV)
5607 return !AM.BaseOffs && !AM.HasBaseReg && !AM.Scale;
5608
5609 switch (AM.Scale) {
5610 case 0: // "r", "r+i" or "i" is allowed
5611 break;
5612 case 1:
5613 if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
5614 return false;
5615 // Otherwise we have r+i.
5616 break;
5617 default:
5618 // No scale > 1 is allowed
5619 return false;
5620 }
5621 return true;
5622}
5623
5624//===----------------------------------------------------------------------===//
5625// NVPTX Inline Assembly Support
5626//===----------------------------------------------------------------------===//
5627
5628/// getConstraintType - Given a constraint letter, return the type of
5629/// constraint it is for this target.
5630NVPTXTargetLowering::ConstraintType
5631NVPTXTargetLowering::getConstraintType(StringRef Constraint) const {
5632 if (Constraint.size() == 1) {
5633 switch (Constraint[0]) {
5634 default:
5635 break;
5636 case 'b':
5637 case 'r':
5638 case 'h':
5639 case 'c':
5640 case 'l':
5641 case 'f':
5642 case 'd':
5643 case 'q':
5644 case '0':
5645 case 'N':
5646 return C_RegisterClass;
5647 }
5648 }
5649 return TargetLowering::getConstraintType(Constraint);
5650}
5651
5652std::pair<unsigned, const TargetRegisterClass *>
5653NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
5654 StringRef Constraint,
5655 MVT VT) const {
5656 if (Constraint.size() == 1) {
5657 switch (Constraint[0]) {
5658 case 'b':
5659 return std::make_pair(x: 0U, y: &NVPTX::B1RegClass);
5660 case 'c':
5661 case 'h':
5662 return std::make_pair(x: 0U, y: &NVPTX::B16RegClass);
5663 case 'r':
5664 case 'f':
5665 return std::make_pair(x: 0U, y: &NVPTX::B32RegClass);
5666 case 'l':
5667 case 'N':
5668 case 'd':
5669 return std::make_pair(x: 0U, y: &NVPTX::B64RegClass);
5670 case 'q': {
5671 if (STI.getSmVersion() < 70)
5672 report_fatal_error(reason: "Inline asm with 128 bit operands is only "
5673 "supported for sm_70 and higher!");
5674 return std::make_pair(x: 0U, y: &NVPTX::B128RegClass);
5675 }
5676 }
5677 }
5678 return TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
5679}
5680
5681//===----------------------------------------------------------------------===//
5682// NVPTX DAG Combining
5683//===----------------------------------------------------------------------===//
5684
5685bool NVPTXTargetLowering::allowFMA(MachineFunction &MF,
5686 CodeGenOptLevel OptLevel) const {
5687 // Always honor command-line argument
5688 if (FMAContractLevelOpt.getNumOccurrences() > 0)
5689 return FMAContractLevelOpt > 0;
5690
5691 // Do not contract if we're not optimizing the code.
5692 if (OptLevel == CodeGenOptLevel::None)
5693 return false;
5694
5695 // Honor TargetOptions flags that explicitly say fusion is okay.
5696 if (MF.getTarget().Options.AllowFPOpFusion == FPOpFusion::Fast)
5697 return true;
5698
5699 return false;
5700}
5701
5702static bool isConstZero(const SDValue &Operand) {
5703 const auto *Const = dyn_cast<ConstantSDNode>(Val: Operand);
5704 return Const && Const->getZExtValue() == 0;
5705}
5706
5707/// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
5708/// operands N0 and N1. This is a helper for PerformADDCombine that is
5709/// called with the default operands, and if that fails, with commuted
5710/// operands.
5711static SDValue
5712PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
5713 TargetLowering::DAGCombinerInfo &DCI) {
5714 EVT VT = N0.getValueType();
5715
5716 // Since integer multiply-add costs the same as integer multiply
5717 // but is more costly than integer add, do the fusion only when
5718 // the mul is only used in the add.
5719 // TODO: this may not be true for later architectures, consider relaxing this
5720 if (!N0.getNode()->hasOneUse())
5721 return SDValue();
5722
5723 // fold (add (select cond, 0, (mul a, b)), c)
5724 // -> (select cond, c, (add (mul a, b), c))
5725 //
5726 if (N0.getOpcode() == ISD::SELECT) {
5727 unsigned ZeroOpNum;
5728 if (isConstZero(Operand: N0->getOperand(Num: 1)))
5729 ZeroOpNum = 1;
5730 else if (isConstZero(Operand: N0->getOperand(Num: 2)))
5731 ZeroOpNum = 2;
5732 else
5733 return SDValue();
5734
5735 SDValue M = N0->getOperand(Num: (ZeroOpNum == 1) ? 2 : 1);
5736 if (M->getOpcode() != ISD::MUL || !M.getNode()->hasOneUse())
5737 return SDValue();
5738
5739 SDLoc DL(N);
5740 SDValue Mul =
5741 DCI.DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: M->getOperand(Num: 0), N2: M->getOperand(Num: 1));
5742 SDValue MAD = DCI.DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Mul, N2: N1);
5743 return DCI.DAG.getSelect(DL: SDLoc(N), VT, Cond: N0->getOperand(Num: 0),
5744 LHS: ((ZeroOpNum == 1) ? N1 : MAD),
5745 RHS: ((ZeroOpNum == 1) ? MAD : N1));
5746 }
5747
5748 return SDValue();
5749}
5750
/// Try to fold (fadd (fmul a, b), c) into (fma a, b, c), guarded by
/// FMA-contraction permission and register-pressure heuristics.
static SDValue
PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
                               TargetLowering::DAGCombinerInfo &DCI,
                               CodeGenOptLevel OptLevel) {
  EVT VT = N0.getValueType();
  if (N0.getOpcode() == ISD::FMUL) {
    const auto *TLI = static_cast<const NVPTXTargetLowering *>(
        &DCI.DAG.getTargetLoweringInfo());
    // Contraction must be permitted either globally (target options / opt
    // level) or via the 'contract' fast-math flag on both the add and mul.
    if (!(TLI->allowFMA(MF&: DCI.DAG.getMachineFunction(), OptLevel) ||
          (N->getFlags().hasAllowContract() &&
           N0->getFlags().hasAllowContract())))
      return SDValue();

    // For floating point:
    // Do the fusion only when the mul has less than 5 uses and all
    // are add.
    // The heuristic is that if a use is not an add, then that use
    // cannot be fused into fma, therefore mul is still needed anyway.
    // If there are more than 4 uses, even if they are all add, fusing
    // them will increase register pressure.
    //
    int numUses = 0;
    int nonAddCount = 0;
    for (const SDNode *User : N0.getNode()->users()) {
      numUses++;
      if (User->getOpcode() != ISD::FADD)
        ++nonAddCount;
      if (numUses >= 5)
        return SDValue();
    }
    if (nonAddCount) {
      int orderNo = N->getIROrder();
      int orderNo2 = N0.getNode()->getIROrder();
      // Simple heuristic for considering potential register pressure: the
      // IR-order difference approximates the distance between def and use;
      // the longer the distance, the more likely register pressure results.
      if (orderNo - orderNo2 < 500)
        return SDValue();

      // Now, check if at least one of the FMUL's operands is live beyond the
      // node N, which guarantees that the FMA will not increase register
      // pressure at node N.
      bool opIsLive = false;
      const SDNode *left = N0.getOperand(i: 0).getNode();
      const SDNode *right = N0.getOperand(i: 1).getNode();

      // Constant operands are always materializable, so treat them as live.
      if (isa<ConstantSDNode>(Val: left) || isa<ConstantSDNode>(Val: right))
        opIsLive = true;

      // An operand is live past N when some user of it comes later in IR
      // order than N itself.
      if (!opIsLive)
        for (const SDNode *User : left->users()) {
          int orderNo3 = User->getIROrder();
          if (orderNo3 > orderNo) {
            opIsLive = true;
            break;
          }
        }

      if (!opIsLive)
        for (const SDNode *User : right->users()) {
          int orderNo3 = User->getIROrder();
          if (orderNo3 > orderNo) {
            opIsLive = true;
            break;
          }
        }

      if (!opIsLive)
        return SDValue();
    }

    return DCI.DAG.getNode(Opcode: ISD::FMA, DL: SDLoc(N), VT, N1: N0.getOperand(i: 0),
                           N2: N0.getOperand(i: 1), N3: N1);
  }

  return SDValue();
}
5829
/// Fold unpacking movs into a load by increasing the number of return values.
///
/// ex:
/// L: v2f16,ch = load <p>
/// a: f16 = extractelt L:0, 0
/// b: f16 = extractelt L:0, 1
/// use(a, b)
///
/// ...is turned into...
///
/// L: f16,f16,ch = LoadV2 <p>
/// use(L:0, L:1)
static SDValue
combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
  // Don't run this optimization before the legalizer
  if (!DCI.isAfterLegalizeDAG())
    return SDValue();

  EVT ElementVT = N->getValueType(ResNo: 0);
  // Avoid non-packed types and v4i8
  if (!NVPTX::isPackedVectorTy(VT: ElementVT) || ElementVT == MVT::v4i8)
    return SDValue();

  // Check whether all outputs are either used by an extractelt or are
  // glue/chain nodes
  if (!all_of(Range: N->uses(), P: [&](SDUse &U) {
        // Skip glue, chain nodes
        if (U.getValueType() == MVT::Glue || U.getValueType() == MVT::Other)
          return true;
        if (U.getUser()->getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
          if (N->getOpcode() != ISD::LOAD)
            return true;
          // Since this is an ISD::LOAD, check all extractelts are used. If
          // any are not used, we don't want to defeat another optimization that
          // will narrow the load.
          //
          // For example:
          //
          // L: v2f16,ch = load <p>
          // e0: f16 = extractelt L:0, 0
          // e1: f16 = extractelt L:0, 1 <-- unused
          // store e0
          //
          // Can be optimized by DAGCombiner to:
          //
          // L: f16,ch = load <p>
          // store L:0
          return !U.getUser()->use_empty();
        }

        // Otherwise, this use prevents us from splitting a value.
        return false;
      }))
    return SDValue();

  auto *LD = cast<MemSDNode>(Val: N);
  SDLoc DL(LD);

  // the new opcode after we double the number of operands
  unsigned Opcode;
  SmallVector<SDValue> Operands(LD->ops());
  unsigned OldNumOutputs; // non-glue, non-chain outputs
  switch (LD->getOpcode()) {
  case ISD::LOAD:
    OldNumOutputs = 1;
    // Any packed type is legal, so the legalizer will not have lowered
    // ISD::LOAD -> NVPTXISD::Load (unless it's under-aligned). We have to do it
    // here.
    Opcode = NVPTXISD::LoadV2;
    // append a "full" used bytes mask operand right before the extension type
    // operand, signifying that all bytes are used.
    Operands.push_back(Elt: DCI.DAG.getConstant(UINT32_MAX, DL, VT: MVT::i32));
    Operands.push_back(Elt: DCI.DAG.getIntPtrConstant(
        Val: cast<LoadSDNode>(Val: LD)->getExtensionType(), DL));
    break;
  case NVPTXISD::LoadV2:
    OldNumOutputs = 2;
    Opcode = NVPTXISD::LoadV4;
    break;
  case NVPTXISD::LoadV4:
    // V8 is only supported for f32/i32. Don't forget, we're not changing the
    // load size here. This is already a 256-bit load.
    if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
      return SDValue();
    OldNumOutputs = 4;
    Opcode = NVPTXISD::LoadV8;
    break;
  case NVPTXISD::LoadV8:
    // PTX doesn't support the next doubling of outputs
    return SDValue();
  }

  // the non-glue, non-chain outputs in the new load
  const unsigned NewNumOutputs = OldNumOutputs * 2;
  SmallVector<EVT> NewVTs(NewNumOutputs, ElementVT.getVectorElementType());
  // add remaining chain and glue values
  NewVTs.append(in_start: LD->value_begin() + OldNumOutputs, in_end: LD->value_end());

  // Create the new load
  SDValue NewLoad = DCI.DAG.getMemIntrinsicNode(
      Opcode, dl: DL, VTList: DCI.DAG.getVTList(VTs: NewVTs), Ops: Operands, MemVT: LD->getMemoryVT(),
      MMO: LD->getMemOperand());

  // Now we use a combination of BUILD_VECTORs and a MERGE_VALUES node to keep
  // the outputs the same. These nodes will be optimized away in later
  // DAGCombiner iterations.
  SmallVector<SDValue> Results;
  for (unsigned I : seq(Size: OldNumOutputs))
    Results.push_back(Elt: DCI.DAG.getBuildVector(
        VT: ElementVT, DL, Ops: {NewLoad.getValue(R: I * 2), NewLoad.getValue(R: I * 2 + 1)}));
  // Add remaining chain and glue nodes
  for (unsigned I : seq(Size: NewLoad->getNumValues() - NewNumOutputs))
    Results.push_back(Elt: NewLoad.getValue(R: NewNumOutputs + I));

  return DCI.DAG.getMergeValues(Ops: Results, dl: DL);
}
5946
/// Fold packing movs into a store.
///
/// ex:
/// v1: v2f16 = BUILD_VECTOR a:f16, b:f16
/// v2: v2f16 = BUILD_VECTOR c:f16, d:f16
/// StoreV2 v1, v2
///
/// ...is turned into...
///
/// StoreV4 a, b, c, d
///
/// \p Front is the number of leading operands (e.g. the chain) and \p Back
/// the number of trailing operands (e.g. pointer/offset) copied over
/// unchanged; only the operands in between are treated as stored values.
static SDValue combinePackingMovIntoStore(SDNode *N,
                                          TargetLowering::DAGCombinerInfo &DCI,
                                          unsigned Front, unsigned Back) {
  // We want to run this as late as possible since other optimizations may
  // eliminate the BUILD_VECTORs.
  if (!DCI.isAfterLegalizeDAG())
    return SDValue();

  // Get the type of the operands being stored.
  EVT ElementVT = N->getOperand(Num: Front).getValueType();

  // Avoid non-packed types and v4i8
  if (!NVPTX::isPackedVectorTy(VT: ElementVT) || ElementVT == MVT::v4i8)
    return SDValue();

  auto *ST = cast<MemSDNode>(Val: N);

  // The new opcode after we double the number of operands.
  unsigned Opcode;
  switch (N->getOpcode()) {
  case ISD::STORE:
    // Any packed type is legal, so the legalizer will not have lowered
    // ISD::STORE -> NVPTXISD::Store (unless it's under-aligned). We have to do
    // it here.
    Opcode = NVPTXISD::StoreV2;
    break;
  case NVPTXISD::StoreV2:
    Opcode = NVPTXISD::StoreV4;
    break;
  case NVPTXISD::StoreV4:
    // V8 is only supported for f32/i32. Don't forget, we're not changing the
    // store size here. This is already a 256-bit store.
    if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
      return SDValue();
    Opcode = NVPTXISD::StoreV8;
    break;
  case NVPTXISD::StoreV8:
    // PTX doesn't support the next doubling of operands
    return SDValue();
  default:
    llvm_unreachable("Unhandled store opcode");
  }

  // Scan the operands and if they're all BUILD_VECTORs, we'll have gathered
  // their elements.
  SmallVector<SDValue, 4> Operands(N->ops().take_front(N: Front));
  for (SDValue BV : N->ops().drop_front(N: Front).drop_back(N: Back)) {
    if (BV.getOpcode() != ISD::BUILD_VECTOR)
      return SDValue();

    // If the operand has multiple uses, this optimization can increase register
    // pressure.
    if (!BV.hasOneUse())
      return SDValue();

    // DAGCombiner visits nodes bottom-up. Check the BUILD_VECTOR operands for
    // any signs they may be folded by some other pattern or rule.
    for (SDValue Op : BV->ops()) {
      // Peek through bitcasts
      if (Op.getOpcode() == ISD::BITCAST)
        Op = Op.getOperand(i: 0);

      // This may be folded into a PRMT.
      if (Op.getValueType() == MVT::i16 && Op.getOpcode() == ISD::TRUNCATE &&
          Op->getOperand(Num: 0).getValueType() == MVT::i32)
        return SDValue();

      // This may be folded into cvt.bf16x2
      if (Op.getOpcode() == ISD::FP_ROUND)
        return SDValue();
    }
    Operands.append(IL: {BV.getOperand(i: 0), BV.getOperand(i: 1)});
  }
  Operands.append(in_start: N->op_end() - Back, in_end: N->op_end());

  // Now we replace the store
  return DCI.DAG.getMemIntrinsicNode(Opcode, dl: SDLoc(N), VTList: N->getVTList(), Ops: Operands,
                                     MemVT: ST->getMemoryVT(), MMO: ST->getMemOperand());
}
6036
6037static SDValue combineSTORE(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
6038 const NVPTXSubtarget &STI) {
6039
6040 if (DCI.isBeforeLegalize() && N->getOpcode() == ISD::STORE) {
6041 // Here is our chance to custom lower a store with a non-simple type.
6042 // Unfortunately, we can't do this in the legalizer because there is no
6043 // way to setOperationAction for an non-simple type.
6044 StoreSDNode *ST = cast<StoreSDNode>(Val: N);
6045 if (!ST->getValue().getValueType().isSimple())
6046 return lowerSTOREVector(Op: SDValue(ST, 0), DAG&: DCI.DAG, STI);
6047 }
6048
6049 return combinePackingMovIntoStore(N, DCI, Front: 1, Back: 2);
6050}
6051
6052static SDValue combineLOAD(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
6053 const NVPTXSubtarget &STI) {
6054 if (DCI.isBeforeLegalize() && N->getOpcode() == ISD::LOAD) {
6055 // Here is our chance to custom lower a load with a non-simple type.
6056 // Unfortunately, we can't do this in the legalizer because there is no
6057 // way to setOperationAction for an non-simple type.
6058 if (!N->getValueType(ResNo: 0).isSimple())
6059 return lowerLoadVector(N, DAG&: DCI.DAG, STI);
6060 }
6061
6062 return combineUnpackingMovIntoLoad(N, DCI);
6063}
6064
6065/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
6066///
6067static SDValue PerformADDCombine(SDNode *N,
6068 TargetLowering::DAGCombinerInfo &DCI,
6069 CodeGenOptLevel OptLevel) {
6070 if (OptLevel == CodeGenOptLevel::None)
6071 return SDValue();
6072
6073 SDValue N0 = N->getOperand(Num: 0);
6074 SDValue N1 = N->getOperand(Num: 1);
6075
6076 // Skip non-integer, non-scalar case
6077 EVT VT = N0.getValueType();
6078 if (VT.isVector() || VT != MVT::i32)
6079 return SDValue();
6080
6081 // First try with the default operand order.
6082 if (SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI))
6083 return Result;
6084
6085 // If that didn't work, try again with the operands commuted.
6086 return PerformADDCombineWithOperands(N, N0: N1, N1: N0, DCI);
6087}
6088
6089/// Check if a v2f32 BUILD_VECTOR provably packs values from non-adjacent
6090/// register pairs (non-coalescable).
6091static bool isNonCoalescableBuildVector(const SDValue &BV) {
6092 if (BV.getOpcode() != ISD::BUILD_VECTOR || BV.getValueType() != MVT::v2f32)
6093 return false;
6094
6095 SDValue Elt0 = BV.getOperand(i: 0);
6096 SDValue Elt1 = BV.getOperand(i: 1);
6097
6098 bool IsExt0 = Elt0.getOpcode() == ISD::EXTRACT_VECTOR_ELT;
6099 bool IsExt1 = Elt1.getOpcode() == ISD::EXTRACT_VECTOR_ELT;
6100
6101 // If neither element is an EXTRACT_VECTOR_ELT they are free-standing
6102 // scalars and the register allocator can still place them side-by-side.
6103 if (!IsExt0 && !IsExt1)
6104 return false;
6105
6106 // If exactly one element is an EXTRACT_VECTOR_ELT, the other is a scalar
6107 // that cannot generally occupy the adjacent register slot.
6108 if (IsExt0 != IsExt1)
6109 return true;
6110
6111 // At this point both sources are extracting from vectors. If they are from
6112 // different vectors, then the BUILD_VECTOR is non-coalescable.
6113 SDValue Src0 = Elt0.getOperand(i: 0);
6114 SDValue Src1 = Elt1.getOperand(i: 0);
6115 if (Src0 != Src1)
6116 return true;
6117
6118 auto *Idx0 = dyn_cast<ConstantSDNode>(Val: Elt0.getOperand(i: 1));
6119 auto *Idx1 = dyn_cast<ConstantSDNode>(Val: Elt1.getOperand(i: 1));
6120 // If both indices are dynamic they will be lowered to
6121 // loads and the vector will be spilled to local memory. The register
6122 // allocator can easily place the results in adjacent registers.
6123 if (!Idx0 && !Idx1)
6124 return false;
6125
6126 // If one index is dynamic and the other is constant, the value from the
6127 // constant load will result in an additional register to pair with the result
6128 // from the dynamic load. We consider this non-coalescable.
6129 if ((Idx0 && !Idx1) || (!Idx0 && Idx1))
6130 return true;
6131
6132 // Both are constant, adjacent pairs are coalescable
6133 return std::abs(i: Idx0->getSExtValue() - Idx1->getSExtValue()) != 1;
6134}
6135
6136/// Scalarize a v2f32 arithmetic node (FADD, FMUL, FSUB, FMA) when at least
6137/// one operand is a BUILD_VECTOR that repacks values from non-adjacent register
6138/// pairs. Without this combine the BUILD_VECTOR forces allocation of a
6139/// temporary 64-bit register, increasing register pressure.
6140///
6141/// Example - before:
6142/// t0: v2f32,v2f32,ch = LoadV2 ...
6143/// t1: f32 = extract_vector_elt t0, 0
6144/// t2: f32 = extract_vector_elt t0:1, 0
6145/// t3: v2f32 = BUILD_VECTOR t1, t2 ;; non-coalescable repack
6146/// t4: v2f32 = fma t_a, t3, t_c
6147///
6148/// After:
6149/// t0: v2f32,v2f32,ch = LoadV2 ...
6150/// t1: f32 = extract_vector_elt t0, 0
6151/// t2: f32 = extract_vector_elt t0:1, 0
6152/// a0: f32 = extract_vector_elt t_a, 0
6153/// a1: f32 = extract_vector_elt t_a, 1
6154/// c0: f32 = extract_vector_elt t_c, 0
6155/// c1: f32 = extract_vector_elt t_c, 1
6156/// r0: f32 = fma a0, t1, c0
6157/// r1: f32 = fma a1, t2, c1
6158/// t4: v2f32 = BUILD_VECTOR r0, r1
static SDValue PerformScalarizeV2F32Op(SDNode *N,
                                       TargetLowering::DAGCombinerInfo &DCI) {
  // Only v2f32 arithmetic nodes are handled; everything else keeps its
  // packed form.
  EVT VT = N->getValueType(ResNo: 0);
  if (VT != MVT::v2f32)
    return SDValue();

  // Only scalarize when at least one operand is a BUILD_VECTOR whose elements
  // are guaranteed to reside in different register pairs.
  if (none_of(Range: N->ops(), P: isNonCoalescableBuildVector))
    return SDValue();

  SelectionDAG &DAG = DCI.DAG;
  SDLoc DL(N);
  EVT EltVT = VT.getVectorElementType(); // f32
  unsigned Opc = N->getOpcode();         // reused for the scalar replacements

  // For each operand, get the scalar element at the given index: if the operand
  // is a BUILD_VECTOR, grab the element directly; otherwise, emit an
  // EXTRACT_VECTOR_ELT.
  auto GetElement = [&](SDValue Op, unsigned Index) -> SDValue {
    if (Op.getOpcode() == ISD::BUILD_VECTOR)
      return Op.getOperand(i: Index);
    return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Op,
                       N2: DAG.getVectorIdxConstant(Val: Index, DL));
  };

  // Build scalar operand lists for element 0 and element 1.
  SmallVector<SDValue, 3> Ops0, Ops1;
  for (const SDValue &Op : N->ops()) {
    Ops0.push_back(Elt: GetElement(Op, 0));
    Ops1.push_back(Elt: GetElement(Op, 1));
  }

  // Emit the two scalar operations, preserving the original node's flags.
  SDValue Res0 = DAG.getNode(Opcode: Opc, DL, VT: EltVT, Ops: Ops0, Flags: N->getFlags());
  SDValue Res1 = DAG.getNode(Opcode: Opc, DL, VT: EltVT, Ops: Ops1, Flags: N->getFlags());

  // Repack the scalar results into the expected v2f32 value.
  return DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT, N1: Res0, N2: Res1);
}
6197
6198/// PerformFADDCombine - Target-specific dag combine xforms for ISD::FADD.
6199///
6200static SDValue PerformFADDCombine(SDNode *N,
6201 TargetLowering::DAGCombinerInfo &DCI,
6202 CodeGenOptLevel OptLevel) {
6203 SDValue N0 = N->getOperand(Num: 0);
6204 SDValue N1 = N->getOperand(Num: 1);
6205
6206 if (SDValue Result = PerformScalarizeV2F32Op(N, DCI))
6207 return Result;
6208
6209 EVT VT = N0.getValueType();
6210 if (VT.isVector() || !(VT == MVT::f32 || VT == MVT::f64))
6211 return SDValue();
6212
6213 // First try with the default operand order.
6214 if (SDValue Result = PerformFADDCombineWithOperands(N, N0, N1, DCI, OptLevel))
6215 return Result;
6216
6217 // If that didn't work, try again with the operands commuted.
6218 return PerformFADDCombineWithOperands(N, N0: N1, N1: N0, DCI, OptLevel);
6219}
6220
6221/// Get 3-input version of a 2-input min/max opcode
6222static unsigned getMinMax3Opcode(unsigned MinMax2Opcode) {
6223 switch (MinMax2Opcode) {
6224 case ISD::FMAXNUM:
6225 case ISD::FMAXIMUMNUM:
6226 return NVPTXISD::FMAXNUM3;
6227 case ISD::FMINNUM:
6228 case ISD::FMINIMUMNUM:
6229 return NVPTXISD::FMINNUM3;
6230 case ISD::FMAXIMUM:
6231 return NVPTXISD::FMAXIMUM3;
6232 case ISD::FMINIMUM:
6233 return NVPTXISD::FMINIMUM3;
6234 default:
6235 llvm_unreachable("Invalid 2-input min/max opcode");
6236 }
6237}
6238
6239/// PerformFMinMaxCombine - Combine (fmaxnum (fmaxnum a, b), c) into
6240/// (fmaxnum3 a, b, c). Also covers other llvm min/max intrinsics.
6241static SDValue PerformFMinMaxCombine(SDNode *N,
6242 TargetLowering::DAGCombinerInfo &DCI,
6243 unsigned PTXVersion, unsigned SmVersion) {
6244
6245 // 3-input min/max requires PTX 8.8+ and SM_100+, and only supports f32s
6246 EVT VT = N->getValueType(ResNo: 0);
6247 if (VT != MVT::f32 || PTXVersion < 88 || SmVersion < 100)
6248 return SDValue();
6249
6250 SDValue Op0 = N->getOperand(Num: 0);
6251 SDValue Op1 = N->getOperand(Num: 1);
6252 unsigned MinMaxOp2 = N->getOpcode();
6253 unsigned MinMaxOp3 = getMinMax3Opcode(MinMax2Opcode: MinMaxOp2);
6254
6255 if (Op0.getOpcode() == MinMaxOp2 && Op0.hasOneUse()) {
6256 // (maxnum (maxnum a, b), c) -> (maxnum3 a, b, c)
6257 SDValue A = Op0.getOperand(i: 0);
6258 SDValue B = Op0.getOperand(i: 1);
6259 SDValue C = Op1;
6260 return DCI.DAG.getNode(Opcode: MinMaxOp3, DL: SDLoc(N), VT, N1: A, N2: B, N3: C, Flags: N->getFlags());
6261 } else if (Op1.getOpcode() == MinMaxOp2 && Op1.hasOneUse()) {
6262 // (maxnum a, (maxnum b, c)) -> (maxnum3 a, b, c)
6263 SDValue A = Op0;
6264 SDValue B = Op1.getOperand(i: 0);
6265 SDValue C = Op1.getOperand(i: 1);
6266 return DCI.DAG.getNode(Opcode: MinMaxOp3, DL: SDLoc(N), VT, N1: A, N2: B, N3: C, Flags: N->getFlags());
6267 }
6268 return SDValue();
6269}
6270
6271static SDValue PerformREMCombine(SDNode *N,
6272 TargetLowering::DAGCombinerInfo &DCI,
6273 CodeGenOptLevel OptLevel) {
6274 assert(N->getOpcode() == ISD::SREM || N->getOpcode() == ISD::UREM);
6275
6276 // Don't do anything at less than -O2.
6277 if (OptLevel < CodeGenOptLevel::Default)
6278 return SDValue();
6279
6280 SelectionDAG &DAG = DCI.DAG;
6281 SDLoc DL(N);
6282 EVT VT = N->getValueType(ResNo: 0);
6283 bool IsSigned = N->getOpcode() == ISD::SREM;
6284 unsigned DivOpc = IsSigned ? ISD::SDIV : ISD::UDIV;
6285
6286 const SDValue &Num = N->getOperand(Num: 0);
6287 const SDValue &Den = N->getOperand(Num: 1);
6288
6289 for (const SDNode *U : Num->users()) {
6290 if (U->getOpcode() == DivOpc && U->getOperand(Num: 0) == Num &&
6291 U->getOperand(Num: 1) == Den) {
6292 // Num % Den -> Num - (Num / Den) * Den
6293 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Num,
6294 N2: DAG.getNode(Opcode: ISD::MUL, DL, VT,
6295 N1: DAG.getNode(Opcode: DivOpc, DL, VT, N1: Num, N2: Den),
6296 N2: Den));
6297 }
6298 }
6299 return SDValue();
6300}
6301
6302// (sign_extend|zero_extend (mul|shl) x, y) -> (mul.wide x, y)
static SDValue combineMulWide(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                              CodeGenOptLevel OptLevel) {
  // Skip entirely at -O0.
  if (OptLevel == CodeGenOptLevel::None)
    return SDValue();

  SDValue Op = N->getOperand(Num: 0);
  // The narrow mul/shl must feed only this extend; otherwise both the narrow
  // and the wide result would stay live.
  if (!Op.hasOneUse())
    return SDValue();
  EVT ToVT = N->getValueType(ResNo: 0);
  EVT FromVT = Op.getValueType();
  // mul.wide exists only for i16 -> i32 and i32 -> i64.
  if (!((ToVT == MVT::i32 && FromVT == MVT::i16) ||
        (ToVT == MVT::i64 && FromVT == MVT::i32)))
    return SDValue();
  // Accept a multiply, or a shift-left by constant (multiply by 2^k).
  if (!(Op.getOpcode() == ISD::MUL ||
        (Op.getOpcode() == ISD::SHL && isa<ConstantSDNode>(Val: Op.getOperand(i: 1)))))
    return SDValue();

  SDLoc DL(N);
  unsigned ExtOpcode = N->getOpcode();
  unsigned Opcode = 0;
  // The matching no-wrap flag guarantees that widening the operands first
  // produces the same result as widening the product.
  if (ExtOpcode == ISD::SIGN_EXTEND && Op->getFlags().hasNoSignedWrap())
    Opcode = NVPTXISD::MUL_WIDE_SIGNED;
  else if (ExtOpcode == ISD::ZERO_EXTEND && Op->getFlags().hasNoUnsignedWrap())
    Opcode = NVPTXISD::MUL_WIDE_UNSIGNED;
  else
    return SDValue();
  SDValue RHS = Op.getOperand(i: 1);
  // Convert a constant shift amount into the equivalent multiplier.
  if (Op.getOpcode() == ISD::SHL) {
    const auto ShiftAmt = Op.getConstantOperandVal(i: 1);
    const auto MulVal = APInt(FromVT.getSizeInBits(), 1) << ShiftAmt;
    RHS = DCI.DAG.getConstant(Val: MulVal, DL, VT: FromVT);
  }
  return DCI.DAG.getNode(Opcode, DL, VT: ToVT, N1: Op.getOperand(i: 0), N2: RHS);
}
6337
// Signedness classification of a prospective mul.wide operand, as computed by
// IsMulWideOperandDemotable below.
enum OperandSignedness {
  Signed = 0, // operand is a sign-extended value
  Unsigned,   // operand is a zero-extended value
  Unknown     // signedness could not be determined
};
6343
6344/// IsMulWideOperandDemotable - Checks if the provided DAG node is an operand
6345/// that can be demoted to \p OptSize bits without loss of information. The
6346/// signedness of the operand, if determinable, is placed in \p S.
6347static bool IsMulWideOperandDemotable(SDValue Op,
6348 unsigned OptSize,
6349 OperandSignedness &S) {
6350 S = Unknown;
6351
6352 if (Op.getOpcode() == ISD::SIGN_EXTEND ||
6353 Op.getOpcode() == ISD::SIGN_EXTEND_INREG) {
6354 EVT OrigVT = Op.getOperand(i: 0).getValueType();
6355 if (OrigVT.getFixedSizeInBits() <= OptSize) {
6356 S = Signed;
6357 return true;
6358 }
6359 } else if (Op.getOpcode() == ISD::ZERO_EXTEND) {
6360 EVT OrigVT = Op.getOperand(i: 0).getValueType();
6361 if (OrigVT.getFixedSizeInBits() <= OptSize) {
6362 S = Unsigned;
6363 return true;
6364 }
6365 }
6366
6367 return false;
6368}
6369
6370/// AreMulWideOperandsDemotable - Checks if the given LHS and RHS operands can
6371/// be demoted to \p OptSize bits without loss of information. If the operands
6372/// contain a constant, it should appear as the RHS operand. The signedness of
6373/// the operands is placed in \p IsSigned.
6374static bool AreMulWideOperandsDemotable(SDValue LHS, SDValue RHS,
6375 unsigned OptSize,
6376 bool &IsSigned) {
6377 OperandSignedness LHSSign;
6378
6379 // The LHS operand must be a demotable op
6380 if (!IsMulWideOperandDemotable(Op: LHS, OptSize, S&: LHSSign))
6381 return false;
6382
6383 // We should have been able to determine the signedness from the LHS
6384 if (LHSSign == Unknown)
6385 return false;
6386
6387 IsSigned = (LHSSign == Signed);
6388
6389 // The RHS can be a demotable op or a constant
6390 if (ConstantSDNode *CI = dyn_cast<ConstantSDNode>(Val&: RHS)) {
6391 const APInt &Val = CI->getAPIntValue();
6392 if (LHSSign == Unsigned) {
6393 return Val.isIntN(N: OptSize);
6394 } else {
6395 return Val.isSignedIntN(N: OptSize);
6396 }
6397 } else {
6398 OperandSignedness RHSSign;
6399 if (!IsMulWideOperandDemotable(Op: RHS, OptSize, S&: RHSSign))
6400 return false;
6401
6402 return LHSSign == RHSSign;
6403 }
6404}
6405
6406/// TryMULWIDECombine - Attempt to replace a multiply of M bits with a multiply
6407/// of M/2 bits that produces an M-bit result (i.e. mul.wide). This transform
6408/// works on both multiply DAG nodes and SHL DAG nodes with a constant shift
6409/// amount.
static SDValue TryMULWIDECombine(SDNode *N,
                                 TargetLowering::DAGCombinerInfo &DCI) {
  // mul.wide produces only i32 or i64 results.
  EVT MulType = N->getValueType(ResNo: 0);
  if (MulType != MVT::i32 && MulType != MVT::i64) {
    return SDValue();
  }

  SDLoc DL(N);
  // Operands must be representable in half the result width.
  unsigned OptSize = MulType.getSizeInBits() >> 1;
  SDValue LHS = N->getOperand(Num: 0);
  SDValue RHS = N->getOperand(Num: 1);

  // Canonicalize the multiply so the constant (if any) is on the right
  if (N->getOpcode() == ISD::MUL) {
    if (isa<ConstantSDNode>(Val: LHS)) {
      std::swap(a&: LHS, b&: RHS);
    }
  }

  // If we have a SHL, determine the actual multiply amount
  if (N->getOpcode() == ISD::SHL) {
    ConstantSDNode *ShlRHS = dyn_cast<ConstantSDNode>(Val&: RHS);
    if (!ShlRHS) {
      return SDValue();
    }

    APInt ShiftAmt = ShlRHS->getAPIntValue();
    unsigned BitWidth = MulType.getSizeInBits();
    // Only an in-range shift can be rewritten as a multiply by 2^amount.
    if (ShiftAmt.sge(RHS: 0) && ShiftAmt.slt(RHS: BitWidth)) {
      APInt MulVal = APInt(BitWidth, 1) << ShiftAmt;
      RHS = DCI.DAG.getConstant(Val: MulVal, DL, VT: MulType);
    } else {
      return SDValue();
    }
  }

  bool Signed;
  // Verify that our operands are demotable
  if (!AreMulWideOperandsDemotable(LHS, RHS, OptSize, IsSigned&: Signed)) {
    return SDValue();
  }

  // Demote to the half-width type: i32 -> i16, i64 -> i32.
  EVT DemotedVT;
  if (MulType == MVT::i32) {
    DemotedVT = MVT::i16;
  } else {
    DemotedVT = MVT::i32;
  }

  // Truncate the operands to the correct size. Note that these are just for
  // type consistency and will (likely) be eliminated in later phases.
  SDValue TruncLHS =
    DCI.DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: DemotedVT, Operand: LHS);
  SDValue TruncRHS =
    DCI.DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: DemotedVT, Operand: RHS);

  // Pick the signed or unsigned widening multiply per the operand analysis.
  unsigned Opc;
  if (Signed) {
    Opc = NVPTXISD::MUL_WIDE_SIGNED;
  } else {
    Opc = NVPTXISD::MUL_WIDE_UNSIGNED;
  }

  return DCI.DAG.getNode(Opcode: Opc, DL, VT: MulType, N1: TruncLHS, N2: TruncRHS);
}
6475
6476static bool isConstOne(const SDValue &Operand) {
6477 const auto *Const = dyn_cast<ConstantSDNode>(Val: Operand);
6478 return Const && Const->getZExtValue() == 1;
6479}
6480
6481static SDValue matchMADConstOnePattern(SDValue Add) {
6482 if (Add->getOpcode() != ISD::ADD)
6483 return SDValue();
6484
6485 if (isConstOne(Operand: Add->getOperand(Num: 0)))
6486 return Add->getOperand(Num: 1);
6487
6488 if (isConstOne(Operand: Add->getOperand(Num: 1)))
6489 return Add->getOperand(Num: 0);
6490
6491 return SDValue();
6492}
6493
6494static SDValue combineMADConstOne(SDValue X, SDValue Add, EVT VT, SDLoc DL,
6495 TargetLowering::DAGCombinerInfo &DCI) {
6496
6497 if (SDValue Y = matchMADConstOnePattern(Add)) {
6498 SDValue Mul = DCI.DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: X, N2: Y);
6499 return DCI.DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Mul, N2: X);
6500 }
6501
6502 return SDValue();
6503}
6504
6505static SDValue combineMulSelectConstOne(SDValue X, SDValue Select, EVT VT,
6506 SDLoc DL,
6507 TargetLowering::DAGCombinerInfo &DCI) {
6508 if (Select->getOpcode() != ISD::SELECT)
6509 return SDValue();
6510
6511 SDValue Cond = Select->getOperand(Num: 0);
6512
6513 unsigned ConstOpNo;
6514 if (isConstOne(Operand: Select->getOperand(Num: 1)))
6515 ConstOpNo = 1;
6516 else if (isConstOne(Operand: Select->getOperand(Num: 2)))
6517 ConstOpNo = 2;
6518 else
6519 return SDValue();
6520
6521 SDValue Y = Select->getOperand(Num: (ConstOpNo == 1) ? 2 : 1);
6522
6523 // Do not combine if the resulting sequence is not obviously profitable.
6524 if (!matchMADConstOnePattern(Add: Y))
6525 return SDValue();
6526
6527 SDValue NewMul = DCI.DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: X, N2: Y);
6528
6529 return DCI.DAG.getNode(Opcode: ISD::SELECT, DL, VT, N1: Cond,
6530 N2: (ConstOpNo == 1) ? X : NewMul,
6531 N3: (ConstOpNo == 1) ? NewMul : X);
6532}
6533
6534static SDValue
6535PerformMULCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
6536 TargetLowering::DAGCombinerInfo &DCI) {
6537
6538 EVT VT = N0.getValueType();
6539 if (VT.isVector())
6540 return SDValue();
6541
6542 if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
6543 return SDValue();
6544
6545 SDLoc DL(N);
6546
6547 // (mul x, (add y, 1)) -> (add (mul x, y), x)
6548 if (SDValue Res = combineMADConstOne(X: N0, Add: N1, VT, DL, DCI))
6549 return Res;
6550 if (SDValue Res = combineMADConstOne(X: N1, Add: N0, VT, DL, DCI))
6551 return Res;
6552
6553 // (mul x, (select y, 1)) -> (select (mul x, y), x)
6554 if (SDValue Res = combineMulSelectConstOne(X: N0, Select: N1, VT, DL, DCI))
6555 return Res;
6556 if (SDValue Res = combineMulSelectConstOne(X: N1, Select: N0, VT, DL, DCI))
6557 return Res;
6558
6559 return SDValue();
6560}
6561
6562/// PerformMULCombine - Runs PTX-specific DAG combine patterns on MUL nodes.
6563static SDValue PerformMULCombine(SDNode *N,
6564 TargetLowering::DAGCombinerInfo &DCI,
6565 CodeGenOptLevel OptLevel) {
6566 if (OptLevel == CodeGenOptLevel::None)
6567 return SDValue();
6568
6569 if (SDValue Ret = TryMULWIDECombine(N, DCI))
6570 return Ret;
6571
6572 SDValue N0 = N->getOperand(Num: 0);
6573 SDValue N1 = N->getOperand(Num: 1);
6574 return PerformMULCombineWithOperands(N, N0, N1, DCI);
6575}
6576
6577/// PerformSHLCombine - Runs PTX-specific DAG combine patterns on SHL nodes.
6578static SDValue PerformSHLCombine(SDNode *N,
6579 TargetLowering::DAGCombinerInfo &DCI,
6580 CodeGenOptLevel OptLevel) {
6581 if (OptLevel > CodeGenOptLevel::None) {
6582 // Try mul.wide combining at OptLevel > 0
6583 if (SDValue Ret = TryMULWIDECombine(N, DCI))
6584 return Ret;
6585 }
6586
6587 return SDValue();
6588}
6589
6590static SDValue PerformSETCCCombine(SDNode *N,
6591 TargetLowering::DAGCombinerInfo &DCI,
6592 unsigned int SmVersion) {
6593 EVT CCType = N->getValueType(ResNo: 0);
6594 SDValue A = N->getOperand(Num: 0);
6595 SDValue B = N->getOperand(Num: 1);
6596
6597 EVT AType = A.getValueType();
6598 if (!(CCType == MVT::v2i1 && (AType == MVT::v2f16 || AType == MVT::v2bf16)))
6599 return SDValue();
6600
6601 if (A.getValueType() == MVT::v2bf16 && SmVersion < 90)
6602 return SDValue();
6603
6604 SDLoc DL(N);
6605 // setp.f16x2 returns two scalar predicates, which we need to
6606 // convert back to v2i1. The returned result will be scalarized by
6607 // the legalizer, but the comparison will remain a single vector
6608 // instruction.
6609 SDValue CCNode = DCI.DAG.getNode(
6610 Opcode: A.getValueType() == MVT::v2f16 ? NVPTXISD::SETP_F16X2
6611 : NVPTXISD::SETP_BF16X2,
6612 DL, VTList: DCI.DAG.getVTList(VT1: MVT::i1, VT2: MVT::i1), Ops: {A, B, N->getOperand(Num: 2)});
6613 return DCI.DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: CCType, N1: CCNode.getValue(R: 0),
6614 N2: CCNode.getValue(R: 1));
6615}
6616
static SDValue PerformEXTRACTCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI) {
  SDValue Vector = N->getOperand(Num: 0);
  // Look through FREEZE to find the actual producer of the vector.
  if (Vector->getOpcode() == ISD::FREEZE)
    Vector = Vector->getOperand(Num: 0);
  SDLoc DL(N);
  EVT VectorVT = Vector.getValueType();
  if (Vector->getOpcode() == ISD::LOAD && VectorVT.isSimple() &&
      IsPTXVectorType(VT: VectorVT.getSimpleVT()))
    return SDValue(); // Native vector loads already combine nicely w/
                      // extract_vector_elt.
  // Don't mess with singletons or packed types (v2*32, v2*16, v4i8 and v8i8),
  // we already handle them OK.
  if (VectorVT.getVectorNumElements() == 1 ||
      NVPTX::isPackedVectorTy(VT: VectorVT) || VectorVT == MVT::v8i8)
    return SDValue();

  // Don't mess with undef values as sra may be simplified to 0, not undef.
  if (Vector->isUndef() || ISD::allOperandsUndef(N: Vector.getNode()))
    return SDValue();

  uint64_t VectorBits = VectorVT.getSizeInBits();
  // We only handle the types we can extract in-register.
  if (!(VectorBits == 16 || VectorBits == 32 || VectorBits == 64))
    return SDValue();

  ConstantSDNode *Index = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
  // Index == 0 is handled by generic DAG combiner.
  if (!Index || Index->getZExtValue() == 0)
    return SDValue();

  // Integer type wide enough to hold the whole vector, and the element's
  // integer equivalent.
  MVT IVT = MVT::getIntegerVT(BitWidth: VectorBits);
  EVT EltVT = VectorVT.getVectorElementType();
  EVT EltIVT = EltVT.changeTypeToInteger();
  uint64_t EltBits = EltVT.getScalarSizeInBits();

  // Extract the element by bitcasting the whole vector to an integer, shifting
  // the desired element down to bit 0, and truncating to the element width.
  SDValue Result = DCI.DAG.getNode(
      Opcode: ISD::TRUNCATE, DL, VT: EltIVT,
      Operand: DCI.DAG.getNode(
          Opcode: ISD::SRA, DL, VT: IVT, N1: DCI.DAG.getNode(Opcode: ISD::BITCAST, DL, VT: IVT, Operand: Vector),
          N2: DCI.DAG.getConstant(Val: Index->getZExtValue() * EltBits, DL, VT: IVT)));

  // If element has non-integer type, bitcast it back to the expected type.
  if (EltVT != EltIVT)
    Result = DCI.DAG.getNode(Opcode: ISD::BITCAST, DL, VT: EltVT, Operand: Result);
  // Past legalizer, we may need to extend i8 -> i16 to match the register type.
  if (EltVT != N->getValueType(ResNo: 0))
    Result = DCI.DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: N->getValueType(ResNo: 0), Operand: Result);

  return Result;
}
6668
6669/// Transform patterns like:
6670/// (select (ugt shift_amt, BitWidth-1), 0, (srl/shl x, shift_amt))
6671/// (select (ult shift_amt, BitWidth), (srl/shl x, shift_amt), 0)
6672/// Into:
6673/// (NVPTXISD::SRL_CLAMP x, shift_amt) or (NVPTXISD::SHL_CLAMP x, shift_amt)
6674///
6675/// These patterns arise from C/C++ code like `shift >= 32 ? 0 : x >> shift`
6676/// which guards against undefined behavior. PTX shr/shl instructions clamp
6677/// shift amounts >= BitWidth to produce 0 for logical shifts, making the
6678/// guard redundant.
6679///
6680/// Note: We only handle SRL and SHL, not SRA, because arithmetic right
6681/// shifts could produce 0 or -1 when shift >= BitWidth.
6682/// Note: We don't handle uge or ule. These don't appear because of
6683/// canonicalization.
static SDValue PerformSELECTShiftCombine(SDNode *N,
                                         TargetLowering::DAGCombinerInfo &DCI) {
  // The clamp nodes are target-specific; only introduce them after the DAG
  // has been legalized.
  if (!DCI.isAfterLegalizeDAG())
    return SDValue();

  using namespace SDPatternMatch;
  unsigned BitWidth = N->getValueType(ResNo: 0).getSizeInBits();
  SDValue ShiftAmt, ShiftOp;

  // Match logical shifts where the shift amount in the guard matches the shift
  // amount in the operation.
  auto LogicalShift =
      m_AllOf(preds: m_Value(N&: ShiftOp),
              preds: m_AnyOf(preds: m_Srl(L: m_Value(), R: m_TruncOrSelf(Op: m_Deferred(V&: ShiftAmt))),
                      preds: m_Shl(L: m_Value(), R: m_TruncOrSelf(Op: m_Deferred(V&: ShiftAmt)))));

  // shift_amt > BitWidth-1 ? 0 : shift_op
  bool MatchedUGT =
      sd_match(N, P: m_Select(Cond: m_SetCC(LHS: m_Value(N&: ShiftAmt),
                                    RHS: m_SpecificInt(V: APInt(BitWidth, BitWidth - 1)),
                                    CC: m_SpecificCondCode(CC: ISD::SETUGT)),
                          T: m_Zero(), F: LogicalShift));
  // shift_amt < BitWidth ? shift_op : 0
  bool MatchedULT =
      !MatchedUGT &&
      sd_match(N, P: m_Select(Cond: m_SetCC(LHS: m_Value(N&: ShiftAmt),
                                    RHS: m_SpecificInt(V: APInt(BitWidth, BitWidth)),
                                    CC: m_SpecificCondCode(CC: ISD::SETULT)),
                          T: LogicalShift, F: m_Zero()));

  if (!MatchedUGT && !MatchedULT)
    return SDValue();

  // Return a clamp shift operation, which has the same semantics as PTX shift.
  unsigned ClampOpc = ShiftOp.getOpcode() == ISD::SRL ? NVPTXISD::SRL_CLAMP
                                                      : NVPTXISD::SHL_CLAMP;
  return DCI.DAG.getNode(Opcode: ClampOpc, DL: SDLoc(N), VT: ShiftOp.getValueType(),
                         N1: ShiftOp.getOperand(i: 0), N2: ShiftOp.getOperand(i: 1));
}
6723
static SDValue PerformVSELECTCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI) {
  SDValue VA = N->getOperand(Num: 1);
  EVT VectorVT = VA.getValueType();
  // Only v4i8 vselects are split here.
  if (VectorVT != MVT::v4i8)
    return SDValue();

  // We need to split vselect into individual per-element operations Because we
  // use BFE/BFI instruction for byte extraction/insertion, we do end up with
  // 32-bit values, so we may as well do comparison as i32 to avoid conversions
  // to/from i16 normally used for i8 values.
  SmallVector<SDValue, 4> E;
  SDLoc DL(N);
  SDValue VCond = N->getOperand(Num: 0);
  SDValue VB = N->getOperand(Num: 2);
  for (int I = 0; I < 4; ++I) {
    // Per-lane condition bit.
    SDValue C = DCI.DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i1, N1: VCond,
                                N2: DCI.DAG.getConstant(Val: I, DL, VT: MVT::i32));
    // Lane I of each input, widened to i32 for the scalar select.
    SDValue EA = DCI.DAG.getAnyExtOrTrunc(
        Op: DCI.DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i8, N1: VA,
                        N2: DCI.DAG.getConstant(Val: I, DL, VT: MVT::i32)),
        DL, VT: MVT::i32);
    SDValue EB = DCI.DAG.getAnyExtOrTrunc(
        Op: DCI.DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i8, N1: VB,
                        N2: DCI.DAG.getConstant(Val: I, DL, VT: MVT::i32)),
        DL, VT: MVT::i32);
    // Select at i32, then truncate back to i8 for the result vector.
    E.push_back(Elt: DCI.DAG.getAnyExtOrTrunc(
        Op: DCI.DAG.getNode(Opcode: ISD::SELECT, DL, VT: MVT::i32, N1: C, N2: EA, N3: EB), DL, VT: MVT::i8));
  }
  return DCI.DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: MVT::v4i8, Ops: E);
}
6755
static SDValue
PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
  auto VT = N->getValueType(ResNo: 0);
  if (!DCI.isAfterLegalizeDAG() ||
      // only process v2*16 types
      !(NVPTX::isPackedVectorTy(VT) && VT.is32BitVector() &&
        VT.getVectorNumElements() == 2))
    return SDValue();

  auto Op0 = N->getOperand(Num: 0);
  auto Op1 = N->getOperand(Num: 1);

  // Start out by assuming we want to take the lower 2 bytes of each i32
  // operand.
  uint64_t Op0Bytes = 0x10;
  uint64_t Op1Bytes = 0x54;

  // Each entry pairs an operand with its PRMT byte selector so both can be
  // updated in lock-step by the loop below.
  std::pair<SDValue *, uint64_t *> OpData[2] = {{&Op0, &Op0Bytes},
                                                {&Op1, &Op1Bytes}};

  // Check that each operand is an i16, truncated from an i32 operand. We'll
  // select individual bytes from those original operands. Optionally, fold in a
  // shift right of that original operand.
  for (auto &[Op, OpBytes] : OpData) {
    // Eat up any bitcast
    if (Op->getOpcode() == ISD::BITCAST)
      *Op = Op->getOperand(i: 0);

    if (!(Op->getValueType() == MVT::i16 && Op->getOpcode() == ISD::TRUNCATE &&
          Op->getOperand(i: 0).getValueType() == MVT::i32))
      return SDValue();

    // If the truncate has multiple uses, this optimization can increase
    // register pressure
    if (!Op->hasOneUse())
      return SDValue();

    // Strip the truncate; PRMT selects bytes from the full i32 source.
    *Op = Op->getOperand(i: 0);

    // Optionally, fold in a shift-right of the original operand and let permute
    // pick the two higher bytes of the original value directly.
    if (Op->getOpcode() == ISD::SRL && isa<ConstantSDNode>(Val: Op->getOperand(i: 1))) {
      if (cast<ConstantSDNode>(Val: Op->getOperand(i: 1))->getZExtValue() == 16) {
        // Shift the PRMT byte selector to pick upper bytes from each respective
        // value, instead of the lower ones: 0x10 -> 0x32, 0x54 -> 0x76
        assert((*OpBytes == 0x10 || *OpBytes == 0x54) &&
               "PRMT selector values out of range");
        *OpBytes += 0x22;
        *Op = Op->getOperand(i: 0);
      }
    }
  }

  SDLoc DL(N);
  auto &DAG = DCI.DAG;

  // Merge both operands' byte selections into a single PRMT byte-permute.
  auto PRMT =
      getPRMT(A: DAG.getBitcast(VT: MVT::i32, V: Op0), B: DAG.getBitcast(VT: MVT::i32, V: Op1),
              Selector: (Op1Bytes << 8) | Op0Bytes, DL, DAG);
  return DAG.getBitcast(VT, V: PRMT);
}
6817
6818static SDValue combineADDRSPACECAST(SDNode *N,
6819 TargetLowering::DAGCombinerInfo &DCI) {
6820 auto *ASCN1 = cast<AddrSpaceCastSDNode>(Val: N);
6821
6822 if (auto *ASCN2 = dyn_cast<AddrSpaceCastSDNode>(Val: ASCN1->getOperand(Num: 0))) {
6823 assert(ASCN2->getDestAddressSpace() == ASCN1->getSrcAddressSpace());
6824
6825 // Fold asc[B -> A](asc[A -> B](x)) -> x
6826 if (ASCN1->getDestAddressSpace() == ASCN2->getSrcAddressSpace())
6827 return ASCN2->getOperand(Num: 0);
6828 }
6829
6830 return SDValue();
6831}
6832
6833// Given a constant selector value and a prmt mode, return the selector value
6834// normalized to the generic prmt mode. See the PTX ISA documentation for more
6835// details:
6836// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt
6837static APInt getPRMTSelector(const APInt &Selector, unsigned Mode) {
6838 assert(Selector.getBitWidth() == 32 && "PRMT must have i32 operands");
6839
6840 if (Mode == NVPTX::PTXPrmtMode::NONE)
6841 return Selector;
6842
6843 const unsigned V = Selector.trunc(width: 2).getZExtValue();
6844
6845 const auto GetSelector = [](unsigned S0, unsigned S1, unsigned S2,
6846 unsigned S3) {
6847 return APInt(32, S0 | (S1 << 4) | (S2 << 8) | (S3 << 12));
6848 };
6849
6850 switch (Mode) {
6851 case NVPTX::PTXPrmtMode::F4E:
6852 return GetSelector(V, V + 1, V + 2, V + 3);
6853 case NVPTX::PTXPrmtMode::B4E:
6854 return GetSelector(V, (V - 1) & 7, (V - 2) & 7, (V - 3) & 7);
6855 case NVPTX::PTXPrmtMode::RC8:
6856 return GetSelector(V, V, V, V);
6857 case NVPTX::PTXPrmtMode::ECL:
6858 return GetSelector(V, std::max(a: V, b: 1U), std::max(a: V, b: 2U), 3U);
6859 case NVPTX::PTXPrmtMode::ECR:
6860 return GetSelector(0, std::min(a: V, b: 1U), std::min(a: V, b: 2U), V);
6861 case NVPTX::PTXPrmtMode::RC16: {
6862 unsigned V1 = (V & 1) << 1;
6863 return GetSelector(V1, V1 + 1, V1, V1 + 1);
6864 }
6865 default:
6866 llvm_unreachable("Invalid PRMT mode");
6867 }
6868}
6869
6870static APInt computePRMT(APInt A, APInt B, APInt Selector, unsigned Mode) {
6871 assert(A.getBitWidth() == 32 && B.getBitWidth() == 32 &&
6872 Selector.getBitWidth() == 32 && "PRMT must have i32 operands");
6873 // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
6874 APInt BitField = B.concat(NewLSB: A);
6875 APInt SelectorVal = getPRMTSelector(Selector, Mode);
6876 APInt Result(32, 0);
6877 for (unsigned I : llvm::seq(Size: 4U)) {
6878 APInt Sel = SelectorVal.extractBits(numBits: 4, bitPosition: I * 4);
6879 unsigned Idx = Sel.getLoBits(numBits: 3).getZExtValue();
6880 unsigned Sign = Sel.getHiBits(numBits: 1).getZExtValue();
6881 APInt Byte = BitField.extractBits(numBits: 8, bitPosition: Idx * 8);
6882 if (Sign)
6883 Byte = Byte.ashr(ShiftAmt: 8);
6884 Result.insertBits(SubBits: Byte, bitPosition: I * 8);
6885 }
6886 return Result;
6887}
6888
6889static SDValue combinePRMT(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
6890 CodeGenOptLevel OptLevel) {
6891 if (OptLevel == CodeGenOptLevel::None)
6892 return SDValue();
6893
6894 // Constant fold PRMT
6895 if (isa<ConstantSDNode>(Val: N->getOperand(Num: 0)) &&
6896 isa<ConstantSDNode>(Val: N->getOperand(Num: 1)) &&
6897 isa<ConstantSDNode>(Val: N->getOperand(Num: 2)))
6898 return DCI.DAG.getConstant(Val: computePRMT(A: N->getConstantOperandAPInt(Num: 0),
6899 B: N->getConstantOperandAPInt(Num: 1),
6900 Selector: N->getConstantOperandAPInt(Num: 2),
6901 Mode: N->getConstantOperandVal(Num: 3)),
6902 DL: SDLoc(N), VT: N->getValueType(ResNo: 0));
6903 return SDValue();
6904}
6905
6906// During call lowering we wrap the return values in a ProxyReg node which
6907// depend on the chain value produced by the completed call. This ensures that
6908// the full call is emitted in cases where libcalls are used to legalize
6909// operations. To improve the functioning of other DAG combines we pull all
6910// operations we can through one of these nodes, ensuring that the ProxyReg
6911// directly wraps a load. That is:
6912//
6913// (ProxyReg (zext (load retval0))) => (zext (ProxyReg (load retval0)))
6914//
static SDValue sinkProxyReg(SDValue R, SDValue Chain,
                            TargetLowering::DAGCombinerInfo &DCI) {
  switch (R.getOpcode()) {
  // Unary ops: sink the ProxyReg into the single operand and recreate the op
  // on top of it.
  case ISD::TRUNCATE:
  case ISD::ANY_EXTEND:
  case ISD::SIGN_EXTEND:
  case ISD::ZERO_EXTEND:
  case ISD::BITCAST: {
    if (SDValue V = sinkProxyReg(R: R.getOperand(i: 0), Chain, DCI))
      return DCI.DAG.getNode(Opcode: R.getOpcode(), DL: SDLoc(R), VT: R.getValueType(), Operand: V);
    return SDValue();
  }
  // Binary ops: both operands must be sinkable.
  case ISD::SHL:
  case ISD::SRL:
  case ISD::SRA:
  case ISD::OR: {
    if (SDValue A = sinkProxyReg(R: R.getOperand(i: 0), Chain, DCI))
      if (SDValue B = sinkProxyReg(R: R.getOperand(i: 1), Chain, DCI))
        return DCI.DAG.getNode(Opcode: R.getOpcode(), DL: SDLoc(R), VT: R.getValueType(), N1: A, N2: B);
    return SDValue();
  }
  // Constants need no ProxyReg at all.
  case ISD::Constant:
    return R;
  // Loads are the desired anchor: wrap them directly in a new ProxyReg.
  case ISD::LOAD:
  case NVPTXISD::LoadV2:
  case NVPTXISD::LoadV4: {
    return DCI.DAG.getNode(Opcode: NVPTXISD::ProxyReg, DL: SDLoc(R), VT: R.getValueType(),
                           Ops: {Chain, R});
  }
  case ISD::BUILD_VECTOR: {
    if (DCI.isBeforeLegalize())
      return SDValue();

    // Every element must be sinkable for the vector as a whole to sink.
    SmallVector<SDValue, 16> Ops;
    for (auto &Op : R->ops()) {
      SDValue V = sinkProxyReg(R: Op, Chain, DCI);
      if (!V)
        return SDValue();
      Ops.push_back(Elt: V);
    }
    return DCI.DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL: SDLoc(R), VT: R.getValueType(), Ops);
  }
  case ISD::EXTRACT_VECTOR_ELT: {
    if (DCI.isBeforeLegalize())
      return SDValue();

    if (SDValue V = sinkProxyReg(R: R.getOperand(i: 0), Chain, DCI))
      return DCI.DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: SDLoc(R),
                             VT: R.getValueType(), N1: V, N2: R.getOperand(i: 1));
    return SDValue();
  }
  // Anything else blocks the sink.
  default:
    return SDValue();
  }
}
6970
6971static unsigned getF16SubOpc(Intrinsic::ID AddIntrinsicID) {
6972 switch (AddIntrinsicID) {
6973 default:
6974 break;
6975 case Intrinsic::nvvm_add_rn_sat_f16:
6976 case Intrinsic::nvvm_add_rn_sat_v2f16:
6977 return NVPTXISD::SUB_RN_SAT;
6978 case Intrinsic::nvvm_add_rn_ftz_sat_f16:
6979 case Intrinsic::nvvm_add_rn_ftz_sat_v2f16:
6980 return NVPTXISD::SUB_RN_FTZ_SAT;
6981 }
6982 llvm_unreachable("Invalid F16 add intrinsic");
6983}
6984
6985static SDValue combineF16AddWithNeg(SDNode *N, SelectionDAG &DAG,
6986 Intrinsic::ID AddIntrinsicID) {
6987 SDValue Op1 = N->getOperand(Num: 1);
6988 SDValue Op2 = N->getOperand(Num: 2);
6989
6990 SDValue SubOp1, SubOp2;
6991
6992 if (Op1.getOpcode() == ISD::FNEG) {
6993 SubOp1 = Op2;
6994 SubOp2 = Op1.getOperand(i: 0);
6995 } else if (Op2.getOpcode() == ISD::FNEG) {
6996 SubOp1 = Op1;
6997 SubOp2 = Op2.getOperand(i: 0);
6998 } else {
6999 return SDValue();
7000 }
7001
7002 SDLoc DL(N);
7003 return DAG.getNode(Opcode: getF16SubOpc(AddIntrinsicID), DL, VT: N->getValueType(ResNo: 0),
7004 N1: SubOp1, N2: SubOp2);
7005}
7006
7007static SDValue combineIntrinsicWOChain(SDNode *N,
7008 TargetLowering::DAGCombinerInfo &DCI,
7009 const NVPTXSubtarget &STI) {
7010 unsigned IID = N->getConstantOperandVal(Num: 0);
7011
7012 switch (IID) {
7013 default:
7014 break;
7015 case Intrinsic::nvvm_add_rn_sat_f16:
7016 case Intrinsic::nvvm_add_rn_ftz_sat_f16:
7017 case Intrinsic::nvvm_add_rn_sat_v2f16:
7018 case Intrinsic::nvvm_add_rn_ftz_sat_v2f16:
7019 return combineF16AddWithNeg(N, DAG&: DCI.DAG, AddIntrinsicID: IID);
7020 }
7021 return SDValue();
7022}
7023
7024static SDValue combineProxyReg(SDNode *N,
7025 TargetLowering::DAGCombinerInfo &DCI) {
7026
7027 SDValue Chain = N->getOperand(Num: 0);
7028 SDValue Reg = N->getOperand(Num: 1);
7029
7030 // If the ProxyReg is not wrapping a load, try to pull the operations through
7031 // the ProxyReg.
7032 if (Reg.getOpcode() != ISD::LOAD) {
7033 if (SDValue V = sinkProxyReg(R: Reg, Chain, DCI))
7034 return V;
7035 }
7036
7037 return SDValue();
7038}
7039
// Target-specific DAG combine dispatcher: route each opcode we care about to
// its dedicated combine helper. Returning an empty SDValue leaves the node
// unchanged.
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                               DAGCombinerInfo &DCI) const {
  CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
  switch (N->getOpcode()) {
  default:
    break;
  case ISD::ADD:
    return PerformADDCombine(N, DCI, OptLevel);
  case ISD::ADDRSPACECAST:
    return combineADDRSPACECAST(N, DCI);
  case ISD::SIGN_EXTEND:
  case ISD::ZERO_EXTEND:
    // Extensions of multiplies may form mul.wide.
    return combineMulWide(N, DCI, OptLevel);
  case ISD::BUILD_VECTOR:
    return PerformBUILD_VECTORCombine(N, DCI);
  case ISD::EXTRACT_VECTOR_ELT:
    return PerformEXTRACTCombine(N, DCI);
  case ISD::FADD:
    return PerformFADDCombine(N, DCI, OptLevel);
  case ISD::FMA:
  case ISD::FMUL:
  case ISD::FSUB:
    return PerformScalarizeV2F32Op(N, DCI);
  case ISD::FMAXNUM:
  case ISD::FMINNUM:
  case ISD::FMAXIMUM:
  case ISD::FMINIMUM:
  case ISD::FMAXIMUMNUM:
  case ISD::FMINIMUMNUM:
    // Min/max lowering depends on PTX ISA and SM capabilities.
    return PerformFMinMaxCombine(N, DCI, PTXVersion: STI.getPTXVersion(),
                                 SmVersion: STI.getSmVersion());
  case ISD::LOAD:
  case NVPTXISD::LoadV2:
  case NVPTXISD::LoadV4:
    return combineLOAD(N, DCI, STI);
  case ISD::MUL:
    return PerformMULCombine(N, DCI, OptLevel);
  case NVPTXISD::PRMT:
    return combinePRMT(N, DCI, OptLevel);
  case NVPTXISD::ProxyReg:
    return combineProxyReg(N, DCI);
  case ISD::SETCC:
    return PerformSETCCCombine(N, DCI, SmVersion: STI.getSmVersion());
  case ISD::SHL:
    return PerformSHLCombine(N, DCI, OptLevel);
  case ISD::SREM:
  case ISD::UREM:
    return PerformREMCombine(N, DCI, OptLevel);
  case ISD::STORE:
  case NVPTXISD::StoreV2:
  case NVPTXISD::StoreV4:
    return combineSTORE(N, DCI, STI);
  case ISD::SELECT:
    return PerformSELECTShiftCombine(N, DCI);
  case ISD::VSELECT:
    return PerformVSELECTCombine(N, DCI);
  case ISD::INTRINSIC_WO_CHAIN:
    return combineIntrinsicWOChain(N, DCI, STI);
  }
  return SDValue();
}
7101
7102static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
7103 SmallVectorImpl<SDValue> &Results) {
7104 // Handle bitcasting to v2i8 without hitting the default promotion
7105 // strategy which goes through stack memory.
7106 SDValue Op(Node, 0);
7107 EVT ToVT = Op->getValueType(ResNo: 0);
7108 if (ToVT != MVT::v2i8) {
7109 return;
7110 }
7111
7112 // Bitcast to i16 and unpack elements into a vector
7113 SDLoc DL(Node);
7114 SDValue AsInt = DAG.getBitcast(VT: MVT::i16, V: Op->getOperand(Num: 0));
7115 SDValue Vec0 = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i8, Operand: AsInt);
7116 SDValue Const8 = DAG.getConstant(Val: 8, DL, VT: MVT::i16);
7117 SDValue Vec1 =
7118 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i8,
7119 Operand: DAG.getNode(Opcode: ISD::SRL, DL, VT: MVT::i16, Ops: {AsInt, Const8}));
7120 Results.push_back(
7121 Elt: DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: MVT::v2i8, Ops: {Vec0, Vec1}));
7122}
7123
// Custom result-type legalization for INTRINSIC_W_CHAIN nodes. Pushes the
// replacement value(s) plus the new chain onto Results; leaving Results empty
// tells the legalizer to fall back to default handling.
static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
                                     SmallVectorImpl<SDValue> &Results) {
  SDValue Chain = N->getOperand(Num: 0);
  SDValue Intrin = N->getOperand(Num: 1);
  SDLoc DL(N);

  // Get the intrinsic ID
  unsigned IntrinNo = Intrin.getNode()->getAsZExtVal();
  switch (IntrinNo) {
  default:
    return;
  case Intrinsic::nvvm_ldu_global_i:
  case Intrinsic::nvvm_ldu_global_f:
  case Intrinsic::nvvm_ldu_global_p: {
    EVT ResVT = N->getValueType(ResNo: 0);

    if (ResVT.isVector()) {
      // Vector LDG/LDU

      unsigned NumElts = ResVT.getVectorNumElements();
      EVT EltVT = ResVT.getVectorElementType();

      // Since LDU/LDG are target nodes, we cannot rely on DAG type
      // legalization.
      // Therefore, we must ensure the type is legal. For i1 and i8, we set the
      // loaded type to i16 and propagate the "real" type as the memory type.
      bool NeedTrunc = false;
      if (EltVT.getSizeInBits() < 16) {
        EltVT = MVT::i16;
        NeedTrunc = true;
      }

      unsigned Opcode = 0;
      SDVTList LdResVTs;

      // Only 2- and 4-element vectors have LDU variants; anything else uses
      // default legalization.
      switch (NumElts) {
      default:
        return;
      case 2:
        Opcode = NVPTXISD::LDUV2;
        LdResVTs = DAG.getVTList(VT1: EltVT, VT2: EltVT, VT3: MVT::Other);
        break;
      case 4: {
        Opcode = NVPTXISD::LDUV4;
        EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
        LdResVTs = DAG.getVTList(VTs: ListVTs);
        break;
      }
      }

      SmallVector<SDValue, 8> OtherOps;

      // Copy regular operands

      OtherOps.push_back(Elt: Chain); // Chain
      // Skip operand 1 (intrinsic ID)
      // Others
      OtherOps.append(in_start: N->op_begin() + 2, in_end: N->op_end());

      MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);

      // Keep the original (possibly sub-i16) memory VT so isel selects the
      // right-width load.
      SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, dl: DL, VTList: LdResVTs, Ops: OtherOps,
                                              MemVT: MemSD->getMemoryVT(),
                                              MMO: MemSD->getMemOperand());

      SmallVector<SDValue, 4> ScalarRes;

      // Truncate each widened result back to the requested element type.
      for (unsigned i = 0; i < NumElts; ++i) {
        SDValue Res = NewLD.getValue(R: i);
        if (NeedTrunc)
          Res =
              DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: ResVT.getVectorElementType(), Operand: Res);
        ScalarRes.push_back(Elt: Res);
      }

      SDValue LoadChain = NewLD.getValue(R: NumElts);

      SDValue BuildVec =
          DAG.getBuildVector(VT: ResVT, DL, Ops: ScalarRes);

      Results.push_back(Elt: BuildVec);
      Results.push_back(Elt: LoadChain);
    } else {
      // i8 LDG/LDU
      assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
             "Custom handling of non-i8 ldu/ldg?");

      // Just copy all operands as-is
      SmallVector<SDValue, 4> Ops(N->ops());

      // Force output to i16
      SDVTList LdResVTs = DAG.getVTList(VT1: MVT::i16, VT2: MVT::Other);

      MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);

      // We make sure the memory type is i8, which will be used during isel
      // to select the proper instruction.
      SDValue NewLD =
          DAG.getMemIntrinsicNode(Opcode: ISD::INTRINSIC_W_CHAIN, dl: DL, VTList: LdResVTs, Ops,
                                  MemVT: MVT::i8, MMO: MemSD->getMemOperand());

      Results.push_back(Elt: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i8,
                                    Operand: NewLD.getValue(R: 0)));
      Results.push_back(Elt: NewLD.getValue(R: 1));
    }
    return;
  }

  // tcgen05.ld variants: lowered by lowerTcgen05Ld, which yields the loaded
  // value and the new chain.
  case Intrinsic::nvvm_tcgen05_ld_16x64b_x4:
  case Intrinsic::nvvm_tcgen05_ld_16x64b_x8:
  case Intrinsic::nvvm_tcgen05_ld_16x64b_x16:
  case Intrinsic::nvvm_tcgen05_ld_16x64b_x32:
  case Intrinsic::nvvm_tcgen05_ld_16x64b_x64:
  case Intrinsic::nvvm_tcgen05_ld_16x64b_x128:
  case Intrinsic::nvvm_tcgen05_ld_32x32b_x4:
  case Intrinsic::nvvm_tcgen05_ld_32x32b_x8:
  case Intrinsic::nvvm_tcgen05_ld_32x32b_x16:
  case Intrinsic::nvvm_tcgen05_ld_32x32b_x32:
  case Intrinsic::nvvm_tcgen05_ld_32x32b_x64:
  case Intrinsic::nvvm_tcgen05_ld_32x32b_x128:
  case Intrinsic::nvvm_tcgen05_ld_16x128b_x2:
  case Intrinsic::nvvm_tcgen05_ld_16x128b_x4:
  case Intrinsic::nvvm_tcgen05_ld_16x128b_x8:
  case Intrinsic::nvvm_tcgen05_ld_16x128b_x16:
  case Intrinsic::nvvm_tcgen05_ld_16x128b_x32:
  case Intrinsic::nvvm_tcgen05_ld_16x128b_x64:
  case Intrinsic::nvvm_tcgen05_ld_16x256b_x1:
  case Intrinsic::nvvm_tcgen05_ld_16x256b_x2:
  case Intrinsic::nvvm_tcgen05_ld_16x256b_x4:
  case Intrinsic::nvvm_tcgen05_ld_16x256b_x8:
  case Intrinsic::nvvm_tcgen05_ld_16x256b_x16:
  case Intrinsic::nvvm_tcgen05_ld_16x256b_x32:
    if (auto Res = lowerTcgen05Ld(N, DAG)) {
      Results.push_back(Elt: Res->first);
      Results.push_back(Elt: Res->second);
    }
    return;

  // tcgen05.ld 16x32bx2 variants additionally carry an offset operand.
  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x4:
  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x8:
  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x16:
  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x32:
  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x64:
  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x128:
    if (auto Res = lowerTcgen05Ld(N, DAG, /*HasOffset=*/true)) {
      Results.push_back(Elt: Res->first);
      Results.push_back(Elt: Res->second);
    }
    return;

  // tcgen05.ld.red variants: lowerTcgen05LdRed yields three results
  // (value, reduction result, chain) as a tuple.
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x8_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x8_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x64_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x64_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x4_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x4_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x32_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x32_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x16_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x16_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x128_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x128_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x8_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x8_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x64_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x64_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x4_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x4_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x32_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x32_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x16_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x16_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x128_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x128_f32:
    if (auto Res = lowerTcgen05LdRed(N, DAG)) {
      Results.push_back(Elt: std::get<0>(t&: *Res));
      Results.push_back(Elt: std::get<1>(t&: *Res));
      Results.push_back(Elt: std::get<2>(t&: *Res));
    }
    return;
  }
}
7306
7307static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
7308 SmallVectorImpl<SDValue> &Results) {
7309 // Change the CopyFromReg to output 2 64-bit results instead of a 128-bit
7310 // result so that it can pass the legalization
7311 SDLoc DL(N);
7312 SDValue Chain = N->getOperand(Num: 0);
7313 SDValue Reg = N->getOperand(Num: 1);
7314 SDValue Glue = N->getOperand(Num: 2);
7315
7316 assert(Reg.getValueType() == MVT::i128 &&
7317 "Custom lowering for CopyFromReg with 128-bit reg only");
7318 SmallVector<EVT, 4> ResultsType = {MVT::i64, MVT::i64, N->getValueType(ResNo: 1),
7319 N->getValueType(ResNo: 2)};
7320 SmallVector<SDValue, 3> NewOps = {Chain, Reg, Glue};
7321
7322 SDValue NewValue = DAG.getNode(Opcode: ISD::CopyFromReg, DL, ResultTys: ResultsType, Ops: NewOps);
7323 SDValue Pair = DAG.getNode(Opcode: ISD::BUILD_PAIR, DL, VT: MVT::i128,
7324 Ops: {NewValue.getValue(R: 0), NewValue.getValue(R: 1)});
7325
7326 Results.push_back(Elt: Pair);
7327 Results.push_back(Elt: NewValue.getValue(R: 2));
7328 Results.push_back(Elt: NewValue.getValue(R: 3));
7329}
7330
7331static void replaceProxyReg(SDNode *N, SelectionDAG &DAG,
7332 const TargetLowering &TLI,
7333 SmallVectorImpl<SDValue> &Results) {
7334 SDValue Chain = N->getOperand(Num: 0);
7335 SDValue Reg = N->getOperand(Num: 1);
7336
7337 MVT VT = TLI.getRegisterType(Context&: *DAG.getContext(), VT: Reg.getValueType());
7338
7339 SDValue NewReg = DAG.getAnyExtOrTrunc(Op: Reg, DL: SDLoc(N), VT);
7340 SDValue NewProxy =
7341 DAG.getNode(Opcode: NVPTXISD::ProxyReg, DL: SDLoc(N), VT, Ops: {Chain, NewReg});
7342 SDValue Res = DAG.getAnyExtOrTrunc(Op: NewProxy, DL: SDLoc(N), VT: N->getValueType(ResNo: 0));
7343
7344 Results.push_back(Elt: Res);
7345}
7346
// Legalize 128-bit ATOMIC_SWAP / ATOMIC_CMP_SWAP by splitting each i128 value
// operand into i64 halves and emitting the target's b128 atomic node. On
// unsupported targets, emits a diagnostic and produces an UNDEF result so
// legalization can still proceed.
static void replaceAtomicSwap128(SDNode *N, SelectionDAG &DAG,
                                 const NVPTXSubtarget &STI,
                                 SmallVectorImpl<SDValue> &Results) {
  assert(N->getValueType(0) == MVT::i128 &&
         "Custom lowering for atomic128 only supports i128");

  AtomicSDNode *AN = cast<AtomicSDNode>(Val: N);
  SDLoc dl(N);

  if (!STI.hasAtomSwap128()) {
    // Report the failure but keep legalization alive by returning UNDEF and
    // the incoming chain.
    DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
        DAG.getMachineFunction().getFunction(),
        "Support for b128 atomics introduced in PTX ISA version 8.3 and "
        "requires target sm_90.",
        dl.getDebugLoc()));

    Results.push_back(Elt: DAG.getUNDEF(VT: MVT::i128));
    Results.push_back(Elt: AN->getOperand(Num: 0)); // Chain
    return;
  }

  // Build the operand list: chain, pointer, then each i128 value operand
  // (swap has one, cmpxchg has two) split into lo/hi i64 halves.
  SmallVector<SDValue, 6> Ops;
  Ops.push_back(Elt: AN->getOperand(Num: 0)); // Chain
  Ops.push_back(Elt: AN->getOperand(Num: 1)); // Ptr
  for (const auto &Op : AN->ops().drop_front(N: 2)) {
    // Low part
    Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_ELEMENT, DL: dl, VT: MVT::i64, N1: Op,
                              N2: DAG.getIntPtrConstant(Val: 0, DL: dl)));
    // High part
    Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_ELEMENT, DL: dl, VT: MVT::i64, N1: Op,
                              N2: DAG.getIntPtrConstant(Val: 1, DL: dl)));
  }
  unsigned Opcode = N->getOpcode() == ISD::ATOMIC_SWAP
                        ? NVPTXISD::ATOMIC_SWAP_B128
                        : NVPTXISD::ATOMIC_CMP_SWAP_B128;
  SDVTList Tys = DAG.getVTList(VT1: MVT::i64, VT2: MVT::i64, VT3: MVT::Other);
  SDValue Result = DAG.getMemIntrinsicNode(Opcode, dl, VTList: Tys, Ops, MemVT: MVT::i128,
                                           MMO: AN->getMemOperand());
  // Recombine the two i64 results into the i128 the callers expect.
  Results.push_back(Elt: DAG.getNode(Opcode: ISD::BUILD_PAIR, DL: dl, VT: MVT::i128,
                                Ops: {Result.getValue(R: 0), Result.getValue(R: 1)}));
  Results.push_back(Elt: Result.getValue(R: 2));
}
7389
// Custom result-type legalization entry point: dispatch each opcode that we
// marked Custom to its dedicated replacement helper. Any opcode reaching the
// default case is a lowering bug.
void NVPTXTargetLowering::ReplaceNodeResults(
    SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
  switch (N->getOpcode()) {
  default:
    report_fatal_error(reason: "Unhandled custom legalization");
  case ISD::BITCAST:
    ReplaceBITCAST(Node: N, DAG, Results);
    return;
  case ISD::LOAD:
  case ISD::MLOAD:
    replaceLoadVector(N, DAG, Results, STI);
    return;
  case ISD::INTRINSIC_W_CHAIN:
    ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
    return;
  case ISD::CopyFromReg:
    ReplaceCopyFromReg_128(N, DAG, Results);
    return;
  case NVPTXISD::ProxyReg:
    replaceProxyReg(N, DAG, TLI: *this, Results);
    return;
  case ISD::ATOMIC_CMP_SWAP:
  case ISD::ATOMIC_SWAP:
    replaceAtomicSwap128(N, DAG, STI, Results);
    return;
  }
}
7417
// Decide how an atomicrmw is lowered: AtomicExpansionKind::None means the
// operation maps directly onto a PTX `atom` instruction; CmpXChg means
// AtomicExpandPass must emulate it with a compare-exchange loop.
NVPTXTargetLowering::AtomicExpansionKind
NVPTXTargetLowering::shouldExpandAtomicRMWInIR(const AtomicRMWInst *AI) const {
  Type *Ty = AI->getValOperand()->getType();

  if (AI->isFloatingPointOperation()) {
    // Of the FP operations, only FAdd has native support, gated on type and
    // the SM / PTX ISA versions that introduced each form.
    if (AI->getOperation() == AtomicRMWInst::BinOp::FAdd) {
      if (Ty->isHalfTy() && STI.getSmVersion() >= 70 &&
          STI.getPTXVersion() >= 63)
        return AtomicExpansionKind::None;
      if (Ty->isBFloatTy() && STI.getSmVersion() >= 90 &&
          STI.getPTXVersion() >= 78)
        return AtomicExpansionKind::None;
      if (Ty->isFloatTy())
        return AtomicExpansionKind::None;
      if (Ty->isDoubleTy() && STI.hasAtomAddF64())
        return AtomicExpansionKind::None;
    }
    return AtomicExpansionKind::CmpXChg;
  }

  assert(Ty->isIntegerTy() && "Ty should be integer at this point");
  const unsigned BitWidth = cast<IntegerType>(Val: Ty)->getBitWidth();

  switch (AI->getOperation()) {
  default:
    return AtomicExpansionKind::CmpXChg;
  case AtomicRMWInst::BinOp::Xchg:
    // 128-bit exchange has dedicated support (see replaceAtomicSwap128);
    // other widths follow the bitwise-op rules below.
    if (BitWidth == 128)
      return AtomicExpansionKind::None;
    [[fallthrough]];
  case AtomicRMWInst::BinOp::And:
  case AtomicRMWInst::BinOp::Or:
  case AtomicRMWInst::BinOp::Xor:
    switch (BitWidth) {
    case 8:
    case 16:
      // Sub-word atomics are not natively supported.
      return AtomicExpansionKind::CmpXChg;
    case 32:
      return AtomicExpansionKind::None;
    case 64:
      if (STI.hasAtomBitwise64())
        return AtomicExpansionKind::None;
      return AtomicExpansionKind::CmpXChg;
    case 128:
      return AtomicExpansionKind::CmpXChg;
    default:
      llvm_unreachable("unsupported width encountered");
    }
  case AtomicRMWInst::BinOp::Add:
  case AtomicRMWInst::BinOp::Sub:
  case AtomicRMWInst::BinOp::Max:
  case AtomicRMWInst::BinOp::Min:
  case AtomicRMWInst::BinOp::UMax:
  case AtomicRMWInst::BinOp::UMin:
    switch (BitWidth) {
    case 8:
    case 16:
      return AtomicExpansionKind::CmpXChg;
    case 32:
      return AtomicExpansionKind::None;
    case 64:
      // 64-bit min/max/add/sub require the atom.min/max.u64 family.
      if (STI.hasAtomMinMax64())
        return AtomicExpansionKind::None;
      return AtomicExpansionKind::CmpXChg;
    case 128:
      return AtomicExpansionKind::CmpXChg;
    default:
      llvm_unreachable("unsupported width encountered");
    }
  case AtomicRMWInst::BinOp::UIncWrap:
  case AtomicRMWInst::BinOp::UDecWrap:
    switch (BitWidth) {
    case 32:
      return AtomicExpansionKind::None;
    case 8:
    case 16:
    case 64:
    case 128:
      return AtomicExpansionKind::CmpXChg;
    default:
      llvm_unreachable("unsupported width encountered");
    }
  }

  // Unreachable: every case above returns. Kept as a conservative default to
  // satisfy control-flow analysis.
  return AtomicExpansionKind::CmpXChg;
}
7504
7505bool NVPTXTargetLowering::shouldInsertFencesForAtomic(
7506 const Instruction *I) const {
7507 // This function returns true iff the operation is emulated using a CAS-loop,
7508 // or if it has the memory order seq_cst (which is not natively supported in
7509 // the PTX `atom` instruction).
7510 //
7511 // atomicrmw and cmpxchg instructions not efficiently supported by PTX
7512 // are lowered to CAS emulation loops that preserve their memory order,
7513 // syncscope, and volatile semantics. For PTX, it is more efficient to use
7514 // atom.cas.relaxed.sco instructions within the loop, and fences before and
7515 // after the loop to restore order.
7516 //
7517 // Atomic instructions efficiently supported by PTX are lowered to
7518 // `atom.<op>.<sem>.<scope` instruction with their corresponding memory order
7519 // and scope. Since PTX does not support seq_cst, we emulate it by lowering to
7520 // a fence.sc followed by an atom according to the PTX atomics ABI
7521 // https://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/atomic-abi.html
7522 if (auto *CI = dyn_cast<AtomicCmpXchgInst>(Val: I))
7523 return (cast<IntegerType>(Val: CI->getCompareOperand()->getType())
7524 ->getBitWidth() < STI.getMinCmpXchgSizeInBits()) ||
7525 CI->getMergedOrdering() == AtomicOrdering::SequentiallyConsistent;
7526 if (auto *RI = dyn_cast<AtomicRMWInst>(Val: I))
7527 return shouldExpandAtomicRMWInIR(AI: RI) == AtomicExpansionKind::CmpXChg ||
7528 RI->getOrdering() == AtomicOrdering::SequentiallyConsistent;
7529 return false;
7530}
7531
7532AtomicOrdering NVPTXTargetLowering::atomicOperationOrderAfterFenceSplit(
7533 const Instruction *I) const {
7534 // If the operation is emulated by a CAS-loop, we lower the instruction to
7535 // atom.<op>.relaxed, since AtomicExpandPass will insert fences for enforcing
7536 // the correct memory ordering around the CAS loop.
7537 //
7538 // When the operation is not emulated, but the memory order is seq_cst,
7539 // we must lower to "fence.sc.<scope>; atom.<op>.acquire.<scope>;" to conform
7540 // to the PTX atomics ABI.
7541 // https://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/atomic-abi.html
7542 // For such cases, emitLeadingFence() will separately insert the leading
7543 // "fence.sc.<scope>;". Here, we only set the memory order to acquire.
7544 //
7545 // Otherwise, the operation is not emulated, and the memory order is not
7546 // seq_cst. In this case, the LLVM memory order is natively supported by the
7547 // PTX `atom` instruction, and we just lower to the corresponding
7548 // `atom.<op>.relaxed|acquire|release|acq_rel". For such cases, this function
7549 // will NOT be called.
7550 // prerequisite: shouldInsertFencesForAtomic() should have returned `true` for
7551 // I before its memory order was modified.
7552 if (auto *CI = dyn_cast<AtomicCmpXchgInst>(Val: I);
7553 CI && CI->getMergedOrdering() == AtomicOrdering::SequentiallyConsistent &&
7554 cast<IntegerType>(Val: CI->getCompareOperand()->getType())->getBitWidth() >=
7555 STI.getMinCmpXchgSizeInBits())
7556 return AtomicOrdering::Acquire;
7557 else if (auto *RI = dyn_cast<AtomicRMWInst>(Val: I);
7558 RI && RI->getOrdering() == AtomicOrdering::SequentiallyConsistent &&
7559 shouldExpandAtomicRMWInIR(AI: RI) == AtomicExpansionKind::None)
7560 return AtomicOrdering::Acquire;
7561
7562 return AtomicOrdering::Monotonic;
7563}
7564
7565Instruction *NVPTXTargetLowering::emitLeadingFence(IRBuilderBase &Builder,
7566 Instruction *Inst,
7567 AtomicOrdering Ord) const {
7568 // prerequisite: shouldInsertFencesForAtomic() should have returned `true` for
7569 // `Inst` before its memory order was modified. We cannot enforce this with an
7570 // assert, because AtomicExpandPass will have modified the memory order
7571 // between the initial call to shouldInsertFencesForAtomic() and the call to
7572 // this function.
7573 if (!isa<AtomicCmpXchgInst>(Val: Inst) && !isa<AtomicRMWInst>(Val: Inst))
7574 return TargetLoweringBase::emitLeadingFence(Builder, Inst, Ord);
7575
7576 // Specialize for cmpxchg and atomicrmw
7577 auto SSID = getAtomicSyncScopeID(I: Inst);
7578 assert(SSID.has_value() && "Expected an atomic operation");
7579
7580 if (isReleaseOrStronger(AO: Ord))
7581 return Builder.CreateFence(Ordering: Ord == AtomicOrdering::SequentiallyConsistent
7582 ? AtomicOrdering::SequentiallyConsistent
7583 : AtomicOrdering::Release,
7584 SSID: SSID.value());
7585
7586 return nullptr;
7587}
7588
7589Instruction *NVPTXTargetLowering::emitTrailingFence(IRBuilderBase &Builder,
7590 Instruction *Inst,
7591 AtomicOrdering Ord) const {
7592 // prerequisite: shouldInsertFencesForAtomic() should have returned `true` for
7593 // `Inst` before its memory order was modified. See `emitLeadingFence` for why
7594 // this cannot be enforced with an assert. Specialize for cmpxchg and
7595 // atomicrmw
7596 auto *CI = dyn_cast<AtomicCmpXchgInst>(Val: Inst);
7597 auto *RI = dyn_cast<AtomicRMWInst>(Val: Inst);
7598 if (!CI && !RI)
7599 return TargetLoweringBase::emitTrailingFence(Builder, Inst, Ord);
7600
7601 auto SSID = getAtomicSyncScopeID(I: Inst);
7602 assert(SSID.has_value() && "Expected an atomic operation");
7603
7604 bool IsEmulated =
7605 CI ? cast<IntegerType>(Val: CI->getCompareOperand()->getType())
7606 ->getBitWidth() < STI.getMinCmpXchgSizeInBits()
7607 : shouldExpandAtomicRMWInIR(AI: RI) == AtomicExpansionKind::CmpXChg;
7608
7609 if (isAcquireOrStronger(AO: Ord) && IsEmulated)
7610 return Builder.CreateFence(Ordering: AtomicOrdering::Acquire, SSID: SSID.value());
7611
7612 return nullptr;
7613}
7614
7615// Rather than default to SINT when both UINT and SINT are custom, we only
7616// change the opcode when UINT is not legal and SINT is. UINT is preferred when
7617// both are custom since unsigned CVT instructions can lead to slightly better
7618// SASS code with fewer instructions.
7619unsigned NVPTXTargetLowering::getPreferredFPToIntOpcode(unsigned Op, EVT FromVT,
7620 EVT ToVT) const {
7621 if (isOperationLegal(Op, VT: ToVT))
7622 return Op;
7623 switch (Op) {
7624 case ISD::FP_TO_UINT:
7625 if (isOperationLegal(Op: ISD::FP_TO_SINT, VT: ToVT))
7626 return ISD::FP_TO_SINT;
7627 break;
7628 case ISD::STRICT_FP_TO_UINT:
7629 if (isOperationLegal(Op: ISD::STRICT_FP_TO_SINT, VT: ToVT))
7630 return ISD::STRICT_FP_TO_SINT;
7631 break;
7632 case ISD::VP_FP_TO_UINT:
7633 if (isOperationLegal(Op: ISD::VP_FP_TO_SINT, VT: ToVT))
7634 return ISD::VP_FP_TO_SINT;
7635 break;
7636 default:
7637 break;
7638 }
7639 return Op;
7640}
7641
// Pin NVPTXTargetObjectFile's vtables to this file. Defining the (defaulted)
// virtual destructor out-of-line makes this TU the home of the vtable,
// avoiding weak-vtable emission in every TU that uses the class.
NVPTXTargetObjectFile::~NVPTXTargetObjectFile() = default;
7644
MCSection *NVPTXTargetObjectFile::SelectSectionForGlobal(
    const GlobalObject *GO, SectionKind Kind, const TargetMachine &TM) const {
  // Every global goes into the single data section, regardless of its kind
  // or name.
  return getDataSection();
}
7649
// Compute known bits for NVPTXISD::PRMT (byte permute): each of the 4 result
// bytes is selected out of the 8-byte pool {B, A} by one 4-bit selector
// nibble. Requires a constant selector operand; otherwise leaves Known as-is.
static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
                                    const SelectionDAG &DAG, unsigned Depth) {
  SDValue A = Op.getOperand(i: 0);
  SDValue B = Op.getOperand(i: 1);
  ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(Val: Op.getOperand(i: 2));
  unsigned Mode = Op.getConstantOperandVal(i: 3);

  // Without a constant selector we cannot tell which bytes are picked.
  if (!Selector)
    return;

  KnownBits AKnown = DAG.computeKnownBits(Op: A, Depth);
  KnownBits BKnown = DAG.computeKnownBits(Op: B, Depth);

  // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
  assert(AKnown.getBitWidth() == 32 && BKnown.getBitWidth() == 32 &&
         "PRMT must have i32 operands");
  assert(Known.getBitWidth() == 32 && "PRMT must have i32 result");
  KnownBits BitField = BKnown.concat(Lo: AKnown);

  // Normalize the selector for the PRMT mode, then decode one nibble per
  // result byte: low 3 bits pick the source byte, the high bit requests sign
  // replication of that byte.
  APInt SelectorVal = getPRMTSelector(Selector: Selector->getAPIntValue(), Mode);
  for (unsigned I : llvm::seq(Size: 4)) {
    APInt Sel = SelectorVal.extractBits(numBits: 4, bitPosition: I * 4);
    unsigned Idx = Sel.getLoBits(numBits: 3).getZExtValue();
    unsigned Sign = Sel.getHiBits(numBits: 1).getZExtValue();
    KnownBits Byte = BitField.extractBits(NumBits: 8, BitPosition: Idx * 8);
    // NOTE(review): sign mode replicates bit 7 across the byte; this relies
    // on KnownBits::ashr smearing the sign bit for a shift equal to the
    // 8-bit width — confirm that semantics for shift == bitwidth.
    if (Sign)
      Byte = KnownBits::ashr(LHS: Byte, RHS: 8);
    Known.insertBits(SubBits: Byte, BitPosition: I * 8);
  }
}
7680
// Refine known bits for NVPTX vector load nodes (LoadV2/V4/V8): for scalar
// results of non-sign-extending loads, all bits above the in-memory element
// width are reported zero.
static void computeKnownBitsForLoadV(const SDValue Op, KnownBits &Known) {
  MemSDNode *LD = cast<MemSDNode>(Val: Op);

  // We can't do anything without knowing the sign bit.
  // The extension kind is encoded as the node's last (constant) operand.
  auto ExtType = LD->getConstantOperandVal(Num: LD->getNumOperands() - 1);
  if (ExtType == ISD::SEXTLOAD)
    return;

  // ExtLoading to vector types is weird and may not work well with known bits.
  auto DestVT = LD->getValueType(ResNo: 0);
  if (DestVT.isVector())
    return;

  // NOTE(review): this treats any-extending loads like zero-extending ones
  // (high bits reported zero) — presumably these loads zero-fill on NVPTX;
  // confirm against the load lowering.
  assert(Known.getBitWidth() == DestVT.getSizeInBits());
  auto ElementBitWidth = NVPTXDAGToDAGISel::getFromTypeWidthForLoad(Mem: LD);
  Known.Zero.setHighBits(Known.getBitWidth() - ElementBitWidth);
}
7698
7699void NVPTXTargetLowering::computeKnownBitsForTargetNode(
7700 const SDValue Op, KnownBits &Known, const APInt &DemandedElts,
7701 const SelectionDAG &DAG, unsigned Depth) const {
7702 Known.resetAll();
7703
7704 switch (Op.getOpcode()) {
7705 case NVPTXISD::PRMT:
7706 computeKnownBitsForPRMT(Op, Known, DAG, Depth);
7707 break;
7708 case NVPTXISD::LoadV2:
7709 case NVPTXISD::LoadV4:
7710 case NVPTXISD::LoadV8:
7711 computeKnownBitsForLoadV(Op, Known);
7712 break;
7713 default:
7714 break;
7715 }
7716}
7717
7718static std::pair<APInt, APInt> getPRMTDemandedBits(const APInt &SelectorVal,
7719 const APInt &DemandedBits) {
7720 APInt DemandedLHS = APInt(32, 0);
7721 APInt DemandedRHS = APInt(32, 0);
7722
7723 for (unsigned I : llvm::seq(Size: 4)) {
7724 if (DemandedBits.extractBits(numBits: 8, bitPosition: I * 8).isZero())
7725 continue;
7726
7727 APInt Sel = SelectorVal.extractBits(numBits: 4, bitPosition: I * 4);
7728 unsigned Idx = Sel.getLoBits(numBits: 3).getZExtValue();
7729 unsigned Sign = Sel.getHiBits(numBits: 1).getZExtValue();
7730
7731 APInt &Src = Idx < 4 ? DemandedLHS : DemandedRHS;
7732 unsigned ByteStart = (Idx % 4) * 8;
7733 if (Sign)
7734 Src.setBit(ByteStart + 7);
7735 else
7736 Src.setBits(loBit: ByteStart, hiBit: ByteStart + 8);
7737 }
7738
7739 return {DemandedLHS, DemandedRHS};
7740}
7741
7742// Replace undef with 0 as this is easier for other optimizations such as
7743// known bits.
7744static SDValue canonicalizePRMTInput(SDValue Op, SelectionDAG &DAG) {
7745 if (!Op)
7746 return SDValue();
7747 if (Op.isUndef())
7748 return DAG.getConstant(Val: 0, DL: SDLoc(), VT: MVT::i32);
7749 return Op;
7750}
7751
// Simplify a PRMT node given which result bits callers actually demand:
// either collapse it to one of its inputs (when the demanded bytes pass
// through unmodified, in order), or shrink the demanded bits of each input
// operand. Requires a constant selector.
static SDValue simplifyDemandedBitsForPRMT(SDValue PRMT,
                                           const APInt &DemandedBits,
                                           SelectionDAG &DAG,
                                           const TargetLowering &TLI,
                                           unsigned Depth) {
  assert(PRMT.getOpcode() == NVPTXISD::PRMT);
  SDValue Op0 = PRMT.getOperand(i: 0);
  SDValue Op1 = PRMT.getOperand(i: 1);
  auto *SelectorConst = dyn_cast<ConstantSDNode>(Val: PRMT.getOperand(i: 2));
  if (!SelectorConst)
    return SDValue();

  unsigned Mode = PRMT.getConstantOperandVal(i: 3);
  const APInt Selector = getPRMTSelector(Selector: SelectorConst->getAPIntValue(), Mode);

  // Try to simplify the PRMT to one of the inputs if the used bytes are all
  // from the same input in the correct order.
  // Only the selector nibbles for bytes below the fully-undemanded high bytes
  // need to match the identity patterns (0x3210 selects Op0 verbatim, 0x7654
  // selects Op1 verbatim).
  const unsigned LeadingBytes = DemandedBits.countLeadingZeros() / 8;
  const unsigned SelBits = (4 - LeadingBytes) * 4;
  if (Selector.getLoBits(numBits: SelBits) == APInt(32, 0x3210).getLoBits(numBits: SelBits))
    return Op0;
  if (Selector.getLoBits(numBits: SelBits) == APInt(32, 0x7654).getLoBits(numBits: SelBits))
    return Op1;

  auto [DemandedLHS, DemandedRHS] = getPRMTDemandedBits(SelectorVal: Selector, DemandedBits);

  // Attempt to avoid multi-use ops if we don't need anything from them.
  SDValue DemandedOp0 =
      TLI.SimplifyMultipleUseDemandedBits(Op: Op0, DemandedBits: DemandedLHS, DAG, Depth: Depth + 1);
  SDValue DemandedOp1 =
      TLI.SimplifyMultipleUseDemandedBits(Op: Op1, DemandedBits: DemandedRHS, DAG, Depth: Depth + 1);

  // Canonicalize undef inputs to zero so later known-bits reasoning works.
  DemandedOp0 = canonicalizePRMTInput(Op: DemandedOp0, DAG);
  DemandedOp1 = canonicalizePRMTInput(Op: DemandedOp1, DAG);
  // Rebuild the PRMT only if at least one operand actually simplified.
  if ((DemandedOp0 && DemandedOp0 != Op0) ||
      (DemandedOp1 && DemandedOp1 != Op1)) {
    Op0 = DemandedOp0 ? DemandedOp0 : Op0;
    Op1 = DemandedOp1 ? DemandedOp1 : Op1;
    return getPRMT(A: Op0, B: Op1, Selector: Selector.getZExtValue(), DL: SDLoc(PRMT), DAG);
  }

  return SDValue();
}
7795
7796bool NVPTXTargetLowering::SimplifyDemandedBitsForTargetNode(
7797 SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
7798 KnownBits &Known, TargetLoweringOpt &TLO, unsigned Depth) const {
7799 Known.resetAll();
7800
7801 switch (Op.getOpcode()) {
7802 case NVPTXISD::PRMT:
7803 if (SDValue Result = simplifyDemandedBitsForPRMT(PRMT: Op, DemandedBits, DAG&: TLO.DAG,
7804 TLI: *this, Depth)) {
7805 TLO.CombineTo(O: Op, N: Result);
7806 return true;
7807 }
7808 break;
7809 default:
7810 break;
7811 }
7812
7813 computeKnownBitsForTargetNode(Op, Known, DemandedElts, DAG: TLO.DAG, Depth);
7814 return false;
7815}
7816