1//===-- NVPTXISelLowering.cpp - NVPTX DAG Lowering Implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the interfaces that NVPTX uses to lower LLVM code into a
10// selection DAG.
11//
12//===----------------------------------------------------------------------===//
13
14#include "NVPTXISelLowering.h"
15#include "MCTargetDesc/NVPTXBaseInfo.h"
16#include "NVPTX.h"
17#include "NVPTXISelDAGToDAG.h"
18#include "NVPTXSelectionDAGInfo.h"
19#include "NVPTXSubtarget.h"
20#include "NVPTXTargetMachine.h"
21#include "NVPTXTargetObjectFile.h"
22#include "NVPTXUtilities.h"
23#include "llvm/ADT/APFloat.h"
24#include "llvm/ADT/APInt.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/SmallVector.h"
27#include "llvm/ADT/StringRef.h"
28#include "llvm/CodeGen/Analysis.h"
29#include "llvm/CodeGen/ISDOpcodes.h"
30#include "llvm/CodeGen/MachineFunction.h"
31#include "llvm/CodeGen/MachineJumpTableInfo.h"
32#include "llvm/CodeGen/MachineMemOperand.h"
33#include "llvm/CodeGen/SDPatternMatch.h"
34#include "llvm/CodeGen/SelectionDAG.h"
35#include "llvm/CodeGen/SelectionDAGNodes.h"
36#include "llvm/CodeGen/TargetCallingConv.h"
37#include "llvm/CodeGen/TargetLowering.h"
38#include "llvm/CodeGen/ValueTypes.h"
39#include "llvm/CodeGenTypes/MachineValueType.h"
40#include "llvm/IR/Argument.h"
41#include "llvm/IR/Attributes.h"
42#include "llvm/IR/Constants.h"
43#include "llvm/IR/DataLayout.h"
44#include "llvm/IR/DerivedTypes.h"
45#include "llvm/IR/DiagnosticInfo.h"
46#include "llvm/IR/FPEnv.h"
47#include "llvm/IR/Function.h"
48#include "llvm/IR/GlobalValue.h"
49#include "llvm/IR/IRBuilder.h"
50#include "llvm/IR/Instruction.h"
51#include "llvm/IR/Instructions.h"
52#include "llvm/IR/IntrinsicsNVPTX.h"
53#include "llvm/IR/Module.h"
54#include "llvm/IR/Type.h"
55#include "llvm/IR/Value.h"
56#include "llvm/Support/Alignment.h"
57#include "llvm/Support/AtomicOrdering.h"
58#include "llvm/Support/Casting.h"
59#include "llvm/Support/CodeGen.h"
60#include "llvm/Support/CommandLine.h"
61#include "llvm/Support/ErrorHandling.h"
62#include "llvm/Support/KnownBits.h"
63#include "llvm/Support/NVPTXAddrSpace.h"
64#include "llvm/Support/raw_ostream.h"
65#include "llvm/Target/TargetMachine.h"
66#include "llvm/Target/TargetOptions.h"
67#include <algorithm>
68#include <cassert>
69#include <cmath>
70#include <cstdint>
71#include <iterator>
72#include <optional>
73#include <string>
74#include <tuple>
75#include <utility>
76#include <vector>
77
78#define DEBUG_TYPE "nvptx-lower"
79
80using namespace llvm;
81
82static cl::opt<bool> sched4reg(
83 "nvptx-sched4reg",
    cl::desc("NVPTX Specific: schedule for register pressure"), cl::init(false));
85
static cl::opt<unsigned> FMAContractLevelOpt(
    "nvptx-fma-level", cl::Hidden,
    cl::desc("NVPTX Specific: FMA contraction (0: don't do it,"
             " 1: do it, 2: do it aggressively)"),
    cl::init(2));
91
92static cl::opt<NVPTX::DivPrecisionLevel> UsePrecDivF32(
93 "nvptx-prec-divf32", cl::Hidden,
94 cl::desc(
95 "NVPTX Specific: Override the precision of the lowering for f32 fdiv"),
96 cl::values(
97 clEnumValN(NVPTX::DivPrecisionLevel::Approx, "0", "Use div.approx"),
98 clEnumValN(NVPTX::DivPrecisionLevel::Full, "1", "Use div.full"),
99 clEnumValN(NVPTX::DivPrecisionLevel::IEEE754, "2",
100 "Use IEEE Compliant F32 div.rnd if available (default)"),
101 clEnumValN(NVPTX::DivPrecisionLevel::IEEE754_NoFTZ, "3",
102 "Use IEEE Compliant F32 div.rnd if available, no FTZ")),
103 cl::init(Val: NVPTX::DivPrecisionLevel::IEEE754));
104
105static cl::opt<bool> UsePrecSqrtF32(
106 "nvptx-prec-sqrtf32", cl::Hidden,
107 cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."),
108 cl::init(Val: true));
109
110/// Whereas CUDA's implementation (see libdevice) uses ex2.approx for exp2(), it
111/// does NOT use lg2.approx for log2, so this is disabled by default.
112static cl::opt<bool> UseApproxLog2F32(
113 "nvptx-approx-log2f32",
114 cl::desc("NVPTX Specific: whether to use lg2.approx for log2"),
115 cl::init(Val: false));
116
117static cl::opt<bool> ForceMinByValParamAlign(
118 "nvptx-force-min-byval-param-align", cl::Hidden,
119 cl::desc("NVPTX Specific: force 4-byte minimal alignment for byval"
120 " params of device functions."),
121 cl::init(Val: false));
122
123NVPTX::DivPrecisionLevel
124NVPTXTargetLowering::getDivF32Level(const MachineFunction &MF,
125 const SDNode &N) const {
  // If nvptx-prec-divf32=N is used on the command line, always honor it.
127 if (UsePrecDivF32.getNumOccurrences() > 0)
128 return UsePrecDivF32;
129
130 const SDNodeFlags Flags = N.getFlags();
131 if (Flags.hasApproximateFuncs())
132 return NVPTX::DivPrecisionLevel::Approx;
133
134 return NVPTX::DivPrecisionLevel::IEEE754;
135}
136
137bool NVPTXTargetLowering::usePrecSqrtF32(const SDNode *N) const {
138 // If nvptx-prec-sqrtf32 is used on the command-line, always honor it
139 if (UsePrecSqrtF32.getNumOccurrences() > 0)
140 return UsePrecSqrtF32;
141
142 if (N) {
143 const SDNodeFlags Flags = N->getFlags();
144 if (Flags.hasApproximateFuncs())
145 return false;
146 }
147
148 return true;
149}
150
151bool NVPTXTargetLowering::useF32FTZ(const MachineFunction &MF) const {
152 return MF.getDenormalMode(FPType: APFloat::IEEEsingle()).Output ==
153 DenormalMode::PreserveSign;
154}
155
156static bool IsPTXVectorType(MVT VT) {
157 switch (VT.SimpleTy) {
158 default:
159 return false;
160 case MVT::v2i1:
161 case MVT::v4i1:
162 case MVT::v2i8:
163 case MVT::v4i8:
164 case MVT::v8i8: // <2 x i8x4>
165 case MVT::v16i8: // <4 x i8x4>
166 case MVT::v2i16:
167 case MVT::v4i16:
168 case MVT::v8i16: // <4 x i16x2>
169 case MVT::v2i32:
170 case MVT::v4i32:
171 case MVT::v2i64:
172 case MVT::v2f16:
173 case MVT::v4f16:
174 case MVT::v8f16: // <4 x f16x2>
175 case MVT::v2bf16:
176 case MVT::v4bf16:
177 case MVT::v8bf16: // <4 x bf16x2>
178 case MVT::v2f32:
179 case MVT::v4f32:
180 case MVT::v2f64:
181 case MVT::v4i64:
182 case MVT::v4f64:
183 case MVT::v8i32:
184 case MVT::v8f32:
185 case MVT::v16f16: // <8 x f16x2>
186 case MVT::v16bf16: // <8 x bf16x2>
187 case MVT::v16i16: // <8 x i16x2>
188 case MVT::v32i8: // <8 x i8x4>
189 return true;
190 }
191}
192
// When legalizing vector loads/stores, this function is called to determine:
// 1. Whether the vector is something we want to custom lower. If not,
//    std::nullopt is returned.
// 2. If we do want to handle it, the two components of the result:
//    - unsigned NumElts - the number of elements in the final vector
//    - MVT EltVT - the type of the elements in the final vector
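//
// For example (illustrative):
//   v8f16 -> {4, v2f16}   (each pair packed into a 32-bit register)
//   v2f64 -> {2, f64}     (already a "native" vector type)
//   i256  -> {4, i64}     (only when 256-bit accesses are available)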
200static std::optional<std::pair<unsigned int, MVT>>
201getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
202 unsigned AddressSpace) {
203 const bool CanLowerTo256Bit = STI.has256BitVectorLoadStore(AS: AddressSpace);
204
205 if (CanLowerTo256Bit && VectorEVT.isScalarInteger() &&
206 VectorEVT.getSizeInBits() == 256)
207 return {{4, MVT::i64}};
208
209 if (!VectorEVT.isSimple())
210 return std::nullopt;
211 const MVT VectorVT = VectorEVT.getSimpleVT();
212
213 if (!VectorVT.isVector()) {
214 if (VectorVT == MVT::i128 || VectorVT == MVT::f128)
215 return {{2, MVT::i64}};
216 return std::nullopt;
217 }
218
219 const MVT EltVT = VectorVT.getVectorElementType();
220 const unsigned NumElts = VectorVT.getVectorNumElements();
221
222 // The size of the PTX virtual register that holds a packed type.
223 unsigned PackRegSize;
224
225 // We only handle "native" vector sizes for now, e.g. <4 x double> is not
226 // legal. We can (and should) split that into 2 stores of <2 x double> here
227 // but I'm leaving that as a TODO for now.
228 switch (VectorVT.SimpleTy) {
229 default:
230 return std::nullopt;
231
232 case MVT::v4i64:
233 case MVT::v4f64:
234 // This is a "native" vector type iff the address space is global and the
235 // target supports 256-bit loads/stores
236 if (!CanLowerTo256Bit)
237 return std::nullopt;
238 [[fallthrough]];
239 case MVT::v2i8:
240 case MVT::v2i64:
241 case MVT::v2f64:
242 // This is a "native" vector type
243 return std::pair(NumElts, EltVT);
244
245 case MVT::v16f16: // <8 x f16x2>
246 case MVT::v16bf16: // <8 x bf16x2>
247 case MVT::v16i16: // <8 x i16x2>
248 case MVT::v32i8: // <8 x i8x4>
249 // This can be upsized into a "native" vector type iff the address space is
250 // global and the target supports 256-bit loads/stores.
251 if (!CanLowerTo256Bit)
252 return std::nullopt;
253 [[fallthrough]];
254 case MVT::v2i16: // <1 x i16x2>
255 case MVT::v2f16: // <1 x f16x2>
256 case MVT::v2bf16: // <1 x bf16x2>
257 case MVT::v4i8: // <1 x i8x4>
258 case MVT::v4i16: // <2 x i16x2>
259 case MVT::v4f16: // <2 x f16x2>
260 case MVT::v4bf16: // <2 x bf16x2>
261 case MVT::v8i8: // <2 x i8x4>
262 case MVT::v8f16: // <4 x f16x2>
263 case MVT::v8bf16: // <4 x bf16x2>
264 case MVT::v8i16: // <4 x i16x2>
265 case MVT::v16i8: // <4 x i8x4>
266 PackRegSize = 32;
267 break;
268
269 case MVT::v8f32: // <4 x f32x2>
270 case MVT::v8i32: // <4 x i32x2>
271 // This is a "native" vector type iff the address space is global and the
272 // target supports 256-bit loads/stores
273 if (!CanLowerTo256Bit)
274 return std::nullopt;
275 [[fallthrough]];
276 case MVT::v2f32: // <1 x f32x2>
277 case MVT::v4f32: // <2 x f32x2>
278 case MVT::v2i32: // <1 x i32x2>
279 case MVT::v4i32: // <2 x i32x2>
280 if (!STI.hasF32x2Instructions())
281 return std::pair(NumElts, EltVT);
282 PackRegSize = 64;
283 break;
284 }
285
286 // If we reach here, then we can pack 2 or more elements into a single 32-bit
287 // or 64-bit PTX register and treat the vector as a new vector containing
288 // packed elements.
289
290 // Number of elements to pack in one word.
291 const unsigned NPerReg = PackRegSize / EltVT.getSizeInBits();
292
293 return std::pair(NumElts / NPerReg, MVT::getVectorVT(VT: EltVT, NumElements: NPerReg));
294}
295
296/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
297/// legal-ish MVTs that compose it. Unlike ComputeValueVTs, this will legalize
298/// the types as required by the calling convention (with special handling for
299/// i8s).
300/// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
301/// same number of types as the Ins/Outs arrays in LowerFormalArguments,
302/// LowerCall, and LowerReturn.
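///
/// For example (illustrative), a struct {i8, i8, float} flattens to
/// ValueVTs = {i8, i8, f32} with Offsets = {0, 1, 4}; the i8 pieces keep
/// their original byte size rather than the promoted i16 register type.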
303static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
304 LLVMContext &Ctx, CallingConv::ID CallConv,
305 Type *Ty, SmallVectorImpl<EVT> &ValueVTs,
306 SmallVectorImpl<uint64_t> &Offsets,
307 uint64_t StartingOffset = 0) {
308 SmallVector<EVT, 16> TempVTs;
309 SmallVector<uint64_t, 16> TempOffsets;
310 ComputeValueVTs(TLI, DL, Ty, ValueVTs&: TempVTs, /*MemVTs=*/nullptr, FixedOffsets: &TempOffsets,
311 StartingOffset);
312
313 for (const auto [VT, Off] : zip(t&: TempVTs, u&: TempOffsets)) {
314 MVT RegisterVT = TLI.getRegisterTypeForCallingConv(Context&: Ctx, CC: CallConv, VT);
315 unsigned NumRegs = TLI.getNumRegistersForCallingConv(Context&: Ctx, CC: CallConv, VT);
316
317 // Since we actually can load/store b8, we need to ensure that we'll use
318 // the original sized type for any i8s or i8 vectors.
319 if (VT.getScalarType() == MVT::i8) {
320 if (RegisterVT == MVT::i16)
321 RegisterVT = MVT::i8;
322 else if (RegisterVT == MVT::v2i16)
323 RegisterVT = MVT::v2i8;
324 else
325 assert(RegisterVT == MVT::v4i8 &&
326 "Expected v4i8, v2i16, or i16 for i8 RegisterVT");
327 }
328
    // TODO: This is horribly incorrect for cases where the vector elements
    // are not a whole number of bytes (e.g. i1) and legal or i8. However,
    // this problem has existed for as long as NVPTX has and no one has
    // complained, so we'll leave it for now.
333 for (unsigned I : seq(Size: NumRegs)) {
334 ValueVTs.push_back(Elt: RegisterVT);
335 Offsets.push_back(Elt: Off + I * RegisterVT.getStoreSize());
336 }
337 }
338}
339
// Returns an EVT that can hold N copies of VT.
// If VT is a vector, the resulting EVT is a flat vector with the same
// element type as VT's element type.
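// For example, getVectorizedVT(f32, 4, C) yields v4f32 and
// getVectorizedVT(v2f16, 3, C) yields v6f16; N == 1 returns VT unchanged.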
343static EVT getVectorizedVT(EVT VT, unsigned N, LLVMContext &C) {
344 if (N == 1)
345 return VT;
346
347 return VT.isVector() ? EVT::getVectorVT(Context&: C, VT: VT.getScalarType(),
348 NumElements: VT.getVectorNumElements() * N)
349 : EVT::getVectorVT(Context&: C, VT, NumElements: N);
350}
351
352static SDValue getExtractVectorizedValue(SDValue V, unsigned I, EVT VT,
353 const SDLoc &dl, SelectionDAG &DAG) {
354 if (V.getValueType() == VT) {
355 assert(I == 0 && "Index must be 0 for scalar value");
356 return V;
357 }
358
359 if (!VT.isVector())
360 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT, N1: V,
361 N2: DAG.getVectorIdxConstant(Val: I, DL: dl));
362
363 return DAG.getNode(
364 Opcode: ISD::EXTRACT_SUBVECTOR, DL: dl, VT, N1: V,
365 N2: DAG.getVectorIdxConstant(Val: I * VT.getVectorNumElements(), DL: dl));
366}
367
368template <typename T>
369static inline SDValue getBuildVectorizedValue(unsigned N, const SDLoc &dl,
370 SelectionDAG &DAG, T GetElement) {
371 if (N == 1)
372 return GetElement(0);
373
374 SmallVector<SDValue, 8> Values;
375 for (const unsigned I : llvm::seq(Size: N)) {
376 SDValue Val = GetElement(I);
377 if (Val.getValueType().isVector())
378 DAG.ExtractVectorElements(Op: Val, Args&: Values);
379 else
380 Values.push_back(Elt: Val);
381 }
382
383 EVT VT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: Values[0].getValueType(),
384 NumElements: Values.size());
385 return DAG.getBuildVector(VT, DL: dl, Ops: Values);
386}
387
/// PromoteScalarIntegerPTX
/// Used to make sure scalar integer arguments/returns are suitable for passing
/// and to promote them to a larger size if they're not.
///
/// Returns the promoted type, or \p VT unchanged if no promotion is needed.
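///
/// For example, i1 stays i1, i5 becomes i8, i24 becomes i32, and i33 becomes
/// i64; non-integer types are returned unchanged.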
393static EVT promoteScalarIntegerPTX(const EVT VT) {
394 if (VT.isScalarInteger()) {
395 switch (PowerOf2Ceil(A: VT.getFixedSizeInBits())) {
396 default:
397 llvm_unreachable(
398 "Promotion is not suitable for scalars of size larger than 64-bits");
399 case 1:
400 return MVT::i1;
401 case 2:
402 case 4:
403 case 8:
404 return MVT::i8;
405 case 16:
406 return MVT::i16;
407 case 32:
408 return MVT::i32;
409 case 64:
410 return MVT::i64;
411 }
412 }
413 return VT;
414}
415
416// Check whether we can merge loads/stores of some of the pieces of a
417// flattened function parameter or return value into a single vector
418// load/store.
419//
420// The flattened parameter is represented as a list of EVTs and
421// offsets, and the whole structure is aligned to ParamAlignment. This
422// function determines whether we can load/store pieces of the
423// parameter starting at index Idx using a single vectorized op of
424// size AccessSize. If so, it returns the number of param pieces
425// covered by the vector op. Otherwise, it returns 1.
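//
// For example (illustrative), with ValueVTs = {f32, f32, f32, f32},
// Offsets = {0, 4, 8, 12}, and ParamAlignment = 16, a call with Idx = 0 and
// AccessSize = 16 returns 4, while AccessSize = 8 returns 2.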
426template <typename T>
427static unsigned canMergeParamLoadStoresStartingAt(
428 unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs,
429 const SmallVectorImpl<T> &Offsets, Align ParamAlignment) {
430
431 // Can't vectorize if param alignment is not sufficient.
432 if (ParamAlignment < AccessSize)
433 return 1;
434 // Can't vectorize if offset is not aligned.
435 if (Offsets[Idx] & (AccessSize - 1))
436 return 1;
437
438 EVT EltVT = ValueVTs[Idx];
439 unsigned EltSize = EltVT.getStoreSize();
440
441 // Element is too large to vectorize.
442 if (EltSize >= AccessSize)
443 return 1;
444
445 unsigned NumElts = AccessSize / EltSize;
  // Can't vectorize if AccessSize is not a multiple of EltSize.
447 if (AccessSize != EltSize * NumElts)
448 return 1;
449
450 // We don't have enough elements to vectorize.
451 if (Idx + NumElts > ValueVTs.size())
452 return 1;
453
454 // PTX ISA can only deal with 2- and 4-element vector ops.
455 if (NumElts != 4 && NumElts != 2)
456 return 1;
457
458 for (unsigned j = Idx + 1; j < Idx + NumElts; ++j) {
459 // Types do not match.
460 if (ValueVTs[j] != EltVT)
461 return 1;
462
463 // Elements are not contiguous.
464 if (Offsets[j] - Offsets[j - 1] != EltSize)
465 return 1;
466 }
  // OK. We can vectorize ValueVTs[Idx..Idx+NumElts).
468 return NumElts;
469}
470
// Computes whether and how we can vectorize the loads/stores of a
// flattened function parameter or return value.
//
// The flattened parameter is represented as the list of ValueVTs and
// Offsets, and is aligned to ParamAlignment bytes. We return a vector
// with one entry per vectorized chunk, giving the number of consecutive
// pieces covered by that chunk (1 for a scalar access); the entries sum
// to ValueVTs.size().
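//
// For example (illustrative), two f64 pieces at offsets {0, 8} with 16-byte
// alignment yield {2} (a single v2 access), while the same pieces with only
// 8-byte alignment yield {1, 1}.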
479template <typename T>
480static SmallVector<unsigned, 16>
481VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
482 const SmallVectorImpl<T> &Offsets, Align ParamAlignment,
483 bool IsVAArg = false) {
  // Variadic arguments are never vectorized; treat every piece as a scalar.
  if (IsVAArg)
    return SmallVector<unsigned>(ValueVTs.size(), 1);
489
490 SmallVector<unsigned, 16> VectorInfo;
491
492 const auto GetNumElts = [&](unsigned I) -> unsigned {
493 for (const unsigned AccessSize : {16, 8, 4, 2}) {
494 const unsigned NumElts = canMergeParamLoadStoresStartingAt(
495 I, AccessSize, ValueVTs, Offsets, ParamAlignment);
496 assert((NumElts == 1 || NumElts == 2 || NumElts == 4) &&
497 "Unexpected vectorization size");
498 if (NumElts != 1)
499 return NumElts;
500 }
501 return 1;
502 };
503
504 // Check what we can vectorize using 128/64/32-bit accesses.
505 for (unsigned I = 0, E = ValueVTs.size(); I != E;) {
506 const unsigned NumElts = GetNumElts(I);
507 VectorInfo.push_back(Elt: NumElts);
508 I += NumElts;
509 }
510 assert(std::accumulate(VectorInfo.begin(), VectorInfo.end(), 0u) ==
511 ValueVTs.size());
512 return VectorInfo;
513}
514
515// NVPTXTargetLowering Constructor.
516NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
517 const NVPTXSubtarget &STI)
518 : TargetLowering(TM, STI), nvTM(&TM), STI(STI), GlobalUniqueCallSite(0) {
  // Always lower memset, memcpy, and memmove intrinsics to load/store
  // instructions, rather than generating calls to memset, memcpy, or memmove.
  MaxStoresPerMemset = MaxStoresPerMemsetOptSize = (unsigned)0xFFFFFFFF;
  MaxStoresPerMemcpy = MaxStoresPerMemcpyOptSize = (unsigned)0xFFFFFFFF;
  MaxStoresPerMemmove = MaxStoresPerMemmoveOptSize = (unsigned)0xFFFFFFFF;
525
526 setBooleanContents(ZeroOrNegativeOneBooleanContent);
527 setBooleanVectorContents(ZeroOrNegativeOneBooleanContent);
528
529 // Jump is Expensive. Don't create extra control flow for 'and', 'or'
530 // condition branches.
531 setJumpIsExpensive(true);
532
533 // Wide divides are _very_ slow. Try to reduce the width of the divide if
534 // possible.
535 addBypassSlowDiv(SlowBitWidth: 64, FastBitWidth: 32);
536
537 // By default, use the Source scheduling
538 if (sched4reg)
539 setSchedulingPreference(Sched::RegPressure);
540 else
541 setSchedulingPreference(Sched::Source);
542
543 auto setFP16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
544 LegalizeAction NoF16Action) {
545 bool IsOpSupported = STI.allowFP16Math();
546 switch (Op) {
547 // Several FP16 instructions are available on sm_80 only.
548 case ISD::FMINNUM:
549 case ISD::FMAXNUM:
550 case ISD::FMAXNUM_IEEE:
551 case ISD::FMINNUM_IEEE:
552 case ISD::FMAXIMUM:
553 case ISD::FMINIMUM:
554 case ISD::FMAXIMUMNUM:
555 case ISD::FMINIMUMNUM:
556 IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
557 break;
558 case ISD::FEXP2:
559 IsOpSupported &= STI.getSmVersion() >= 75 && STI.getPTXVersion() >= 70;
560 break;
561 }
562 setOperationAction(Op, VT, Action: IsOpSupported ? Action : NoF16Action);
563 };
564
565 auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
566 LegalizeAction NoBF16Action) {
567 bool IsOpSupported = STI.hasNativeBF16Support(Opcode: Op);
568 setOperationAction(
569 Op, VT, Action: IsOpSupported ? Action : NoBF16Action);
570 };
571
572 auto setI16x2OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
573 LegalizeAction NoI16x2Action) {
574 bool IsOpSupported = false;
    // These instructions are available on sm_90 only.
576 switch (Op) {
577 case ISD::ADD:
578 case ISD::SMAX:
579 case ISD::SMIN:
580 case ISD::UMIN:
581 case ISD::UMAX:
582 IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 80;
583 break;
584 }
585 setOperationAction(Op, VT, Action: IsOpSupported ? Action : NoI16x2Action);
586 };
587
588 addRegisterClass(VT: MVT::i1, RC: &NVPTX::B1RegClass);
589 addRegisterClass(VT: MVT::i16, RC: &NVPTX::B16RegClass);
590 addRegisterClass(VT: MVT::v2i16, RC: &NVPTX::B32RegClass);
591 addRegisterClass(VT: MVT::v4i8, RC: &NVPTX::B32RegClass);
592 addRegisterClass(VT: MVT::i32, RC: &NVPTX::B32RegClass);
593 addRegisterClass(VT: MVT::i64, RC: &NVPTX::B64RegClass);
594 addRegisterClass(VT: MVT::f32, RC: &NVPTX::B32RegClass);
595 addRegisterClass(VT: MVT::f64, RC: &NVPTX::B64RegClass);
596 addRegisterClass(VT: MVT::f16, RC: &NVPTX::B16RegClass);
597 addRegisterClass(VT: MVT::v2f16, RC: &NVPTX::B32RegClass);
598 addRegisterClass(VT: MVT::bf16, RC: &NVPTX::B16RegClass);
599 addRegisterClass(VT: MVT::v2bf16, RC: &NVPTX::B32RegClass);
600
601 if (STI.hasF32x2Instructions()) {
602 addRegisterClass(VT: MVT::v2f32, RC: &NVPTX::B64RegClass);
603 addRegisterClass(VT: MVT::v2i32, RC: &NVPTX::B64RegClass);
604 }
605
606 // Conversion to/from FP16/FP16x2 is always legal.
607 setOperationAction(Op: ISD::BUILD_VECTOR, VT: MVT::v2f16, Action: Custom);
608 setOperationAction(Op: ISD::EXTRACT_VECTOR_ELT, VT: MVT::v2f16, Action: Custom);
609 setOperationAction(Op: ISD::INSERT_VECTOR_ELT, VT: MVT::v2f16, Action: Expand);
610 setOperationAction(Op: ISD::VECTOR_SHUFFLE, VT: MVT::v2f16, Action: Expand);
611
612 setOperationAction(Op: ISD::READCYCLECOUNTER, VT: MVT::i64, Action: Legal);
613 if (STI.getSmVersion() >= 30 && STI.getPTXVersion() > 31)
614 setOperationAction(Op: ISD::READSTEADYCOUNTER, VT: MVT::i64, Action: Legal);
615
616 setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote);
617 setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand);
618
  // Conversion to/from BF16/BF16x2 is always legal.
620 setOperationAction(Op: ISD::BUILD_VECTOR, VT: MVT::v2bf16, Action: Custom);
621 setOperationAction(Op: ISD::EXTRACT_VECTOR_ELT, VT: MVT::v2bf16, Action: Custom);
622 setOperationAction(Op: ISD::INSERT_VECTOR_ELT, VT: MVT::v2bf16, Action: Expand);
623 setOperationAction(Op: ISD::VECTOR_SHUFFLE, VT: MVT::v2bf16, Action: Expand);
624
625 setBF16OperationAction(ISD::SETCC, MVT::v2bf16, Legal, Expand);
626 setBF16OperationAction(ISD::SETCC, MVT::bf16, Legal, Promote);
627 if (getOperationAction(Op: ISD::SETCC, VT: MVT::bf16) == Promote)
628 AddPromotedToType(Opc: ISD::SETCC, OrigVT: MVT::bf16, DestVT: MVT::f32);
629
630 // Conversion to/from i16/i16x2 is always legal.
631 setOperationAction(Op: ISD::BUILD_VECTOR, VT: MVT::v2i16, Action: Custom);
632 setOperationAction(Op: ISD::EXTRACT_VECTOR_ELT, VT: MVT::v2i16, Action: Custom);
633 setOperationAction(Op: ISD::INSERT_VECTOR_ELT, VT: MVT::v2i16, Action: Expand);
634 setOperationAction(Op: ISD::VECTOR_SHUFFLE, VT: MVT::v2i16, Action: Expand);
635
636 setOperationAction(Op: ISD::BUILD_VECTOR, VT: MVT::v4i8, Action: Custom);
637 setOperationAction(Op: ISD::EXTRACT_VECTOR_ELT, VT: MVT::v4i8, Action: Custom);
638 setOperationAction(Op: ISD::INSERT_VECTOR_ELT, VT: MVT::v4i8, Action: Custom);
639 setOperationAction(Op: ISD::VECTOR_SHUFFLE, VT: MVT::v4i8, Action: Custom);
640
641 // No support for these operations with v2f32/v2i32
642 setOperationAction(Ops: ISD::INSERT_VECTOR_ELT, VTs: {MVT::v2f32, MVT::v2i32}, Action: Expand);
643 setOperationAction(Ops: ISD::VECTOR_SHUFFLE, VTs: {MVT::v2f32, MVT::v2i32}, Action: Expand);
644
645 setOperationAction(Op: ISD::TRUNCATE, VT: MVT::v2i16, Action: Expand);
646 setOperationAction(Ops: {ISD::ANY_EXTEND, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND},
647 VT: MVT::v2i32, Action: Expand);
648
649 // Need custom lowering in case the index is dynamic.
650 if (STI.hasF32x2Instructions())
651 setOperationAction(Ops: ISD::EXTRACT_VECTOR_ELT, VTs: {MVT::v2f32, MVT::v2i32},
652 Action: Custom);
653
654 // Custom conversions to/from v2i8.
655 setOperationAction(Op: ISD::BITCAST, VT: MVT::v2i8, Action: Custom);
656
  // Only logical ops can be done on v4i8/v2i32 directly; others must be done
  // elementwise.
659 setOperationAction(
660 Ops: {ISD::ABS, ISD::ADD, ISD::ADDC, ISD::ADDE,
661 ISD::BITREVERSE, ISD::CTLZ, ISD::CTPOP, ISD::CTTZ,
662 ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FSHL, ISD::FSHR,
663 ISD::MUL, ISD::MULHS, ISD::MULHU, ISD::PARITY,
664 ISD::ROTL, ISD::ROTR, ISD::SADDO, ISD::SADDO_CARRY,
665 ISD::SADDSAT, ISD::SDIV, ISD::SDIVREM, ISD::SELECT_CC,
666 ISD::SETCC, ISD::SHL, ISD::SINT_TO_FP, ISD::SMAX,
667 ISD::SMIN, ISD::SMULO, ISD::SMUL_LOHI, ISD::SRA,
668 ISD::SREM, ISD::SRL, ISD::SSHLSAT, ISD::SSUBO,
669 ISD::SSUBO_CARRY, ISD::SSUBSAT, ISD::SUB, ISD::SUBC,
670 ISD::SUBE, ISD::UADDO, ISD::UADDO_CARRY, ISD::UADDSAT,
671 ISD::UDIV, ISD::UDIVREM, ISD::UINT_TO_FP, ISD::UMAX,
672 ISD::UMIN, ISD::UMULO, ISD::UMUL_LOHI, ISD::UREM,
673 ISD::USHLSAT, ISD::USUBO, ISD::USUBO_CARRY, ISD::VSELECT,
674 ISD::USUBSAT},
675 VTs: {MVT::v4i8, MVT::v2i32}, Action: Expand);
676
677 // Operations not directly supported by NVPTX.
678 for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
679 MVT::v2f32, MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16,
680 MVT::v4i8, MVT::i32, MVT::v2i32, MVT::i64}) {
681 setOperationAction(Op: ISD::SELECT_CC, VT, Action: Expand);
682 setOperationAction(Op: ISD::BR_CC, VT, Action: Expand);
683 }
684
685 // We don't want ops like FMINIMUM or UMAX to be lowered to SETCC+VSELECT.
686 setOperationAction(Ops: ISD::VSELECT, VTs: {MVT::v2f32, MVT::v2i32}, Action: Expand);
687
688 // Some SIGN_EXTEND_INREG can be done using cvt instruction.
689 // For others we will expand to a SHL/SRA pair.
690 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i64, Action: Legal);
691 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i32, Action: Legal);
692 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i16, Action: Legal);
693 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i8 , Action: Legal);
694 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i1, Action: Expand);
695 setOperationAction(Ops: ISD::SIGN_EXTEND_INREG, VTs: {MVT::v2i16, MVT::v2i32}, Action: Expand);
696
697 setOperationAction(Op: ISD::SHL_PARTS, VT: MVT::i32 , Action: Custom);
698 setOperationAction(Op: ISD::SRA_PARTS, VT: MVT::i32 , Action: Custom);
699 setOperationAction(Op: ISD::SRL_PARTS, VT: MVT::i32 , Action: Custom);
700 setOperationAction(Op: ISD::SHL_PARTS, VT: MVT::i64 , Action: Custom);
701 setOperationAction(Op: ISD::SRA_PARTS, VT: MVT::i64 , Action: Custom);
702 setOperationAction(Op: ISD::SRL_PARTS, VT: MVT::i64 , Action: Custom);
703
704 setOperationAction(Op: ISD::BITREVERSE, VT: MVT::i32, Action: Legal);
705 setOperationAction(Op: ISD::BITREVERSE, VT: MVT::i64, Action: Legal);
706
707 setOperationAction(Ops: {ISD::ROTL, ISD::ROTR},
708 VTs: {MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64},
709 Action: Expand);
710
711 if (STI.hasHWROT32()) {
712 setOperationAction(Ops: {ISD::FSHL, ISD::FSHR}, VT: MVT::i32, Action: Legal);
713 setOperationAction(Ops: {ISD::ROTL, ISD::ROTR, ISD::FSHL, ISD::FSHR}, VT: MVT::i64,
714 Action: Custom);
715 }
716
717 setOperationAction(Op: ISD::BR_JT, VT: MVT::Other, Action: STI.hasBrx() ? Legal : Expand);
718 setOperationAction(Op: ISD::BRIND, VT: MVT::Other, Action: Expand);
719
  // We want to legalize constant-related memmove and memcpy
  // intrinsics.
722 setOperationAction(Op: ISD::INTRINSIC_W_CHAIN, VT: MVT::Other, Action: Custom);
723
724 // FP extload/truncstore is not legal in PTX. We need to expand all these.
725 for (auto FloatVTs :
726 {MVT::fp_valuetypes(), MVT::fp_fixedlen_vector_valuetypes()}) {
727 for (MVT ValVT : FloatVTs) {
728 for (MVT MemVT : FloatVTs) {
729 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT, MemVT, Action: Expand);
730 setTruncStoreAction(ValVT, MemVT, Action: Expand);
731 }
732 }
733 }
734
735 // To improve CodeGen we'll legalize any-extend loads to zext loads. This is
736 // how they'll be lowered in ISel anyway, and by doing this a little earlier
737 // we allow for more DAG combine opportunities.
738 for (auto IntVTs :
739 {MVT::integer_valuetypes(), MVT::integer_fixedlen_vector_valuetypes()})
740 for (MVT ValVT : IntVTs)
741 for (MVT MemVT : IntVTs)
742 if (isTypeLegal(VT: ValVT))
743 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT, MemVT, Action: Custom);
744
745 // PTX does not support load / store predicate registers
746 setOperationAction(Ops: {ISD::LOAD, ISD::STORE}, VT: MVT::i1, Action: Custom);
747 for (MVT VT : MVT::integer_valuetypes()) {
748 setLoadExtAction(ExtTypes: {ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}, ValVT: VT, MemVT: MVT::i1,
749 Action: Promote);
750 setTruncStoreAction(ValVT: VT, MemVT: MVT::i1, Action: Expand);
751 }
752
  // Disable generation of extload/truncstore for v2i32/v2i16/v2i8. The
  // generic expansion for these nodes when they are unaligned is incorrect if
  // the type is a vector.
756 //
757 // TODO: Fix the generic expansion for these nodes found in
758 // TargetLowering::expandUnalignedLoad/Store.
759 setLoadExtAction(ExtTypes: {ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, ValVT: MVT::v2i16,
760 MemVT: MVT::v2i8, Action: Expand);
761 setLoadExtAction(ExtTypes: {ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, ValVT: MVT::v2i32,
762 MemVTs: {MVT::v2i8, MVT::v2i16}, Action: Expand);
763 setTruncStoreAction(ValVT: MVT::v2i16, MemVT: MVT::v2i8, Action: Expand);
764 setTruncStoreAction(ValVT: MVT::v2i32, MemVT: MVT::v2i16, Action: Expand);
765 setTruncStoreAction(ValVT: MVT::v2i32, MemVT: MVT::v2i8, Action: Expand);
766
  // Register custom handling for illegal-type loads/stores. We'll try to
  // custom lower almost all illegal types, and the lowering logic will
  // discard cases we can't handle.
770 setOperationAction(Ops: {ISD::LOAD, ISD::STORE}, VTs: {MVT::i128, MVT::i256, MVT::f128},
771 Action: Custom);
772 for (MVT VT : MVT::fixedlen_vector_valuetypes())
773 if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
774 setOperationAction(Ops: {ISD::STORE, ISD::LOAD, ISD::MSTORE, ISD::MLOAD}, VT,
775 Action: Custom);
776
777 // Custom legalization for LDU intrinsics.
778 // TODO: The logic to lower these is not very robust and we should rewrite it.
779 // Perhaps LDU should not be represented as an intrinsic at all.
780 setOperationAction(Op: ISD::INTRINSIC_W_CHAIN, VT: MVT::i8, Action: Custom);
781 for (MVT VT : MVT::fixedlen_vector_valuetypes())
782 if (IsPTXVectorType(VT))
783 setOperationAction(Op: ISD::INTRINSIC_W_CHAIN, VT, Action: Custom);
784
785 setCondCodeAction(CCs: {ISD::SETNE, ISD::SETEQ, ISD::SETUGE, ISD::SETULE,
786 ISD::SETUGT, ISD::SETULT, ISD::SETGT, ISD::SETLT,
787 ISD::SETGE, ISD::SETLE},
788 VT: MVT::i1, Action: Expand);
789
  // ConstantFP is legal in NVPTX for these types.
791 setOperationAction(Op: ISD::ConstantFP, VT: MVT::f64, Action: Legal);
792 setOperationAction(Op: ISD::ConstantFP, VT: MVT::f32, Action: Legal);
793 setOperationAction(Op: ISD::ConstantFP, VT: MVT::f16, Action: Legal);
794 setOperationAction(Op: ISD::ConstantFP, VT: MVT::bf16, Action: Legal);
795
796 setOperationAction(Ops: ISD::DYNAMIC_STACKALLOC, VTs: {MVT::i32, MVT::i64}, Action: Custom);
797 setOperationAction(Ops: {ISD::STACKRESTORE, ISD::STACKSAVE}, VT: MVT::Other, Action: Custom);
798
799 // TRAP can be lowered to PTX trap
800 setOperationAction(Op: ISD::TRAP, VT: MVT::Other, Action: Legal);
801 // DEBUGTRAP can be lowered to PTX brkpt
802 setOperationAction(Op: ISD::DEBUGTRAP, VT: MVT::Other, Action: Legal);
803
804 // Support varargs.
805 setOperationAction(Op: ISD::VASTART, VT: MVT::Other, Action: Custom);
806 setOperationAction(Op: ISD::VAARG, VT: MVT::Other, Action: Custom);
807 setOperationAction(Op: ISD::VACOPY, VT: MVT::Other, Action: Expand);
808 setOperationAction(Op: ISD::VAEND, VT: MVT::Other, Action: Expand);
809
810 setOperationAction(Ops: {ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX},
811 VTs: {MVT::i16, MVT::i32, MVT::i64}, Action: Legal);
812
813 setOperationAction(Ops: {ISD::CTPOP, ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF}, VT: MVT::i16,
814 Action: Promote);
815 setOperationAction(Ops: {ISD::CTPOP, ISD::CTLZ}, VT: MVT::i32, Action: Legal);
816 setOperationAction(Ops: {ISD::CTPOP, ISD::CTLZ}, VT: MVT::i64, Action: Custom);
817
818 setI16x2OperationAction(ISD::ABS, MVT::v2i16, Legal, Custom);
819 setI16x2OperationAction(ISD::SMIN, MVT::v2i16, Legal, Custom);
820 setI16x2OperationAction(ISD::SMAX, MVT::v2i16, Legal, Custom);
821 setI16x2OperationAction(ISD::UMIN, MVT::v2i16, Legal, Custom);
822 setI16x2OperationAction(ISD::UMAX, MVT::v2i16, Legal, Custom);
823 setI16x2OperationAction(ISD::CTPOP, MVT::v2i16, Legal, Expand);
824 setI16x2OperationAction(ISD::CTLZ, MVT::v2i16, Legal, Expand);
825
826 setI16x2OperationAction(ISD::ADD, MVT::v2i16, Legal, Custom);
827 setI16x2OperationAction(ISD::SUB, MVT::v2i16, Legal, Custom);
828 setI16x2OperationAction(ISD::MUL, MVT::v2i16, Legal, Custom);
829 setI16x2OperationAction(ISD::SHL, MVT::v2i16, Legal, Custom);
830 setI16x2OperationAction(ISD::SREM, MVT::v2i16, Legal, Custom);
831 setI16x2OperationAction(ISD::UREM, MVT::v2i16, Legal, Custom);
832
833 // Other arithmetic and logic ops are unsupported.
834 setOperationAction(Ops: {ISD::SDIV, ISD::UDIV, ISD::SRA, ISD::SRL, ISD::MULHS,
835 ISD::MULHU, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
836 ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::SETCC},
837 VTs: {MVT::v2i16, MVT::v2i32}, Action: Expand);
838
839 // v2i32 is not supported for any arithmetic operations
840 setOperationAction(Ops: {ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX,
841 ISD::CTPOP, ISD::CTLZ, ISD::ADD, ISD::SUB, ISD::MUL,
842 ISD::SHL, ISD::SRA, ISD::SRL, ISD::OR, ISD::AND, ISD::XOR,
843 ISD::SREM, ISD::UREM},
844 VT: MVT::v2i32, Action: Expand);
845
846 setOperationAction(Op: ISD::ADDC, VT: MVT::i32, Action: Legal);
847 setOperationAction(Op: ISD::ADDE, VT: MVT::i32, Action: Legal);
848 setOperationAction(Op: ISD::SUBC, VT: MVT::i32, Action: Legal);
849 setOperationAction(Op: ISD::SUBE, VT: MVT::i32, Action: Legal);
850 if (STI.getPTXVersion() >= 43) {
851 setOperationAction(Op: ISD::ADDC, VT: MVT::i64, Action: Legal);
852 setOperationAction(Op: ISD::ADDE, VT: MVT::i64, Action: Legal);
853 setOperationAction(Op: ISD::SUBC, VT: MVT::i64, Action: Legal);
854 setOperationAction(Op: ISD::SUBE, VT: MVT::i64, Action: Legal);
855 }
856
857 setOperationAction(Op: ISD::CTTZ, VT: MVT::i16, Action: Expand);
858 setOperationAction(Ops: ISD::CTTZ, VTs: {MVT::v2i16, MVT::v2i32}, Action: Expand);
859 setOperationAction(Op: ISD::CTTZ, VT: MVT::i32, Action: Expand);
860 setOperationAction(Op: ISD::CTTZ, VT: MVT::i64, Action: Expand);
861
862 // PTX does not directly support SELP of i1, so promote to i32 first
863 setOperationAction(Op: ISD::SELECT, VT: MVT::i1, Action: Custom);
864
865 // PTX cannot multiply two i64s in a single instruction.
866 setOperationAction(Op: ISD::SMUL_LOHI, VT: MVT::i64, Action: Expand);
867 setOperationAction(Op: ISD::UMUL_LOHI, VT: MVT::i64, Action: Expand);
868
869 // We have some custom DAG combine patterns for these nodes
870 setTargetDAGCombine({ISD::ADD,
871 ISD::AND,
872 ISD::EXTRACT_VECTOR_ELT,
873 ISD::FADD,
874 ISD::FMAXNUM,
875 ISD::FMINNUM,
876 ISD::FMAXIMUM,
877 ISD::FMINIMUM,
878 ISD::FMAXIMUMNUM,
879 ISD::FMINIMUMNUM,
880 ISD::MUL,
881 ISD::SELECT,
882 ISD::SHL,
883 ISD::SREM,
884 ISD::UREM,
885 ISD::VSELECT,
886 ISD::BUILD_VECTOR,
887 ISD::ADDRSPACECAST,
888 ISD::LOAD,
889 ISD::STORE,
890 ISD::ZERO_EXTEND,
891 ISD::SIGN_EXTEND,
892 ISD::INTRINSIC_WO_CHAIN});
893
  // setcc for f16x2 and bf16x2 needs special handling to prevent the
  // legalizer from scalarizing it, since v2i1 is not legal.
896 if (STI.allowFP16Math() || STI.hasBF16Math())
897 setTargetDAGCombine(ISD::SETCC);
898
899 // Vector reduction operations. These may be turned into shuffle or tree
900 // reductions depending on what instructions are available for each type.
901 for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
902 MVT EltVT = VT.getVectorElementType();
903 if (EltVT == MVT::f32 || EltVT == MVT::f64) {
904 setOperationAction(Ops: {ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
905 ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
906 VT, Action: Custom);
907 }
908 }
909
  // Promote fp16 arithmetic if fp16 hardware isn't available or the
  // user passed --nvptx-no-fp16-math. The flag is useful because,
  // although sm_53+ GPUs have some sort of FP16 support in
  // hardware, only sm_53 and sm_60 have a full implementation. Others
  // only have a token amount of hardware and are likely to run faster
  // by using fp32 units instead.
916 for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
917 setFP16OperationAction(Op, MVT::f16, Legal, Promote);
918 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
919 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
920 // bf16 must be promoted to f32.
921 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
922 if (getOperationAction(Op, VT: MVT::bf16) == Promote)
923 AddPromotedToType(Opc: Op, OrigVT: MVT::bf16, DestVT: MVT::f32);
924 setOperationAction(Op, VT: MVT::v2f32,
925 Action: STI.hasF32x2Instructions() ? Legal : Expand);
926 }
927
928 // On SM80, we select add/mul/sub as fma to avoid promotion to float
929 for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
930 for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
931 if (!STI.hasNativeBF16Support(Opcode: Op) && STI.hasNativeBF16Support(Opcode: ISD::FMA)) {
932 setOperationAction(Op, VT, Action: Custom);
933 }
934 }
935 }
936
  // f16/f16x2 neg was introduced in PTX ISA 6.0, sm_53.
938 const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
939 STI.getPTXVersion() >= 60 &&
940 STI.allowFP16Math();
941 for (const auto &VT : {MVT::f16, MVT::v2f16})
942 setOperationAction(Op: ISD::FNEG, VT,
943 Action: IsFP16FP16x2NegAvailable ? Legal : Expand);
944
945 setBF16OperationAction(ISD::FNEG, MVT::bf16, Legal, Expand);
946 setBF16OperationAction(ISD::FNEG, MVT::v2bf16, Legal, Expand);
947 setOperationAction(Op: ISD::FNEG, VT: MVT::v2f32, Action: Expand);
948 // (would be) Library functions.
949
950 // These map to conversion instructions for scalar FP types.
951 for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
952 ISD::FROUNDEVEN, ISD::FTRUNC}) {
953 setOperationAction(Op, VT: MVT::f16, Action: Legal);
954 setOperationAction(Op, VT: MVT::f32, Action: Legal);
955 setOperationAction(Op, VT: MVT::f64, Action: Legal);
956 setOperationAction(Op, VT: MVT::v2f16, Action: Expand);
957 setOperationAction(Op, VT: MVT::v2bf16, Action: Expand);
958 setOperationAction(Op, VT: MVT::v2f32, Action: Expand);
959 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
960 if (getOperationAction(Op, VT: MVT::bf16) == Promote)
961 AddPromotedToType(Opc: Op, OrigVT: MVT::bf16, DestVT: MVT::f32);
962 }
963
964 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71) {
965 setOperationAction(Op: ISD::BF16_TO_FP, VT: MVT::f32, Action: Expand);
966 }
967 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
968 for (MVT VT : {MVT::bf16, MVT::f32, MVT::f64}) {
969 setOperationAction(Op: ISD::FP_EXTEND, VT, Action: Custom);
970 setOperationAction(Op: ISD::FP_ROUND, VT, Action: Custom);
971 }
972 }
973
974 // Expand v2f32 = fp_extend
975 setOperationAction(Op: ISD::FP_EXTEND, VT: MVT::v2f32, Action: Expand);
976 // Expand v2[b]f16 = fp_round v2f32
977 setOperationAction(Ops: ISD::FP_ROUND, VTs: {MVT::v2bf16, MVT::v2f16}, Action: Expand);
978
979 // sm_80 only has conversions between f32 and bf16. Custom lower all other
980 // bf16 conversions.
981 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
982 for (MVT VT : {MVT::i1, MVT::i16, MVT::i32, MVT::i64}) {
983 setOperationAction(
984 Ops: {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
985 VT, Action: Custom);
986 }
987 setOperationAction(
988 Ops: {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
989 VT: MVT::bf16, Action: Custom);
990 }
991
992 setOperationAction(Op: ISD::FROUND, VT: MVT::f16, Action: Promote);
993 setOperationAction(Op: ISD::FROUND, VT: MVT::v2f16, Action: Expand);
994 setOperationAction(Op: ISD::FROUND, VT: MVT::v2bf16, Action: Expand);
995 setOperationAction(Op: ISD::FROUND, VT: MVT::f32, Action: Custom);
996 setOperationAction(Op: ISD::FROUND, VT: MVT::f64, Action: Custom);
997 setOperationAction(Op: ISD::FROUND, VT: MVT::bf16, Action: Promote);
998 AddPromotedToType(Opc: ISD::FROUND, OrigVT: MVT::bf16, DestVT: MVT::f32);
999
1000 // 'Expand' implements FCOPYSIGN without calling an external library.
1001 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::f16, Action: Expand);
1002 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::v2f16, Action: Expand);
1003 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::bf16, Action: Expand);
1004 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::v2bf16, Action: Expand);
1005 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::f32, Action: Custom);
1006 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::f64, Action: Custom);
1007
1008 // These map to corresponding instructions for f32/f64. f16 must be
1009 // promoted to f32. v2f16 is expanded to f16, which is then promoted
1010 // to f32.
1011 for (const auto &Op :
1012 {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS, ISD::FTANH}) {
1013 setOperationAction(Op, VT: MVT::f16, Action: Promote);
1014 setOperationAction(Op, VT: MVT::f32, Action: Legal);
1015 // only div/rem/sqrt are legal for f64
1016 if (Op == ISD::FDIV || Op == ISD::FREM || Op == ISD::FSQRT) {
1017 setOperationAction(Op, VT: MVT::f64, Action: Legal);
1018 }
1019 setOperationAction(Ops: Op, VTs: {MVT::v2f16, MVT::v2bf16, MVT::v2f32}, Action: Expand);
1020 setOperationAction(Op, VT: MVT::bf16, Action: Promote);
1021 AddPromotedToType(Opc: Op, OrigVT: MVT::bf16, DestVT: MVT::f32);
1022 }
1023 setOperationAction(Ops: ISD::FREM, VTs: {MVT::f32, MVT::f64}, Action: Custom);
1024
1025 setOperationAction(Ops: ISD::FABS, VTs: {MVT::f32, MVT::f64}, Action: Legal);
1026 setOperationAction(Op: ISD::FABS, VT: MVT::v2f32, Action: Expand);
1027 if (STI.getPTXVersion() >= 65) {
1028 setFP16OperationAction(ISD::FABS, MVT::f16, Legal, Promote);
1029 setFP16OperationAction(ISD::FABS, MVT::v2f16, Legal, Expand);
1030 } else {
1031 setOperationAction(Op: ISD::FABS, VT: MVT::f16, Action: Promote);
1032 setOperationAction(Op: ISD::FABS, VT: MVT::v2f16, Action: Expand);
1033 }
1034 setBF16OperationAction(ISD::FABS, MVT::v2bf16, Legal, Expand);
1035 setBF16OperationAction(ISD::FABS, MVT::bf16, Legal, Promote);
1036 if (getOperationAction(Op: ISD::FABS, VT: MVT::bf16) == Promote)
1037 AddPromotedToType(Opc: ISD::FABS, OrigVT: MVT::bf16, DestVT: MVT::f32);
1038
1039 for (const auto &Op :
1040 {ISD::FMINNUM, ISD::FMAXNUM, ISD::FMINIMUMNUM, ISD::FMAXIMUMNUM}) {
1041 setOperationAction(Op, VT: MVT::f32, Action: Legal);
1042 setOperationAction(Op, VT: MVT::f64, Action: Legal);
1043 setFP16OperationAction(Op, MVT::f16, Legal, Promote);
1044 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
1045 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
1046 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
1047 if (getOperationAction(Op, VT: MVT::bf16) == Promote)
1048 AddPromotedToType(Opc: Op, OrigVT: MVT::bf16, DestVT: MVT::f32);
1049 setOperationAction(Op, VT: MVT::v2f32, Action: Expand);
1050 }
1051 bool SupportsF32MinMaxNaN =
1052 STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
1053 for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
1054 setOperationAction(Op, VT: MVT::f32, Action: SupportsF32MinMaxNaN ? Legal : Expand);
1055 setFP16OperationAction(Op, MVT::f16, Legal, Expand);
1056 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
1057 setBF16OperationAction(Op, MVT::bf16, Legal, Expand);
1058 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
1059 setOperationAction(Op, VT: MVT::v2f32, Action: Expand);
1060 }
1061
1062 // Custom lowering for inline asm with 128-bit operands
1063 setOperationAction(Op: ISD::CopyToReg, VT: MVT::i128, Action: Custom);
1064 setOperationAction(Op: ISD::CopyFromReg, VT: MVT::i128, Action: Custom);
1065
1066 // FEXP2 support:
1067 // - f32
1068 // - f16/f16x2 (sm_70+, PTX 7.0+)
1069 // - bf16/bf16x2 (sm_90+, PTX 7.8+)
1070 // When f16/bf16 types aren't supported, they are promoted/expanded to f32.
1071 setOperationAction(Op: ISD::FEXP2, VT: MVT::f32, Action: Legal);
1072 setOperationAction(Op: ISD::FEXP2, VT: MVT::v2f32, Action: Expand);
1073 setFP16OperationAction(ISD::FEXP2, MVT::f16, Legal, Promote);
1074 setFP16OperationAction(ISD::FEXP2, MVT::v2f16, Legal, Expand);
1075 setBF16OperationAction(ISD::FEXP2, MVT::bf16, Legal, Promote);
1076 setBF16OperationAction(ISD::FEXP2, MVT::v2bf16, Legal, Expand);
1077
1078 // FLOG2 supports f32 only
1079 // f16/bf16 types aren't supported, but they are promoted/expanded to f32.
1080 if (UseApproxLog2F32) {
1081 setOperationAction(Op: ISD::FLOG2, VT: MVT::f32, Action: Legal);
1082 setOperationPromotedToType(Opc: ISD::FLOG2, OrigVT: MVT::f16, DestVT: MVT::f32);
1083 setOperationPromotedToType(Opc: ISD::FLOG2, OrigVT: MVT::bf16, DestVT: MVT::f32);
1084 setOperationAction(Ops: ISD::FLOG2, VTs: {MVT::v2f16, MVT::v2bf16, MVT::v2f32},
1085 Action: Expand);
1086 }
1087
1088 setOperationAction(Ops: ISD::ADDRSPACECAST, VTs: {MVT::i32, MVT::i64}, Action: Custom);
1089
1090 setOperationAction(Ops: ISD::ATOMIC_LOAD_SUB, VTs: {MVT::i32, MVT::i64}, Action: Expand);
1091
1092 // atom.b128 is legal in PTX but since we don't represent i128 as a legal
1093 // type, we need to custom lower it.
1094 setOperationAction(Ops: {ISD::ATOMIC_CMP_SWAP, ISD::ATOMIC_SWAP}, VT: MVT::i128,
1095 Action: Custom);
1096
  // Now deduce the information based on the above-mentioned actions.
1099 computeRegisterProperties(TRI: STI.getRegisterInfo());
1100
  // PTX support for 16-bit CAS is emulated. Only use 32-bit CAS and wider.
1102 setMinCmpXchgSizeInBits(STI.getMinCmpXchgSizeInBits());
1103 setMaxAtomicSizeInBitsSupported(STI.hasAtomSwap128() ? 128 : 64);
1104 setMaxDivRemBitWidthSupported(64);
1105
1106 // Custom lowering for tcgen05.ld vector operands
1107 setOperationAction(Ops: ISD::INTRINSIC_W_CHAIN,
1108 VTs: {MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
1109 MVT::v32i32, MVT::v64i32, MVT::v128i32, MVT::v2f32,
1110 MVT::v4f32, MVT::v8f32, MVT::v16f32, MVT::v32f32,
1111 MVT::v64f32, MVT::v128f32},
1112 Action: Custom);
1113
1114 // Custom lowering for tcgen05.st vector operands
1115 setOperationAction(Ops: ISD::INTRINSIC_VOID,
1116 VTs: {MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
1117 MVT::v32i32, MVT::v64i32, MVT::v128i32, MVT::Other},
1118 Action: Custom);
1119
1120 // Enable custom lowering for the following:
1121 // * MVT::i128 - clusterlaunchcontrol
1122 // * MVT::i32 - prmt
1123 // * MVT::v4f32 - cvt_rs fp{4/6/8}x4 intrinsics
1124 // * MVT::Other - internal.addrspace.wrap
1125 setOperationAction(Ops: ISD::INTRINSIC_WO_CHAIN,
1126 VTs: {MVT::i32, MVT::i128, MVT::v4f32, MVT::Other}, Action: Custom);
1127
1128 // Custom lowering for bswap
1129 setOperationAction(Ops: ISD::BSWAP, VTs: {MVT::i16, MVT::i32, MVT::i64, MVT::v2i16},
1130 Action: Custom);
1131}
1132
1133TargetLoweringBase::LegalizeTypeAction
1134NVPTXTargetLowering::getPreferredVectorAction(MVT VT) const {
1135 if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
1136 VT.getScalarType() == MVT::i1)
1137 return TypeSplitVector;
1138 return TargetLoweringBase::getPreferredVectorAction(VT);
1139}
1140
1141SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
1142 int Enabled, int &ExtraSteps,
1143 bool &UseOneConst,
1144 bool Reciprocal) const {
1145 if (!(Enabled == ReciprocalEstimate::Enabled ||
1146 (Enabled == ReciprocalEstimate::Unspecified && !usePrecSqrtF32())))
1147 return SDValue();
1148
1149 if (ExtraSteps == ReciprocalEstimate::Unspecified)
1150 ExtraSteps = 0;
1151
1152 SDLoc DL(Operand);
1153 EVT VT = Operand.getValueType();
1154 bool Ftz = useF32FTZ(MF: DAG.getMachineFunction());
1155
1156 auto MakeIntrinsicCall = [&](Intrinsic::ID IID) {
1157 return DAG.getNode(Opcode: ISD::INTRINSIC_WO_CHAIN, DL, VT,
1158 N1: DAG.getConstant(Val: IID, DL, VT: MVT::i32), N2: Operand);
1159 };
1160
1161 // The sqrt and rsqrt refinement processes assume we always start out with an
1162 // approximation of the rsqrt. Therefore, if we're going to do any refinement
1163 // (i.e. ExtraSteps > 0), we must return an rsqrt. But if we're *not* doing
1164 // any refinement, we must return a regular sqrt.
1165 if (Reciprocal || ExtraSteps > 0) {
1166 if (VT == MVT::f32)
1167 return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_rsqrt_approx_ftz_f
1168 : Intrinsic::nvvm_rsqrt_approx_f);
1169 else if (VT == MVT::f64)
1170 return MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d);
1171 else
1172 return SDValue();
1173 } else {
1174 if (VT == MVT::f32)
1175 return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f
1176 : Intrinsic::nvvm_sqrt_approx_f);
1177 else {
1178 // There's no sqrt.approx.f64 instruction, so we emit
1179 // reciprocal(rsqrt(x)). This is faster than
1180 // select(x == 0, 0, x * rsqrt(x)). (In fact, it's faster than plain
1181 // x * rsqrt(x).)
1182 return DAG.getNode(
1183 Opcode: ISD::INTRINSIC_WO_CHAIN, DL, VT,
1184 N1: DAG.getConstant(Val: Intrinsic::nvvm_rcp_approx_ftz_d, DL, VT: MVT::i32),
1185 N2: MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d));
1186 }
1187 }
1188}
1189
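// An illustrative example of the prototype string produced below, for a call
// returning f32 and taking an i32 and a 16-byte aggregate (exact alignment
// and size depend on the aggregate type):
//   prototype_0 : .callprototype (.param .b32 _) _ (.param .b32 _, .param .align 16 .b8 _[16]);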
1190std::string NVPTXTargetLowering::getPrototype(
1191 const DataLayout &DL, Type *RetTy, const ArgListTy &Args,
1192 const SmallVectorImpl<ISD::OutputArg> &Outs,
1193 std::optional<unsigned> FirstVAArg, const CallBase &CB,
1194 unsigned UniqueCallSite) const {
1195 auto PtrVT = getPointerTy(DL);
1196
1197 std::string Prototype;
1198 raw_string_ostream O(Prototype);
1199 O << "prototype_" << UniqueCallSite << " : .callprototype ";
1200
1201 if (RetTy->isVoidTy()) {
1202 O << "()";
1203 } else {
1204 O << "(";
1205 if (shouldPassAsArray(Ty: RetTy)) {
1206 const Align RetAlign = getArgumentAlignment(CB: &CB, Ty: RetTy, Idx: 0, DL);
1207 O << ".param .align " << RetAlign.value() << " .b8 _["
1208 << DL.getTypeAllocSize(Ty: RetTy) << "]";
1209 } else if (RetTy->isFloatingPointTy() || RetTy->isIntegerTy()) {
1210 unsigned size = 0;
1211 if (auto *ITy = dyn_cast<IntegerType>(Val: RetTy)) {
1212 size = ITy->getBitWidth();
1213 } else {
1214 assert(RetTy->isFloatingPointTy() &&
1215 "Floating point type expected here");
1216 size = RetTy->getPrimitiveSizeInBits();
1217 }
1218 // PTX ABI requires all scalar return values to be at least 32
1219 // bits in size. fp16 normally uses .b16 as its storage type in
1220 // PTX, so its size must be adjusted here, too.
1221 size = promoteScalarArgumentSize(size);
1222
1223 O << ".param .b" << size << " _";
1224 } else if (isa<PointerType>(Val: RetTy)) {
1225 O << ".param .b" << PtrVT.getSizeInBits() << " _";
1226 } else {
1227 llvm_unreachable("Unknown return type");
1228 }
1229 O << ") ";
1230 }
1231 O << "_ (";
1232
1233 bool first = true;
1234
1235 const unsigned NumArgs = FirstVAArg.value_or(u: Args.size());
1236 auto AllOuts = ArrayRef(Outs);
1237 for (const unsigned I : llvm::seq(Size: NumArgs)) {
1238 const auto ArgOuts =
1239 AllOuts.take_while(Pred: [I](auto O) { return O.OrigArgIndex == I; });
1240 AllOuts = AllOuts.drop_front(N: ArgOuts.size());
1241
1242 Type *Ty = Args[I].Ty;
1243 if (!first) {
1244 O << ", ";
1245 }
1246 first = false;
1247
1248 if (ArgOuts[0].Flags.isByVal()) {
      // Indirect calls need strict ABI alignment, so we disable alignment
      // optimizations by not providing a Function to query.
1251 Type *ETy = Args[I].IndirectType;
1252 Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
1253 Align ParamByValAlign =
1254 getFunctionByValParamAlign(/*F=*/nullptr, ArgTy: ETy, InitialAlign, DL);
1255
1256 O << ".param .align " << ParamByValAlign.value() << " .b8 _["
1257 << ArgOuts[0].Flags.getByValSize() << "]";
1258 } else {
1259 if (shouldPassAsArray(Ty)) {
1260 Align ParamAlign =
1261 getArgumentAlignment(CB: &CB, Ty, Idx: I + AttributeList::FirstArgIndex, DL);
1262 O << ".param .align " << ParamAlign.value() << " .b8 _["
1263 << DL.getTypeAllocSize(Ty) << "]";
1264 continue;
1265 }
1266 // i8 types in IR will be i16 types in SDAG
1267 assert((getValueType(DL, Ty) == ArgOuts[0].VT ||
1268 (getValueType(DL, Ty) == MVT::i8 && ArgOuts[0].VT == MVT::i16)) &&
1269 "type mismatch between callee prototype and arguments");
1270 // scalar type
1271 unsigned sz = 0;
1272 if (auto *ITy = dyn_cast<IntegerType>(Val: Ty)) {
1273 sz = promoteScalarArgumentSize(size: ITy->getBitWidth());
1274 } else if (isa<PointerType>(Val: Ty)) {
1275 sz = PtrVT.getSizeInBits();
1276 } else {
1277 sz = Ty->getPrimitiveSizeInBits();
1278 }
1279 O << ".param .b" << sz << " _";
1280 }
1281 }
1282
1283 if (FirstVAArg)
1284 O << (first ? "" : ",") << " .param .align "
1285 << STI.getMaxRequiredAlignment() << " .b8 _[]";
1286 O << ")";
1287 if (shouldEmitPTXNoReturn(V: &CB, TM: *nvTM))
1288 O << " .noreturn";
1289 O << ";";
1290
1291 return Prototype;
1292}
1293
1294Align NVPTXTargetLowering::getFunctionArgumentAlignment(
1295 const Function *F, Type *Ty, unsigned Idx, const DataLayout &DL) const {
1296 return getAlign(F: *F, Index: Idx).value_or(u: getFunctionParamOptimizedAlign(F, ArgTy: Ty, DL));
1297}
1298
1299Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
1300 unsigned Idx,
1301 const DataLayout &DL) const {
1302 if (!CB) {
    // CallSite is zero; fall back to the ABI type alignment.
1304 return DL.getABITypeAlign(Ty);
1305 }
1306
1307 const Function *DirectCallee = CB->getCalledFunction();
1308
1309 if (!DirectCallee) {
1310 // We don't have a direct function symbol, but that may be because of
1311 // constant cast instructions in the call.
1312
1313 // With bitcast'd call targets, the instruction will be the call
1314 if (const auto *CI = dyn_cast<CallInst>(Val: CB)) {
1315 // Check if we have call alignment metadata
1316 if (MaybeAlign StackAlign = getAlign(*CI, Idx))
1317 return StackAlign.value();
1318 }
1319 DirectCallee = getMaybeBitcastedCallee(CB);
1320 }
1321
1322 // Check for function alignment information if we found that the
1323 // ultimate target is a Function
1324 if (DirectCallee)
1325 return getFunctionArgumentAlignment(F: DirectCallee, Ty, Idx, DL);
1326
1327 // Call is indirect, fall back to the ABI type alignment
1328 return DL.getABITypeAlign(Ty);
1329}
1330
1331static bool shouldConvertToIndirectCall(const CallBase *CB,
1332 const GlobalAddressSDNode *Func) {
1333 if (!Func)
1334 return false;
1335 if (auto *CalleeFunc = dyn_cast<Function>(Val: Func->getGlobal()))
1336 return CB->getFunctionType() != CalleeFunc->getFunctionType();
1337 return false;
1338}
1339
1340static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
1341 const DataLayout &DL,
1342 const TargetLowering &TL) {
1343 if (Ptr->getOpcode() == ISD::FrameIndex) {
1344 auto Ty = TL.getPointerTy(DL, AS: ADDRESS_SPACE_LOCAL);
1345 Ptr = DAG.getAddrSpaceCast(dl: SDLoc(), VT: Ty, Ptr, SrcAS: ADDRESS_SPACE_GENERIC,
1346 DestAS: ADDRESS_SPACE_LOCAL);
1347
1348 return MachinePointerInfo(ADDRESS_SPACE_LOCAL);
1349 }
1350
  // Peel off an addrspacecast to generic and load directly from the specific
  // address space.
1353 if (Ptr->getOpcode() == ISD::ADDRSPACECAST) {
1354 const auto *ASC = cast<AddrSpaceCastSDNode>(Val&: Ptr);
1355 if (ASC->getDestAddressSpace() == ADDRESS_SPACE_GENERIC) {
1356 Ptr = ASC->getOperand(Num: 0);
1357 return MachinePointerInfo(ASC->getSrcAddressSpace());
1358 }
1359 }
1360
1361 return MachinePointerInfo();
1362}
1363
1364static ISD::NodeType getExtOpcode(const ISD::ArgFlagsTy &Flags) {
1365 if (Flags.isSExt())
1366 return ISD::SIGN_EXTEND;
1367 if (Flags.isZExt())
1368 return ISD::ZERO_EXTEND;
1369 return ISD::ANY_EXTEND;
1370}
1371
1372static SDValue correctParamType(SDValue V, EVT ExpectedVT,
1373 ISD::ArgFlagsTy Flags, SelectionDAG &DAG,
1374 SDLoc dl) {
1375 const EVT ActualVT = V.getValueType();
1376 assert((ActualVT == ExpectedVT ||
1377 (ExpectedVT.isInteger() && ActualVT.isInteger())) &&
1378 "Non-integer argument type size mismatch");
1379 if (ExpectedVT.bitsGT(VT: ActualVT))
1380 return DAG.getNode(Opcode: getExtOpcode(Flags), DL: dl, VT: ExpectedVT, Operand: V);
1381 if (ExpectedVT.bitsLT(VT: ActualVT))
1382 return DAG.getNode(Opcode: ISD::TRUNCATE, DL: dl, VT: ExpectedVT, Operand: V);
1383
1384 return V;
1385}
1386
1387SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
1388 SmallVectorImpl<SDValue> &InVals) const {
1389
1390 if (CLI.IsVarArg && (STI.getPTXVersion() < 60 || STI.getSmVersion() < 30))
1391 report_fatal_error(
1392 reason: "Support for variadic functions (unsized array parameter) introduced "
1393 "in PTX ISA version 6.0 and requires target sm_30.");
1394
1395 SelectionDAG &DAG = CLI.DAG;
1396 SDLoc dl = CLI.DL;
1397 const SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
1398 SDValue Callee = CLI.Callee;
1399 ArgListTy &Args = CLI.getArgs();
1400 Type *RetTy = CLI.RetTy;
1401 const CallBase *CB = CLI.CB;
1402 const DataLayout &DL = DAG.getDataLayout();
1403 LLVMContext &Ctx = *DAG.getContext();
1404
1405 const auto GetI32 = [&](const unsigned I) {
1406 return DAG.getConstant(Val: I, DL: dl, VT: MVT::i32);
1407 };
1408
1409 const unsigned UniqueCallSite = GlobalUniqueCallSite++;
1410 const SDValue CallChain = CLI.Chain;
1411 const SDValue StartChain =
1412 DAG.getCALLSEQ_START(Chain: CallChain, InSize: UniqueCallSite, OutSize: 0, DL: dl);
1413 SDValue DeclareGlue = StartChain.getValue(R: 1);
1414
1415 SmallVector<SDValue, 16> CallPrereqs{StartChain};
1416
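  // Helpers that emit the param declarations for this call. These ultimately
  // print as PTX declarations of roughly the following shape (illustrative):
  //   .param .b32 param0;                  // scalar param
  //   .param .align 4 .b8 param1[16];      // aggregate / array param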
1417 const auto MakeDeclareScalarParam = [&](SDValue Symbol, unsigned Size) {
1418 // PTX ABI requires integral types to be at least 32 bits in size. FP16 is
1419 // loaded/stored using i16, so it's handled here as well.
1420 const unsigned SizeBits = promoteScalarArgumentSize(size: Size * 8);
1421 SDValue Declare =
1422 DAG.getNode(Opcode: NVPTXISD::DeclareScalarParam, DL: dl, ResultTys: {MVT::Other, MVT::Glue},
1423 Ops: {StartChain, Symbol, GetI32(SizeBits), DeclareGlue});
1424 CallPrereqs.push_back(Elt: Declare);
1425 DeclareGlue = Declare.getValue(R: 1);
1426 return Declare;
1427 };
1428
1429 const auto MakeDeclareArrayParam = [&](SDValue Symbol, Align Align,
1430 unsigned Size) {
1431 SDValue Declare = DAG.getNode(
1432 Opcode: NVPTXISD::DeclareArrayParam, DL: dl, ResultTys: {MVT::Other, MVT::Glue},
1433 Ops: {StartChain, Symbol, GetI32(Align.value()), GetI32(Size), DeclareGlue});
1434 CallPrereqs.push_back(Elt: Declare);
1435 DeclareGlue = Declare.getValue(R: 1);
1436 return Declare;
1437 };
1438
1439 // Variadic arguments.
1440 //
1441 // Normally, for each argument, we declare a param scalar or a param
1442 // byte array in the .param space, and store the argument value to that
1443 // param scalar or array starting at offset 0.
1444 //
1445 // In the case of the first variadic argument, we declare a vararg byte array
1446 // with size 0. The exact size of this array isn't known at this point, so
1447 // it'll be patched later. All the variadic arguments will be stored to this
1448 // array at a certain offset (which gets tracked by 'VAOffset'). The offset is
1449 // initially set to 0, so it can be used for non-variadic arguments (which use
1450 // 0 offset) to simplify the code.
1451 //
1452  // After all varargs are processed, 'VAOffset' holds the size of the
1453 // vararg byte array.
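  //
  // As an illustration, the whole variadic portion is declared as one array
  // param, roughly `.param .align <A> .b8 param<n>[<size>]`, whose size
  // operand is patched to the final 'VAOffset' once all arguments have been
  // processed (see the MorphNodeTo of VADeclareParam below).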
1454 assert((CLI.IsVarArg || CLI.Args.size() == CLI.NumFixedArgs) &&
1455 "Non-VarArg function with extra arguments");
1456
1457 const unsigned FirstVAArg = CLI.NumFixedArgs; // position of first variadic
1458 unsigned VAOffset = 0; // current offset in the param array
1459
1460 const SDValue VADeclareParam =
1461 CLI.Args.size() > FirstVAArg
1462 ? MakeDeclareArrayParam(getCallParamSymbol(DAG, I: FirstVAArg, T: MVT::i32),
1463 Align(STI.getMaxRequiredAlignment()), 0)
1464 : SDValue();
1465
1466 // Args.size() and Outs.size() need not match.
1467 // Outs.size() will be larger
1468 // * if there is an aggregate argument with multiple fields (each field
1469 // showing up separately in Outs)
1470  //   * if there is a vector argument with more elements than the typical
1471  //     vector length (generally more than 4), where each vector element is
1472  //     individually present in Outs.
1473 // So a different index should be used for indexing into Outs/OutVals.
1474 // See similar issue in LowerFormalArguments.
1475 auto AllOuts = ArrayRef(CLI.Outs);
1476 auto AllOutVals = ArrayRef(CLI.OutVals);
1477 assert(AllOuts.size() == AllOutVals.size() &&
1478 "Outs and OutVals must be the same size");
1479  // Declare the .param or .reg variables needed to pass values
1480  // to the function.
1481 for (const auto E : llvm::enumerate(First&: Args)) {
1482 const auto ArgI = E.index();
1483 const auto Arg = E.value();
1484 const auto ArgOuts =
1485 AllOuts.take_while(Pred: [&](auto O) { return O.OrigArgIndex == ArgI; });
1486 const auto ArgOutVals = AllOutVals.take_front(N: ArgOuts.size());
1487 AllOuts = AllOuts.drop_front(N: ArgOuts.size());
1488 AllOutVals = AllOutVals.drop_front(N: ArgOuts.size());
1489
1490 const bool IsVAArg = (ArgI >= FirstVAArg);
1491 const bool IsByVal = Arg.IsByVal;
1492
1493 const SDValue ParamSymbol =
1494 getCallParamSymbol(DAG, I: IsVAArg ? FirstVAArg : ArgI, T: MVT::i32);
1495
1496 assert((!IsByVal || Arg.IndirectType) &&
1497 "byval arg must have indirect type");
1498 Type *ETy = (IsByVal ? Arg.IndirectType : Arg.Ty);
1499
1500 const Align ArgAlign = [&]() {
1501 if (IsByVal) {
1502        // The ByValAlign in ArgOuts[0].Flags is always set at this point,
1503        // so we don't need to worry about whether it's naturally aligned.
1504 // See TargetLowering::LowerCallTo().
1505 const Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
1506 return getFunctionByValParamAlign(F: CB->getCalledFunction(), ArgTy: ETy,
1507 InitialAlign, DL);
1508 }
1509 return getArgumentAlignment(CB, Ty: Arg.Ty, Idx: ArgI + 1, DL);
1510 }();
1511
1512 const unsigned TySize = DL.getTypeAllocSize(Ty: ETy);
1513 assert((!IsByVal || TySize == ArgOuts[0].Flags.getByValSize()) &&
1514 "type size mismatch");
1515
1516 const SDValue ArgDeclare = [&]() {
1517 if (IsVAArg)
1518 return VADeclareParam;
1519
1520 if (IsByVal || shouldPassAsArray(Ty: Arg.Ty))
1521 return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TySize);
1522
1523 assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
1524 assert((ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint()) &&
1525 "Only int and float types are supported as non-array arguments");
1526
1527 return MakeDeclareScalarParam(ParamSymbol, TySize);
1528 }();
1529
1530 if (IsByVal) {
1531 assert(ArgOutVals.size() == 1 && "We must pass only one value as byval");
1532 SDValue SrcPtr = ArgOutVals[0];
1533 const auto PointerInfo = refinePtrAS(Ptr&: SrcPtr, DAG, DL, TL: *this);
1534 const Align BaseSrcAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
1535
1536 if (IsVAArg)
1537 VAOffset = alignTo(Size: VAOffset, A: ArgAlign);
1538
1539 SmallVector<EVT, 4> ValueVTs, MemVTs;
1540 SmallVector<TypeSize, 4> Offsets;
1541 ComputeValueVTs(TLI: *this, DL, Ty: ETy, ValueVTs, MemVTs: &MemVTs, Offsets: &Offsets);
1542
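      // Copy the byval aggregate into the .param array piece by piece: each
      // vectorized chunk is loaded from the source pointer and stored at the
      // matching offset (shifted by VAOffset for variadic arguments).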
1543 unsigned J = 0;
1544 const auto VI = VectorizePTXValueVTs(ValueVTs: MemVTs, Offsets, ParamAlignment: ArgAlign, IsVAArg);
1545 for (const unsigned NumElts : VI) {
1546 EVT LoadVT = getVectorizedVT(VT: MemVTs[J], N: NumElts, C&: Ctx);
1547 Align SrcAlign = commonAlignment(A: BaseSrcAlign, Offset: Offsets[J]);
1548 SDValue SrcAddr = DAG.getObjectPtrOffset(SL: dl, Ptr: SrcPtr, Offset: Offsets[J]);
1549 SDValue SrcLoad =
1550 DAG.getLoad(VT: LoadVT, dl, Chain: CallChain, Ptr: SrcAddr, PtrInfo: PointerInfo, Alignment: SrcAlign);
1551
1552 TypeSize ParamOffset = Offsets[J].getWithIncrement(RHS: VAOffset);
1553 Align ParamAlign = commonAlignment(A: ArgAlign, Offset: ParamOffset);
1554 SDValue ParamAddr =
1555 DAG.getObjectPtrOffset(SL: dl, Ptr: ParamSymbol, Offset: ParamOffset);
1556 SDValue StoreParam =
1557 DAG.getStore(Chain: ArgDeclare, dl, Val: SrcLoad, Ptr: ParamAddr,
1558 PtrInfo: MachinePointerInfo(ADDRESS_SPACE_PARAM), Alignment: ParamAlign);
1559 CallPrereqs.push_back(Elt: StoreParam);
1560
1561 J += NumElts;
1562 }
1563 if (IsVAArg)
1564 VAOffset += TySize;
1565 } else {
1566 SmallVector<EVT, 16> VTs;
1567 SmallVector<uint64_t, 16> Offsets;
1568 ComputePTXValueVTs(TLI: *this, DL, Ctx, CallConv: CLI.CallConv, Ty: Arg.Ty, ValueVTs&: VTs, Offsets,
1569 StartingOffset: VAOffset);
1570 assert(VTs.size() == Offsets.size() && "Size mismatch");
1571 assert(VTs.size() == ArgOuts.size() && "Size mismatch");
1572
1573 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
1574 // than 32-bits are sign extended or zero extended, depending on
1575 // whether they are signed or unsigned types. This case applies
1576 // only to scalar parameters and not to aggregate values.
1577 const bool ExtendIntegerParam =
1578 Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty: Arg.Ty) < 32;
1579
1580 const auto GetStoredValue = [&](const unsigned I) {
1581 SDValue StVal = ArgOutVals[I];
1582 assert(promoteScalarIntegerPTX(StVal.getValueType()) ==
1583 StVal.getValueType() &&
1584 "OutVal type should always be legal");
1585
1586 const EVT VTI = promoteScalarIntegerPTX(VT: VTs[I]);
1587 const EVT StoreVT =
1588 ExtendIntegerParam ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
1589
1590 return correctParamType(V: StVal, ExpectedVT: StoreVT, Flags: ArgOuts[I].Flags, DAG, dl);
1591 };
1592
1593 unsigned J = 0;
1594 const auto VI = VectorizePTXValueVTs(ValueVTs: VTs, Offsets, ParamAlignment: ArgAlign, IsVAArg);
1595 for (const unsigned NumElts : VI) {
1596 const EVT EltVT = promoteScalarIntegerPTX(VT: VTs[J]);
1597
1598 unsigned Offset;
1599 if (IsVAArg) {
1600 // TODO: We may need to support vector types that can be passed
1601 // as scalars in variadic arguments.
1602 assert(NumElts == 1 &&
1603 "Vectorization should be disabled for vaargs.");
1604
1605 // Align each part of the variadic argument to their type.
1606 VAOffset = alignTo(Size: VAOffset, A: DAG.getEVTAlign(MemoryVT: EltVT));
1607 Offset = VAOffset;
1608
1609 const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
1610 VAOffset += DL.getTypeAllocSize(Ty: TheStoreType.getTypeForEVT(Context&: Ctx));
1611 } else {
1612 assert(VAOffset == 0 && "VAOffset must be 0 for non-VA args");
1613 Offset = Offsets[J];
1614 }
1615
1616 SDValue Ptr =
1617 DAG.getObjectPtrOffset(SL: dl, Ptr: ParamSymbol, Offset: TypeSize::getFixed(ExactSize: Offset));
1618
1619 const MaybeAlign CurrentAlign = ExtendIntegerParam
1620 ? MaybeAlign(std::nullopt)
1621 : commonAlignment(A: ArgAlign, Offset);
1622
1623 SDValue Val =
1624 getBuildVectorizedValue(N: NumElts, dl, DAG, GetElement: [&](unsigned K) {
1625 return GetStoredValue(J + K);
1626 });
1627
1628 SDValue StoreParam =
1629 DAG.getStore(Chain: ArgDeclare, dl, Val, Ptr,
1630 PtrInfo: MachinePointerInfo(ADDRESS_SPACE_PARAM), Alignment: CurrentAlign);
1631 CallPrereqs.push_back(Elt: StoreParam);
1632
1633 J += NumElts;
1634 }
1635 }
1636 }
1637
1638 // Handle Result
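  // The return value comes back through a dedicated "retval0" param: it is
  // declared here and loaded from after the call (roughly an
  // `ld.param.b32 %r, [retval0+N];` sequence, depending on the return type).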
1639 if (!Ins.empty()) {
1640 const SDValue RetSymbol = DAG.getExternalSymbol(Sym: "retval0", VT: MVT::i32);
1641 const unsigned ResultSize = DL.getTypeAllocSize(Ty: RetTy);
1642 if (shouldPassAsArray(Ty: RetTy)) {
1643 const Align RetAlign = getArgumentAlignment(CB, Ty: RetTy, Idx: 0, DL);
1644 MakeDeclareArrayParam(RetSymbol, RetAlign, ResultSize);
1645 } else {
1646 MakeDeclareScalarParam(RetSymbol, ResultSize);
1647 }
1648 }
1649
1650 // Set the size of the vararg param byte array if the callee is a variadic
1651 // function and the variadic part is not empty.
1652 if (VADeclareParam) {
1653 SDValue DeclareParamOps[] = {VADeclareParam.getOperand(i: 0),
1654 VADeclareParam.getOperand(i: 1),
1655 VADeclareParam.getOperand(i: 2), GetI32(VAOffset),
1656 VADeclareParam.getOperand(i: 4)};
1657 DAG.MorphNodeTo(N: VADeclareParam.getNode(), Opc: VADeclareParam.getOpcode(),
1658 VTs: VADeclareParam->getVTList(), Ops: DeclareParamOps);
1659 }
1660
1661 const auto *Func = dyn_cast<GlobalAddressSDNode>(Val: Callee.getNode());
1662 // If the type of the callsite does not match that of the function, convert
1663 // the callsite to an indirect call.
1664 const bool ConvertToIndirectCall = shouldConvertToIndirectCall(CB, Func);
1665
1666 // Both indirect calls and libcalls have nullptr Func. In order to distinguish
1667 // between them we must rely on the call site value which is valid for
1668 // indirect calls but is always null for libcalls.
1669 const bool IsIndirectCall = (!Func && CB) || ConvertToIndirectCall;
1670
1671 if (isa<ExternalSymbolSDNode>(Val: Callee)) {
1672 Function* CalleeFunc = nullptr;
1673
1674 // Try to find the callee in the current module.
1675 Callee = DAG.getSymbolFunctionGlobalAddress(Op: Callee, TargetFunction: &CalleeFunc);
1676 assert(CalleeFunc != nullptr && "Libcall callee must be set.");
1677
1678 // Set the "libcall callee" attribute to indicate that the function
1679 // must always have a declaration.
1680 CalleeFunc->addFnAttr(Kind: "nvptx-libcall-callee", Val: "true");
1681 }
1682
1683 if (IsIndirectCall) {
1684    // This is the indirect function call case: PTX requires a prototype of
1685    // the form
1686    // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
1687    // to be emitted, and the label has to be used as the last argument of
1688    // the call instruction.
1689 // The prototype is embedded in a string and put as the operand for a
1690 // CallPrototype SDNode which will print out to the value of the string.
1691 const bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
1692 std::string Proto =
1693 getPrototype(DL, RetTy, Args, Outs: CLI.Outs,
1694 FirstVAArg: HasVAArgs ? std::optional(FirstVAArg) : std::nullopt, CB: *CB,
1695 UniqueCallSite);
1696 const char *ProtoStr = nvTM->getStrPool().save(S: Proto).data();
1697 const SDValue PrototypeDeclare = DAG.getNode(
1698 Opcode: NVPTXISD::CallPrototype, DL: dl, VT: MVT::Other,
1699 Ops: {StartChain, DAG.getTargetExternalSymbol(Sym: ProtoStr, VT: MVT::i32)});
1700 CallPrereqs.push_back(Elt: PrototypeDeclare);
1701 }
1702
1703 const unsigned Proto = IsIndirectCall ? UniqueCallSite : 0;
1704 const unsigned NumArgs =
1705 std::min<unsigned>(a: CLI.NumFixedArgs + 1, b: Args.size());
1706 /// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns,
1707 /// NumParams, Callee, Proto)
1708 const SDValue CallToken = DAG.getTokenFactor(DL: dl, Vals&: CallPrereqs);
1709 const SDValue Call = DAG.getNode(
1710 Opcode: NVPTXISD::CALL, DL: dl, VT: MVT::Other,
1711 Ops: {CallToken, GetI32(CLI.IsConvergent), GetI32(IsIndirectCall),
1712 GetI32(Ins.empty() ? 0 : 1), GetI32(NumArgs), Callee, GetI32(Proto)});
1713
1714 SmallVector<SDValue, 16> LoadChains{Call};
1715 SmallVector<SDValue, 16> ProxyRegOps;
1716 if (!Ins.empty()) {
1717 SmallVector<EVT, 16> VTs;
1718 SmallVector<uint64_t, 16> Offsets;
1719 ComputePTXValueVTs(TLI: *this, DL, Ctx, CallConv: CLI.CallConv, Ty: RetTy, ValueVTs&: VTs, Offsets);
1720 assert(VTs.size() == Ins.size() && "Bad value decomposition");
1721
1722 const Align RetAlign = getArgumentAlignment(CB, Ty: RetTy, Idx: 0, DL);
1723 const SDValue RetSymbol = DAG.getExternalSymbol(Sym: "retval0", VT: MVT::i32);
1724
1725 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
1726 // 32-bits are sign extended or zero extended, depending on whether
1727 // they are signed or unsigned types.
1728 const bool ExtendIntegerRetVal =
1729 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty: RetTy) < 32;
1730
1731 unsigned I = 0;
1732 const auto VI = VectorizePTXValueVTs(ValueVTs: VTs, Offsets, ParamAlignment: RetAlign);
1733 for (const unsigned NumElts : VI) {
1734 const MaybeAlign CurrentAlign =
1735 ExtendIntegerRetVal ? MaybeAlign(std::nullopt)
1736 : commonAlignment(A: RetAlign, Offset: Offsets[I]);
1737
1738 const EVT VTI = promoteScalarIntegerPTX(VT: VTs[I]);
1739 const EVT LoadVT =
1740 ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
1741 const EVT VecVT = getVectorizedVT(VT: LoadVT, N: NumElts, C&: Ctx);
1742 SDValue Ptr =
1743 DAG.getObjectPtrOffset(SL: dl, Ptr: RetSymbol, Offset: TypeSize::getFixed(ExactSize: Offsets[I]));
1744
1745 SDValue R =
1746 DAG.getLoad(VT: VecVT, dl, Chain: Call, Ptr,
1747 PtrInfo: MachinePointerInfo(ADDRESS_SPACE_PARAM), Alignment: CurrentAlign);
1748
1749 LoadChains.push_back(Elt: R.getValue(R: 1));
1750 for (const unsigned J : llvm::seq(Size: NumElts))
1751 ProxyRegOps.push_back(Elt: getExtractVectorizedValue(V: R, I: J, VT: LoadVT, dl, DAG));
1752 I += NumElts;
1753 }
1754 }
1755
1756 const SDValue EndToken = DAG.getTokenFactor(DL: dl, Vals&: LoadChains);
1757 const SDValue CallEnd = DAG.getCALLSEQ_END(Chain: EndToken, Size1: UniqueCallSite,
1758 Size2: UniqueCallSite + 1, Glue: SDValue(), DL: dl);
1759
1760  // Append ProxyReg instructions to the chain to make sure that `callseq_end`
1761  // will not get lost. Otherwise, during libcall expansion, these nodes can
1762  // become dangling.
1763 for (const auto [I, Reg] : llvm::enumerate(First&: ProxyRegOps)) {
1764 SDValue Proxy =
1765 DAG.getNode(Opcode: NVPTXISD::ProxyReg, DL: dl, VT: Reg.getValueType(), Ops: {CallEnd, Reg});
1766 SDValue Ret = correctParamType(V: Proxy, ExpectedVT: Ins[I].VT, Flags: Ins[I].Flags, DAG, dl);
1767 InVals.push_back(Elt: Ret);
1768 }
1769
1770  // Set IsTailCall to false for now, until we figure out how to express
1771  // tail call optimization in PTX.
1772 CLI.IsTailCall = false;
1773 return CallEnd;
1774}
1775
1776SDValue NVPTXTargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op,
1777 SelectionDAG &DAG) const {
1778
1779 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1780 const Function &Fn = DAG.getMachineFunction().getFunction();
1781
1782 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
1783 Fn,
1784 "Support for dynamic alloca introduced in PTX ISA version 7.3 and "
1785 "requires target sm_52.",
1786 SDLoc(Op).getDebugLoc()));
1787 auto Ops = {DAG.getConstant(Val: 0, DL: SDLoc(), VT: Op.getValueType()),
1788 Op.getOperand(i: 0)};
1789 return DAG.getMergeValues(Ops, dl: SDLoc());
1790 }
1791
1792 SDLoc DL(Op.getNode());
1793 SDValue Chain = Op.getOperand(i: 0);
1794 SDValue Size = Op.getOperand(i: 1);
1795 uint64_t Align = Op.getConstantOperandVal(i: 2);
1796
1797  // The alignment on an ISD::DYNAMIC_STACKALLOC node may be 0 to indicate
1798  // that the default stack alignment should be used.
1799 if (Align == 0)
1800 Align = DAG.getSubtarget().getFrameLowering()->getStackAlign().value();
1801
1802  // The size operand of the PTX alloca is 64-bit for m64 and 32-bit for m32.
1803 const MVT LocalVT = getPointerTy(DL: DAG.getDataLayout(), AS: ADDRESS_SPACE_LOCAL);
1804
1805 SDValue Alloc =
1806 DAG.getNode(Opcode: NVPTXISD::DYNAMIC_STACKALLOC, DL, ResultTys: {LocalVT, MVT::Other},
1807 Ops: {Chain, DAG.getZExtOrTrunc(Op: Size, DL, VT: LocalVT),
1808 DAG.getTargetConstant(Val: Align, DL, VT: MVT::i32)});
1809
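  // The PTX-level alloca returns a pointer in the local address space; cast it
  // back to a generic pointer (cvta.local) so users of the value see the
  // address space they expect.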
1810 SDValue ASC = DAG.getAddrSpaceCast(
1811 dl: DL, VT: Op.getValueType(), Ptr: Alloc, SrcAS: ADDRESS_SPACE_LOCAL, DestAS: ADDRESS_SPACE_GENERIC);
1812
1813 return DAG.getMergeValues(Ops: {ASC, SDValue(Alloc.getNode(), 1)}, dl: DL);
1814}
1815
1816SDValue NVPTXTargetLowering::LowerSTACKRESTORE(SDValue Op,
1817 SelectionDAG &DAG) const {
1818 SDLoc DL(Op.getNode());
1819 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1820 const Function &Fn = DAG.getMachineFunction().getFunction();
1821
1822 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
1823 Fn,
1824 "Support for stackrestore requires PTX ISA version >= 7.3 and target "
1825 ">= sm_52.",
1826 DL.getDebugLoc()));
1827 return Op.getOperand(i: 0);
1828 }
1829
1830 const MVT LocalVT = getPointerTy(DL: DAG.getDataLayout(), AS: ADDRESS_SPACE_LOCAL);
1831 SDValue Chain = Op.getOperand(i: 0);
1832 SDValue Ptr = Op.getOperand(i: 1);
1833 SDValue ASC = DAG.getAddrSpaceCast(dl: DL, VT: LocalVT, Ptr, SrcAS: ADDRESS_SPACE_GENERIC,
1834 DestAS: ADDRESS_SPACE_LOCAL);
1835 return DAG.getNode(Opcode: NVPTXISD::STACKRESTORE, DL, VT: MVT::Other, Ops: {Chain, ASC});
1836}
1837
1838SDValue NVPTXTargetLowering::LowerSTACKSAVE(SDValue Op,
1839 SelectionDAG &DAG) const {
1840 SDLoc DL(Op.getNode());
1841 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1842 const Function &Fn = DAG.getMachineFunction().getFunction();
1843
1844 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
1845 Fn,
1846 "Support for stacksave requires PTX ISA version >= 7.3 and target >= "
1847 "sm_52.",
1848 DL.getDebugLoc()));
1849 auto Ops = {DAG.getConstant(Val: 0, DL, VT: Op.getValueType()), Op.getOperand(i: 0)};
1850 return DAG.getMergeValues(Ops, dl: DL);
1851 }
1852
1853 const MVT LocalVT = getPointerTy(DL: DAG.getDataLayout(), AS: ADDRESS_SPACE_LOCAL);
1854 SDValue Chain = Op.getOperand(i: 0);
1855 SDValue SS =
1856 DAG.getNode(Opcode: NVPTXISD::STACKSAVE, DL, ResultTys: {LocalVT, MVT::Other}, Ops: Chain);
1857 SDValue ASC = DAG.getAddrSpaceCast(
1858 dl: DL, VT: Op.getValueType(), Ptr: SS, SrcAS: ADDRESS_SPACE_LOCAL, DestAS: ADDRESS_SPACE_GENERIC);
1859 return DAG.getMergeValues(Ops: {ASC, SDValue(SS.getNode(), 1)}, dl: DL);
1860}
1861
1862// By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
1863// (see LegalizeDAG.cpp). This is slow and uses local memory.
1864// We use extract/insert/build_vector just as LegalizeOp() did in LLVM 2.5.
1865SDValue
1866NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
1867 SDNode *Node = Op.getNode();
1868 SDLoc dl(Node);
1869 SmallVector<SDValue, 8> Ops;
1870 unsigned NumOperands = Node->getNumOperands();
1871 for (unsigned i = 0; i < NumOperands; ++i) {
1872 SDValue SubOp = Node->getOperand(Num: i);
1873 EVT VVT = SubOp.getNode()->getValueType(ResNo: 0);
1874 EVT EltVT = VVT.getVectorElementType();
1875 unsigned NumSubElem = VVT.getVectorNumElements();
1876 for (unsigned j = 0; j < NumSubElem; ++j) {
1877 Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT: EltVT, N1: SubOp,
1878 N2: DAG.getIntPtrConstant(Val: j, DL: dl)));
1879 }
1880 }
1881 return DAG.getBuildVector(VT: Node->getValueType(ResNo: 0), DL: dl, Ops);
1882}
1883
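/// Build an NVPTX PRMT (byte permute) node. prmt.b32 treats {B, A} as a single
/// 8-byte source (A supplies bytes 0-3, B supplies bytes 4-7); each nibble of
/// the selector chooses which source byte ends up in the corresponding byte of
/// the i32 result.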
1884static SDValue getPRMT(SDValue A, SDValue B, SDValue Selector, SDLoc DL,
1885 SelectionDAG &DAG,
1886 unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
1887 assert(A.getValueType() == MVT::i32 && B.getValueType() == MVT::i32 &&
1888 Selector.getValueType() == MVT::i32 && "PRMT must have i32 operands");
1889 return DAG.getNode(Opcode: NVPTXISD::PRMT, DL, VT: MVT::i32,
1890 Ops: {A, B, Selector, DAG.getConstant(Val: Mode, DL, VT: MVT::i32)});
1891}
1892
1893static SDValue getPRMT(SDValue A, SDValue B, uint64_t Selector, SDLoc DL,
1894 SelectionDAG &DAG,
1895 unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
1896 return getPRMT(A, B, Selector: DAG.getConstant(Val: Selector, DL, VT: MVT::i32), DL, DAG, Mode);
1897}
1898
1899/// Reduces the elements using the scalar operations provided. The operations
1900/// must be sorted in descending order of the number of inputs they take. The
1901/// flags on the original reduction operation will be propagated to each scalar
1902/// operation. Nearby elements are grouped in the tree reduction, unlike the
1903/// shuffle-based reduction used in ExpandReductions and SelectionDAG.
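///
/// For example, reducing 8 elements with Ops = {(FMAXNUM3, 3), (FMAXNUM, 2)}
/// produces roughly:
///   fmax(fmax3(fmax3(a, b, c), fmax3(d, e, f), g), h)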
1904static SDValue buildTreeReduction(
1905 const SmallVector<SDValue> &Elements, EVT EltTy,
1906 ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
1907 const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
1908 // Build the reduction tree at each level, starting with all the elements.
1909 SmallVector<SDValue> Level = Elements;
1910
1911 unsigned OpIdx = 0;
1912 while (Level.size() > 1) {
1913 // Try to reduce this level using the current operator.
1914 const auto [Op, NumInputs] = Ops[OpIdx];
1915
1916 // Build the next level by partially reducing all elements.
1917 SmallVector<SDValue> ReducedLevel;
1918 unsigned I = 0, E = Level.size();
1919 for (; I + NumInputs <= E; I += NumInputs) {
1920 // Reduce elements in groups of [NumInputs], as much as possible.
1921 ReducedLevel.push_back(Elt: DAG.getNode(
1922 Opcode: Op, DL, VT: EltTy, Ops: ArrayRef<SDValue>(Level).slice(N: I, M: NumInputs), Flags));
1923 }
1924
1925 if (I < E) {
1926 // Handle leftover elements.
1927
1928 if (ReducedLevel.empty()) {
1929 // We didn't reduce anything at this level. We need to pick a smaller
1930 // operator.
1931 ++OpIdx;
1932 assert(OpIdx < Ops.size() && "no smaller operators for reduction");
1933 continue;
1934 }
1935
1936 // We reduced some things but there's still more left, meaning the
1937 // operator's number of inputs doesn't evenly divide this level size. Move
1938 // these elements to the next level.
1939 for (; I < E; ++I)
1940 ReducedLevel.push_back(Elt: Level[I]);
1941 }
1942
1943 // Process the next level.
1944 Level = ReducedLevel;
1945 }
1946
1947 return *Level.begin();
1948}
1949
1950// Get scalar reduction opcode
1951static ISD::NodeType getScalarOpcodeForReduction(unsigned ReductionOpcode) {
1952 switch (ReductionOpcode) {
1953 case ISD::VECREDUCE_FMAX:
1954 return ISD::FMAXNUM;
1955 case ISD::VECREDUCE_FMIN:
1956 return ISD::FMINNUM;
1957 case ISD::VECREDUCE_FMAXIMUM:
1958 return ISD::FMAXIMUM;
1959 case ISD::VECREDUCE_FMINIMUM:
1960 return ISD::FMINIMUM;
1961 default:
1962 llvm_unreachable("unhandled reduction opcode");
1963 }
1964}
1965
1966/// Get 3-input scalar reduction opcode
1967static std::optional<unsigned>
1968getScalar3OpcodeForReduction(unsigned ReductionOpcode) {
1969 switch (ReductionOpcode) {
1970 case ISD::VECREDUCE_FMAX:
1971 return NVPTXISD::FMAXNUM3;
1972 case ISD::VECREDUCE_FMIN:
1973 return NVPTXISD::FMINNUM3;
1974 case ISD::VECREDUCE_FMAXIMUM:
1975 return NVPTXISD::FMAXIMUM3;
1976 case ISD::VECREDUCE_FMINIMUM:
1977 return NVPTXISD::FMINIMUM3;
1978 default:
1979 return std::nullopt;
1980 }
1981}
1982
1983/// Lower reductions to either a sequence of operations or a tree if
1984/// reassociations are allowed. This method will use larger operations like
1985/// max3/min3 when the target supports them.
1986SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
1987 SelectionDAG &DAG) const {
1988 SDLoc DL(Op);
1989 const SDNodeFlags Flags = Op->getFlags();
1990 SDValue Vector = Op.getOperand(i: 0);
1991
1992 const unsigned Opcode = Op->getOpcode();
1993 const EVT EltTy = Vector.getValueType().getVectorElementType();
1994
1995 // Whether we can use 3-input min/max when expanding the reduction.
1996 const bool CanUseMinMax3 =
1997 EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
1998 STI.getPTXVersion() >= 88 &&
1999 (Opcode == ISD::VECREDUCE_FMAX || Opcode == ISD::VECREDUCE_FMIN ||
2000 Opcode == ISD::VECREDUCE_FMAXIMUM || Opcode == ISD::VECREDUCE_FMINIMUM);
2001
2002 // A list of SDNode opcodes with equivalent semantics, sorted descending by
2003 // number of inputs they take.
2004 SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
2005
2006 if (auto Opcode3Elem = getScalar3OpcodeForReduction(ReductionOpcode: Opcode);
2007 CanUseMinMax3 && Opcode3Elem)
2008 ScalarOps.push_back(Elt: {*Opcode3Elem, 3});
2009 ScalarOps.push_back(Elt: {getScalarOpcodeForReduction(ReductionOpcode: Opcode), 2});
2010
2011 SmallVector<SDValue> Elements;
2012 DAG.ExtractVectorElements(Op: Vector, Args&: Elements);
2013
2014 return buildTreeReduction(Elements, EltTy, Ops: ScalarOps, DL, Flags, DAG);
2015}
2016
2017SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
2018 // Handle bitcasting from v2i8 without hitting the default promotion
2019 // strategy which goes through stack memory.
2020 EVT FromVT = Op->getOperand(Num: 0)->getValueType(ResNo: 0);
2021 if (FromVT != MVT::v2i8) {
2022 return Op;
2023 }
2024
2025 // Pack vector elements into i16 and bitcast to final type
2026 SDLoc DL(Op);
2027 SDValue Vec0 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i8,
2028 N1: Op->getOperand(Num: 0), N2: DAG.getIntPtrConstant(Val: 0, DL));
2029 SDValue Vec1 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i8,
2030 N1: Op->getOperand(Num: 0), N2: DAG.getIntPtrConstant(Val: 1, DL));
2031 SDValue Extend0 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: MVT::i16, Operand: Vec0);
2032 SDValue Extend1 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: MVT::i16, Operand: Vec1);
2033 SDValue Const8 = DAG.getConstant(Val: 8, DL, VT: MVT::i16);
2034 SDValue AsInt = DAG.getNode(
2035 Opcode: ISD::OR, DL, VT: MVT::i16,
2036 Ops: {Extend0, DAG.getNode(Opcode: ISD::SHL, DL, VT: MVT::i16, Ops: {Extend1, Const8})});
2037 EVT ToVT = Op->getValueType(ResNo: 0);
2038 return DAG.getBitcast(VT: ToVT, V: AsInt);
2039}
2040
2041// We can initialize a constant f16x2/v2i16/v4i8 with a single .b32 move.
2042// Normally it would get lowered as two constant loads and a vector-packing move.
2043// Instead we want just a constant move:
2044// mov.b32 %r2, 0x40003C00
2045SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
2046 SelectionDAG &DAG) const {
2047 EVT VT = Op->getValueType(ResNo: 0);
2048 if (!(NVPTX::isPackedVectorTy(VT) && VT.is32BitVector()))
2049 return Op;
2050 SDLoc DL(Op);
2051
2052 if (!llvm::all_of(Range: Op->ops(), P: [](SDValue Operand) {
2053 return Operand->isUndef() || isa<ConstantSDNode>(Val: Operand) ||
2054 isa<ConstantFPSDNode>(Val: Operand);
2055 })) {
2056 if (VT != MVT::v4i8)
2057 return Op;
2058    // Lower a non-constant v4i8 vector as a byte-wise constructed i32, which
2059    // allows us to optimize the calculation of its constant parts.
2060 auto GetPRMT = [&](const SDValue Left, const SDValue Right, bool Cast,
2061 uint64_t SelectionValue) -> SDValue {
2062 SDValue L = Left;
2063 SDValue R = Right;
2064 if (Cast) {
2065 L = DAG.getAnyExtOrTrunc(Op: L, DL, VT: MVT::i32);
2066 R = DAG.getAnyExtOrTrunc(Op: R, DL, VT: MVT::i32);
2067 }
2068 return getPRMT(A: L, B: R, Selector: SelectionValue, DL, DAG);
2069 };
2070 auto PRMT__10 = GetPRMT(Op->getOperand(Num: 0), Op->getOperand(Num: 1), true, 0x3340);
2071 auto PRMT__32 = GetPRMT(Op->getOperand(Num: 2), Op->getOperand(Num: 3), true, 0x3340);
2072 auto PRMT3210 = GetPRMT(PRMT__10, PRMT__32, false, 0x5410);
2073 return DAG.getBitcast(VT, V: PRMT3210);
2074 }
2075
2076  // Get the value of the Nth operand as an APInt(32). Undefs are treated as 0.
2077 auto GetOperand = [](SDValue Op, int N) -> APInt {
2078 const SDValue &Operand = Op->getOperand(Num: N);
2079 EVT VT = Op->getValueType(ResNo: 0);
2080 if (Operand->isUndef())
2081 return APInt(32, 0);
2082 APInt Value;
2083 if (VT == MVT::v2f16 || VT == MVT::v2bf16)
2084 Value = cast<ConstantFPSDNode>(Val: Operand)->getValueAPF().bitcastToAPInt();
2085 else if (VT == MVT::v2i16 || VT == MVT::v4i8)
2086 Value = Operand->getAsAPIntVal();
2087 else
2088 llvm_unreachable("Unsupported type");
2089    // i8 values are carried around as i16, so we need to zero out the upper
2090    // bits so they do not get in the way of combining individual byte values.
2091 if (VT == MVT::v4i8)
2092 Value = Value.trunc(width: 8);
2093 return Value.zext(width: 32);
2094 };
2095
2096 // Construct a 32-bit constant by shifting into place smaller values
2097 // (elements of the vector type VT).
2098 // For example, if VT has 2 elements, then N == 2:
2099 // ShiftAmount = 32 / N = 16
2100 // Value |= Op0 (b16) << 0
2101 // Value |= Op1 (b16) << 16
2102 // If N == 4:
2103 // ShiftAmount = 32 / N = 8
2104 // Value |= Op0 (b8) << 0
2105 // Value |= Op1 (b8) << 8
2106 // Value |= Op2 (b8) << 16
2107 // Value |= Op3 (b8) << 24
2108 // ...etc
2109 APInt Value(32, 0);
2110 const unsigned NumElements = VT.getVectorNumElements();
2111 assert(32 % NumElements == 0 && "must evenly divide bit length");
2112 const unsigned ShiftAmount = 32 / NumElements;
2113 for (unsigned ElementNo : seq(Size: NumElements))
2114 Value |= GetOperand(Op, ElementNo).shl(shiftAmt: ElementNo * ShiftAmount);
2115 SDValue Const = DAG.getConstant(Val: Value, DL, VT: MVT::i32);
2116 return DAG.getNode(Opcode: ISD::BITCAST, DL, VT: Op->getValueType(ResNo: 0), Operand: Const);
2117}
2118
2119SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
2120 SelectionDAG &DAG) const {
2121 SDValue Index = Op->getOperand(Num: 1);
2122 SDValue Vector = Op->getOperand(Num: 0);
2123 SDLoc DL(Op);
2124 EVT VectorVT = Vector.getValueType();
2125
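  // For v4i8, extract the element with a PRMT: the low nibble of the selector
  // is the (possibly dynamic) index, and the remaining nibbles select byte 7
  // of the zero operand, so the chosen byte lands zero-extended in byte 0.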
2126 if (VectorVT == MVT::v4i8) {
2127 SDValue Selector = DAG.getNode(Opcode: ISD::OR, DL, VT: MVT::i32,
2128 N1: DAG.getZExtOrTrunc(Op: Index, DL, VT: MVT::i32),
2129 N2: DAG.getConstant(Val: 0x7770, DL, VT: MVT::i32));
2130 SDValue PRMT = getPRMT(A: DAG.getBitcast(VT: MVT::i32, V: Vector),
2131 B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector, DL, DAG);
2132 SDValue Ext = DAG.getAnyExtOrTrunc(Op: PRMT, DL, VT: Op->getValueType(ResNo: 0));
2133 SDNodeFlags Flags;
2134 Flags.setNoSignedWrap(Ext.getScalarValueSizeInBits() > 8);
2135 Flags.setNoUnsignedWrap(Ext.getScalarValueSizeInBits() >= 8);
2136 Ext->setFlags(Flags);
2137 return Ext;
2138 }
2139
2140 // Constant index will be matched by tablegen.
2141 if (isa<ConstantSDNode>(Val: Index.getNode()))
2142 return Op;
2143
2144 // Extract individual elements and select one of them.
2145 assert(NVPTX::isPackedVectorTy(VectorVT) &&
2146 VectorVT.getVectorNumElements() == 2 && "Unexpected vector type.");
2147 EVT EltVT = VectorVT.getVectorElementType();
2148
2149 SDLoc dl(Op.getNode());
2150 SDValue E0 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT: EltVT, N1: Vector,
2151 N2: DAG.getIntPtrConstant(Val: 0, DL: dl));
2152 SDValue E1 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT: EltVT, N1: Vector,
2153 N2: DAG.getIntPtrConstant(Val: 1, DL: dl));
2154 return DAG.getSelectCC(DL: dl, LHS: Index, RHS: DAG.getIntPtrConstant(Val: 0, DL: dl), True: E0, False: E1,
2155 Cond: ISD::CondCode::SETEQ);
2156}
2157
2158SDValue NVPTXTargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
2159 SelectionDAG &DAG) const {
2160 SDValue Vector = Op->getOperand(Num: 0);
2161 EVT VectorVT = Vector.getValueType();
2162
2163 if (VectorVT != MVT::v4i8)
2164 return Op;
2165 SDLoc DL(Op);
2166 SDValue Value = Op->getOperand(Num: 1);
2167 if (Value->isUndef())
2168 return Vector;
2169
2170 SDValue Index = Op->getOperand(Num: 2);
2171
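  // Insert the new byte with a bit-field insert (PTX bfi.b32): 8 bits of the
  // zero-extended value are placed at bit offset Index * 8 of the vector.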
2172 SDValue BFI =
2173 DAG.getNode(Opcode: NVPTXISD::BFI, DL, VT: MVT::i32,
2174 Ops: {DAG.getZExtOrTrunc(Op: Value, DL, VT: MVT::i32), Vector,
2175 DAG.getNode(Opcode: ISD::MUL, DL, VT: MVT::i32,
2176 N1: DAG.getZExtOrTrunc(Op: Index, DL, VT: MVT::i32),
2177 N2: DAG.getConstant(Val: 8, DL, VT: MVT::i32)),
2178 DAG.getConstant(Val: 8, DL, VT: MVT::i32)});
2179 return DAG.getNode(Opcode: ISD::BITCAST, DL, VT: Op->getValueType(ResNo: 0), Operand: BFI);
2180}
2181
2182SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
2183 SelectionDAG &DAG) const {
2184 SDValue V1 = Op.getOperand(i: 0);
2185 EVT VectorVT = V1.getValueType();
2186 if (VectorVT != MVT::v4i8 || Op.getValueType() != MVT::v4i8)
2187 return Op;
2188
2189 // Lower shuffle to PRMT instruction.
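  // Each mask entry selects one byte of the concatenated {V2, V1} pair, so its
  // value (0-7) becomes the corresponding selector nibble; undef lanes (-1)
  // are left as 0 and therefore read byte 0 of V1.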
2190 const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Val: Op.getNode());
2191 SDValue V2 = Op.getOperand(i: 1);
2192 uint32_t Selector = 0;
2193 for (auto I : llvm::enumerate(First: SVN->getMask())) {
2194 if (I.value() != -1) // -1 is a placeholder for undef.
2195 Selector |= (I.value() << (I.index() * 4));
2196 }
2197
2198 SDLoc DL(Op);
2199 SDValue PRMT = getPRMT(A: DAG.getBitcast(VT: MVT::i32, V: V1),
2200 B: DAG.getBitcast(VT: MVT::i32, V: V2), Selector, DL, DAG);
2201 return DAG.getBitcast(VT: Op.getValueType(), V: PRMT);
2202}
2203/// LowerShiftRightParts - Lower SRL_PARTS and SRA_PARTS, which
2204/// 1) return two i32 values and take a 2 x i32 value to shift plus a shift
2205/// amount, or
2206/// 2) return two i64 values and take a 2 x i64 value to shift plus a shift
2207/// amount.
2208SDValue NVPTXTargetLowering::LowerShiftRightParts(SDValue Op,
2209 SelectionDAG &DAG) const {
2210 assert(Op.getNumOperands() == 3 && "Not a double-shift!");
2211 assert(Op.getOpcode() == ISD::SRA_PARTS || Op.getOpcode() == ISD::SRL_PARTS);
2212
2213 EVT VT = Op.getValueType();
2214 unsigned VTBits = VT.getSizeInBits();
2215 SDLoc dl(Op);
2216 SDValue ShOpLo = Op.getOperand(i: 0);
2217 SDValue ShOpHi = Op.getOperand(i: 1);
2218 SDValue ShAmt = Op.getOperand(i: 2);
2219 unsigned Opc = (Op.getOpcode() == ISD::SRA_PARTS) ? ISD::SRA : ISD::SRL;
2220
2221 if (VTBits == 32 && STI.getSmVersion() >= 35) {
2222 // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
2223 // {dHi, dLo} = {aHi, aLo} >> Amt
2224 // dHi = aHi >> Amt
2225 // dLo = shf.r.clamp aLo, aHi, Amt
2226
2227 SDValue Hi = DAG.getNode(Opcode: Opc, DL: dl, VT, N1: ShOpHi, N2: ShAmt);
2228 SDValue Lo =
2229 DAG.getNode(Opcode: NVPTXISD::FSHR_CLAMP, DL: dl, VT, N1: ShOpHi, N2: ShOpLo, N3: ShAmt);
2230
2231 SDValue Ops[2] = { Lo, Hi };
2232 return DAG.getMergeValues(Ops, dl);
2233  } else {
2235 // {dHi, dLo} = {aHi, aLo} >> Amt
2236 // - if (Amt>=size) then
2237 // dLo = aHi >> (Amt-size)
2238 // dHi = aHi >> Amt (this is either all 0 or all 1)
2239 // else
2240 // dLo = (aLo >>logic Amt) | (aHi << (size-Amt))
2241 // dHi = aHi >> Amt
2242
2243 SDValue RevShAmt = DAG.getNode(Opcode: ISD::SUB, DL: dl, VT: MVT::i32,
2244 N1: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32),
2245 N2: ShAmt);
2246 SDValue Tmp1 = DAG.getNode(Opcode: ISD::SRL, DL: dl, VT, N1: ShOpLo, N2: ShAmt);
2247 SDValue ExtraShAmt = DAG.getNode(Opcode: ISD::SUB, DL: dl, VT: MVT::i32, N1: ShAmt,
2248 N2: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32));
2249 SDValue Tmp2 = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpHi, N2: RevShAmt);
2250 SDValue FalseVal = DAG.getNode(Opcode: ISD::OR, DL: dl, VT, N1: Tmp1, N2: Tmp2);
2251 SDValue TrueVal = DAG.getNode(Opcode: Opc, DL: dl, VT, N1: ShOpHi, N2: ExtraShAmt);
2252
2253 SDValue Cmp = DAG.getSetCC(DL: dl, VT: MVT::i1, LHS: ShAmt,
2254 RHS: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32),
2255 Cond: ISD::SETGE);
2256 SDValue Hi = DAG.getNode(Opcode: Opc, DL: dl, VT, N1: ShOpHi, N2: ShAmt);
2257 SDValue Lo = DAG.getNode(Opcode: ISD::SELECT, DL: dl, VT, N1: Cmp, N2: TrueVal, N3: FalseVal);
2258
2259 SDValue Ops[2] = { Lo, Hi };
2260 return DAG.getMergeValues(Ops, dl);
2261 }
2262}
2263
2264/// LowerShiftLeftParts - Lower SHL_PARTS, which
2265/// 1) returns two i32 values and takes a 2 x i32 value to shift plus a shift
2266/// amount, or
2267/// 2) returns two i64 values and takes a 2 x i64 value to shift plus a shift
2268/// amount.
2269SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op,
2270 SelectionDAG &DAG) const {
2271 assert(Op.getNumOperands() == 3 && "Not a double-shift!");
2272 assert(Op.getOpcode() == ISD::SHL_PARTS);
2273
2274 EVT VT = Op.getValueType();
2275 unsigned VTBits = VT.getSizeInBits();
2276 SDLoc dl(Op);
2277 SDValue ShOpLo = Op.getOperand(i: 0);
2278 SDValue ShOpHi = Op.getOperand(i: 1);
2279 SDValue ShAmt = Op.getOperand(i: 2);
2280
2281 if (VTBits == 32 && STI.getSmVersion() >= 35) {
2282 // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
2283 // {dHi, dLo} = {aHi, aLo} << Amt
2284 // dHi = shf.l.clamp aLo, aHi, Amt
2285 // dLo = aLo << Amt
2286
2287 SDValue Hi =
2288 DAG.getNode(Opcode: NVPTXISD::FSHL_CLAMP, DL: dl, VT, N1: ShOpHi, N2: ShOpLo, N3: ShAmt);
2289 SDValue Lo = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpLo, N2: ShAmt);
2290
2291 SDValue Ops[2] = { Lo, Hi };
2292 return DAG.getMergeValues(Ops, dl);
2293  } else {
2295 // {dHi, dLo} = {aHi, aLo} << Amt
2296 // - if (Amt>=size) then
2297 // dLo = aLo << Amt (all 0)
2298    //    dHi = aLo << (Amt-size)
2299 // else
2300 // dLo = aLo << Amt
2301 // dHi = (aHi << Amt) | (aLo >> (size-Amt))
2302
2303 SDValue RevShAmt = DAG.getNode(Opcode: ISD::SUB, DL: dl, VT: MVT::i32,
2304 N1: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32),
2305 N2: ShAmt);
2306 SDValue Tmp1 = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpHi, N2: ShAmt);
2307 SDValue ExtraShAmt = DAG.getNode(Opcode: ISD::SUB, DL: dl, VT: MVT::i32, N1: ShAmt,
2308 N2: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32));
2309 SDValue Tmp2 = DAG.getNode(Opcode: ISD::SRL, DL: dl, VT, N1: ShOpLo, N2: RevShAmt);
2310 SDValue FalseVal = DAG.getNode(Opcode: ISD::OR, DL: dl, VT, N1: Tmp1, N2: Tmp2);
2311 SDValue TrueVal = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpLo, N2: ExtraShAmt);
2312
2313 SDValue Cmp = DAG.getSetCC(DL: dl, VT: MVT::i1, LHS: ShAmt,
2314 RHS: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32),
2315 Cond: ISD::SETGE);
2316 SDValue Lo = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpLo, N2: ShAmt);
2317 SDValue Hi = DAG.getNode(Opcode: ISD::SELECT, DL: dl, VT, N1: Cmp, N2: TrueVal, N3: FalseVal);
2318
2319 SDValue Ops[2] = { Lo, Hi };
2320 return DAG.getMergeValues(Ops, dl);
2321 }
2322}
2323
2324/// If the types match, convert the generic copysign to the NVPTXISD version,
2325/// otherwise bail out, ensuring that mismatched cases are properly expanded.
2326SDValue NVPTXTargetLowering::LowerFCOPYSIGN(SDValue Op,
2327 SelectionDAG &DAG) const {
2328 EVT VT = Op.getValueType();
2329 SDLoc DL(Op);
2330
2331 SDValue In1 = Op.getOperand(i: 0);
2332 SDValue In2 = Op.getOperand(i: 1);
2333 EVT SrcVT = In2.getValueType();
2334
2335 if (!SrcVT.bitsEq(VT))
2336 return SDValue();
2337
2338 return DAG.getNode(Opcode: NVPTXISD::FCOPYSIGN, DL, VT, N1: In1, N2: In2);
2339}
2340
2341SDValue NVPTXTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const {
2342 EVT VT = Op.getValueType();
2343
2344 if (VT == MVT::f32)
2345 return LowerFROUND32(Op, DAG);
2346
2347 if (VT == MVT::f64)
2348 return LowerFROUND64(Op, DAG);
2349
2350 llvm_unreachable("unhandled type");
2351}
2352
2353// This is the rounding method used in CUDA libdevice, in C-like code:
2354// float roundf(float A)
2355// {
2356// float RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f));
2357// RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
2358// return abs(A) < 0.5 ? (float)(int)A : RoundedA;
2359// }
2360SDValue NVPTXTargetLowering::LowerFROUND32(SDValue Op,
2361 SelectionDAG &DAG) const {
2362 SDLoc SL(Op);
2363 SDValue A = Op.getOperand(i: 0);
2364 EVT VT = Op.getValueType();
2365
2366 SDValue AbsA = DAG.getNode(Opcode: ISD::FABS, DL: SL, VT, Operand: A);
2367
2368 // RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f))
2369 SDValue Bitcast = DAG.getNode(Opcode: ISD::BITCAST, DL: SL, VT: MVT::i32, Operand: A);
2370 const unsigned SignBitMask = 0x80000000;
2371 SDValue Sign = DAG.getNode(Opcode: ISD::AND, DL: SL, VT: MVT::i32, N1: Bitcast,
2372 N2: DAG.getConstant(Val: SignBitMask, DL: SL, VT: MVT::i32));
2373 const unsigned PointFiveInBits = 0x3F000000;
2374 SDValue PointFiveWithSignRaw =
2375 DAG.getNode(Opcode: ISD::OR, DL: SL, VT: MVT::i32, N1: Sign,
2376 N2: DAG.getConstant(Val: PointFiveInBits, DL: SL, VT: MVT::i32));
2377 SDValue PointFiveWithSign =
2378 DAG.getNode(Opcode: ISD::BITCAST, DL: SL, VT, Operand: PointFiveWithSignRaw);
2379 SDValue AdjustedA = DAG.getNode(Opcode: ISD::FADD, DL: SL, VT, N1: A, N2: PointFiveWithSign);
2380 SDValue RoundedA = DAG.getNode(Opcode: ISD::FTRUNC, DL: SL, VT, Operand: AdjustedA);
2381
2382 // RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
2383 EVT SetCCVT = getSetCCResultType(DL: DAG.getDataLayout(), Ctx&: *DAG.getContext(), VT);
2384 SDValue IsLarge =
2385 DAG.getSetCC(DL: SL, VT: SetCCVT, LHS: AbsA, RHS: DAG.getConstantFP(Val: pow(x: 2.0, y: 23.0), DL: SL, VT),
2386 Cond: ISD::SETOGT);
2387 RoundedA = DAG.getNode(Opcode: ISD::SELECT, DL: SL, VT, N1: IsLarge, N2: A, N3: RoundedA);
2388
2389 // return abs(A) < 0.5 ? (float)(int)A : RoundedA;
2390  SDValue IsSmall = DAG.getSetCC(DL: SL, VT: SetCCVT, LHS: AbsA,
2391                                 RHS: DAG.getConstantFP(Val: 0.5, DL: SL, VT), Cond: ISD::SETOLT);
2392 SDValue RoundedAForSmallA = DAG.getNode(Opcode: ISD::FTRUNC, DL: SL, VT, Operand: A);
2393 return DAG.getNode(Opcode: ISD::SELECT, DL: SL, VT, N1: IsSmall, N2: RoundedAForSmallA, N3: RoundedA);
2394}
2395
2396// The implementation of round(double) is similar to that of round(float) in
2397// that they both separate the value range into three regions and use a method
2398// specific to the region to round the values. However, round(double) first
2399// calculates the round of the absolute value and then adds the sign back while
2400// round(float) directly rounds the value with sign.
2401SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
2402 SelectionDAG &DAG) const {
2403 SDLoc SL(Op);
2404 SDValue A = Op.getOperand(i: 0);
2405 EVT VT = Op.getValueType();
2406
2407 SDValue AbsA = DAG.getNode(Opcode: ISD::FABS, DL: SL, VT, Operand: A);
2408
2409 // double RoundedA = (double) (int) (abs(A) + 0.5f);
2410 SDValue AdjustedA = DAG.getNode(Opcode: ISD::FADD, DL: SL, VT, N1: AbsA,
2411 N2: DAG.getConstantFP(Val: 0.5, DL: SL, VT));
2412 SDValue RoundedA = DAG.getNode(Opcode: ISD::FTRUNC, DL: SL, VT, Operand: AdjustedA);
2413
2414 // RoundedA = abs(A) < 0.5 ? (double)0 : RoundedA;
2415 EVT SetCCVT = getSetCCResultType(DL: DAG.getDataLayout(), Ctx&: *DAG.getContext(), VT);
2416  SDValue IsSmall = DAG.getSetCC(DL: SL, VT: SetCCVT, LHS: AbsA,
2417                                 RHS: DAG.getConstantFP(Val: 0.5, DL: SL, VT), Cond: ISD::SETOLT);
2418 RoundedA = DAG.getNode(Opcode: ISD::SELECT, DL: SL, VT, N1: IsSmall,
2419 N2: DAG.getConstantFP(Val: 0, DL: SL, VT),
2420 N3: RoundedA);
2421
2422  // Add the sign of A back to RoundedA.
2423  RoundedA = DAG.getNode(Opcode: ISD::FCOPYSIGN, DL: SL, VT, N1: RoundedA, N2: A);
2425
2426 // RoundedA = abs(A) > 0x1.0p52 ? A : RoundedA;
2427 SDValue IsLarge =
2428 DAG.getSetCC(DL: SL, VT: SetCCVT, LHS: AbsA, RHS: DAG.getConstantFP(Val: pow(x: 2.0, y: 52.0), DL: SL, VT),
2429 Cond: ISD::SETOGT);
2430 return DAG.getNode(Opcode: ISD::SELECT, DL: SL, VT, N1: IsLarge, N2: A, N3: RoundedA);
2431}
2432
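/// Promote the operands of a floating-point binary operation (scalar or a
/// vector thereof) to f32, perform the operation in f32, and round the result
/// back to the original type.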
2433static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) {
2434 EVT VT = N->getValueType(ResNo: 0);
2435 EVT NVT = MVT::f32;
2436 if (VT.isVector()) {
2437 NVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NVT, EC: VT.getVectorElementCount());
2438 }
2439 SDLoc DL(N);
2440 SDValue Tmp0 = DAG.getFPExtendOrRound(Op: N->getOperand(Num: 0), DL, VT: NVT);
2441 SDValue Tmp1 = DAG.getFPExtendOrRound(Op: N->getOperand(Num: 1), DL, VT: NVT);
2442 SDValue Res = DAG.getNode(Opcode: N->getOpcode(), DL, VT: NVT, N1: Tmp0, N2: Tmp1, Flags: N->getFlags());
2443 return DAG.getFPExtendOrRound(Op: Res, DL, VT);
2444}
2445
2446SDValue NVPTXTargetLowering::PromoteBinOpIfF32FTZ(SDValue Op,
2447 SelectionDAG &DAG) const {
2448 if (useF32FTZ(MF: DAG.getMachineFunction())) {
2449 return PromoteBinOpToF32(N: Op.getNode(), DAG);
2450 }
2451 return Op;
2452}
2453
2454SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
2455 SelectionDAG &DAG) const {
2456 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2457
2458 if (Op.getValueType() == MVT::bf16) {
2459 SDLoc Loc(Op);
2460 return DAG.getNode(
2461 Opcode: ISD::FP_ROUND, DL: Loc, VT: MVT::bf16,
2462 N1: DAG.getNode(Opcode: Op.getOpcode(), DL: Loc, VT: MVT::f32, Operand: Op.getOperand(i: 0)),
2463 N2: DAG.getIntPtrConstant(Val: 0, DL: Loc, /*isTarget=*/true));
2464 }
2465
2466 // Everything else is considered legal.
2467 return Op;
2468}
2469
2470SDValue NVPTXTargetLowering::LowerFP_TO_INT(SDValue Op,
2471 SelectionDAG &DAG) const {
2472 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2473
2474 if (Op.getOperand(i: 0).getValueType() == MVT::bf16) {
2475 SDLoc Loc(Op);
2476 return DAG.getNode(
2477 Opcode: Op.getOpcode(), DL: Loc, VT: Op.getValueType(),
2478 Operand: DAG.getNode(Opcode: ISD::FP_EXTEND, DL: Loc, VT: MVT::f32, Operand: Op.getOperand(i: 0)));
2479 }
2480
2481 // Everything else is considered legal.
2482 return Op;
2483}
2484
2485SDValue NVPTXTargetLowering::LowerFP_ROUND(SDValue Op,
2486 SelectionDAG &DAG) const {
2487 EVT NarrowVT = Op.getValueType();
2488 SDValue Wide = Op.getOperand(i: 0);
2489 EVT WideVT = Wide.getValueType();
2490 if (NarrowVT.getScalarType() == MVT::bf16) {
2491 const TargetLowering *TLI = STI.getTargetLowering();
2492 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 70) {
2493 return TLI->expandFP_ROUND(Node: Op.getNode(), DAG);
2494 }
2495 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
2496 // This combination was the first to support f32 -> bf16.
2497 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70) {
2498 if (WideVT.getScalarType() == MVT::f32) {
2499 return Op;
2500 }
2501 if (WideVT.getScalarType() == MVT::f64) {
2502 SDLoc Loc(Op);
2503 // Round-inexact-to-odd f64 to f32, then do the final rounding using
2504 // the hardware f32 -> bf16 instruction.
2505 SDValue rod = TLI->expandRoundInexactToOdd(
2506 ResultVT: WideVT.changeElementType(Context&: *DAG.getContext(), EltVT: MVT::f32), Op: Wide, DL: Loc,
2507 DAG);
2508 return DAG.getFPExtendOrRound(Op: rod, DL: Loc, VT: NarrowVT);
2509 }
2510 }
2511 return TLI->expandFP_ROUND(Node: Op.getNode(), DAG);
2512 }
2513 }
2514
2515 // Everything else is considered legal.
2516 return Op;
2517}
2518
2519SDValue NVPTXTargetLowering::LowerFP_EXTEND(SDValue Op,
2520 SelectionDAG &DAG) const {
2521 SDValue Narrow = Op.getOperand(i: 0);
2522 EVT NarrowVT = Narrow.getValueType();
2523 EVT WideVT = Op.getValueType();
2524 if (NarrowVT.getScalarType() == MVT::bf16) {
2525 if (WideVT.getScalarType() == MVT::f32 &&
2526 (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71)) {
2527 SDLoc Loc(Op);
2528 return DAG.getNode(Opcode: ISD::BF16_TO_FP, DL: Loc, VT: WideVT, Operand: Narrow);
2529 }
2530 if (WideVT.getScalarType() == MVT::f64 &&
2531 (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
2532 EVT F32 = NarrowVT.changeElementType(Context&: *DAG.getContext(), EltVT: MVT::f32);
2533 SDLoc Loc(Op);
2534 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 71) {
2535 Op = DAG.getNode(Opcode: ISD::FP_EXTEND, DL: Loc, VT: F32, Operand: Narrow);
2536 } else {
2537 Op = DAG.getNode(Opcode: ISD::BF16_TO_FP, DL: Loc, VT: F32, Operand: Narrow);
2538 }
2539 return DAG.getNode(Opcode: ISD::FP_EXTEND, DL: Loc, VT: WideVT, Operand: Op);
2540 }
2541 }
2542
2543 // Everything else is considered legal.
2544 return Op;
2545}
2546
2547static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) {
2548 SDLoc DL(Op);
2549 if (Op.getValueType() != MVT::v2i16)
2550 return Op;
2551 EVT EltVT = Op.getValueType().getVectorElementType();
2552 SmallVector<SDValue> VecElements;
2553 for (int I = 0, E = Op.getValueType().getVectorNumElements(); I < E; I++) {
2554 SmallVector<SDValue> ScalarArgs;
2555 llvm::transform(Range: Op->ops(), d_first: std::back_inserter(x&: ScalarArgs),
2556 F: [&](const SDUse &O) {
2557 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT,
2558 N1: O.get(), N2: DAG.getIntPtrConstant(Val: I, DL));
2559 });
2560 VecElements.push_back(Elt: DAG.getNode(Opcode: Op.getOpcode(), DL, VT: EltVT, Ops: ScalarArgs));
2561 }
2562 SDValue V =
2563 DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: Op.getValueType(), Ops: VecElements);
2564 return V;
2565}
2566
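/// Lower a tcgen05.st intrinsic by flattening any vector data operands into
/// individual scalar elements before rebuilding the memory-intrinsic node with
/// the same value types and memory operand.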
2567static SDValue lowerTcgen05St(SDValue Op, SelectionDAG &DAG) {
2568 SDNode *N = Op.getNode();
2569 SDLoc DL(N);
2570 SmallVector<SDValue, 32> Ops;
2571
2572  // Split any vector operands into their scalar elements.
2573 for (size_t I = 0; I < N->getNumOperands(); I++) {
2574 SDValue Val = N->getOperand(Num: I);
2575 EVT ValVT = Val.getValueType();
2576 if (ValVT.isVector()) {
2577 EVT EltVT = ValVT.getVectorElementType();
2578 for (unsigned J = 0, NElts = ValVT.getVectorNumElements(); J < NElts; J++)
2579 Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Val,
2580 N2: DAG.getIntPtrConstant(Val: J, DL)));
2581 } else
2582 Ops.push_back(Elt: Val);
2583 }
2584
2585 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
2586 SDValue Tcgen05StNode =
2587 DAG.getMemIntrinsicNode(Opcode: ISD::INTRINSIC_VOID, dl: DL, VTList: N->getVTList(), Ops,
2588 MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());
2589
2590 return Tcgen05StNode;
2591}
2592
2593static SDValue lowerBSWAP(SDValue Op, SelectionDAG &DAG) {
2594 SDLoc DL(Op);
2595 SDValue Src = Op.getOperand(i: 0);
2596 EVT VT = Op.getValueType();
2597
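  // Byte swaps are implemented with PRMT: e.g. selector 0x0123 reverses the
  // four bytes of an i32, while 0x2301 swaps the bytes within each i16 half.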
2598 switch (VT.getSimpleVT().SimpleTy) {
2599 case MVT::i16: {
2600 SDValue Extended = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: MVT::i32, Operand: Src);
2601 SDValue Swapped =
2602 getPRMT(A: Extended, B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x7701, DL, DAG);
2603 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i16, Operand: Swapped);
2604 }
2605 case MVT::i32: {
2606 return getPRMT(A: Src, B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x0123, DL, DAG);
2607 }
2608 case MVT::v2i16: {
2609 SDValue Converted = DAG.getBitcast(VT: MVT::i32, V: Src);
2610 SDValue Swapped =
2611 getPRMT(A: Converted, B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x2301, DL, DAG);
2612 return DAG.getNode(Opcode: ISD::BITCAST, DL, VT: MVT::v2i16, Operand: Swapped);
2613 }
2614 case MVT::i64: {
2615 SDValue UnpackSrc =
2616 DAG.getNode(Opcode: NVPTXISD::UNPACK_VECTOR, DL, ResultTys: {MVT::i32, MVT::i32}, Ops: Src);
2617 SDValue SwappedLow =
2618 getPRMT(A: UnpackSrc.getValue(R: 0), B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x0123,
2619 DL, DAG);
2620 SDValue SwappedHigh =
2621 getPRMT(A: UnpackSrc.getValue(R: 1), B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x0123,
2622 DL, DAG);
2623 return DAG.getNode(Opcode: NVPTXISD::BUILD_VECTOR, DL, VT: MVT::i64,
2624 Ops: {SwappedHigh, SwappedLow});
2625 }
2626 default:
2627 llvm_unreachable("unsupported type for bswap");
2628 }
2629}
2630
2631static unsigned getTcgen05MMADisableOutputLane(unsigned IID) {
2632 switch (IID) {
2633 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
2634 return NVPTXISD::TCGEN05_MMA_SHARED_DISABLE_OUTPUT_LANE_CG1;
2635 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2:
2636 return NVPTXISD::TCGEN05_MMA_SHARED_DISABLE_OUTPUT_LANE_CG2;
2637 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1:
2638 return NVPTXISD::TCGEN05_MMA_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
2639 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2:
2640 return NVPTXISD::TCGEN05_MMA_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
2641 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1:
2642 return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG1;
2643 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2:
2644 return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG2;
2645 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1:
2646 return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
2647 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2:
2648 return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
2649 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift:
2650 return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
2651 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift:
2652 return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
2653 case Intrinsic::
2654 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift:
2655 return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
2656 case Intrinsic::
2657 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift:
2658 return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
2659 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1:
2660 return NVPTXISD::TCGEN05_MMA_SP_SHARED_DISABLE_OUTPUT_LANE_CG1;
2661 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2:
2662 return NVPTXISD::TCGEN05_MMA_SP_SHARED_DISABLE_OUTPUT_LANE_CG2;
2663 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1:
2664 return NVPTXISD::TCGEN05_MMA_SP_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
2665 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2:
2666 return NVPTXISD::TCGEN05_MMA_SP_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
2667 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1:
2668 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG1;
2669 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2:
2670 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG2;
2671 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift:
2672 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
2673 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift:
2674 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
2675 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1:
2676 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
2677 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2:
2678 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
2679 case Intrinsic::
2680 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift:
2681 return NVPTXISD::
2682 TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
2683 case Intrinsic::
2684 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift:
2685 return NVPTXISD::
2686 TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
  }
2688 llvm_unreachable("unhandled tcgen05.mma.disable_output_lane intrinsic");
2689}
2690
2691static SDValue LowerTcgen05MMADisableOutputLane(SDValue Op, SelectionDAG &DAG) {
2692 SDNode *N = Op.getNode();
2693 SDLoc DL(N);
2694 unsigned IID = cast<ConstantSDNode>(Val: N->getOperand(Num: 1))->getZExtValue();
2695
2696 SmallVector<SDValue, 16> Ops;
  // Split any vector operands into their scalar elements, skipping the
  // intrinsic ID operand.
2698 for (size_t I = 0; I < N->getNumOperands(); I++) {
2699 if (I == 1)
2700 continue; // skip IID
2701 SDValue Val = N->getOperand(Num: I);
2702 EVT ValVT = Val.getValueType();
2703 if (ValVT.isVector()) {
2704 EVT EltVT = ValVT.getVectorElementType();
2705 for (unsigned J = 0, NElts = ValVT.getVectorNumElements(); J < NElts; J++)
2706 Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Val,
2707 N2: DAG.getIntPtrConstant(Val: J, DL)));
2708 } else
2709 Ops.push_back(Elt: Val);
2710 }
2711
2712 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
2713 SDValue Tcgen05MMANode = DAG.getMemIntrinsicNode(
2714 Opcode: getTcgen05MMADisableOutputLane(IID), dl: DL, VTList: N->getVTList(), Ops,
2715 MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());
2716
2717 return Tcgen05MMANode;
2718}
2719
2720// Lower vector return type of tcgen05.ld intrinsics
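// For example, a tcgen05.ld returning v4i32 is rebuilt as an INTRINSIC_W_CHAIN
// memory-intrinsic node producing four scalar i32 results plus a chain; the
// scalars are then reassembled into the original vector with BUILD_VECTOR.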
2721static std::optional<std::pair<SDValue, SDValue>>
2722lowerTcgen05Ld(SDNode *N, SelectionDAG &DAG, bool HasOffset = false) {
2723 SDLoc DL(N);
2724 EVT ResVT = N->getValueType(ResNo: 0);
2725 if (!ResVT.isVector())
2726 return {}; // already legalized.
2727
2728 const unsigned NumElts = ResVT.getVectorNumElements();
2729
  // Create the result types of the instruction
2731 SmallVector<EVT, 5> ListVTs;
2732 for (unsigned i = 0; i < NumElts; ++i)
2733 ListVTs.push_back(Elt: MVT::i32);
2734
2735 ListVTs.push_back(Elt: N->getValueType(ResNo: 1)); // Chain
2736
2737 SDVTList ResVTs = DAG.getVTList(VTs: ListVTs);
2738
2739 SmallVector<SDValue, 8> Ops{N->getOperand(Num: 0), N->getOperand(Num: 1),
2740 N->getOperand(Num: 2)};
2741
2742 if (HasOffset) {
2743 Ops.push_back(Elt: N->getOperand(Num: 3)); // offset
2744 Ops.push_back(Elt: N->getOperand(Num: 4)); // Pack flag
2745 } else
2746 Ops.push_back(Elt: N->getOperand(Num: 3)); // Pack flag
2747
2748 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
2749 SDValue NewNode =
2750 DAG.getMemIntrinsicNode(Opcode: ISD::INTRINSIC_W_CHAIN, dl: DL, VTList: ResVTs, Ops,
2751 MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());
2752
  // Split the vector result
2754 SmallVector<SDValue, 4> ScalarRes;
2755 for (unsigned i = 0; i < NumElts; ++i) {
2756 SDValue Res = NewNode.getValue(R: i);
2757 ScalarRes.push_back(Elt: Res);
2758 }
2759
2760 SDValue Chain = NewNode.getValue(R: NumElts);
2761 SDValue BuildVector = DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: ResVT, Ops: ScalarRes);
2762 return {{BuildVector, Chain}};
2763}
2764
2765static SDValue reportInvalidTensormapReplaceUsage(SDValue Op, SelectionDAG &DAG,
2766 unsigned Val) {
2767 SDNode *N = Op.getNode();
2768 SDLoc DL(N);
2769
2770 const Function &Fn = DAG.getMachineFunction().getFunction();
2771
2772 unsigned AS = 0;
2773 if (auto *MemN = dyn_cast<MemIntrinsicSDNode>(Val: N))
2774 AS = MemN->getAddressSpace();
2775 Type *PtrTy = PointerType::get(C&: *DAG.getContext(), AddressSpace: AS);
2776 Module *M = DAG.getMachineFunction().getFunction().getParent();
2777
2778 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
2779 Fn,
2780 "Intrinsic " +
2781 Intrinsic::getName(Id: N->getConstantOperandVal(Num: 1), Tys: {PtrTy}, M) +
2782 " with value " + Twine(Val) +
2783 " is not supported on the given target.",
2784 DL.getDebugLoc()));
2785 return Op.getOperand(i: 0);
2786}
2787
2788static SDValue lowerTensormapReplaceElemtype(SDValue Op, SelectionDAG &DAG) {
2789 SDNode *N = Op.getNode();
2790 SDLoc DL(N);
2791
2792 // immediate argument representing elemtype
2793 unsigned Val = N->getConstantOperandVal(Num: 3);
2794
2795 if (!DAG.getSubtarget<NVPTXSubtarget>().hasTensormapReplaceElemtypeSupport(
2796 value: Val))
2797 return reportInvalidTensormapReplaceUsage(Op, DAG, Val);
2798
2799 return Op;
2800}
2801
2802static SDValue lowerTensormapReplaceSwizzleMode(SDValue Op, SelectionDAG &DAG) {
2803 SDNode *N = Op.getNode();
2804 SDLoc DL(N);
2805
2806 // immediate argument representing swizzle mode
2807 unsigned Val = N->getConstantOperandVal(Num: 3);
2808
2809 if (!DAG.getSubtarget<NVPTXSubtarget>().hasTensormapReplaceSwizzleModeSupport(
2810 value: Val))
2811 return reportInvalidTensormapReplaceUsage(Op, DAG, Val);
2812
2813 return Op;
2814}
2815
2816static SDValue lowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
2817 SDNode *N = Op.getNode();
2818 SDValue Intrin = N->getOperand(Num: 1);
2819
2820 // Get the intrinsic ID
2821 unsigned IntrinNo = cast<ConstantSDNode>(Val: Intrin.getNode())->getZExtValue();
2822 switch (IntrinNo) {
2823 default:
2824 break;
2825 case Intrinsic::nvvm_tcgen05_st_16x64b_x1:
2826 case Intrinsic::nvvm_tcgen05_st_16x64b_x2:
2827 case Intrinsic::nvvm_tcgen05_st_16x64b_x4:
2828 case Intrinsic::nvvm_tcgen05_st_16x64b_x8:
2829 case Intrinsic::nvvm_tcgen05_st_16x64b_x16:
2830 case Intrinsic::nvvm_tcgen05_st_16x64b_x32:
2831 case Intrinsic::nvvm_tcgen05_st_16x64b_x128:
2832 case Intrinsic::nvvm_tcgen05_st_16x128b_x1:
2833 case Intrinsic::nvvm_tcgen05_st_16x128b_x2:
2834 case Intrinsic::nvvm_tcgen05_st_16x128b_x4:
2835 case Intrinsic::nvvm_tcgen05_st_16x128b_x8:
2836 case Intrinsic::nvvm_tcgen05_st_16x128b_x16:
2837 case Intrinsic::nvvm_tcgen05_st_16x128b_x32:
2838 case Intrinsic::nvvm_tcgen05_st_16x128b_x64:
2839 case Intrinsic::nvvm_tcgen05_st_16x256b_x1:
2840 case Intrinsic::nvvm_tcgen05_st_16x256b_x2:
2841 case Intrinsic::nvvm_tcgen05_st_16x256b_x4:
2842 case Intrinsic::nvvm_tcgen05_st_16x256b_x8:
2843 case Intrinsic::nvvm_tcgen05_st_16x256b_x16:
2844 case Intrinsic::nvvm_tcgen05_st_16x256b_x32:
2845 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x1:
2846 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x2:
2847 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x4:
2848 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x8:
2849 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x16:
2850 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x32:
2851 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x64:
2852 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x128:
2853 case Intrinsic::nvvm_tcgen05_st_32x32b_x1:
2854 case Intrinsic::nvvm_tcgen05_st_32x32b_x2:
2855 case Intrinsic::nvvm_tcgen05_st_32x32b_x4:
2856 case Intrinsic::nvvm_tcgen05_st_32x32b_x8:
2857 case Intrinsic::nvvm_tcgen05_st_32x32b_x16:
2858 case Intrinsic::nvvm_tcgen05_st_32x32b_x32:
2859 case Intrinsic::nvvm_tcgen05_st_16x64b_x64:
2860 case Intrinsic::nvvm_tcgen05_st_32x32b_x64:
2861 case Intrinsic::nvvm_tcgen05_st_32x32b_x128:
2862 return lowerTcgen05St(Op, DAG);
2863 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
2864 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2:
2865 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1:
2866 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2:
2867 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1:
2868 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2:
2869 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1:
2870 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2:
2871 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1:
2872 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2:
2873 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1:
2874 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2:
2875 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1:
2876 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2:
2877 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1:
2878 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2:
2879 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift:
2880 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift:
2881 case Intrinsic::
2882 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift:
2883 case Intrinsic::
2884 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift:
2885 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift:
2886 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift:
2887 case Intrinsic::
2888 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift:
2889 case Intrinsic::
2890 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift:
2891 return LowerTcgen05MMADisableOutputLane(Op, DAG);
2892 case Intrinsic::nvvm_tensormap_replace_elemtype:
2893 return lowerTensormapReplaceElemtype(Op, DAG);
2894 case Intrinsic::nvvm_tensormap_replace_swizzle_mode:
2895 return lowerTensormapReplaceSwizzleMode(Op, DAG);
2896 }
2897 return Op;
2898}
2899
2900static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op,
2901 SelectionDAG &DAG) {
2902
2903 SDNode *N = Op.getNode();
2904 if (N->getOperand(Num: 1).getValueType() != MVT::i128) {
    // Return early if the operand has already been lowered.
2906 return SDValue();
2907 }
2908
2909 unsigned IID =
2910 cast<ConstantSDNode>(Val: N->getOperand(Num: 0).getNode())->getZExtValue();
2911 auto Opcode = [&]() {
2912 switch (IID) {
2913 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:
2914 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_IS_CANCELED;
2915 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x:
2916 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_X;
2917 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y:
2918 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Y;
2919 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z:
2920 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z;
2921 default:
2922 llvm_unreachable("unsupported/unhandled intrinsic");
2923 }
2924 }();
2925
2926 SDLoc DL(N);
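  // The i128 try-cancel response is bitcast to v2i64 and split into its two
  // 64-bit halves, which become the operands of the target node.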
2927 SDValue TryCancelResponse = N->getOperand(Num: 1);
2928 SDValue Cast = DAG.getNode(Opcode: ISD::BITCAST, DL, VT: MVT::v2i64, Operand: TryCancelResponse);
2929 SDValue TryCancelResponse0 =
2930 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i64, N1: Cast,
2931 N2: DAG.getIntPtrConstant(Val: 0, DL));
2932 SDValue TryCancelResponse1 =
2933 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i64, N1: Cast,
2934 N2: DAG.getIntPtrConstant(Val: 1, DL));
2935
2936 return DAG.getNode(Opcode, DL, VTList: N->getVTList(),
2937 Ops: {TryCancelResponse0, TryCancelResponse1});
2938}
2939
2940static SDValue lowerCvtRSIntrinsics(SDValue Op, SelectionDAG &DAG) {
2941 SDNode *N = Op.getNode();
2942 SDLoc DL(N);
2943 SDValue F32Vec = N->getOperand(Num: 1);
2944 SDValue RBits = N->getOperand(Num: 2);
2945
2946 unsigned IntrinsicID = N->getConstantOperandVal(Num: 0);
2947
2948 // Extract the 4 float elements from the vector
2949 SmallVector<SDValue, 6> Ops;
2950 for (unsigned i = 0; i < 4; ++i)
2951 Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::f32, N1: F32Vec,
2952 N2: DAG.getIntPtrConstant(Val: i, DL)));
2953
2954 using NVPTX::PTXCvtMode::CvtMode;
2955
2956 auto [OpCode, RetTy, CvtModeFlag] =
2957 [&]() -> std::tuple<unsigned, MVT::SimpleValueType, uint32_t> {
2958 switch (IntrinsicID) {
2959 case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
2960 return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8,
2961 CvtMode::RS | CvtMode::RELU_FLAG};
2962 case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
2963 return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
2964 case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
2965 return {NVPTXISD::CVT_E5M2X4_F32X4_RS_SF, MVT::v4i8,
2966 CvtMode::RS | CvtMode::RELU_FLAG};
2967 case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
2968 return {NVPTXISD::CVT_E5M2X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
2969 case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
2970 return {NVPTXISD::CVT_E2M3X4_F32X4_RS_SF, MVT::v4i8,
2971 CvtMode::RS | CvtMode::RELU_FLAG};
2972 case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
2973 return {NVPTXISD::CVT_E2M3X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
2974 case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
2975 return {NVPTXISD::CVT_E3M2X4_F32X4_RS_SF, MVT::v4i8,
2976 CvtMode::RS | CvtMode::RELU_FLAG};
2977 case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
2978 return {NVPTXISD::CVT_E3M2X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
2979 case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
2980 return {NVPTXISD::CVT_E2M1X4_F32X4_RS_SF, MVT::i16,
2981 CvtMode::RS | CvtMode::RELU_FLAG};
2982 case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite:
2983 return {NVPTXISD::CVT_E2M1X4_F32X4_RS_SF, MVT::i16, CvtMode::RS};
2984 default:
2985 llvm_unreachable("unsupported/unhandled intrinsic");
2986 }
2987 }();
2988
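  // The operand order is: the four extracted f32 lanes, the random bits
  // (RBits), and finally the conversion-mode flag.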
2989 Ops.push_back(Elt: RBits);
2990 Ops.push_back(Elt: DAG.getConstant(Val: CvtModeFlag, DL, VT: MVT::i32));
2991
2992 return DAG.getNode(Opcode: OpCode, DL, VT: RetTy, Ops);
2993}
2994
2995static SDValue lowerPrmtIntrinsic(SDValue Op, SelectionDAG &DAG) {
2996 const unsigned Mode = [&]() {
2997 switch (Op->getConstantOperandVal(Num: 0)) {
2998 case Intrinsic::nvvm_prmt:
2999 return NVPTX::PTXPrmtMode::NONE;
3000 case Intrinsic::nvvm_prmt_b4e:
3001 return NVPTX::PTXPrmtMode::B4E;
3002 case Intrinsic::nvvm_prmt_ecl:
3003 return NVPTX::PTXPrmtMode::ECL;
3004 case Intrinsic::nvvm_prmt_ecr:
3005 return NVPTX::PTXPrmtMode::ECR;
3006 case Intrinsic::nvvm_prmt_f4e:
3007 return NVPTX::PTXPrmtMode::F4E;
3008 case Intrinsic::nvvm_prmt_rc16:
3009 return NVPTX::PTXPrmtMode::RC16;
3010 case Intrinsic::nvvm_prmt_rc8:
3011 return NVPTX::PTXPrmtMode::RC8;
3012 default:
3013 llvm_unreachable("unsupported/unhandled intrinsic");
3014 }
3015 }();
3016 SDLoc DL(Op);
3017 SDValue A = Op->getOperand(Num: 1);
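  // Two-source variants carry four operands (ID, A, B, selector);
  // single-source variants omit B, so substitute a zero constant. The
  // selector is always the last operand.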
3018 SDValue B = Op.getNumOperands() == 4 ? Op.getOperand(i: 2)
3019 : DAG.getConstant(Val: 0, DL, VT: MVT::i32);
3020 SDValue Selector = (Op->op_end() - 1)->get();
3021 return getPRMT(A, B, Selector, DL, DAG, Mode);
3022}
3023
3024#define TCGEN05_LD_RED_INTR(SHAPE, NUM, TYPE) \
3025 Intrinsic::nvvm_tcgen05_ld_red_##SHAPE##_x##NUM##_##TYPE
3026
3027#define TCGEN05_LD_RED_INST(SHAPE, NUM, TYPE) \
3028 NVPTXISD::TCGEN05_LD_RED_##SHAPE##_X##NUM##_##TYPE
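// For example, TCGEN05_LD_RED_INTR(32x32b, 2, f32) expands to
// Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_f32, and
// TCGEN05_LD_RED_INST(32x32b, 2, F32) expands to
// NVPTXISD::TCGEN05_LD_RED_32x32b_X2_F32.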
3029
3030static unsigned getTcgen05LdRedID(Intrinsic::ID IID) {
3031 switch (IID) {
3032 case TCGEN05_LD_RED_INTR(32x32b, 2, f32):
3033 return TCGEN05_LD_RED_INST(32x32b, 2, F32);
3034 case TCGEN05_LD_RED_INTR(32x32b, 4, f32):
3035 return TCGEN05_LD_RED_INST(32x32b, 4, F32);
3036 case TCGEN05_LD_RED_INTR(32x32b, 8, f32):
3037 return TCGEN05_LD_RED_INST(32x32b, 8, F32);
3038 case TCGEN05_LD_RED_INTR(32x32b, 16, f32):
3039 return TCGEN05_LD_RED_INST(32x32b, 16, F32);
3040 case TCGEN05_LD_RED_INTR(32x32b, 32, f32):
3041 return TCGEN05_LD_RED_INST(32x32b, 32, F32);
3042 case TCGEN05_LD_RED_INTR(32x32b, 64, f32):
3043 return TCGEN05_LD_RED_INST(32x32b, 64, F32);
3044 case TCGEN05_LD_RED_INTR(32x32b, 128, f32):
3045 return TCGEN05_LD_RED_INST(32x32b, 128, F32);
3046 case TCGEN05_LD_RED_INTR(16x32bx2, 2, f32):
3047 return TCGEN05_LD_RED_INST(16x32bx2, 2, F32);
3048 case TCGEN05_LD_RED_INTR(16x32bx2, 4, f32):
3049 return TCGEN05_LD_RED_INST(16x32bx2, 4, F32);
3050 case TCGEN05_LD_RED_INTR(16x32bx2, 8, f32):
3051 return TCGEN05_LD_RED_INST(16x32bx2, 8, F32);
3052 case TCGEN05_LD_RED_INTR(16x32bx2, 16, f32):
3053 return TCGEN05_LD_RED_INST(16x32bx2, 16, F32);
3054 case TCGEN05_LD_RED_INTR(16x32bx2, 32, f32):
3055 return TCGEN05_LD_RED_INST(16x32bx2, 32, F32);
3056 case TCGEN05_LD_RED_INTR(16x32bx2, 64, f32):
3057 return TCGEN05_LD_RED_INST(16x32bx2, 64, F32);
3058 case TCGEN05_LD_RED_INTR(16x32bx2, 128, f32):
3059 return TCGEN05_LD_RED_INST(16x32bx2, 128, F32);
3060 case TCGEN05_LD_RED_INTR(32x32b, 2, i32):
3061 return TCGEN05_LD_RED_INST(32x32b, 2, I32);
3062 case TCGEN05_LD_RED_INTR(32x32b, 4, i32):
3063 return TCGEN05_LD_RED_INST(32x32b, 4, I32);
3064 case TCGEN05_LD_RED_INTR(32x32b, 8, i32):
3065 return TCGEN05_LD_RED_INST(32x32b, 8, I32);
3066 case TCGEN05_LD_RED_INTR(32x32b, 16, i32):
3067 return TCGEN05_LD_RED_INST(32x32b, 16, I32);
3068 case TCGEN05_LD_RED_INTR(32x32b, 32, i32):
3069 return TCGEN05_LD_RED_INST(32x32b, 32, I32);
3070 case TCGEN05_LD_RED_INTR(32x32b, 64, i32):
3071 return TCGEN05_LD_RED_INST(32x32b, 64, I32);
3072 case TCGEN05_LD_RED_INTR(32x32b, 128, i32):
3073 return TCGEN05_LD_RED_INST(32x32b, 128, I32);
3074 case TCGEN05_LD_RED_INTR(16x32bx2, 2, i32):
3075 return TCGEN05_LD_RED_INST(16x32bx2, 2, I32);
3076 case TCGEN05_LD_RED_INTR(16x32bx2, 4, i32):
3077 return TCGEN05_LD_RED_INST(16x32bx2, 4, I32);
3078 case TCGEN05_LD_RED_INTR(16x32bx2, 8, i32):
3079 return TCGEN05_LD_RED_INST(16x32bx2, 8, I32);
3080 case TCGEN05_LD_RED_INTR(16x32bx2, 16, i32):
3081 return TCGEN05_LD_RED_INST(16x32bx2, 16, I32);
3082 case TCGEN05_LD_RED_INTR(16x32bx2, 32, i32):
3083 return TCGEN05_LD_RED_INST(16x32bx2, 32, I32);
3084 case TCGEN05_LD_RED_INTR(16x32bx2, 64, i32):
3085 return TCGEN05_LD_RED_INST(16x32bx2, 64, I32);
3086 case TCGEN05_LD_RED_INTR(16x32bx2, 128, i32):
3087 return TCGEN05_LD_RED_INST(16x32bx2, 128, I32);
3088 default:
3089 llvm_unreachable("Invalid tcgen05.ld.red intrinsic ID");
3090 }
3091}
3092
// Lower the vector return type of tcgen05.ld.red intrinsics
3094static std::optional<std::tuple<SDValue, SDValue, SDValue>>
3095lowerTcgen05LdRed(SDNode *N, SelectionDAG &DAG) {
3096 SDLoc DL(N);
3097 EVT ResVT = N->getValueType(ResNo: 0);
3098 if (!ResVT.isVector())
3099 return {}; // already legalized.
3100
3101 const unsigned NumElts = ResVT.getVectorNumElements();
3102
  // Create the result types of the instruction; the extra (+1) entry holds
  // the reduction value
3105 SmallVector<EVT, 132> ListVTs{
3106 NumElts + 1,
3107 ResVT.getVectorElementType().isFloatingPoint() ? MVT::f32 : MVT::i32};
3108
3109 ListVTs.push_back(Elt: MVT::Other); // Chain
3110
3111 SDVTList ResVTs = DAG.getVTList(VTs: ListVTs);
3112
3113 // Prepare the Operands
3114 SmallVector<SDValue, 8> Ops{N->getOperand(Num: 0)}; // Chain
3115
3116 // skip IID at index 1
3117 for (unsigned i = 2; i < N->getNumOperands(); i++)
3118 Ops.push_back(Elt: N->getOperand(Num: i));
3119
3120 unsigned IID = cast<ConstantSDNode>(Val: N->getOperand(Num: 1))->getZExtValue();
3121 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
3122 SDValue NewNode =
3123 DAG.getMemIntrinsicNode(Opcode: getTcgen05LdRedID(IID), dl: DL, VTList: ResVTs, Ops,
3124 MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());
3125
3126 // Split vector result
3127 SmallVector<SDValue, 132> ScalarRes;
3128 for (unsigned i = 0; i < NumElts; ++i) {
3129 SDValue Res = NewNode.getValue(R: i);
3130 ScalarRes.push_back(Elt: Res);
3131 }
3132
3133 SDValue BuildVector = DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: ResVT, Ops: ScalarRes);
3134 SDValue RedResult = NewNode.getValue(R: NumElts);
3135 SDValue Chain = NewNode.getValue(R: NumElts + 1);
3136 return {{BuildVector, RedResult, Chain}};
3137}
3138
3139static SDValue lowerIntrinsicWChain(SDValue Op, SelectionDAG &DAG) {
3140 switch (Op->getConstantOperandVal(Num: 1)) {
3141 default:
3142 return Op;
3143
3144 // These tcgen05 intrinsics return a v2i32, which is legal, so we have to
3145 // lower them through LowerOperation() instead of ReplaceNodeResults().
3146 case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
3147 case Intrinsic::nvvm_tcgen05_ld_16x128b_x1:
3148 case Intrinsic::nvvm_tcgen05_ld_32x32b_x2:
3149 if (auto Res = lowerTcgen05Ld(N: Op.getNode(), DAG))
3150 return DAG.getMergeValues(Ops: {Res->first, Res->second}, dl: SDLoc(Op));
3151 return SDValue();
3152
3153 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2:
3154 if (auto Res = lowerTcgen05Ld(N: Op.getNode(), DAG, /*HasOffset=*/true))
3155 return DAG.getMergeValues(Ops: {Res->first, Res->second}, dl: SDLoc(Op));
3156 return SDValue();
3157
3158 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_f32:
3159 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_i32:
3160 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x2_f32:
3161 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x2_i32:
3162 if (auto Res = lowerTcgen05LdRed(N: Op.getNode(), DAG))
3163 return DAG.getMergeValues(
3164 Ops: {std::get<0>(t&: *Res), std::get<1>(t&: *Res), std::get<2>(t&: *Res)}, dl: SDLoc(Op));
3165 return SDValue();
3166 }
3167}
3168
3169static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG) {
3170 switch (Op->getConstantOperandVal(Num: 0)) {
3171 default:
3172 return Op;
3173 case Intrinsic::nvvm_prmt:
3174 case Intrinsic::nvvm_prmt_b4e:
3175 case Intrinsic::nvvm_prmt_ecl:
3176 case Intrinsic::nvvm_prmt_ecr:
3177 case Intrinsic::nvvm_prmt_f4e:
3178 case Intrinsic::nvvm_prmt_rc16:
3179 case Intrinsic::nvvm_prmt_rc8:
3180 return lowerPrmtIntrinsic(Op, DAG);
3181 case Intrinsic::nvvm_internal_addrspace_wrap:
3182 return Op.getOperand(i: 1);
3183 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:
3184 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x:
3185 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y:
3186 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z:
3187 return LowerClusterLaunchControlQueryCancel(Op, DAG);
3188 case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
3189 case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
3190 case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
3191 case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
3192 case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
3193 case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
3194 case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
3195 case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
3196 case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite:
3197 case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
3198 return lowerCvtRSIntrinsics(Op, DAG);
3199 }
3200}
3201
// In PTX, 64-bit CTLZ and CTPOP are supported, but they return a 32-bit value.
// Lower these into a node returning a 32-bit result, which is then
// zero-extended back to the original 64-bit width.
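// For example, (i64 (ctlz x)) becomes (i64 (zext nneg (i32 (ctlz x)))).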
3205static SDValue lowerCTLZCTPOP(SDValue Op, SelectionDAG &DAG) {
3206 SDValue V = Op->getOperand(Num: 0);
3207 assert(V.getValueType() == MVT::i64 &&
3208 "Unexpected CTLZ/CTPOP type to legalize");
3209
3210 SDLoc DL(Op);
3211 SDValue CT = DAG.getNode(Opcode: Op->getOpcode(), DL, VT: MVT::i32, Operand: V);
3212 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: MVT::i64, Operand: CT, Flags: SDNodeFlags::NonNeg);
3213}
3214
3215static SDValue expandFSH64(SDValue A, SDValue B, SDValue ShiftAmount, SDLoc DL,
3216 unsigned Opcode, SelectionDAG &DAG) {
3217 assert(A.getValueType() == MVT::i64 && B.getValueType() == MVT::i64);
3218
3219 const auto *AmtConst = dyn_cast<ConstantSDNode>(Val&: ShiftAmount);
3220 if (!AmtConst)
3221 return SDValue();
3222 const auto Amt = AmtConst->getZExtValue() & 63;
3223
3224 SDValue UnpackA =
3225 DAG.getNode(Opcode: NVPTXISD::UNPACK_VECTOR, DL, ResultTys: {MVT::i32, MVT::i32}, Ops: A);
3226 SDValue UnpackB =
3227 DAG.getNode(Opcode: NVPTXISD::UNPACK_VECTOR, DL, ResultTys: {MVT::i32, MVT::i32}, Ops: B);
3228
  // The architecture is little endian: 0 = low bits, 1 = high bits
3230 SDValue ALo = UnpackA.getValue(R: 0);
3231 SDValue AHi = UnpackA.getValue(R: 1);
3232 SDValue BLo = UnpackB.getValue(R: 0);
3233 SDValue BHi = UnpackB.getValue(R: 1);
3234
  // The bitfield consists of { AHi : ALo : BHi : BLo }
3236 //
3237 // * FSHL, Amt < 32 - The window will contain { AHi : ALo : BHi }
3238 // * FSHL, Amt >= 32 - The window will contain { ALo : BHi : BLo }
3239 // * FSHR, Amt < 32 - The window will contain { ALo : BHi : BLo }
3240 // * FSHR, Amt >= 32 - The window will contain { AHi : ALo : BHi }
3241 //
3242 // Note that Amt = 0 and Amt = 32 are special cases where 32-bit funnel shifts
3243 // are not needed at all. Amt = 0 is a no-op producing either A or B depending
3244 // on the direction. Amt = 32 can be implemented by a packing and unpacking
  // move to select and arrange the 32-bit values. For simplicity, these cases
3246 // are not handled here explicitly and instead we rely on DAGCombiner to
3247 // remove the no-op funnel shifts we insert.
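  // For example, FSHL with Amt = 40 selects { ALo : BHi : BLo } and shifts by
  // Amt & 31 = 8, giving RHi = fshl(ALo, BHi, 8) and RLo = fshl(BHi, BLo, 8).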
3248 auto [High, Mid, Low] = ((Opcode == ISD::FSHL) == (Amt < 32))
3249 ? std::make_tuple(args&: AHi, args&: ALo, args&: BHi)
3250 : std::make_tuple(args&: ALo, args&: BHi, args&: BLo);
3251
3252 SDValue NewAmt = DAG.getConstant(Val: Amt & 31, DL, VT: MVT::i32);
3253 SDValue RHi = DAG.getNode(Opcode, DL, VT: MVT::i32, Ops: {High, Mid, NewAmt});
3254 SDValue RLo = DAG.getNode(Opcode, DL, VT: MVT::i32, Ops: {Mid, Low, NewAmt});
3255
3256 return DAG.getNode(Opcode: NVPTXISD::BUILD_VECTOR, DL, VT: MVT::i64, Ops: {RLo, RHi});
3257}
3258
3259static SDValue lowerFSH(SDValue Op, SelectionDAG &DAG) {
3260 return expandFSH64(A: Op->getOperand(Num: 0), B: Op->getOperand(Num: 1), ShiftAmount: Op->getOperand(Num: 2),
3261 DL: SDLoc(Op), Opcode: Op->getOpcode(), DAG);
3262}
3263
3264static SDValue lowerROT(SDValue Op, SelectionDAG &DAG) {
3265 unsigned Opcode = Op->getOpcode() == ISD::ROTL ? ISD::FSHL : ISD::FSHR;
3266 return expandFSH64(A: Op->getOperand(Num: 0), B: Op->getOperand(Num: 0), ShiftAmount: Op->getOperand(Num: 1),
3267 DL: SDLoc(Op), Opcode, DAG);
3268}
3269
3270static SDValue lowerFREM(SDValue Op, SelectionDAG &DAG) {
3271 // Lower (frem x, y) into (sub x, (mul (ftrunc (div x, y)) y)),
3272 // i.e. "poor man's fmod()". When y is infinite, x is returned. This matches
3273 // the semantics of LLVM's frem.
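  // For example, frem(5.5, 2.0) lowers to 5.5 - trunc(5.5 / 2.0) * 2.0
  // = 5.5 - 2.0 * 2.0 = 1.5.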
3274 SDLoc DL(Op);
3275 SDValue X = Op->getOperand(Num: 0);
3276 SDValue Y = Op->getOperand(Num: 1);
3277 EVT Ty = Op.getValueType();
3278 SDNodeFlags Flags = Op->getFlags();
3279
3280 SDValue Div = DAG.getNode(Opcode: ISD::FDIV, DL, VT: Ty, N1: X, N2: Y, Flags);
3281 SDValue Trunc = DAG.getNode(Opcode: ISD::FTRUNC, DL, VT: Ty, Operand: Div, Flags);
3282 SDValue Mul = DAG.getNode(Opcode: ISD::FMUL, DL, VT: Ty, N1: Trunc, N2: Y,
3283 Flags: Flags | SDNodeFlags::AllowContract);
3284 SDValue Sub = DAG.getNode(Opcode: ISD::FSUB, DL, VT: Ty, N1: X, N2: Mul,
3285 Flags: Flags | SDNodeFlags::AllowContract);
3286
3287 if (Flags.hasNoInfs())
3288 return Sub;
3289
3290 // If Y is infinite, return X
3291 SDValue AbsY = DAG.getNode(Opcode: ISD::FABS, DL, VT: Ty, Operand: Y);
3292 SDValue Inf =
3293 DAG.getConstantFP(Val: APFloat::getInf(Sem: Ty.getFltSemantics()), DL, VT: Ty);
3294 SDValue IsInf = DAG.getSetCC(DL, VT: MVT::i1, LHS: AbsY, RHS: Inf, Cond: ISD::SETEQ);
3295 return DAG.getSelect(DL, VT: Ty, Cond: IsInf, LHS: X, RHS: Sub);
3296}
3297
3298static SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) {
3299 assert(Op.getValueType() == MVT::i1 && "Custom lowering enabled only for i1");
3300
3301 SDValue Cond = Op->getOperand(Num: 0);
3302 SDValue TrueVal = Op->getOperand(Num: 1);
3303 SDValue FalseVal = Op->getOperand(Num: 2);
3304 SDLoc DL(Op);
3305
3306 // If both operands are truncated, we push the select through the truncates.
3307 if (TrueVal.getOpcode() == ISD::TRUNCATE &&
3308 FalseVal.getOpcode() == ISD::TRUNCATE) {
3309 TrueVal = TrueVal.getOperand(i: 0);
3310 FalseVal = FalseVal.getOperand(i: 0);
3311
3312 EVT VT = TrueVal.getSimpleValueType().bitsLE(VT: FalseVal.getSimpleValueType())
3313 ? TrueVal.getValueType()
3314 : FalseVal.getValueType();
3315 TrueVal = DAG.getAnyExtOrTrunc(Op: TrueVal, DL, VT);
3316 FalseVal = DAG.getAnyExtOrTrunc(Op: FalseVal, DL, VT);
3317 SDValue Select = DAG.getSelect(DL, VT, Cond, LHS: TrueVal, RHS: FalseVal);
3318 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i1, Operand: Select);
3319 }
3320
3321 // Otherwise, expand the select into a series of logical operations. These
3322 // often can be folded into other operations either by us or ptxas.
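  // For i1 values, select(Cond, T, F) is equivalent to
  // (Cond & T) | (~Cond & F).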
3323 TrueVal = DAG.getFreeze(V: TrueVal);
3324 FalseVal = DAG.getFreeze(V: FalseVal);
3325 SDValue And1 = DAG.getNode(Opcode: ISD::AND, DL, VT: MVT::i1, N1: Cond, N2: TrueVal);
3326 SDValue NotCond = DAG.getNOT(DL, Val: Cond, VT: MVT::i1);
3327 SDValue And2 = DAG.getNode(Opcode: ISD::AND, DL, VT: MVT::i1, N1: NotCond, N2: FalseVal);
3328 SDValue Or = DAG.getNode(Opcode: ISD::OR, DL, VT: MVT::i1, N1: And1, N2: And2);
3329 return Or;
3330}
3331
3332static SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) {
3333 SDNode *N = Op.getNode();
3334
3335 SDValue Chain = N->getOperand(Num: 0);
3336 SDValue Val = N->getOperand(Num: 1);
3337 SDValue BasePtr = N->getOperand(Num: 2);
3338 SDValue Offset = N->getOperand(Num: 3);
3339 SDValue Mask = N->getOperand(Num: 4);
3340
3341 SDLoc DL(N);
3342 EVT ValVT = Val.getValueType();
3343 MemSDNode *MemSD = cast<MemSDNode>(Val: N);
3344 assert(ValVT.isVector() && "Masked vector store must have vector type");
3345 assert(MemSD->getAlign() >= DAG.getEVTAlign(ValVT) &&
3346 "Unexpected alignment for masked store");
3347
3348 unsigned Opcode = 0;
3349 switch (ValVT.getSimpleVT().SimpleTy) {
3350 default:
3351 llvm_unreachable("Unexpected masked vector store type");
3352 case MVT::v4i64:
3353 case MVT::v4f64: {
3354 Opcode = NVPTXISD::StoreV4;
3355 break;
3356 }
3357 case MVT::v8i32:
3358 case MVT::v8f32: {
3359 Opcode = NVPTXISD::StoreV8;
3360 break;
3361 }
3362 }
3363
3364 SmallVector<SDValue, 8> Ops;
3365
3366 // Construct the new SDNode. First operand is the chain.
3367 Ops.push_back(Elt: Chain);
3368
3369 // The next N operands are the values to store. Encode the mask into the
3370 // values using the sentinel register 0 to represent a masked-off element.
3371 assert(Mask.getValueType().isVector() &&
3372 Mask.getValueType().getVectorElementType() == MVT::i1 &&
3373 "Mask must be a vector of i1");
3374 assert(Mask.getOpcode() == ISD::BUILD_VECTOR &&
3375 "Mask expected to be a BUILD_VECTOR");
3376 assert(Mask.getValueType().getVectorNumElements() ==
3377 ValVT.getVectorNumElements() &&
3378 "Mask size must be the same as the vector size");
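  // For example, storing a v4f64 with mask <1, 0, 1, 1> emits a StoreV4 whose
  // second value operand is the sentinel register for the masked-off element.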
3379 for (auto [I, Op] : enumerate(First: Mask->ops())) {
3380 // Mask elements must be constants.
3381 if (Op.getNode()->getAsZExtVal() == 0) {
3382 // Append a sentinel register 0 to the Ops vector to represent a masked
3383 // off element, this will be handled in tablegen
3384 Ops.push_back(Elt: DAG.getRegister(Reg: MCRegister::NoRegister,
3385 VT: ValVT.getVectorElementType()));
3386 } else {
3387 // Extract the element from the vector to store
3388 SDValue ExtVal =
3389 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ValVT.getVectorElementType(),
3390 N1: Val, N2: DAG.getIntPtrConstant(Val: I, DL));
3391 Ops.push_back(Elt: ExtVal);
3392 }
3393 }
3394
3395 // Next, the pointer operand.
3396 Ops.push_back(Elt: BasePtr);
3397
3398 // Finally, the offset operand. We expect this to always be undef, and it will
3399 // be ignored in lowering, but to mirror the handling of the other vector
3400 // store instructions we include it in the new SDNode.
3401 assert(Offset.getOpcode() == ISD::UNDEF &&
3402 "Offset operand expected to be undef");
3403 Ops.push_back(Elt: Offset);
3404
3405 SDValue NewSt =
3406 DAG.getMemIntrinsicNode(Opcode, dl: DL, VTList: DAG.getVTList(VT: MVT::Other), Ops,
3407 MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());
3408
3409 return NewSt;
3410}
3411
3412SDValue
3413NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
3414 switch (Op.getOpcode()) {
3415 case ISD::RETURNADDR:
3416 return SDValue();
3417 case ISD::FRAMEADDR:
3418 return SDValue();
3419 case ISD::ADDRSPACECAST:
3420 return LowerADDRSPACECAST(Op, DAG);
3421 case ISD::INTRINSIC_W_CHAIN:
3422 return lowerIntrinsicWChain(Op, DAG);
3423 case ISD::INTRINSIC_WO_CHAIN:
3424 return lowerIntrinsicWOChain(Op, DAG);
3425 case ISD::INTRINSIC_VOID:
3426 return lowerIntrinsicVoid(Op, DAG);
3427 case ISD::BUILD_VECTOR:
3428 return LowerBUILD_VECTOR(Op, DAG);
3429 case ISD::BITCAST:
3430 return LowerBITCAST(Op, DAG);
3431 case ISD::EXTRACT_SUBVECTOR:
3432 return Op;
3433 case ISD::EXTRACT_VECTOR_ELT:
3434 return LowerEXTRACT_VECTOR_ELT(Op, DAG);
3435 case ISD::INSERT_VECTOR_ELT:
3436 return LowerINSERT_VECTOR_ELT(Op, DAG);
3437 case ISD::VECTOR_SHUFFLE:
3438 return LowerVECTOR_SHUFFLE(Op, DAG);
3439 case ISD::CONCAT_VECTORS:
3440 return LowerCONCAT_VECTORS(Op, DAG);
3441 case ISD::VECREDUCE_FMAX:
3442 case ISD::VECREDUCE_FMIN:
3443 case ISD::VECREDUCE_FMAXIMUM:
3444 case ISD::VECREDUCE_FMINIMUM:
3445 return LowerVECREDUCE(Op, DAG);
3446 case ISD::STORE:
3447 return LowerSTORE(Op, DAG);
3448 case ISD::MSTORE: {
3449 assert(STI.has256BitVectorLoadStore(
3450 cast<MemSDNode>(Op.getNode())->getAddressSpace()) &&
3451 "Masked store vector not supported on subtarget.");
3452 return lowerMSTORE(Op, DAG);
3453 }
3454 case ISD::LOAD:
3455 return LowerLOAD(Op, DAG);
3456 case ISD::MLOAD:
3457 return LowerMLOAD(Op, DAG);
3458 case ISD::SHL_PARTS:
3459 return LowerShiftLeftParts(Op, DAG);
3460 case ISD::SRA_PARTS:
3461 case ISD::SRL_PARTS:
3462 return LowerShiftRightParts(Op, DAG);
3463 case ISD::SELECT:
3464 return lowerSELECT(Op, DAG);
3465 case ISD::FROUND:
3466 return LowerFROUND(Op, DAG);
3467 case ISD::FCOPYSIGN:
3468 return LowerFCOPYSIGN(Op, DAG);
3469 case ISD::SINT_TO_FP:
3470 case ISD::UINT_TO_FP:
3471 return LowerINT_TO_FP(Op, DAG);
3472 case ISD::FP_TO_SINT:
3473 case ISD::FP_TO_UINT:
3474 return LowerFP_TO_INT(Op, DAG);
3475 case ISD::FP_ROUND:
3476 return LowerFP_ROUND(Op, DAG);
3477 case ISD::FP_EXTEND:
3478 return LowerFP_EXTEND(Op, DAG);
3479 case ISD::VAARG:
3480 return LowerVAARG(Op, DAG);
3481 case ISD::VASTART:
3482 return LowerVASTART(Op, DAG);
3483 case ISD::FSHL:
3484 case ISD::FSHR:
3485 return lowerFSH(Op, DAG);
3486 case ISD::ROTL:
3487 case ISD::ROTR:
3488 return lowerROT(Op, DAG);
3489 case ISD::ABS:
3490 case ISD::SMIN:
3491 case ISD::SMAX:
3492 case ISD::UMIN:
3493 case ISD::UMAX:
3494 case ISD::ADD:
3495 case ISD::SUB:
3496 case ISD::MUL:
3497 case ISD::SHL:
3498 case ISD::SREM:
3499 case ISD::UREM:
3500 return LowerVectorArith(Op, DAG);
3501 case ISD::DYNAMIC_STACKALLOC:
3502 return LowerDYNAMIC_STACKALLOC(Op, DAG);
3503 case ISD::STACKRESTORE:
3504 return LowerSTACKRESTORE(Op, DAG);
3505 case ISD::STACKSAVE:
3506 return LowerSTACKSAVE(Op, DAG);
3507 case ISD::CopyToReg:
3508 return LowerCopyToReg_128(Op, DAG);
3509 case ISD::FADD:
3510 case ISD::FSUB:
3511 case ISD::FMUL:
3512 // Used only for bf16 on SM80, where we select fma for non-ftz operation
3513 return PromoteBinOpIfF32FTZ(Op, DAG);
3514 case ISD::CTPOP:
3515 case ISD::CTLZ:
3516 return lowerCTLZCTPOP(Op, DAG);
3517 case ISD::FREM:
3518 return lowerFREM(Op, DAG);
3519 case ISD::BSWAP:
3520 return lowerBSWAP(Op, DAG);
3521 default:
3522 llvm_unreachable("Custom lowering not defined for operation");
3523 }
3524}
3525
3526// This will prevent AsmPrinter from trying to print the jump tables itself.
3527unsigned NVPTXTargetLowering::getJumpTableEncoding() const {
3528 return MachineJumpTableInfo::EK_Inline;
3529}
3530
3531SDValue NVPTXTargetLowering::LowerADDRSPACECAST(SDValue Op,
3532 SelectionDAG &DAG) const {
3533 AddrSpaceCastSDNode *N = cast<AddrSpaceCastSDNode>(Val: Op.getNode());
3534 unsigned SrcAS = N->getSrcAddressSpace();
3535 unsigned DestAS = N->getDestAddressSpace();
3536 if (SrcAS != llvm::ADDRESS_SPACE_GENERIC &&
3537 DestAS != llvm::ADDRESS_SPACE_GENERIC) {
3538 // Shared and SharedCluster can be converted to each other through generic
3539 // space
3540 if ((SrcAS == llvm::ADDRESS_SPACE_SHARED &&
3541 DestAS == llvm::ADDRESS_SPACE_SHARED_CLUSTER) ||
3542 (SrcAS == llvm::ADDRESS_SPACE_SHARED_CLUSTER &&
3543 DestAS == llvm::ADDRESS_SPACE_SHARED)) {
3544 SDLoc DL(Op.getNode());
      const MVT GenericVT =
          getPointerTy(DL: DAG.getDataLayout(), AS: ADDRESS_SPACE_GENERIC);
      SDValue GenericConversion = DAG.getAddrSpaceCast(
          dl: DL, VT: GenericVT, Ptr: Op.getOperand(i: 0), SrcAS, DestAS: ADDRESS_SPACE_GENERIC);
3549 SDValue SharedClusterConversion =
3550 DAG.getAddrSpaceCast(dl: DL, VT: Op.getValueType(), Ptr: GenericConversion,
3551 SrcAS: ADDRESS_SPACE_GENERIC, DestAS);
3552 return SharedClusterConversion;
3553 }
3554
3555 return DAG.getUNDEF(VT: Op.getValueType());
3556 }
3557
3558 return Op;
3559}
3560
3561// This function is almost a copy of SelectionDAG::expandVAArg().
// The only difference is that this one produces loads from the local address
// space.
3563SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
3564 const TargetLowering *TLI = STI.getTargetLowering();
3565 SDLoc DL(Op);
3566
3567 SDNode *Node = Op.getNode();
3568 const Value *V = cast<SrcValueSDNode>(Val: Node->getOperand(Num: 2))->getValue();
3569 EVT VT = Node->getValueType(ResNo: 0);
3570 auto *Ty = VT.getTypeForEVT(Context&: *DAG.getContext());
3571 SDValue Tmp1 = Node->getOperand(Num: 0);
3572 SDValue Tmp2 = Node->getOperand(Num: 1);
3573 const MaybeAlign MA(Node->getConstantOperandVal(Num: 3));
3574
3575 SDValue VAListLoad = DAG.getLoad(VT: TLI->getPointerTy(DL: DAG.getDataLayout()), dl: DL,
3576 Chain: Tmp1, Ptr: Tmp2, PtrInfo: MachinePointerInfo(V));
3577 SDValue VAList = VAListLoad;
3578
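  // If the argument needs more alignment than the minimum stack argument
  // alignment, round the pointer up: VAList = (VAList + Align - 1) & -Align.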
3579 if (MA && *MA > TLI->getMinStackArgumentAlignment()) {
3580 VAList = DAG.getNode(
3581 Opcode: ISD::ADD, DL, VT: VAList.getValueType(), N1: VAList,
3582 N2: DAG.getConstant(Val: MA->value() - 1, DL, VT: VAList.getValueType()));
3583
3584 VAList = DAG.getNode(Opcode: ISD::AND, DL, VT: VAList.getValueType(), N1: VAList,
3585 N2: DAG.getSignedConstant(Val: -(int64_t)MA->value(), DL,
3586 VT: VAList.getValueType()));
3587 }
3588
3589 // Increment the pointer, VAList, to the next vaarg
3590 Tmp1 = DAG.getNode(Opcode: ISD::ADD, DL, VT: VAList.getValueType(), N1: VAList,
3591 N2: DAG.getConstant(Val: DAG.getDataLayout().getTypeAllocSize(Ty),
3592 DL, VT: VAList.getValueType()));
3593
3594 // Store the incremented VAList to the legalized pointer
3595 Tmp1 = DAG.getStore(Chain: VAListLoad.getValue(R: 1), dl: DL, Val: Tmp1, Ptr: Tmp2,
3596 PtrInfo: MachinePointerInfo(V));
3597
3598 const Value *SrcV = Constant::getNullValue(
3599 Ty: PointerType::get(C&: *DAG.getContext(), AddressSpace: ADDRESS_SPACE_LOCAL));
3600
3601 // Load the actual argument out of the pointer VAList
3602 return DAG.getLoad(VT, dl: DL, Chain: Tmp1, Ptr: VAList, PtrInfo: MachinePointerInfo(SrcV));
3603}
3604
3605SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
3606 const TargetLowering *TLI = STI.getTargetLowering();
3607 SDLoc DL(Op);
3608 EVT PtrVT = TLI->getPointerTy(DL: DAG.getDataLayout());
3609
3610 // Store the address of unsized array <function>_vararg[] in the ap object.
3611 SDValue VAReg = getParamSymbol(DAG, /* vararg */ I: -1, T: PtrVT);
3612
3613 const Value *SV = cast<SrcValueSDNode>(Val: Op.getOperand(i: 2))->getValue();
3614 return DAG.getStore(Chain: Op.getOperand(i: 0), dl: DL, Val: VAReg, Ptr: Op.getOperand(i: 1),
3615 PtrInfo: MachinePointerInfo(SV));
3616}
3617
3618static std::pair<MemSDNode *, uint32_t>
3619convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG,
3620 const NVPTXSubtarget &STI) {
3621 SDValue Chain = N->getOperand(Num: 0);
3622 SDValue BasePtr = N->getOperand(Num: 1);
3623 SDValue Mask = N->getOperand(Num: 3);
3624 [[maybe_unused]] SDValue Passthru = N->getOperand(Num: 4);
3625
3626 SDLoc DL(N);
3627 EVT ResVT = N->getValueType(ResNo: 0);
3628 assert(ResVT.isVector() && "Masked vector load must have vector type");
  // We only expect poison passthru vectors as input to the backend; however,
  // when the legalization framework splits a poison vector in half it creates
  // two undef vectors, so those may technically appear here as well.
3632 assert((Passthru.getOpcode() == ISD::POISON ||
3633 Passthru.getOpcode() == ISD::UNDEF) &&
3634 "Passthru operand expected to be poison or undef");
3635
3636 // Extract the mask and convert it to a uint32_t representing the used bytes
3637 // of the entire vector load
3638 uint32_t UsedBytesMask = 0;
3639 uint32_t ElementSizeInBits = ResVT.getVectorElementType().getSizeInBits();
3640 assert(ElementSizeInBits % 8 == 0 && "Unexpected element size");
3641 uint32_t ElementSizeInBytes = ElementSizeInBits / 8;
3642 uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
3643
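  // For example, a v4i32 load with mask <1, 1, 0, 1> produces
  // UsedBytesMask = 0xF0FF: one bit per byte, with element 2's bytes cleared.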
3644 for (SDValue Op : reverse(C: Mask->ops())) {
    // This shift is only needed on every iteration after the first, but on the
    // first iteration UsedBytesMask is 0, so the shift is a harmless no-op.
3648 UsedBytesMask <<= ElementSizeInBytes;
3649
3650 // Mask elements must be constants.
3651 if (Op->getAsZExtVal() != 0)
3652 UsedBytesMask |= ElementMask;
3653 }
3654
3655 assert(UsedBytesMask != 0 && UsedBytesMask != UINT32_MAX &&
3656 "Unexpected masked load with elements masked all on or all off");
3657
3658 // Create a new load sd node to be handled normally by ReplaceLoadVector.
3659 MemSDNode *NewLD = cast<MemSDNode>(
3660 Val: DAG.getLoad(VT: ResVT, dl: DL, Chain, Ptr: BasePtr, MMO: N->getMemOperand()).getNode());
3661
3662 // If our subtarget does not support the used bytes mask pragma, "drop" the
3663 // mask by setting it to UINT32_MAX
3664 if (!STI.hasUsedBytesMaskPragma())
3665 UsedBytesMask = UINT32_MAX;
3666
3667 return {NewLD, UsedBytesMask};
3668}
3669
3670/// replaceLoadVector - Convert vector loads into multi-output scalar loads.
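/// For example, an aligned load of v4i32 becomes an NVPTXISD::LoadV4 that
/// produces four i32 results plus a chain, which are recombined into the
/// original vector type with BUILD_VECTOR.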
3671static std::optional<std::pair<SDValue, SDValue>>
3672replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
3673 MemSDNode *LD = cast<MemSDNode>(Val: N);
3674 const EVT ResVT = LD->getValueType(ResNo: 0);
3675 const EVT MemVT = LD->getMemoryVT();
3676
3677 // If we're doing sign/zero extension as part of the load, avoid lowering to
3678 // a LoadV node. TODO: consider relaxing this restriction.
3679 if (ResVT != MemVT)
3680 return std::nullopt;
3681
3682 const auto NumEltsAndEltVT =
3683 getVectorLoweringShape(VectorEVT: ResVT, STI, AddressSpace: LD->getAddressSpace());
3684 if (!NumEltsAndEltVT)
3685 return std::nullopt;
3686 const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
3687
3688 Align Alignment = LD->getAlign();
3689 const auto &TD = DAG.getDataLayout();
3690 Align PrefAlign = TD.getPrefTypeAlign(Ty: MemVT.getTypeForEVT(Context&: *DAG.getContext()));
3691 if (Alignment < PrefAlign) {
3692 // This load is not sufficiently aligned, so bail out and let this vector
3693 // load be scalarized. Note that we may still be able to emit smaller
3694 // vector loads. For example, if we are loading a <4 x float> with an
3695 // alignment of 8, this check will fail but the legalizer will try again
3696 // with 2 x <2 x float>, which will succeed with an alignment of 8.
3697 return std::nullopt;
3698 }
3699
3700 // If we have a masked load, convert it to a normal load now
3701 std::optional<uint32_t> UsedBytesMask = std::nullopt;
3702 if (LD->getOpcode() == ISD::MLOAD)
3703 std::tie(args&: LD, args&: UsedBytesMask) =
3704 convertMLOADToLoadWithUsedBytesMask(N: LD, DAG, STI);
3705
3706 // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
3707 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
3708 // loaded type to i16 and propagate the "real" type as the memory type.
3709 const MVT LoadEltVT = (EltVT.getSizeInBits() < 16) ? MVT::i16 : EltVT;
3710
3711 unsigned Opcode;
3712 switch (NumElts) {
3713 default:
3714 return std::nullopt;
3715 case 2:
3716 Opcode = NVPTXISD::LoadV2;
3717 break;
3718 case 4:
3719 Opcode = NVPTXISD::LoadV4;
3720 break;
3721 case 8:
3722 Opcode = NVPTXISD::LoadV8;
3723 break;
3724 }
3725 auto ListVTs = SmallVector<EVT, 9>(NumElts, LoadEltVT);
3726 ListVTs.push_back(Elt: MVT::Other);
3727 SDVTList LdResVTs = DAG.getVTList(VTs: ListVTs);
3728
3729 SDLoc DL(LD);
3730
3731 // Copy regular operands
3732 SmallVector<SDValue, 8> OtherOps(LD->ops());
3733
3734 OtherOps.push_back(
3735 Elt: DAG.getConstant(Val: UsedBytesMask.value_or(UINT32_MAX), DL, VT: MVT::i32));
3736
3737 // The select routine does not have access to the LoadSDNode instance, so
3738 // pass along the extension information
3739 OtherOps.push_back(
3740 Elt: DAG.getIntPtrConstant(Val: cast<LoadSDNode>(Val: LD)->getExtensionType(), DL));
3741
3742 SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, dl: DL, VTList: LdResVTs, Ops: OtherOps, MemVT,
3743 MMO: LD->getMemOperand());
3744
3745 SmallVector<SDValue> ScalarRes;
3746 if (EltVT.isVector()) {
3747 assert(EVT(EltVT.getVectorElementType()) == ResVT.getVectorElementType());
3748 assert(NumElts * EltVT.getVectorNumElements() ==
3749 ResVT.getVectorNumElements());
3750 // Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
3751 // into individual elements.
3752 for (const unsigned I : llvm::seq(Size: NumElts)) {
3753 SDValue SubVector = NewLD.getValue(R: I);
3754 DAG.ExtractVectorElements(Op: SubVector, Args&: ScalarRes);
3755 }
3756 } else {
3757 for (const unsigned I : llvm::seq(Size: NumElts)) {
3758 SDValue Res = NewLD.getValue(R: I);
3759 if (LoadEltVT != EltVT)
3760 Res = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: EltVT, Operand: Res);
3761 ScalarRes.push_back(Elt: Res);
3762 }
3763 }
3764
3765 SDValue LoadChain = NewLD.getValue(R: NumElts);
3766
3767 const MVT BuildVecVT =
3768 MVT::getVectorVT(VT: EltVT.getScalarType(), NumElements: ScalarRes.size());
3769 SDValue BuildVec = DAG.getBuildVector(VT: BuildVecVT, DL, Ops: ScalarRes);
3770 SDValue LoadValue = DAG.getBitcast(VT: ResVT, V: BuildVec);
3771
3772 return {{LoadValue, LoadChain}};
3773}
3774
3775static void replaceLoadVector(SDNode *N, SelectionDAG &DAG,
3776 SmallVectorImpl<SDValue> &Results,
3777 const NVPTXSubtarget &STI) {
3778 if (auto Res = replaceLoadVector(N, DAG, STI))
3779 Results.append(IL: {Res->first, Res->second});
3780}
3781
3782static SDValue lowerLoadVector(SDNode *N, SelectionDAG &DAG,
3783 const NVPTXSubtarget &STI) {
3784 if (auto Res = replaceLoadVector(N, DAG, STI))
3785 return DAG.getMergeValues(Ops: {Res->first, Res->second}, dl: SDLoc(N));
3786 return SDValue();
3787}
3788
3789// v = ld i1* addr
3790// =>
3791// v1 = ld i8* addr (-> i16)
3792// v = trunc i16 to i1
3793static SDValue lowerLOADi1(LoadSDNode *LD, SelectionDAG &DAG) {
3794 SDLoc dl(LD);
3795 assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
3796 assert(LD->getValueType(0) == MVT::i1 && "Custom lowering for i1 load only");
3797 SDValue newLD = DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl, VT: MVT::i16, Chain: LD->getChain(),
3798 Ptr: LD->getBasePtr(), PtrInfo: LD->getPointerInfo(),
3799 MemVT: MVT::i8, Alignment: LD->getAlign(),
3800 MMOFlags: LD->getMemOperand()->getFlags());
3801 SDValue result = DAG.getNode(Opcode: ISD::TRUNCATE, DL: dl, VT: MVT::i1, Operand: newLD);
3802 // The legalizer (the caller) is expecting two values from the legalized
3803 // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
3804 // in LegalizeDAG.cpp which also uses MergeValues.
3805 return DAG.getMergeValues(Ops: {result, LD->getChain()}, dl);
3806}
3807
3808SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
3809 LoadSDNode *LD = cast<LoadSDNode>(Val&: Op);
3810
3811 if (Op.getValueType() == MVT::i1)
3812 return lowerLOADi1(LD, DAG);
3813
3814 // To improve CodeGen we'll legalize any-extend loads to zext loads. This is
3815 // how they'll be lowered in ISel anyway, and by doing this a little earlier
3816 // we allow for more DAG combine opportunities.
3817 if (LD->getExtensionType() == ISD::EXTLOAD) {
3818 assert(LD->getValueType(0).isInteger() && LD->getMemoryVT().isInteger() &&
3819 "Unexpected fpext-load");
3820 return DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl: SDLoc(Op), VT: Op.getValueType(),
3821 Chain: LD->getChain(), Ptr: LD->getBasePtr(), MemVT: LD->getMemoryVT(),
3822 MMO: LD->getMemOperand());
3823 }
3824
3825 llvm_unreachable("Unexpected custom lowering for load");
3826}
3827
3828SDValue NVPTXTargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
3829 // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
3830 // masked loads of these types and have to handle them here.
3831 // v2f32 also needs to be handled here if the subtarget has f32x2
3832 // instructions, making it legal.
3833 //
3834 // Note: misaligned masked loads should never reach this point
3835 // because the override of isLegalMaskedLoad in NVPTXTargetTransformInfo.cpp
3836 // will validate alignment. Therefore, we do not need to special case handle
3837 // them here.
3838 EVT VT = Op.getValueType();
3839 if (NVPTX::isPackedVectorTy(VT)) {
3840 auto Result = convertMLOADToLoadWithUsedBytesMask(
3841 N: cast<MemSDNode>(Val: Op.getNode()), DAG, STI);
3842 MemSDNode *LD = std::get<0>(in&: Result);
3843 uint32_t UsedBytesMask = std::get<1>(in&: Result);
3844
3845 SDLoc DL(LD);
3846
3847 // Copy regular operands
3848 SmallVector<SDValue, 8> OtherOps(LD->ops());
3849
3850 OtherOps.push_back(Elt: DAG.getConstant(Val: UsedBytesMask, DL, VT: MVT::i32));
3851
    // We do not currently lower extending masked loads, but pass the extension
    // type anyway since later handling expects it.
3854 OtherOps.push_back(
3855 Elt: DAG.getIntPtrConstant(Val: cast<LoadSDNode>(Val: LD)->getExtensionType(), DL));
3856 SDValue NewLD =
3857 DAG.getMemIntrinsicNode(Opcode: NVPTXISD::MLoad, dl: DL, VTList: LD->getVTList(), Ops: OtherOps,
3858 MemVT: LD->getMemoryVT(), MMO: LD->getMemOperand());
3859 return NewLD;
3860 }
3861 return SDValue();
3862}
3863
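// Convert a vector store into an NVPTXISD::StoreV{2,4,8} node whose value
// operands are the individual elements; for example, a sufficiently aligned
// store of v4i32 becomes a StoreV4 taking four i32 operands.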
3864static SDValue lowerSTOREVector(SDValue Op, SelectionDAG &DAG,
3865 const NVPTXSubtarget &STI) {
3866 MemSDNode *N = cast<MemSDNode>(Val: Op.getNode());
3867 SDValue Val = N->getOperand(Num: 1);
3868 SDLoc DL(N);
3869 const EVT ValVT = Val.getValueType();
3870 const EVT MemVT = N->getMemoryVT();
3871
3872 // If we're truncating as part of the store, avoid lowering to a StoreV node.
3873 // TODO: consider relaxing this restriction.
3874 if (ValVT != MemVT)
3875 return SDValue();
3876
3877 const auto NumEltsAndEltVT =
3878 getVectorLoweringShape(VectorEVT: ValVT, STI, AddressSpace: N->getAddressSpace());
3879 if (!NumEltsAndEltVT)
3880 return SDValue();
3881 const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
3882
3883 const DataLayout &TD = DAG.getDataLayout();
3884
3885 Align Alignment = N->getAlign();
3886 Align PrefAlign = TD.getPrefTypeAlign(Ty: ValVT.getTypeForEVT(Context&: *DAG.getContext()));
3887 if (Alignment < PrefAlign) {
3888 // This store is not sufficiently aligned, so bail out and let this vector
3889 // store be scalarized. Note that we may still be able to emit smaller
3890 // vector stores. For example, if we are storing a <4 x float> with an
3891 // alignment of 8, this check will fail but the legalizer will try again
3892 // with 2 x <2 x float>, which will succeed with an alignment of 8.
3893 return SDValue();
3894 }
3895
3896 unsigned Opcode;
3897 switch (NumElts) {
3898 default:
3899 return SDValue();
3900 case 2:
3901 Opcode = NVPTXISD::StoreV2;
3902 break;
3903 case 4:
3904 Opcode = NVPTXISD::StoreV4;
3905 break;
3906 case 8:
3907 Opcode = NVPTXISD::StoreV8;
3908 break;
3909 }
3910
3911 SmallVector<SDValue, 8> Ops;
3912
3913 // First is the chain
3914 Ops.push_back(Elt: N->getOperand(Num: 0));
3915
3916 // Then the split values
3917 if (EltVT.isVector()) {
3918 assert(EVT(EltVT.getVectorElementType()) == ValVT.getVectorElementType());
3919 assert(NumElts * EltVT.getVectorNumElements() ==
3920 ValVT.getVectorNumElements());
3921 // Combine individual elements into v2[i,f,bf]16/v4i8 subvectors to be
3922 // stored as b32s
3923 const unsigned NumEltsPerSubVector = EltVT.getVectorNumElements();
3924 for (const unsigned I : llvm::seq(Size: NumElts)) {
3925 SmallVector<SDValue, 4> SubVectorElts;
3926 DAG.ExtractVectorElements(Op: Val, Args&: SubVectorElts, Start: I * NumEltsPerSubVector,
3927 Count: NumEltsPerSubVector);
3928 Ops.push_back(Elt: DAG.getBuildVector(VT: EltVT, DL, Ops: SubVectorElts));
3929 }
3930 } else {
3931 SDValue V = DAG.getBitcast(VT: MVT::getVectorVT(VT: EltVT, NumElements: NumElts), V: Val);
3932 for (const unsigned I : llvm::seq(Size: NumElts)) {
3933 SDValue ExtVal = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: V,
3934 N2: DAG.getIntPtrConstant(Val: I, DL));
3935
3936 // Since StoreV2 is a target node, we cannot rely on DAG type
3937 // legalization. Therefore, we must ensure the type is legal. For i1 and
3938 // i8, we set the stored type to i16 and propagate the "real" type as the
3939 // memory type.
3940 if (EltVT.getSizeInBits() < 16)
3941 ExtVal = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: MVT::i16, Operand: ExtVal);
3942 Ops.push_back(Elt: ExtVal);
3943 }
3944 }
3945
3946 // Then any remaining arguments
3947 Ops.append(in_start: N->op_begin() + 2, in_end: N->op_end());
3948
3949 SDValue NewSt =
3950 DAG.getMemIntrinsicNode(Opcode, dl: DL, VTList: DAG.getVTList(VT: MVT::Other), Ops,
3951 MemVT: N->getMemoryVT(), MMO: N->getMemOperand());
3952
3954 return NewSt;
3955}
3956
3957SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
3958 StoreSDNode *Store = cast<StoreSDNode>(Val&: Op);
3959 EVT VT = Store->getMemoryVT();
3960
3961 if (VT == MVT::i1)
3962 return LowerSTOREi1(Op, DAG);
3963
  // Lower stores of any other vector type, including v2f32, which we break
  // apart since it is not a widely supported type.
3966 return lowerSTOREVector(Op, DAG, STI);
3967}
3968
3969// st i1 v, addr
3970// =>
3971// v1 = zxt v to i16
3972// st.u8 i16, addr
3973SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
3974 SDNode *Node = Op.getNode();
3975 SDLoc dl(Node);
3976 StoreSDNode *ST = cast<StoreSDNode>(Val: Node);
3977 SDValue Tmp1 = ST->getChain();
3978 SDValue Tmp2 = ST->getBasePtr();
3979 SDValue Tmp3 = ST->getValue();
3980 assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
3981 Tmp3 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: dl, VT: MVT::i16, Operand: Tmp3);
3982 SDValue Result =
3983 DAG.getTruncStore(Chain: Tmp1, dl, Val: Tmp3, Ptr: Tmp2, PtrInfo: ST->getPointerInfo(), SVT: MVT::i8,
3984 Alignment: ST->getAlign(), MMOFlags: ST->getMemOperand()->getFlags());
3985 return Result;
3986}
3987
3988SDValue NVPTXTargetLowering::LowerCopyToReg_128(SDValue Op,
3989 SelectionDAG &DAG) const {
  // Change the CopyToReg to take two 64-bit operands instead of a single
  // 128-bit operand so that it can pass legalization.
3992
3993 assert(Op.getOperand(1).getValueType() == MVT::i128 &&
3994 "Custom lowering for 128-bit CopyToReg only");
3995
3996 SDNode *Node = Op.getNode();
3997 SDLoc DL(Node);
3998
3999 SDValue Cast = DAG.getBitcast(VT: MVT::v2i64, V: Op->getOperand(Num: 2));
4000 SDValue Lo = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i64, N1: Cast,
4001 N2: DAG.getIntPtrConstant(Val: 0, DL));
4002 SDValue Hi = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i64, N1: Cast,
4003 N2: DAG.getIntPtrConstant(Val: 1, DL));
4004
4005 SmallVector<SDValue, 5> NewOps(Op->getNumOperands() + 1);
4006 SmallVector<EVT, 3> ResultsType(Node->values());
4007
4008 NewOps[0] = Op->getOperand(Num: 0); // Chain
4009 NewOps[1] = Op->getOperand(Num: 1); // Dst Reg
4010 NewOps[2] = Lo; // Lower 64-bit
4011 NewOps[3] = Hi; // Higher 64-bit
4012 if (Op.getNumOperands() == 4)
4013 NewOps[4] = Op->getOperand(Num: 3); // Glue if exists
4014
4015 return DAG.getNode(Opcode: ISD::CopyToReg, DL, ResultTys: ResultsType, Ops: NewOps);
4016}
4017
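// Keep an i128 value as a single register part here; 128-bit CopyToReg nodes
// are instead split into two i64 operands by LowerCopyToReg_128 above.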
4018unsigned NVPTXTargetLowering::getNumRegisters(
4019 LLVMContext &Context, EVT VT,
4020 std::optional<MVT> RegisterVT = std::nullopt) const {
4021 if (VT == MVT::i128 && RegisterVT == MVT::i128)
4022 return 1;
4023 return TargetLoweringBase::getNumRegisters(Context, VT, RegisterVT);
4024}
4025
4026bool NVPTXTargetLowering::splitValueIntoRegisterParts(
4027 SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
4028 unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
4029 if (Val.getValueType() == MVT::i128 && NumParts == 1) {
4030 Parts[0] = Val;
4031 return true;
4032 }
4033 return false;
4034}
4035
// This creates a target external symbol for a function parameter. The name of
// the symbol is composed from the parameter's index and the function name. A
// negative index corresponds to the special parameter (an unsized array) used
// for passing variable arguments.
4040SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int I,
4041 EVT T) const {
4042 StringRef SavedStr = nvTM->getStrPool().save(
4043 S: getParamName(F: &DAG.getMachineFunction().getFunction(), Idx: I));
4044 return DAG.getExternalSymbol(Sym: SavedStr.data(), VT: T);
4045}
4046
4047SDValue NVPTXTargetLowering::getCallParamSymbol(SelectionDAG &DAG, int I,
4048 EVT T) const {
4049 const StringRef SavedStr = nvTM->getStrPool().save(S: "param" + Twine(I));
4050 return DAG.getExternalSymbol(Sym: SavedStr.data(), VT: T);
4051}
4052
4053SDValue NVPTXTargetLowering::LowerFormalArguments(
4054 SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
4055 const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
4056 SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
4057 const DataLayout &DL = DAG.getDataLayout();
4058 LLVMContext &Ctx = *DAG.getContext();
4059 auto PtrVT = getPointerTy(DL: DAG.getDataLayout());
4060
4061 const Function &F = DAG.getMachineFunction().getFunction();
4062
4063 SDValue Root = DAG.getRoot();
4064 SmallVector<SDValue, 16> OutChains;
4065
  // The number of arguments in F.args() and Ins.size() need not match;
  // Ins.size() will be larger
  //   * if there is an aggregate argument with multiple fields (each field
  //     shows up separately in Ins), or
  //   * if there is a vector argument with more elements than the typical
  //     vector length (generally more than 4), where each vector element is
  //     present individually in Ins.
  // So a different index should be used for indexing into Ins.
  // See the similar issue in LowerCall.
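  //
  // For example (a sketch): an argument of type {i32, float} contributes two
  // entries to Ins, both carrying that argument's number in OrigArgIndex,
  // which is what the take_while over OrigArgIndex below relies on.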
4075
4076 auto AllIns = ArrayRef(Ins);
4077 for (const auto &Arg : F.args()) {
4078 const auto ArgIns = AllIns.take_while(
4079 Pred: [&](auto I) { return I.OrigArgIndex == Arg.getArgNo(); });
4080 AllIns = AllIns.drop_front(N: ArgIns.size());
4081
4082 Type *Ty = Arg.getType();
4083
4084 if (ArgIns.empty())
4085 report_fatal_error(reason: "Empty parameter types are not supported");
4086
4087 if (Arg.use_empty()) {
      // The argument is dead: emit UNDEFs for its Ins entries.
4089 for (const auto &In : ArgIns) {
4090 assert(!In.Used && "Arg.use_empty() is true but Arg is used?");
4091 InVals.push_back(Elt: DAG.getUNDEF(VT: In.VT));
4092 }
4093 continue;
4094 }
4095
4096 SDValue ArgSymbol = getParamSymbol(DAG, I: Arg.getArgNo(), T: PtrVT);
4097
    // In the following cases, assign a node order of "ArgNo + 1" to newly
    // created nodes. The SDNodes for params have to appear in the same order
    // as their order of appearance in the original function, and
    // "Arg.getArgNo() + 1" preserves that order.
4102 if (Arg.hasByValAttr()) {
      // Param has the ByVal attribute: return MoveParam(param symbol).
      // Ideally, the param symbol could be returned directly, but if the
      // SDNode builder decides to use it in a CopyToReg(), instruction
      // selection fails because a TargetExternalSymbol is target-dependent
      // and not lowered, while CopyToReg assumes its source has already been
      // lowered.
4110 assert(ArgIns.size() == 1 && "ByVal argument must be a pointer");
4111 const auto &ByvalIn = ArgIns[0];
4112 assert(getValueType(DL, Ty) == ByvalIn.VT &&
4113 "Ins type did not match function type");
4114 assert(ByvalIn.VT == PtrVT && "ByVal argument must be a pointer");
4115
4116 SDValue P;
4117 if (isKernelFunction(F)) {
4118 P = ArgSymbol;
4119 P.getNode()->setIROrder(Arg.getArgNo() + 1);
4120 } else {
4121 P = DAG.getNode(Opcode: NVPTXISD::MoveParam, DL: dl, VT: ByvalIn.VT, Operand: ArgSymbol);
4122 P.getNode()->setIROrder(Arg.getArgNo() + 1);
4123 P = DAG.getAddrSpaceCast(dl, VT: ByvalIn.VT, Ptr: P, SrcAS: ADDRESS_SPACE_LOCAL,
4124 DestAS: ADDRESS_SPACE_GENERIC);
4125 }
4126 InVals.push_back(Elt: P);
4127 } else {
4128 SmallVector<EVT, 16> VTs;
4129 SmallVector<uint64_t, 16> Offsets;
4130 ComputePTXValueVTs(TLI: *this, DL, Ctx, CallConv, Ty, ValueVTs&: VTs, Offsets);
4131 assert(VTs.size() == ArgIns.size() && "Size mismatch");
4132 assert(VTs.size() == Offsets.size() && "Size mismatch");
4133
4134 const Align ArgAlign = getFunctionArgumentAlignment(
4135 F: &F, Ty, Idx: Arg.getArgNo() + AttributeList::FirstArgIndex, DL);
4136
4137 unsigned I = 0;
4138 const auto VI = VectorizePTXValueVTs(ValueVTs: VTs, Offsets, ParamAlignment: ArgAlign);
4139 for (const unsigned NumElts : VI) {
4140 // i1 is loaded/stored as i8
4141 const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
4142 const EVT VecVT = getVectorizedVT(VT: LoadVT, N: NumElts, C&: Ctx);
4143
4144 SDValue VecAddr = DAG.getObjectPtrOffset(
4145 SL: dl, Ptr: ArgSymbol, Offset: TypeSize::getFixed(ExactSize: Offsets[I]));
4146
4147 const Align PartAlign = commonAlignment(A: ArgAlign, Offset: Offsets[I]);
4148 SDValue P =
4149 DAG.getLoad(VT: VecVT, dl, Chain: Root, Ptr: VecAddr,
4150 PtrInfo: MachinePointerInfo(ADDRESS_SPACE_PARAM), Alignment: PartAlign,
4151 MMOFlags: MachineMemOperand::MODereferenceable |
4152 MachineMemOperand::MOInvariant);
4153 P.getNode()->setIROrder(Arg.getArgNo() + 1);
4154 for (const unsigned J : llvm::seq(Size: NumElts)) {
4155 SDValue Elt = getExtractVectorizedValue(V: P, I: J, VT: LoadVT, dl, DAG);
4156
4157 Elt = correctParamType(V: Elt, ExpectedVT: ArgIns[I + J].VT, Flags: ArgIns[I + J].Flags,
4158 DAG, dl);
4159 InVals.push_back(Elt);
4160 }
4161 I += NumElts;
4162 }
4163 }
4164 }
4165
4166 if (!OutChains.empty())
4167 DAG.setRoot(DAG.getTokenFactor(DL: dl, Vals&: OutChains));
4168
4169 return Chain;
4170}
4171
4172SDValue
4173NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
4174 bool isVarArg,
4175 const SmallVectorImpl<ISD::OutputArg> &Outs,
4176 const SmallVectorImpl<SDValue> &OutVals,
4177 const SDLoc &dl, SelectionDAG &DAG) const {
4178 const Function &F = DAG.getMachineFunction().getFunction();
4179 Type *RetTy = F.getReturnType();
4180
4181 if (RetTy->isVoidTy()) {
4182 assert(OutVals.empty() && Outs.empty() && "Return value expected for void");
4183 return DAG.getNode(Opcode: NVPTXISD::RET_GLUE, DL: dl, VT: MVT::Other, Operand: Chain);
4184 }
4185
4186 const DataLayout &DL = DAG.getDataLayout();
4187 LLVMContext &Ctx = *DAG.getContext();
4188
4189 const SDValue RetSymbol = DAG.getExternalSymbol(Sym: "func_retval0", VT: MVT::i32);
4190 const auto RetAlign = getFunctionParamOptimizedAlign(F: &F, ArgTy: RetTy, DL);
4191
  // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than 32 bits
  // are sign extended or zero extended, depending on whether they are signed
  // or unsigned types.
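  // For example, an i8 or i16 return value is widened and stored into
  // func_retval0 as a 32-bit value.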
4195 const bool ExtendIntegerRetVal =
4196 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty: RetTy) < 32;
4197
4198 SmallVector<EVT, 16> VTs;
4199 SmallVector<uint64_t, 16> Offsets;
4200 ComputePTXValueVTs(TLI: *this, DL, Ctx, CallConv, Ty: RetTy, ValueVTs&: VTs, Offsets);
4201 assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
4202
4203 const auto GetRetVal = [&](unsigned I) -> SDValue {
4204 SDValue RetVal = OutVals[I];
4205 assert(promoteScalarIntegerPTX(RetVal.getValueType()) ==
4206 RetVal.getValueType() &&
4207 "OutVal type should always be legal");
4208
4209 const EVT VTI = promoteScalarIntegerPTX(VT: VTs[I]);
4210 const EVT StoreVT =
4211 ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
4212 return correctParamType(V: RetVal, ExpectedVT: StoreVT, Flags: Outs[I].Flags, DAG, dl);
4213 };
4214
4215 unsigned I = 0;
4216 const auto VI = VectorizePTXValueVTs(ValueVTs: VTs, Offsets, ParamAlignment: RetAlign);
4217 for (const unsigned NumElts : VI) {
4218 const MaybeAlign CurrentAlign = ExtendIntegerRetVal
4219 ? MaybeAlign(std::nullopt)
4220 : commonAlignment(A: RetAlign, Offset: Offsets[I]);
4221
4222 SDValue Val = getBuildVectorizedValue(
4223 N: NumElts, dl, DAG, GetElement: [&](unsigned K) { return GetRetVal(I + K); });
4224
4225 SDValue Ptr =
4226 DAG.getObjectPtrOffset(SL: dl, Ptr: RetSymbol, Offset: TypeSize::getFixed(ExactSize: Offsets[I]));
4227
4228 Chain = DAG.getStore(Chain, dl, Val, Ptr,
4229 PtrInfo: MachinePointerInfo(ADDRESS_SPACE_PARAM), Alignment: CurrentAlign);
4230
4231 I += NumElts;
4232 }
4233
4234 return DAG.getNode(Opcode: NVPTXISD::RET_GLUE, DL: dl, VT: MVT::Other, Operand: Chain);
4235}
4236
4237void NVPTXTargetLowering::LowerAsmOperandForConstraint(
4238 SDValue Op, StringRef Constraint, std::vector<SDValue> &Ops,
4239 SelectionDAG &DAG) const {
4240 if (Constraint.size() > 1)
4241 return;
4242 TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
4243}
4244
// llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
// TgtMemIntrinsic because we need information that is only available in the
// "Value" type of the destination pointer, in particular its address space.
4250void NVPTXTargetLowering::getTgtMemIntrinsic(
4251 SmallVectorImpl<IntrinsicInfo> &Infos, const CallBase &I,
4252 MachineFunction &MF, unsigned Intrinsic) const {
4253 IntrinsicInfo Info;
4254 switch (Intrinsic) {
4255 default:
4256 return;
4257 case Intrinsic::nvvm_match_all_sync_i32p:
4258 case Intrinsic::nvvm_match_all_sync_i64p:
4259 Info.opc = ISD::INTRINSIC_W_CHAIN;
    // memVT is bogus. These intrinsics have the IntrInaccessibleMemOnly
    // attribute in order to model data exchange with other threads, but they
    // perform no real memory accesses.
4263 Info.memVT = MVT::i1;
4264
    // Our result depends on both our own and other threads' arguments.
4266 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
4267 Infos.push_back(Elt: Info);
4268 return;
4269 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col:
4270 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row:
4271 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col_stride:
4272 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride:
4273 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col:
4274 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row:
4275 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride:
4276 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride:
4277 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col:
4278 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row:
4279 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride:
4280 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride:
4281 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col:
4282 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row:
4283 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride:
4284 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride:
4285 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col:
4286 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row:
4287 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride:
4288 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride:
4289 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col:
4290 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row:
4291 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride:
4292 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride: {
4293 Info.opc = ISD::INTRINSIC_W_CHAIN;
4294 Info.memVT = MVT::v8f16;
4295 Info.ptrVal = I.getArgOperand(i: 0);
4296 Info.offset = 0;
4297 Info.flags = MachineMemOperand::MOLoad;
4298 Info.align = Align(16);
4299 Infos.push_back(Elt: Info);
4300 return;
4301 }
4302 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col:
4303 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col_stride:
4304 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col_stride:
4305 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col:
4306 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row:
4307 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row_stride:
4308 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row_stride:
4309 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row:
4310 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col:
4311 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col_stride:
4312 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row:
4313 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row_stride:
4314 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col:
4315 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col_stride:
4316 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col_stride:
4317 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col:
4318 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row:
4319 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row_stride:
4320 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row_stride:
4321 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row:
4322 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col:
4323 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col_stride:
4324 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row:
4325 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row_stride: {
4326 Info.opc = ISD::INTRINSIC_W_CHAIN;
4327 Info.memVT = MVT::v2i32;
4328 Info.ptrVal = I.getArgOperand(i: 0);
4329 Info.offset = 0;
4330 Info.flags = MachineMemOperand::MOLoad;
4331 Info.align = Align(8);
4332 Infos.push_back(Elt: Info);
4333 return;
4334 }
4335
4336 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col:
4337 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col_stride:
4338 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col_stride:
4339 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col:
4340 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row:
4341 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row_stride:
4342 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row_stride:
4343 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row:
4344 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col:
4345 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col_stride:
4346 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row:
4347 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row_stride:
4348 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col:
4349 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col_stride:
4350 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row:
4351 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row_stride:
4352
4353 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col:
4354 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col_stride:
4355 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col_stride:
4356 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col:
4357 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row:
4358 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row_stride:
4359 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row_stride:
4360 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row:
4361 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col:
4362 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col_stride:
4363 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row:
4364 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row_stride:
4365 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col:
4366 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col_stride:
4367 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
4368 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride:
4369 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16:
4370 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16:
4371 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8:
4372 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64:
4373 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32:
4374 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64:
4375 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32: {
4376 Info.opc = ISD::INTRINSIC_W_CHAIN;
4377 Info.memVT = MVT::v4i32;
4378 Info.ptrVal = I.getArgOperand(i: 0);
4379 Info.offset = 0;
4380 Info.flags = MachineMemOperand::MOLoad;
4381 Info.align = Align(16);
4382 Infos.push_back(Elt: Info);
4383 return;
4384 }
4385
4386 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col:
4387 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col_stride:
4388 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col_stride:
4389 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col:
4390 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row:
4391 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row_stride:
4392 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row_stride:
4393 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row:
4394
4395 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col:
4396 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col_stride:
4397 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col_stride:
4398 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col:
4399 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row:
4400 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row_stride:
4401 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row_stride:
4402 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row:
4403 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row:
4404 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row_stride:
4405 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col:
4406 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col_stride:
4407 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row:
4408 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row_stride:
4409 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row_stride:
4410 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row:
4411 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col:
4412 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col_stride:
4413 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
4414 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col:
4415 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16:
4416 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16:
4417 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64:
4418 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32: {
4419 Info.opc = ISD::INTRINSIC_W_CHAIN;
4420 Info.memVT = MVT::i32;
4421 Info.ptrVal = I.getArgOperand(i: 0);
4422 Info.offset = 0;
4423 Info.flags = MachineMemOperand::MOLoad;
4424 Info.align = Align(4);
4425 Infos.push_back(Elt: Info);
4426 return;
4427 }
4428
4429 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col:
4430 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row:
4431 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride:
4432 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride:
4433 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col:
4434 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row:
4435 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride:
4436 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride:
4437 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col:
4438 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row:
4439 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride:
4440 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride: {
4441 Info.opc = ISD::INTRINSIC_W_CHAIN;
4442 Info.memVT = MVT::v4f16;
4443 Info.ptrVal = I.getArgOperand(i: 0);
4444 Info.offset = 0;
4445 Info.flags = MachineMemOperand::MOLoad;
4446 Info.align = Align(16);
4447 Infos.push_back(Elt: Info);
4448 return;
4449 }
4450
4451 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col:
4452 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row:
4453 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride:
4454 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride:
4455 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col:
4456 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row:
4457 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride:
4458 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride:
4459 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col:
4460 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row:
4461 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride:
4462 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride:
4463 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col:
4464 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row:
4465 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col_stride:
4466 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row_stride: {
4467 Info.opc = ISD::INTRINSIC_W_CHAIN;
4468 Info.memVT = MVT::v8f32;
4469 Info.ptrVal = I.getArgOperand(i: 0);
4470 Info.offset = 0;
4471 Info.flags = MachineMemOperand::MOLoad;
4472 Info.align = Align(16);
4473 Infos.push_back(Elt: Info);
4474 return;
4475 }
4476
4477 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col:
4478 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col_stride:
4479 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row:
4480 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row_stride:
4481
4482 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col:
4483 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col_stride:
4484 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row:
4485 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row_stride:
4486
4487 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col:
4488 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col_stride:
4489 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row:
4490 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row_stride:
4491 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col:
4492 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col_stride:
4493 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row:
4494 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row_stride:
4495 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col:
4496 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col_stride:
4497 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row:
4498 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row_stride: {
4499 Info.opc = ISD::INTRINSIC_W_CHAIN;
4500 Info.memVT = MVT::v8i32;
4501 Info.ptrVal = I.getArgOperand(i: 0);
4502 Info.offset = 0;
4503 Info.flags = MachineMemOperand::MOLoad;
4504 Info.align = Align(16);
4505 Infos.push_back(Elt: Info);
4506 return;
4507 }
4508
4509 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col:
4510 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col_stride:
4511 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row:
4512 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row_stride:
4513 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col:
4514 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col_stride:
4515 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
4516 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride:
4517 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16:
4518 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16:
4519 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8:
4520 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64:
4521 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32:
4522 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64:
4523 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32: {
4524 Info.opc = ISD::INTRINSIC_W_CHAIN;
4525 Info.memVT = MVT::v2i32;
4526 Info.ptrVal = I.getArgOperand(i: 0);
4527 Info.offset = 0;
4528 Info.flags = MachineMemOperand::MOLoad;
4529 Info.align = Align(8);
4530 Infos.push_back(Elt: Info);
4531 return;
4532 }
4533
4534 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col:
4535 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col_stride:
4536 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row:
4537 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row_stride:
4538
4539 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col:
4540 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col_stride:
4541 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row:
4542 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row_stride: {
4543 Info.opc = ISD::INTRINSIC_W_CHAIN;
4544 Info.memVT = MVT::f64;
4545 Info.ptrVal = I.getArgOperand(i: 0);
4546 Info.offset = 0;
4547 Info.flags = MachineMemOperand::MOLoad;
4548 Info.align = Align(8);
4549 Infos.push_back(Elt: Info);
4550 return;
4551 }
4552
4553 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col:
4554 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col_stride:
4555 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row:
4556 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row_stride: {
4557 Info.opc = ISD::INTRINSIC_W_CHAIN;
4558 Info.memVT = MVT::v2f64;
4559 Info.ptrVal = I.getArgOperand(i: 0);
4560 Info.offset = 0;
4561 Info.flags = MachineMemOperand::MOLoad;
4562 Info.align = Align(16);
4563 Infos.push_back(Elt: Info);
4564 return;
4565 }
4566
4567 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
4568 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
4569 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
4570 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride:
4571 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col:
4572 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row:
4573 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride:
4574 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride:
4575 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col:
4576 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row:
4577 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride:
4578 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride: {
4579 Info.opc = ISD::INTRINSIC_VOID;
4580 Info.memVT = MVT::v4f16;
4581 Info.ptrVal = I.getArgOperand(i: 0);
4582 Info.offset = 0;
4583 Info.flags = MachineMemOperand::MOStore;
4584 Info.align = Align(16);
4585 Infos.push_back(Elt: Info);
4586 return;
4587 }
4588
4589 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col:
4590 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row:
4591 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride:
4592 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride:
4593 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col:
4594 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row:
4595 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride:
4596 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride:
4597 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col:
4598 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row:
4599 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride:
4600 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride:
4601 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col:
4602 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row:
4603 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col_stride:
4604 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row_stride: {
4605 Info.opc = ISD::INTRINSIC_VOID;
4606 Info.memVT = MVT::v8f32;
4607 Info.ptrVal = I.getArgOperand(i: 0);
4608 Info.offset = 0;
4609 Info.flags = MachineMemOperand::MOStore;
4610 Info.align = Align(16);
4611 Infos.push_back(Elt: Info);
4612 return;
4613 }
4614
4615 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col:
4616 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col_stride:
4617 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row:
4618 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row_stride:
4619 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col:
4620 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col_stride:
4621 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row:
4622 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row_stride:
4623 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col:
4624 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col_stride:
4625 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row:
4626 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row_stride: {
4627 Info.opc = ISD::INTRINSIC_VOID;
4628 Info.memVT = MVT::v8i32;
4629 Info.ptrVal = I.getArgOperand(i: 0);
4630 Info.offset = 0;
4631 Info.flags = MachineMemOperand::MOStore;
4632 Info.align = Align(16);
4633 Infos.push_back(Elt: Info);
4634 return;
4635 }
4636
4637 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col:
4638 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col_stride:
4639 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row:
4640 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row_stride:
4641 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
4642 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
4643 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
4644 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride:
4645 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16:
4646 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16:
4647 case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: {
4648 Info.opc = ISD::INTRINSIC_VOID;
4649 Info.memVT = MVT::v2i32;
4650 Info.ptrVal = I.getArgOperand(i: 0);
4651 Info.offset = 0;
4652 Info.flags = MachineMemOperand::MOStore;
4653 Info.align = Align(8);
4654 Infos.push_back(Elt: Info);
4655 return;
4656 }
4657
4658 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col:
4659 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col_stride:
4660 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row:
4661 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row_stride: {
4662 Info.opc = ISD::INTRINSIC_VOID;
4663 Info.memVT = MVT::v2f64;
4664 Info.ptrVal = I.getArgOperand(i: 0);
4665 Info.offset = 0;
4666 Info.flags = MachineMemOperand::MOStore;
4667 Info.align = Align(16);
4668 Infos.push_back(Elt: Info);
4669 return;
4670 }
4671
4672 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16:
4673 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16:
4674 case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: {
4675 Info.opc = ISD::INTRINSIC_VOID;
4676 Info.memVT = MVT::i32;
4677 Info.ptrVal = I.getArgOperand(i: 0);
4678 Info.offset = 0;
4679 Info.flags = MachineMemOperand::MOStore;
4680 Info.align = Align(4);
4681 Infos.push_back(Elt: Info);
4682 return;
4683 }
4684
4685 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16:
4686 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16:
4687 case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: {
4688 Info.opc = ISD::INTRINSIC_VOID;
4689 Info.memVT = MVT::v4i32;
4690 Info.ptrVal = I.getArgOperand(i: 0);
4691 Info.offset = 0;
4692 Info.flags = MachineMemOperand::MOStore;
4693 Info.align = Align(16);
4694 Infos.push_back(Elt: Info);
4695 return;
4696 }
4697
4698 case Intrinsic::nvvm_atomic_add_gen_f_cta:
4699 case Intrinsic::nvvm_atomic_add_gen_f_sys:
4700 case Intrinsic::nvvm_atomic_add_gen_i_cta:
4701 case Intrinsic::nvvm_atomic_add_gen_i_sys:
4702 case Intrinsic::nvvm_atomic_and_gen_i_cta:
4703 case Intrinsic::nvvm_atomic_and_gen_i_sys:
4704 case Intrinsic::nvvm_atomic_cas_gen_i_cta:
4705 case Intrinsic::nvvm_atomic_cas_gen_i_sys:
4706 case Intrinsic::nvvm_atomic_dec_gen_i_cta:
4707 case Intrinsic::nvvm_atomic_dec_gen_i_sys:
4708 case Intrinsic::nvvm_atomic_inc_gen_i_cta:
4709 case Intrinsic::nvvm_atomic_inc_gen_i_sys:
4710 case Intrinsic::nvvm_atomic_max_gen_i_cta:
4711 case Intrinsic::nvvm_atomic_max_gen_i_sys:
4712 case Intrinsic::nvvm_atomic_min_gen_i_cta:
4713 case Intrinsic::nvvm_atomic_min_gen_i_sys:
4714 case Intrinsic::nvvm_atomic_or_gen_i_cta:
4715 case Intrinsic::nvvm_atomic_or_gen_i_sys:
4716 case Intrinsic::nvvm_atomic_exch_gen_i_cta:
4717 case Intrinsic::nvvm_atomic_exch_gen_i_sys:
4718 case Intrinsic::nvvm_atomic_xor_gen_i_cta:
4719 case Intrinsic::nvvm_atomic_xor_gen_i_sys: {
4720 auto &DL = I.getDataLayout();
4721 Info.opc = ISD::INTRINSIC_W_CHAIN;
4722 Info.memVT = getValueType(DL, Ty: I.getType());
4723 Info.ptrVal = I.getArgOperand(i: 0);
4724 Info.offset = 0;
4725 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
4726 Info.align.reset();
4727 Infos.push_back(Elt: Info);
4728 return;
4729 }
4730
4731 case Intrinsic::nvvm_prefetch_tensormap: {
4732 auto &DL = I.getDataLayout();
4733 Info.opc = ISD::INTRINSIC_VOID;
4734 Info.memVT = getPointerTy(DL);
4735 Info.ptrVal = I.getArgOperand(i: 0);
4736 Info.offset = 0;
4737 Info.flags =
4738 MachineMemOperand::MOLoad | MachineMemOperand::MODereferenceable;
4739 Info.align.reset();
4740 Infos.push_back(Elt: Info);
4741 return;
4742 }
4743
4744 case Intrinsic::nvvm_tensormap_replace_global_address:
4745 case Intrinsic::nvvm_tensormap_replace_global_stride: {
4746 Info.opc = ISD::INTRINSIC_VOID;
4747 Info.memVT = MVT::i64;
4748 Info.ptrVal = I.getArgOperand(i: 0);
4749 Info.offset = 0;
4750 Info.flags = MachineMemOperand::MOStore;
4751 Info.align.reset();
4752 Infos.push_back(Elt: Info);
4753 return;
4754 }
4755
4756 case Intrinsic::nvvm_tensormap_replace_rank:
4757 case Intrinsic::nvvm_tensormap_replace_box_dim:
4758 case Intrinsic::nvvm_tensormap_replace_global_dim:
4759 case Intrinsic::nvvm_tensormap_replace_element_stride:
4760 case Intrinsic::nvvm_tensormap_replace_elemtype:
4761 case Intrinsic::nvvm_tensormap_replace_interleave_layout:
4762 case Intrinsic::nvvm_tensormap_replace_swizzle_mode:
4763 case Intrinsic::nvvm_tensormap_replace_swizzle_atomicity:
4764 case Intrinsic::nvvm_tensormap_replace_fill_mode: {
4765 Info.opc = ISD::INTRINSIC_VOID;
4766 Info.memVT = MVT::i32;
4767 Info.ptrVal = I.getArgOperand(i: 0);
4768 Info.offset = 0;
4769 Info.flags = MachineMemOperand::MOStore;
4770 Info.align.reset();
4771 Infos.push_back(Elt: Info);
4772 return;
4773 }
4774
4775 case Intrinsic::nvvm_ldu_global_i:
4776 case Intrinsic::nvvm_ldu_global_f:
4777 case Intrinsic::nvvm_ldu_global_p: {
4778 Info.opc = ISD::INTRINSIC_W_CHAIN;
4779 Info.memVT = getValueType(DL: I.getDataLayout(), Ty: I.getType());
4780 Info.ptrVal = I.getArgOperand(i: 0);
4781 Info.offset = 0;
4782 Info.flags = MachineMemOperand::MOLoad;
4783 Info.align = cast<ConstantInt>(Val: I.getArgOperand(i: 1))->getMaybeAlignValue();
4784
4785 Infos.push_back(Elt: Info);
4786 return;
4787 }
4788 case Intrinsic::nvvm_tex_1d_v4f32_s32:
4789 case Intrinsic::nvvm_tex_1d_v4f32_f32:
4790 case Intrinsic::nvvm_tex_1d_level_v4f32_f32:
4791 case Intrinsic::nvvm_tex_1d_grad_v4f32_f32:
4792 case Intrinsic::nvvm_tex_1d_array_v4f32_s32:
4793 case Intrinsic::nvvm_tex_1d_array_v4f32_f32:
4794 case Intrinsic::nvvm_tex_1d_array_level_v4f32_f32:
4795 case Intrinsic::nvvm_tex_1d_array_grad_v4f32_f32:
4796 case Intrinsic::nvvm_tex_2d_v4f32_s32:
4797 case Intrinsic::nvvm_tex_2d_v4f32_f32:
4798 case Intrinsic::nvvm_tex_2d_level_v4f32_f32:
4799 case Intrinsic::nvvm_tex_2d_grad_v4f32_f32:
4800 case Intrinsic::nvvm_tex_2d_array_v4f32_s32:
4801 case Intrinsic::nvvm_tex_2d_array_v4f32_f32:
4802 case Intrinsic::nvvm_tex_2d_array_level_v4f32_f32:
4803 case Intrinsic::nvvm_tex_2d_array_grad_v4f32_f32:
4804 case Intrinsic::nvvm_tex_3d_v4f32_s32:
4805 case Intrinsic::nvvm_tex_3d_v4f32_f32:
4806 case Intrinsic::nvvm_tex_3d_level_v4f32_f32:
4807 case Intrinsic::nvvm_tex_3d_grad_v4f32_f32:
4808 case Intrinsic::nvvm_tex_cube_v4f32_f32:
4809 case Intrinsic::nvvm_tex_cube_level_v4f32_f32:
4810 case Intrinsic::nvvm_tex_cube_array_v4f32_f32:
4811 case Intrinsic::nvvm_tex_cube_array_level_v4f32_f32:
4812 case Intrinsic::nvvm_tld4_r_2d_v4f32_f32:
4813 case Intrinsic::nvvm_tld4_g_2d_v4f32_f32:
4814 case Intrinsic::nvvm_tld4_b_2d_v4f32_f32:
4815 case Intrinsic::nvvm_tld4_a_2d_v4f32_f32:
4816 case Intrinsic::nvvm_tex_unified_1d_v4f32_s32:
4817 case Intrinsic::nvvm_tex_unified_1d_v4f32_f32:
4818 case Intrinsic::nvvm_tex_unified_1d_level_v4f32_f32:
4819 case Intrinsic::nvvm_tex_unified_1d_grad_v4f32_f32:
4820 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_s32:
4821 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_f32:
4822 case Intrinsic::nvvm_tex_unified_1d_array_level_v4f32_f32:
4823 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4f32_f32:
4824 case Intrinsic::nvvm_tex_unified_2d_v4f32_s32:
4825 case Intrinsic::nvvm_tex_unified_2d_v4f32_f32:
4826 case Intrinsic::nvvm_tex_unified_2d_level_v4f32_f32:
4827 case Intrinsic::nvvm_tex_unified_2d_grad_v4f32_f32:
4828 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_s32:
4829 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_f32:
4830 case Intrinsic::nvvm_tex_unified_2d_array_level_v4f32_f32:
4831 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4f32_f32:
4832 case Intrinsic::nvvm_tex_unified_3d_v4f32_s32:
4833 case Intrinsic::nvvm_tex_unified_3d_v4f32_f32:
4834 case Intrinsic::nvvm_tex_unified_3d_level_v4f32_f32:
4835 case Intrinsic::nvvm_tex_unified_3d_grad_v4f32_f32:
4836 case Intrinsic::nvvm_tex_unified_cube_v4f32_f32:
4837 case Intrinsic::nvvm_tex_unified_cube_level_v4f32_f32:
4838 case Intrinsic::nvvm_tex_unified_cube_array_v4f32_f32:
4839 case Intrinsic::nvvm_tex_unified_cube_array_level_v4f32_f32:
4840 case Intrinsic::nvvm_tex_unified_cube_grad_v4f32_f32:
4841 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4f32_f32:
4842 case Intrinsic::nvvm_tld4_unified_r_2d_v4f32_f32:
4843 case Intrinsic::nvvm_tld4_unified_g_2d_v4f32_f32:
4844 case Intrinsic::nvvm_tld4_unified_b_2d_v4f32_f32:
4845 case Intrinsic::nvvm_tld4_unified_a_2d_v4f32_f32:
4846 Info.opc = ISD::INTRINSIC_W_CHAIN;
4847 Info.memVT = MVT::v4f32;
4848 Info.ptrVal = nullptr;
4849 Info.offset = 0;
4850 Info.flags = MachineMemOperand::MOLoad;
4851 Info.align = Align(16);
4852 Infos.push_back(Elt: Info);
4853 return;
4854
4855 case Intrinsic::nvvm_tex_1d_v4s32_s32:
4856 case Intrinsic::nvvm_tex_1d_v4s32_f32:
4857 case Intrinsic::nvvm_tex_1d_level_v4s32_f32:
4858 case Intrinsic::nvvm_tex_1d_grad_v4s32_f32:
4859 case Intrinsic::nvvm_tex_1d_array_v4s32_s32:
4860 case Intrinsic::nvvm_tex_1d_array_v4s32_f32:
4861 case Intrinsic::nvvm_tex_1d_array_level_v4s32_f32:
4862 case Intrinsic::nvvm_tex_1d_array_grad_v4s32_f32:
4863 case Intrinsic::nvvm_tex_2d_v4s32_s32:
4864 case Intrinsic::nvvm_tex_2d_v4s32_f32:
4865 case Intrinsic::nvvm_tex_2d_level_v4s32_f32:
4866 case Intrinsic::nvvm_tex_2d_grad_v4s32_f32:
4867 case Intrinsic::nvvm_tex_2d_array_v4s32_s32:
4868 case Intrinsic::nvvm_tex_2d_array_v4s32_f32:
4869 case Intrinsic::nvvm_tex_2d_array_level_v4s32_f32:
4870 case Intrinsic::nvvm_tex_2d_array_grad_v4s32_f32:
4871 case Intrinsic::nvvm_tex_3d_v4s32_s32:
4872 case Intrinsic::nvvm_tex_3d_v4s32_f32:
4873 case Intrinsic::nvvm_tex_3d_level_v4s32_f32:
4874 case Intrinsic::nvvm_tex_3d_grad_v4s32_f32:
4875 case Intrinsic::nvvm_tex_cube_v4s32_f32:
4876 case Intrinsic::nvvm_tex_cube_level_v4s32_f32:
4877 case Intrinsic::nvvm_tex_cube_array_v4s32_f32:
4878 case Intrinsic::nvvm_tex_cube_array_level_v4s32_f32:
4879 case Intrinsic::nvvm_tex_cube_v4u32_f32:
4880 case Intrinsic::nvvm_tex_cube_level_v4u32_f32:
4881 case Intrinsic::nvvm_tex_cube_array_v4u32_f32:
4882 case Intrinsic::nvvm_tex_cube_array_level_v4u32_f32:
4883 case Intrinsic::nvvm_tex_1d_v4u32_s32:
4884 case Intrinsic::nvvm_tex_1d_v4u32_f32:
4885 case Intrinsic::nvvm_tex_1d_level_v4u32_f32:
4886 case Intrinsic::nvvm_tex_1d_grad_v4u32_f32:
4887 case Intrinsic::nvvm_tex_1d_array_v4u32_s32:
4888 case Intrinsic::nvvm_tex_1d_array_v4u32_f32:
4889 case Intrinsic::nvvm_tex_1d_array_level_v4u32_f32:
4890 case Intrinsic::nvvm_tex_1d_array_grad_v4u32_f32:
4891 case Intrinsic::nvvm_tex_2d_v4u32_s32:
4892 case Intrinsic::nvvm_tex_2d_v4u32_f32:
4893 case Intrinsic::nvvm_tex_2d_level_v4u32_f32:
4894 case Intrinsic::nvvm_tex_2d_grad_v4u32_f32:
4895 case Intrinsic::nvvm_tex_2d_array_v4u32_s32:
4896 case Intrinsic::nvvm_tex_2d_array_v4u32_f32:
4897 case Intrinsic::nvvm_tex_2d_array_level_v4u32_f32:
4898 case Intrinsic::nvvm_tex_2d_array_grad_v4u32_f32:
4899 case Intrinsic::nvvm_tex_3d_v4u32_s32:
4900 case Intrinsic::nvvm_tex_3d_v4u32_f32:
4901 case Intrinsic::nvvm_tex_3d_level_v4u32_f32:
4902 case Intrinsic::nvvm_tex_3d_grad_v4u32_f32:
4903 case Intrinsic::nvvm_tld4_r_2d_v4s32_f32:
4904 case Intrinsic::nvvm_tld4_g_2d_v4s32_f32:
4905 case Intrinsic::nvvm_tld4_b_2d_v4s32_f32:
4906 case Intrinsic::nvvm_tld4_a_2d_v4s32_f32:
4907 case Intrinsic::nvvm_tld4_r_2d_v4u32_f32:
4908 case Intrinsic::nvvm_tld4_g_2d_v4u32_f32:
4909 case Intrinsic::nvvm_tld4_b_2d_v4u32_f32:
4910 case Intrinsic::nvvm_tld4_a_2d_v4u32_f32:
4911 case Intrinsic::nvvm_tex_unified_1d_v4s32_s32:
4912 case Intrinsic::nvvm_tex_unified_1d_v4s32_f32:
4913 case Intrinsic::nvvm_tex_unified_1d_level_v4s32_f32:
4914 case Intrinsic::nvvm_tex_unified_1d_grad_v4s32_f32:
4915 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_s32:
4916 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_f32:
4917 case Intrinsic::nvvm_tex_unified_1d_array_level_v4s32_f32:
4918 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4s32_f32:
4919 case Intrinsic::nvvm_tex_unified_2d_v4s32_s32:
4920 case Intrinsic::nvvm_tex_unified_2d_v4s32_f32:
4921 case Intrinsic::nvvm_tex_unified_2d_level_v4s32_f32:
4922 case Intrinsic::nvvm_tex_unified_2d_grad_v4s32_f32:
4923 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_s32:
4924 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_f32:
4925 case Intrinsic::nvvm_tex_unified_2d_array_level_v4s32_f32:
4926 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4s32_f32:
4927 case Intrinsic::nvvm_tex_unified_3d_v4s32_s32:
4928 case Intrinsic::nvvm_tex_unified_3d_v4s32_f32:
4929 case Intrinsic::nvvm_tex_unified_3d_level_v4s32_f32:
4930 case Intrinsic::nvvm_tex_unified_3d_grad_v4s32_f32:
4931 case Intrinsic::nvvm_tex_unified_1d_v4u32_s32:
4932 case Intrinsic::nvvm_tex_unified_1d_v4u32_f32:
4933 case Intrinsic::nvvm_tex_unified_1d_level_v4u32_f32:
4934 case Intrinsic::nvvm_tex_unified_1d_grad_v4u32_f32:
4935 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_s32:
4936 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_f32:
4937 case Intrinsic::nvvm_tex_unified_1d_array_level_v4u32_f32:
4938 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4u32_f32:
4939 case Intrinsic::nvvm_tex_unified_2d_v4u32_s32:
4940 case Intrinsic::nvvm_tex_unified_2d_v4u32_f32:
4941 case Intrinsic::nvvm_tex_unified_2d_level_v4u32_f32:
4942 case Intrinsic::nvvm_tex_unified_2d_grad_v4u32_f32:
4943 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_s32:
4944 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_f32:
4945 case Intrinsic::nvvm_tex_unified_2d_array_level_v4u32_f32:
4946 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4u32_f32:
4947 case Intrinsic::nvvm_tex_unified_3d_v4u32_s32:
4948 case Intrinsic::nvvm_tex_unified_3d_v4u32_f32:
4949 case Intrinsic::nvvm_tex_unified_3d_level_v4u32_f32:
4950 case Intrinsic::nvvm_tex_unified_3d_grad_v4u32_f32:
4951 case Intrinsic::nvvm_tex_unified_cube_v4s32_f32:
4952 case Intrinsic::nvvm_tex_unified_cube_level_v4s32_f32:
4953 case Intrinsic::nvvm_tex_unified_cube_array_v4s32_f32:
4954 case Intrinsic::nvvm_tex_unified_cube_array_level_v4s32_f32:
4955 case Intrinsic::nvvm_tex_unified_cube_v4u32_f32:
4956 case Intrinsic::nvvm_tex_unified_cube_level_v4u32_f32:
4957 case Intrinsic::nvvm_tex_unified_cube_array_v4u32_f32:
4958 case Intrinsic::nvvm_tex_unified_cube_array_level_v4u32_f32:
4959 case Intrinsic::nvvm_tex_unified_cube_grad_v4s32_f32:
4960 case Intrinsic::nvvm_tex_unified_cube_grad_v4u32_f32:
4961 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4s32_f32:
4962 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4u32_f32:
4963 case Intrinsic::nvvm_tld4_unified_r_2d_v4s32_f32:
4964 case Intrinsic::nvvm_tld4_unified_g_2d_v4s32_f32:
4965 case Intrinsic::nvvm_tld4_unified_b_2d_v4s32_f32:
4966 case Intrinsic::nvvm_tld4_unified_a_2d_v4s32_f32:
4967 case Intrinsic::nvvm_tld4_unified_r_2d_v4u32_f32:
4968 case Intrinsic::nvvm_tld4_unified_g_2d_v4u32_f32:
4969 case Intrinsic::nvvm_tld4_unified_b_2d_v4u32_f32:
4970 case Intrinsic::nvvm_tld4_unified_a_2d_v4u32_f32:
4971 Info.opc = ISD::INTRINSIC_W_CHAIN;
4972 Info.memVT = MVT::v4i32;
4973 Info.ptrVal = nullptr;
4974 Info.offset = 0;
4975 Info.flags = MachineMemOperand::MOLoad;
4976 Info.align = Align(16);
4977 Infos.push_back(Elt: Info);
4978 return;
4979
4980 case Intrinsic::nvvm_suld_1d_i8_clamp:
4981 case Intrinsic::nvvm_suld_1d_v2i8_clamp:
4982 case Intrinsic::nvvm_suld_1d_v4i8_clamp:
4983 case Intrinsic::nvvm_suld_1d_array_i8_clamp:
4984 case Intrinsic::nvvm_suld_1d_array_v2i8_clamp:
4985 case Intrinsic::nvvm_suld_1d_array_v4i8_clamp:
4986 case Intrinsic::nvvm_suld_2d_i8_clamp:
4987 case Intrinsic::nvvm_suld_2d_v2i8_clamp:
4988 case Intrinsic::nvvm_suld_2d_v4i8_clamp:
4989 case Intrinsic::nvvm_suld_2d_array_i8_clamp:
4990 case Intrinsic::nvvm_suld_2d_array_v2i8_clamp:
4991 case Intrinsic::nvvm_suld_2d_array_v4i8_clamp:
4992 case Intrinsic::nvvm_suld_3d_i8_clamp:
4993 case Intrinsic::nvvm_suld_3d_v2i8_clamp:
4994 case Intrinsic::nvvm_suld_3d_v4i8_clamp:
4995 case Intrinsic::nvvm_suld_1d_i8_trap:
4996 case Intrinsic::nvvm_suld_1d_v2i8_trap:
4997 case Intrinsic::nvvm_suld_1d_v4i8_trap:
4998 case Intrinsic::nvvm_suld_1d_array_i8_trap:
4999 case Intrinsic::nvvm_suld_1d_array_v2i8_trap:
5000 case Intrinsic::nvvm_suld_1d_array_v4i8_trap:
5001 case Intrinsic::nvvm_suld_2d_i8_trap:
5002 case Intrinsic::nvvm_suld_2d_v2i8_trap:
5003 case Intrinsic::nvvm_suld_2d_v4i8_trap:
5004 case Intrinsic::nvvm_suld_2d_array_i8_trap:
5005 case Intrinsic::nvvm_suld_2d_array_v2i8_trap:
5006 case Intrinsic::nvvm_suld_2d_array_v4i8_trap:
5007 case Intrinsic::nvvm_suld_3d_i8_trap:
5008 case Intrinsic::nvvm_suld_3d_v2i8_trap:
5009 case Intrinsic::nvvm_suld_3d_v4i8_trap:
5010 case Intrinsic::nvvm_suld_1d_i8_zero:
5011 case Intrinsic::nvvm_suld_1d_v2i8_zero:
5012 case Intrinsic::nvvm_suld_1d_v4i8_zero:
5013 case Intrinsic::nvvm_suld_1d_array_i8_zero:
5014 case Intrinsic::nvvm_suld_1d_array_v2i8_zero:
5015 case Intrinsic::nvvm_suld_1d_array_v4i8_zero:
5016 case Intrinsic::nvvm_suld_2d_i8_zero:
5017 case Intrinsic::nvvm_suld_2d_v2i8_zero:
5018 case Intrinsic::nvvm_suld_2d_v4i8_zero:
5019 case Intrinsic::nvvm_suld_2d_array_i8_zero:
5020 case Intrinsic::nvvm_suld_2d_array_v2i8_zero:
5021 case Intrinsic::nvvm_suld_2d_array_v4i8_zero:
5022 case Intrinsic::nvvm_suld_3d_i8_zero:
5023 case Intrinsic::nvvm_suld_3d_v2i8_zero:
5024 case Intrinsic::nvvm_suld_3d_v4i8_zero:
5025 Info.opc = ISD::INTRINSIC_W_CHAIN;
5026 Info.memVT = MVT::i8;
5027 Info.ptrVal = nullptr;
5028 Info.offset = 0;
5029 Info.flags = MachineMemOperand::MOLoad;
5030 Info.align = Align(16);
5031 Infos.push_back(Elt: Info);
5032 return;
5033
5034 case Intrinsic::nvvm_suld_1d_i16_clamp:
5035 case Intrinsic::nvvm_suld_1d_v2i16_clamp:
5036 case Intrinsic::nvvm_suld_1d_v4i16_clamp:
5037 case Intrinsic::nvvm_suld_1d_array_i16_clamp:
5038 case Intrinsic::nvvm_suld_1d_array_v2i16_clamp:
5039 case Intrinsic::nvvm_suld_1d_array_v4i16_clamp:
5040 case Intrinsic::nvvm_suld_2d_i16_clamp:
5041 case Intrinsic::nvvm_suld_2d_v2i16_clamp:
5042 case Intrinsic::nvvm_suld_2d_v4i16_clamp:
5043 case Intrinsic::nvvm_suld_2d_array_i16_clamp:
5044 case Intrinsic::nvvm_suld_2d_array_v2i16_clamp:
5045 case Intrinsic::nvvm_suld_2d_array_v4i16_clamp:
5046 case Intrinsic::nvvm_suld_3d_i16_clamp:
5047 case Intrinsic::nvvm_suld_3d_v2i16_clamp:
5048 case Intrinsic::nvvm_suld_3d_v4i16_clamp:
5049 case Intrinsic::nvvm_suld_1d_i16_trap:
5050 case Intrinsic::nvvm_suld_1d_v2i16_trap:
5051 case Intrinsic::nvvm_suld_1d_v4i16_trap:
5052 case Intrinsic::nvvm_suld_1d_array_i16_trap:
5053 case Intrinsic::nvvm_suld_1d_array_v2i16_trap:
5054 case Intrinsic::nvvm_suld_1d_array_v4i16_trap:
5055 case Intrinsic::nvvm_suld_2d_i16_trap:
5056 case Intrinsic::nvvm_suld_2d_v2i16_trap:
5057 case Intrinsic::nvvm_suld_2d_v4i16_trap:
5058 case Intrinsic::nvvm_suld_2d_array_i16_trap:
5059 case Intrinsic::nvvm_suld_2d_array_v2i16_trap:
5060 case Intrinsic::nvvm_suld_2d_array_v4i16_trap:
5061 case Intrinsic::nvvm_suld_3d_i16_trap:
5062 case Intrinsic::nvvm_suld_3d_v2i16_trap:
5063 case Intrinsic::nvvm_suld_3d_v4i16_trap:
5064 case Intrinsic::nvvm_suld_1d_i16_zero:
5065 case Intrinsic::nvvm_suld_1d_v2i16_zero:
5066 case Intrinsic::nvvm_suld_1d_v4i16_zero:
5067 case Intrinsic::nvvm_suld_1d_array_i16_zero:
5068 case Intrinsic::nvvm_suld_1d_array_v2i16_zero:
5069 case Intrinsic::nvvm_suld_1d_array_v4i16_zero:
5070 case Intrinsic::nvvm_suld_2d_i16_zero:
5071 case Intrinsic::nvvm_suld_2d_v2i16_zero:
5072 case Intrinsic::nvvm_suld_2d_v4i16_zero:
5073 case Intrinsic::nvvm_suld_2d_array_i16_zero:
5074 case Intrinsic::nvvm_suld_2d_array_v2i16_zero:
5075 case Intrinsic::nvvm_suld_2d_array_v4i16_zero:
5076 case Intrinsic::nvvm_suld_3d_i16_zero:
5077 case Intrinsic::nvvm_suld_3d_v2i16_zero:
5078 case Intrinsic::nvvm_suld_3d_v4i16_zero:
5079 Info.opc = ISD::INTRINSIC_W_CHAIN;
5080 Info.memVT = MVT::i16;
5081 Info.ptrVal = nullptr;
5082 Info.offset = 0;
5083 Info.flags = MachineMemOperand::MOLoad;
5084 Info.align = Align(16);
5085 Infos.push_back(Elt: Info);
5086 return;
5087
5088 case Intrinsic::nvvm_suld_1d_i32_clamp:
5089 case Intrinsic::nvvm_suld_1d_v2i32_clamp:
5090 case Intrinsic::nvvm_suld_1d_v4i32_clamp:
5091 case Intrinsic::nvvm_suld_1d_array_i32_clamp:
5092 case Intrinsic::nvvm_suld_1d_array_v2i32_clamp:
5093 case Intrinsic::nvvm_suld_1d_array_v4i32_clamp:
5094 case Intrinsic::nvvm_suld_2d_i32_clamp:
5095 case Intrinsic::nvvm_suld_2d_v2i32_clamp:
5096 case Intrinsic::nvvm_suld_2d_v4i32_clamp:
5097 case Intrinsic::nvvm_suld_2d_array_i32_clamp:
5098 case Intrinsic::nvvm_suld_2d_array_v2i32_clamp:
5099 case Intrinsic::nvvm_suld_2d_array_v4i32_clamp:
5100 case Intrinsic::nvvm_suld_3d_i32_clamp:
5101 case Intrinsic::nvvm_suld_3d_v2i32_clamp:
5102 case Intrinsic::nvvm_suld_3d_v4i32_clamp:
5103 case Intrinsic::nvvm_suld_1d_i32_trap:
5104 case Intrinsic::nvvm_suld_1d_v2i32_trap:
5105 case Intrinsic::nvvm_suld_1d_v4i32_trap:
5106 case Intrinsic::nvvm_suld_1d_array_i32_trap:
5107 case Intrinsic::nvvm_suld_1d_array_v2i32_trap:
5108 case Intrinsic::nvvm_suld_1d_array_v4i32_trap:
5109 case Intrinsic::nvvm_suld_2d_i32_trap:
5110 case Intrinsic::nvvm_suld_2d_v2i32_trap:
5111 case Intrinsic::nvvm_suld_2d_v4i32_trap:
5112 case Intrinsic::nvvm_suld_2d_array_i32_trap:
5113 case Intrinsic::nvvm_suld_2d_array_v2i32_trap:
5114 case Intrinsic::nvvm_suld_2d_array_v4i32_trap:
5115 case Intrinsic::nvvm_suld_3d_i32_trap:
5116 case Intrinsic::nvvm_suld_3d_v2i32_trap:
5117 case Intrinsic::nvvm_suld_3d_v4i32_trap:
5118 case Intrinsic::nvvm_suld_1d_i32_zero:
5119 case Intrinsic::nvvm_suld_1d_v2i32_zero:
5120 case Intrinsic::nvvm_suld_1d_v4i32_zero:
5121 case Intrinsic::nvvm_suld_1d_array_i32_zero:
5122 case Intrinsic::nvvm_suld_1d_array_v2i32_zero:
5123 case Intrinsic::nvvm_suld_1d_array_v4i32_zero:
5124 case Intrinsic::nvvm_suld_2d_i32_zero:
5125 case Intrinsic::nvvm_suld_2d_v2i32_zero:
5126 case Intrinsic::nvvm_suld_2d_v4i32_zero:
5127 case Intrinsic::nvvm_suld_2d_array_i32_zero:
5128 case Intrinsic::nvvm_suld_2d_array_v2i32_zero:
5129 case Intrinsic::nvvm_suld_2d_array_v4i32_zero:
5130 case Intrinsic::nvvm_suld_3d_i32_zero:
5131 case Intrinsic::nvvm_suld_3d_v2i32_zero:
5132 case Intrinsic::nvvm_suld_3d_v4i32_zero:
5133 Info.opc = ISD::INTRINSIC_W_CHAIN;
5134 Info.memVT = MVT::i32;
5135 Info.ptrVal = nullptr;
5136 Info.offset = 0;
5137 Info.flags = MachineMemOperand::MOLoad;
5138 Info.align = Align(16);
5139 Infos.push_back(Elt: Info);
5140 return;
5141
5142 case Intrinsic::nvvm_suld_1d_i64_clamp:
5143 case Intrinsic::nvvm_suld_1d_v2i64_clamp:
5144 case Intrinsic::nvvm_suld_1d_array_i64_clamp:
5145 case Intrinsic::nvvm_suld_1d_array_v2i64_clamp:
5146 case Intrinsic::nvvm_suld_2d_i64_clamp:
5147 case Intrinsic::nvvm_suld_2d_v2i64_clamp:
5148 case Intrinsic::nvvm_suld_2d_array_i64_clamp:
5149 case Intrinsic::nvvm_suld_2d_array_v2i64_clamp:
5150 case Intrinsic::nvvm_suld_3d_i64_clamp:
5151 case Intrinsic::nvvm_suld_3d_v2i64_clamp:
5152 case Intrinsic::nvvm_suld_1d_i64_trap:
5153 case Intrinsic::nvvm_suld_1d_v2i64_trap:
5154 case Intrinsic::nvvm_suld_1d_array_i64_trap:
5155 case Intrinsic::nvvm_suld_1d_array_v2i64_trap:
5156 case Intrinsic::nvvm_suld_2d_i64_trap:
5157 case Intrinsic::nvvm_suld_2d_v2i64_trap:
5158 case Intrinsic::nvvm_suld_2d_array_i64_trap:
5159 case Intrinsic::nvvm_suld_2d_array_v2i64_trap:
5160 case Intrinsic::nvvm_suld_3d_i64_trap:
5161 case Intrinsic::nvvm_suld_3d_v2i64_trap:
5162 case Intrinsic::nvvm_suld_1d_i64_zero:
5163 case Intrinsic::nvvm_suld_1d_v2i64_zero:
5164 case Intrinsic::nvvm_suld_1d_array_i64_zero:
5165 case Intrinsic::nvvm_suld_1d_array_v2i64_zero:
5166 case Intrinsic::nvvm_suld_2d_i64_zero:
5167 case Intrinsic::nvvm_suld_2d_v2i64_zero:
5168 case Intrinsic::nvvm_suld_2d_array_i64_zero:
5169 case Intrinsic::nvvm_suld_2d_array_v2i64_zero:
5170 case Intrinsic::nvvm_suld_3d_i64_zero:
5171 case Intrinsic::nvvm_suld_3d_v2i64_zero:
5172 Info.opc = ISD::INTRINSIC_W_CHAIN;
5173 Info.memVT = MVT::i64;
5174 Info.ptrVal = nullptr;
5175 Info.offset = 0;
5176 Info.flags = MachineMemOperand::MOLoad;
5177 Info.align = Align(16);
5178 Infos.push_back(Elt: Info);
5179 return;
5180
5181 case Intrinsic::nvvm_tcgen05_ld_16x64b_x1:
5182 case Intrinsic::nvvm_tcgen05_ld_32x32b_x1:
5183 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x1: {
5184 Info.opc = ISD::INTRINSIC_W_CHAIN;
5185 Info.memVT = MVT::v1i32;
5186 Info.ptrVal = I.getArgOperand(i: 0);
5187 Info.offset = 0;
5188 Info.flags = MachineMemOperand::MOLoad;
5189 Info.align.reset();
5190 Infos.push_back(Elt: Info);
5191 return;
5192 }
5193
5194 case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
5195 case Intrinsic::nvvm_tcgen05_ld_16x128b_x1:
5196 case Intrinsic::nvvm_tcgen05_ld_32x32b_x2:
5197 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2:
5198 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_i32:
5199 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x2_i32: {
5200 Info.opc = ISD::INTRINSIC_W_CHAIN;
5201 Info.memVT = MVT::v2i32;
5202 Info.ptrVal = I.getArgOperand(i: 0);
5203 Info.offset = 0;
5204 Info.flags = MachineMemOperand::MOLoad;
5205 Info.align.reset();
5206 Infos.push_back(Elt: Info);
5207 return;
5208 }
5209
5210 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_f32:
5211 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x2_f32: {
5212 Info.opc = ISD::INTRINSIC_W_CHAIN;
5213 Info.memVT = MVT::v2f32;
5214 Info.ptrVal = I.getArgOperand(i: 0);
5215 Info.offset = 0;
5216 Info.flags = MachineMemOperand::MOLoad;
5217 Info.align.reset();
5218 Infos.push_back(Elt: Info);
5219 return;
5220 }
5221
5222 case Intrinsic::nvvm_tcgen05_ld_16x64b_x4:
5223 case Intrinsic::nvvm_tcgen05_ld_16x128b_x2:
5224 case Intrinsic::nvvm_tcgen05_ld_32x32b_x4:
5225 case Intrinsic::nvvm_tcgen05_ld_16x256b_x1:
5226 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x4:
5227 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x4_i32:
5228 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x4_i32: {
5229 Info.opc = ISD::INTRINSIC_W_CHAIN;
5230 Info.memVT = MVT::v4i32;
5231 Info.ptrVal = I.getArgOperand(i: 0);
5232 Info.offset = 0;
5233 Info.flags = MachineMemOperand::MOLoad;
5234 Info.align.reset();
5235 Infos.push_back(Elt: Info);
5236 return;
5237 }
5238
5239 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x4_f32:
5240 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x4_f32: {
5241 Info.opc = ISD::INTRINSIC_W_CHAIN;
5242 Info.memVT = MVT::v4f32;
5243 Info.ptrVal = I.getArgOperand(i: 0);
5244 Info.offset = 0;
5245 Info.flags = MachineMemOperand::MOLoad;
5246 Info.align.reset();
5247 Infos.push_back(Elt: Info);
5248 return;
5249 }
5250
5251 case Intrinsic::nvvm_tcgen05_ld_16x64b_x8:
5252 case Intrinsic::nvvm_tcgen05_ld_16x128b_x4:
5253 case Intrinsic::nvvm_tcgen05_ld_16x256b_x2:
5254 case Intrinsic::nvvm_tcgen05_ld_32x32b_x8:
5255 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x8:
5256 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x8_i32:
5257 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x8_i32: {
5258 Info.opc = ISD::INTRINSIC_W_CHAIN;
5259 Info.memVT = MVT::v8i32;
5260 Info.ptrVal = I.getArgOperand(i: 0);
5261 Info.offset = 0;
5262 Info.flags = MachineMemOperand::MOLoad;
5263 Info.align.reset();
5264 Infos.push_back(Elt: Info);
5265 return;
5266 }
5267
5268 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x8_f32:
5269 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x8_f32: {
5270 Info.opc = ISD::INTRINSIC_W_CHAIN;
5271 Info.memVT = MVT::v8f32;
5272 Info.ptrVal = I.getArgOperand(i: 0);
5273 Info.offset = 0;
5274 Info.flags = MachineMemOperand::MOLoad;
5275 Info.align.reset();
5276 Infos.push_back(Elt: Info);
5277 return;
5278 }
5279
5280 case Intrinsic::nvvm_tcgen05_ld_16x64b_x16:
5281 case Intrinsic::nvvm_tcgen05_ld_16x128b_x8:
5282 case Intrinsic::nvvm_tcgen05_ld_16x256b_x4:
5283 case Intrinsic::nvvm_tcgen05_ld_32x32b_x16:
5284 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x16:
5285 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x16_i32:
5286 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x16_i32: {
5287 Info.opc = ISD::INTRINSIC_W_CHAIN;
5288 Info.memVT = MVT::v16i32;
5289 Info.ptrVal = I.getArgOperand(i: 0);
5290 Info.offset = 0;
5291 Info.flags = MachineMemOperand::MOLoad;
5292 Info.align.reset();
5293 Infos.push_back(Elt: Info);
5294 return;
5295 }
5296
5297 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x16_f32:
5298 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x16_f32: {
5299 Info.opc = ISD::INTRINSIC_W_CHAIN;
5300 Info.memVT = MVT::v16f32;
5301 Info.ptrVal = I.getArgOperand(i: 0);
5302 Info.offset = 0;
5303 Info.flags = MachineMemOperand::MOLoad;
5304 Info.align.reset();
5305 Infos.push_back(Elt: Info);
5306 return;
5307 }
5308
5309 case Intrinsic::nvvm_tcgen05_ld_16x64b_x32:
5310 case Intrinsic::nvvm_tcgen05_ld_16x128b_x16:
5311 case Intrinsic::nvvm_tcgen05_ld_16x256b_x8:
5312 case Intrinsic::nvvm_tcgen05_ld_32x32b_x32:
5313 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x32:
5314 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x32_i32:
5315 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x32_i32: {
5316 Info.opc = ISD::INTRINSIC_W_CHAIN;
5317 Info.memVT = MVT::v32i32;
5318 Info.ptrVal = I.getArgOperand(i: 0);
5319 Info.offset = 0;
5320 Info.flags = MachineMemOperand::MOLoad;
5321 Info.align.reset();
5322 Infos.push_back(Elt: Info);
5323 return;
5324 }
5325
5326 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x32_f32:
5327 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x32_f32: {
5328 Info.opc = ISD::INTRINSIC_W_CHAIN;
5329 Info.memVT = MVT::v32f32;
5330 Info.ptrVal = I.getArgOperand(i: 0);
5331 Info.offset = 0;
5332 Info.flags = MachineMemOperand::MOLoad;
5333 Info.align.reset();
5334 Infos.push_back(Elt: Info);
5335 return;
5336 }
5337
5338 case Intrinsic::nvvm_tcgen05_ld_16x64b_x64:
5339 case Intrinsic::nvvm_tcgen05_ld_16x128b_x32:
5340 case Intrinsic::nvvm_tcgen05_ld_16x256b_x16:
5341 case Intrinsic::nvvm_tcgen05_ld_32x32b_x64:
5342 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x64:
5343 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x64_i32:
5344 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x64_i32: {
5345 Info.opc = ISD::INTRINSIC_W_CHAIN;
5346 Info.memVT = MVT::v64i32;
5347 Info.ptrVal = I.getArgOperand(i: 0);
5348 Info.offset = 0;
5349 Info.flags = MachineMemOperand::MOLoad;
5350 Info.align.reset();
5351 Infos.push_back(Elt: Info);
5352 return;
5353 }
5354
5355 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x64_f32:
5356 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x64_f32: {
5357 Info.opc = ISD::INTRINSIC_W_CHAIN;
5358 Info.memVT = MVT::v64f32;
5359 Info.ptrVal = I.getArgOperand(i: 0);
5360 Info.offset = 0;
5361 Info.flags = MachineMemOperand::MOLoad;
5362 Info.align.reset();
5363 Infos.push_back(Elt: Info);
5364 return;
5365 }
5366
5367 case Intrinsic::nvvm_tcgen05_ld_16x64b_x128:
5368 case Intrinsic::nvvm_tcgen05_ld_16x128b_x64:
5369 case Intrinsic::nvvm_tcgen05_ld_16x256b_x32:
5370 case Intrinsic::nvvm_tcgen05_ld_32x32b_x128:
5371 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x128:
5372 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x128_i32:
5373 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x128_i32: {
5374 Info.opc = ISD::INTRINSIC_W_CHAIN;
5375 Info.memVT = MVT::v128i32;
5376 Info.ptrVal = I.getArgOperand(i: 0);
5377 Info.offset = 0;
5378 Info.flags = MachineMemOperand::MOLoad;
5379 Info.align.reset();
5380 Infos.push_back(Elt: Info);
5381 return;
5382 }
5383
5384 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x128_f32:
5385 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x128_f32: {
5386 Info.opc = ISD::INTRINSIC_W_CHAIN;
5387 Info.memVT = MVT::v128f32;
5388 Info.ptrVal = I.getArgOperand(i: 0);
5389 Info.offset = 0;
5390 Info.flags = MachineMemOperand::MOLoad;
5391 Info.align.reset();
5392 Infos.push_back(Elt: Info);
5393 return;
5394 }
5395
5396 case Intrinsic::nvvm_tcgen05_st_16x64b_x1:
5397 case Intrinsic::nvvm_tcgen05_st_32x32b_x1:
5398 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x1: {
5399 Info.opc = ISD::INTRINSIC_VOID;
5400 Info.memVT = MVT::i32;
5401 Info.ptrVal = I.getArgOperand(i: 0);
5402 Info.offset = 0;
5403 Info.flags = MachineMemOperand::MOStore;
5404 Info.align.reset();
5405 Infos.push_back(Elt: Info);
5406 return;
5407 }
5408
5409 case Intrinsic::nvvm_tcgen05_st_16x64b_x2:
5410 case Intrinsic::nvvm_tcgen05_st_16x128b_x1:
5411 case Intrinsic::nvvm_tcgen05_st_32x32b_x2:
5412 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x2: {
5413 Info.opc = ISD::INTRINSIC_VOID;
5414 Info.memVT = MVT::v2i32;
5415 Info.ptrVal = I.getArgOperand(i: 0);
5416 Info.offset = 0;
5417 Info.flags = MachineMemOperand::MOStore;
5418 Info.align.reset();
5419 Infos.push_back(Elt: Info);
5420 return;
5421 }
5422
5423 case Intrinsic::nvvm_tcgen05_st_16x64b_x4:
5424 case Intrinsic::nvvm_tcgen05_st_16x128b_x2:
5425 case Intrinsic::nvvm_tcgen05_st_16x256b_x1:
5426 case Intrinsic::nvvm_tcgen05_st_32x32b_x4:
5427 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x4: {
5428 Info.opc = ISD::INTRINSIC_VOID;
5429 Info.memVT = MVT::v4i32;
5430 Info.ptrVal = I.getArgOperand(i: 0);
5431 Info.offset = 0;
5432 Info.flags = MachineMemOperand::MOStore;
5433 Info.align.reset();
5434 Infos.push_back(Elt: Info);
5435 return;
5436 }
5437
5438 case Intrinsic::nvvm_tcgen05_st_16x64b_x8:
5439 case Intrinsic::nvvm_tcgen05_st_16x128b_x4:
5440 case Intrinsic::nvvm_tcgen05_st_16x256b_x2:
5441 case Intrinsic::nvvm_tcgen05_st_32x32b_x8:
5442 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x8: {
5443 Info.opc = ISD::INTRINSIC_VOID;
5444 Info.memVT = MVT::v8i32;
5445 Info.ptrVal = I.getArgOperand(i: 0);
5446 Info.offset = 0;
5447 Info.flags = MachineMemOperand::MOStore;
5448 Info.align.reset();
5449 Infos.push_back(Elt: Info);
5450 return;
5451 }
5452
5453 case Intrinsic::nvvm_tcgen05_st_16x64b_x16:
5454 case Intrinsic::nvvm_tcgen05_st_16x128b_x8:
5455 case Intrinsic::nvvm_tcgen05_st_16x256b_x4:
5456 case Intrinsic::nvvm_tcgen05_st_32x32b_x16:
5457 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x16: {
5458 Info.opc = ISD::INTRINSIC_VOID;
5459 Info.memVT = MVT::v16i32;
5460 Info.ptrVal = I.getArgOperand(i: 0);
5461 Info.offset = 0;
5462 Info.flags = MachineMemOperand::MOStore;
5463 Info.align.reset();
5464 Infos.push_back(Elt: Info);
5465 return;
5466 }
5467
5468 case Intrinsic::nvvm_tcgen05_st_16x64b_x32:
5469 case Intrinsic::nvvm_tcgen05_st_16x128b_x16:
5470 case Intrinsic::nvvm_tcgen05_st_16x256b_x8:
5471 case Intrinsic::nvvm_tcgen05_st_32x32b_x32:
5472 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x32: {
5473 Info.opc = ISD::INTRINSIC_VOID;
5474 Info.memVT = MVT::v32i32;
5475 Info.ptrVal = I.getArgOperand(i: 0);
5476 Info.offset = 0;
5477 Info.flags = MachineMemOperand::MOStore;
5478 Info.align.reset();
5479 Infos.push_back(Elt: Info);
5480 return;
5481 }
5482
5483 case Intrinsic::nvvm_tcgen05_st_16x64b_x64:
5484 case Intrinsic::nvvm_tcgen05_st_16x128b_x32:
5485 case Intrinsic::nvvm_tcgen05_st_16x256b_x16:
5486 case Intrinsic::nvvm_tcgen05_st_32x32b_x64:
5487 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x64: {
5488 Info.opc = ISD::INTRINSIC_VOID;
5489 Info.memVT = MVT::v64i32;
5490 Info.ptrVal = I.getArgOperand(i: 0);
5491 Info.offset = 0;
5492 Info.flags = MachineMemOperand::MOStore;
5493 Info.align.reset();
5494 Infos.push_back(Elt: Info);
5495 return;
5496 }
5497
5498 case Intrinsic::nvvm_tcgen05_st_16x64b_x128:
5499 case Intrinsic::nvvm_tcgen05_st_16x128b_x64:
5500 case Intrinsic::nvvm_tcgen05_st_16x256b_x32:
5501 case Intrinsic::nvvm_tcgen05_st_32x32b_x128:
5502 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x128: {
5503 Info.opc = ISD::INTRINSIC_VOID;
5504 Info.memVT = MVT::v128i32;
5505 Info.ptrVal = I.getArgOperand(i: 0);
5506 Info.offset = 0;
5507 Info.flags = MachineMemOperand::MOStore;
5508 Info.align.reset();
5509 Infos.push_back(Elt: Info);
5510 return;
5511 }
5512 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
5513 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1:
5514 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1:
5515 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1:
5516 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1:
5517 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1:
5518 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift:
5519 case Intrinsic::
5520 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift:
5521 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1:
5522 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1:
5523 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift:
5524 case Intrinsic::
5525 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift: {
5526 // We are reading and writing back to TMem
5527 Info.opc = ISD::INTRINSIC_VOID;
5528 Info.memVT = MVT::v4i32;
5529 Info.ptrVal = I.getArgOperand(i: 0);
5530 Info.offset = 0;
5531 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
5532 Info.align = Align(16);
5533 Infos.push_back(Elt: Info);
5534 return;
5535 }
5536
5537 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2:
5538 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2:
5539 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2:
5540 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2:
5541 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2:
5542 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2:
5543 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2:
5544 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2:
5545 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift:
5546 case Intrinsic::
5547 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift:
5548 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift:
5549 case Intrinsic::
5550 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift: {
5551 // We are reading and writing back to TMem
5552 Info.opc = ISD::INTRINSIC_VOID;
5553 Info.memVT = MVT::v8i32;
5554 Info.ptrVal = I.getArgOperand(i: 0);
5555 Info.offset = 0;
5556 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
5557 Info.align = Align(16);
5558 Infos.push_back(Elt: Info);
5559 return;
5560 }
5561 }
5562}
5563
/// getFunctionParamOptimizedAlign - Since function arguments are passed via
/// .param space, we may want to increase their alignment in a way that
/// ensures we can effectively vectorize their loads & stores. We can
/// increase the alignment only if the function has internal or private
/// linkage, as for other linkage types callers may already rely on the
/// default alignment. To allow 128-bit vectorized loads/stores, this
/// function ensures that the alignment is 16 or greater.
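///
/// For example (illustrative, not exhaustive): a byval aggregate of eight
/// i32s passed to a function with local linkage would have its .param
/// alignment raised from the ABI's 4 to 16, allowing its loads to be
/// vectorized into 128-bit accesses such as ld.param.v4.u32.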
5571Align NVPTXTargetLowering::getFunctionParamOptimizedAlign(
5572 const Function *F, Type *ArgTy, const DataLayout &DL) const {
5573 // Capping the alignment to 128 bytes as that is the maximum alignment
5574 // supported by PTX.
5575 const Align ABITypeAlign = std::min(a: Align(128), b: DL.getABITypeAlign(Ty: ArgTy));
5576
5577 // If a function has linkage different from internal or private, we
5578 // must use default ABI alignment as external users rely on it. Same
5579 // for a function that may be called from a function pointer.
5580 if (!F || !F->hasLocalLinkage() ||
5581 F->hasAddressTaken(/*Users=*/nullptr,
5582 /*IgnoreCallbackUses=*/false,
5583 /*IgnoreAssumeLikeCalls=*/true,
                         /*IgnoreLLVMUsed=*/true))
5585 return ABITypeAlign;
5586
5587 assert(!isKernelFunction(*F) && "Expect kernels to have non-local linkage");
5588 return std::max(a: Align(16), b: ABITypeAlign);
5589}
5590
5591/// Helper for computing alignment of a device function byval parameter.
5592Align NVPTXTargetLowering::getFunctionByValParamAlign(
5593 const Function *F, Type *ArgTy, Align InitialAlign,
5594 const DataLayout &DL) const {
5595 Align ArgAlign = InitialAlign;
5596 // Try to increase alignment to enhance vectorization options.
5597 if (F)
5598 ArgAlign = std::max(a: ArgAlign, b: getFunctionParamOptimizedAlign(F, ArgTy, DL));
5599
  // Old ptxas versions have a bug: when PTX code takes the address of a
  // byval parameter with alignment < 4, ptxas generates code to spill the
  // argument into memory. Alas, on sm_50+ ptxas generates SASS code that
  // fails with a misaligned access. To work around the problem, make sure
  // that byval parameters are aligned by at least 4. This bug seems to be
  // fixed starting from ptxas > 9.0.
  // TODO: remove this after verifying the bug is not reproduced
  // on non-deprecated ptxas versions.
5609 if (ForceMinByValParamAlign)
5610 ArgAlign = std::max(a: ArgAlign, b: Align(4));
5611
5612 return ArgAlign;
5613}
5614
// Helper for getting a function parameter name. The name is composed from
// the parameter's index and the function name. A negative index corresponds
// to the special parameter (an unsized array) used for passing variable
// arguments.
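//
// For example, parameter 1 of a function whose symbol is "foo" is named
// "foo_param_1", and its vararg parameter is named "foo_vararg"
// (illustrative function name).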
5618std::string NVPTXTargetLowering::getParamName(const Function *F,
5619 int Idx) const {
5620 std::string ParamName;
5621 raw_string_ostream ParamStr(ParamName);
5622
5623 ParamStr << getTargetMachine().getSymbol(GV: F)->getName();
5624 if (Idx < 0)
5625 ParamStr << "_vararg";
5626 else
5627 ParamStr << "_param_" << Idx;
5628
5629 return ParamName;
5630}
5631
5632/// isLegalAddressingMode - Return true if the addressing mode represented
5633/// by AM is legal for this target, for a load/store of the specified type.
5634/// Used to guide target specific optimizations, like loop strength reduction
5635/// (LoopStrengthReduce.cpp) and memory optimization for address mode
5636/// (CodeGenPrepare.cpp)
5637bool NVPTXTargetLowering::isLegalAddressingMode(const DataLayout &DL,
5638 const AddrMode &AM, Type *Ty,
5639 unsigned AS, Instruction *I) const {
5640 // AddrMode - This represents an addressing mode of:
5641 // BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
5642 //
5643 // The legal address modes are
5644 // - [avar]
5645 // - [areg]
5646 // - [areg+immoff]
5647 // - [immAddr]
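  //
  // For example (illustrative operands), [%r1+16] (areg+immoff) and [gvar]
  // (avar) are legal, while a reg+reg form such as [%r1+%r2] is rejected
  // below.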
5648
5649 // immoff must fit in a signed 32-bit int
5650 if (!APInt(64, AM.BaseOffs).isSignedIntN(N: 32))
5651 return false;
5652
5653 if (AM.BaseGV)
5654 return !AM.BaseOffs && !AM.HasBaseReg && !AM.Scale;
5655
5656 switch (AM.Scale) {
5657 case 0: // "r", "r+i" or "i" is allowed
5658 break;
5659 case 1:
5660 if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
5661 return false;
5662 // Otherwise we have r+i.
5663 break;
5664 default:
5665 // No scale > 1 is allowed
5666 return false;
5667 }
5668 return true;
5669}
5670
5671//===----------------------------------------------------------------------===//
5672// NVPTX Inline Assembly Support
5673//===----------------------------------------------------------------------===//
5674
5675/// getConstraintType - Given a constraint letter, return the type of
5676/// constraint it is for this target.
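///
/// For example (illustrative snippet, not from this file), the 'r'
/// constraints in
///   asm("add.s32 %0, %1, %2" : "=r"(d) : "r"(a), "r"(b))
/// are classified here as register-class constraints and are later mapped to
/// the 32-bit register class by getRegForInlineAsmConstraint below.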
5677NVPTXTargetLowering::ConstraintType
5678NVPTXTargetLowering::getConstraintType(StringRef Constraint) const {
5679 if (Constraint.size() == 1) {
5680 switch (Constraint[0]) {
5681 default:
5682 break;
5683 case 'b':
5684 case 'r':
5685 case 'h':
5686 case 'c':
5687 case 'l':
5688 case 'f':
5689 case 'd':
5690 case 'q':
5691 case '0':
5692 case 'N':
5693 return C_RegisterClass;
5694 }
5695 }
5696 return TargetLowering::getConstraintType(Constraint);
5697}
5698
5699std::pair<unsigned, const TargetRegisterClass *>
5700NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
5701 StringRef Constraint,
5702 MVT VT) const {
5703 if (Constraint.size() == 1) {
5704 switch (Constraint[0]) {
5705 case 'b':
5706 return std::make_pair(x: 0U, y: &NVPTX::B1RegClass);
5707 case 'c':
5708 case 'h':
5709 return std::make_pair(x: 0U, y: &NVPTX::B16RegClass);
5710 case 'r':
5711 case 'f':
5712 return std::make_pair(x: 0U, y: &NVPTX::B32RegClass);
5713 case 'l':
5714 case 'N':
5715 case 'd':
5716 return std::make_pair(x: 0U, y: &NVPTX::B64RegClass);
5717 case 'q': {
5718 if (STI.getSmVersion() < 70)
5719 report_fatal_error(reason: "Inline asm with 128 bit operands is only "
5720 "supported for sm_70 and higher!");
5721 return std::make_pair(x: 0U, y: &NVPTX::B128RegClass);
5722 }
5723 }
5724 }
5725 return TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
5726}
5727
5728//===----------------------------------------------------------------------===//
5729// NVPTX DAG Combining
5730//===----------------------------------------------------------------------===//
5731
5732bool NVPTXTargetLowering::allowFMA(MachineFunction &MF,
5733 CodeGenOptLevel OptLevel) const {
5734 // Always honor command-line argument
5735 if (FMAContractLevelOpt.getNumOccurrences() > 0)
5736 return FMAContractLevelOpt > 0;
5737
5738 // Do not contract if we're not optimizing the code.
5739 if (OptLevel == CodeGenOptLevel::None)
5740 return false;
5741
5742 // Honor TargetOptions flags that explicitly say fusion is okay.
5743 if (MF.getTarget().Options.AllowFPOpFusion == FPOpFusion::Fast)
5744 return true;
5745
5746 return false;
5747}
5748
5749static bool isConstZero(const SDValue &Operand) {
5750 const auto *Const = dyn_cast<ConstantSDNode>(Val: Operand);
5751 return Const && Const->getZExtValue() == 0;
5752}
5753
5754/// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
5755/// operands N0 and N1. This is a helper for PerformADDCombine that is
5756/// called with the default operands, and if that fails, with commuted
5757/// operands.
5758static SDValue
5759PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
5760 TargetLowering::DAGCombinerInfo &DCI) {
5761 EVT VT = N0.getValueType();
5762
5763 // Since integer multiply-add costs the same as integer multiply
5764 // but is more costly than integer add, do the fusion only when
5765 // the mul is only used in the add.
5766 // TODO: this may not be true for later architectures, consider relaxing this
5767 if (!N0.getNode()->hasOneUse())
5768 return SDValue();
5769
5770 // fold (add (select cond, 0, (mul a, b)), c)
5771 // -> (select cond, c, (add (mul a, b), c))
5772 //
5773 if (N0.getOpcode() == ISD::SELECT) {
5774 unsigned ZeroOpNum;
5775 if (isConstZero(Operand: N0->getOperand(Num: 1)))
5776 ZeroOpNum = 1;
5777 else if (isConstZero(Operand: N0->getOperand(Num: 2)))
5778 ZeroOpNum = 2;
5779 else
5780 return SDValue();
5781
5782 SDValue M = N0->getOperand(Num: (ZeroOpNum == 1) ? 2 : 1);
5783 if (M->getOpcode() != ISD::MUL || !M.getNode()->hasOneUse())
5784 return SDValue();
5785
5786 SDLoc DL(N);
5787 SDValue Mul =
5788 DCI.DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: M->getOperand(Num: 0), N2: M->getOperand(Num: 1));
5789 SDValue MAD = DCI.DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Mul, N2: N1);
5790 return DCI.DAG.getSelect(DL: SDLoc(N), VT, Cond: N0->getOperand(Num: 0),
5791 LHS: ((ZeroOpNum == 1) ? N1 : MAD),
5792 RHS: ((ZeroOpNum == 1) ? MAD : N1));
5793 }
5794
5795 return SDValue();
5796}
5797
5798static SDValue
5799PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
5800 TargetLowering::DAGCombinerInfo &DCI,
5801 CodeGenOptLevel OptLevel) {
5802 EVT VT = N0.getValueType();
5803 if (N0.getOpcode() == ISD::FMUL) {
5804 const auto *TLI = static_cast<const NVPTXTargetLowering *>(
5805 &DCI.DAG.getTargetLoweringInfo());
5806 if (!(TLI->allowFMA(MF&: DCI.DAG.getMachineFunction(), OptLevel) ||
5807 (N->getFlags().hasAllowContract() &&
5808 N0->getFlags().hasAllowContract())))
5809 return SDValue();
5810
    // For floating point:
    // Do the fusion only when the mul has fewer than 5 uses and all
    // of them are adds.
    // The heuristic is that if a use is not an add, then that use
    // cannot be fused into an fma, so the mul is still needed anyway.
    // If there are more than 4 uses, even if they are all adds, fusing
    // them will increase register pressure.
5818 //
5819 int numUses = 0;
5820 int nonAddCount = 0;
5821 for (const SDNode *User : N0.getNode()->users()) {
5822 numUses++;
5823 if (User->getOpcode() != ISD::FADD)
5824 ++nonAddCount;
5825 if (numUses >= 5)
5826 return SDValue();
5827 }
5828 if (nonAddCount) {
5829 int orderNo = N->getIROrder();
5830 int orderNo2 = N0.getNode()->getIROrder();
      // Simple heuristic for estimating potential register pressure: the
      // difference in IR order approximates the distance between the def and
      // this use, and the longer that distance, the more likely the fusion
      // is to increase register pressure.
5835 if (orderNo - orderNo2 < 500)
5836 return SDValue();
5837
5838 // Now, check if at least one of the FMUL's operands is live beyond the
5839 // node N, which guarantees that the FMA will not increase register
5840 // pressure at node N.
5841 bool opIsLive = false;
5842 const SDNode *left = N0.getOperand(i: 0).getNode();
5843 const SDNode *right = N0.getOperand(i: 1).getNode();
5844
5845 if (isa<ConstantSDNode>(Val: left) || isa<ConstantSDNode>(Val: right))
5846 opIsLive = true;
5847
5848 if (!opIsLive)
5849 for (const SDNode *User : left->users()) {
5850 int orderNo3 = User->getIROrder();
5851 if (orderNo3 > orderNo) {
5852 opIsLive = true;
5853 break;
5854 }
5855 }
5856
5857 if (!opIsLive)
5858 for (const SDNode *User : right->users()) {
5859 int orderNo3 = User->getIROrder();
5860 if (orderNo3 > orderNo) {
5861 opIsLive = true;
5862 break;
5863 }
5864 }
5865
5866 if (!opIsLive)
5867 return SDValue();
5868 }
5869
5870 return DCI.DAG.getNode(Opcode: ISD::FMA, DL: SDLoc(N), VT, N1: N0.getOperand(i: 0),
5871 N2: N0.getOperand(i: 1), N3: N1);
5872 }
5873
5874 return SDValue();
5875}
5876
5877/// Fold unpacking movs into a load by increasing the number of return values.
5878///
5879/// ex:
5880/// L: v2f16,ch = load <p>
5881/// a: f16 = extractelt L:0, 0
5882/// b: f16 = extractelt L:0, 1
5883/// use(a, b)
5884///
5885/// ...is turned into...
5886///
5887/// L: f16,f16,ch = LoadV2 <p>
5888/// use(L:0, L:1)
5889static SDValue
5890combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
5891 // Don't run this optimization before the legalizer
5892 if (!DCI.isAfterLegalizeDAG())
5893 return SDValue();
5894
5895 EVT ElementVT = N->getValueType(ResNo: 0);
5896 // Avoid non-packed types and v4i8
5897 if (!NVPTX::isPackedVectorTy(VT: ElementVT) || ElementVT == MVT::v4i8)
5898 return SDValue();
5899
5900 // Check whether all outputs are either used by an extractelt or are
5901 // glue/chain nodes
5902 if (!all_of(Range: N->uses(), P: [&](SDUse &U) {
5903 // Skip glue, chain nodes
5904 if (U.getValueType() == MVT::Glue || U.getValueType() == MVT::Other)
5905 return true;
5906 if (U.getUser()->getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
5907 if (N->getOpcode() != ISD::LOAD)
5908 return true;
5909 // Since this is an ISD::LOAD, check all extractelts are used. If
5910 // any are not used, we don't want to defeat another optimization that
5911 // will narrow the load.
5912 //
5913 // For example:
5914 //
5915 // L: v2f16,ch = load <p>
5916 // e0: f16 = extractelt L:0, 0
5917 // e1: f16 = extractelt L:0, 1 <-- unused
5918 // store e0
5919 //
5920 // Can be optimized by DAGCombiner to:
5921 //
5922 // L: f16,ch = load <p>
5923 // store L:0
5924 return !U.getUser()->use_empty();
5925 }
5926
5927 // Otherwise, this use prevents us from splitting a value.
5928 return false;
5929 }))
5930 return SDValue();
5931
5932 auto *LD = cast<MemSDNode>(Val: N);
5933 SDLoc DL(LD);
5934
  // The new opcode after we double the number of return values.
5936 unsigned Opcode;
5937 SmallVector<SDValue> Operands(LD->ops());
5938 unsigned OldNumOutputs; // non-glue, non-chain outputs
5939 switch (LD->getOpcode()) {
5940 case ISD::LOAD:
5941 OldNumOutputs = 1;
5942 // Any packed type is legal, so the legalizer will not have lowered
5943 // ISD::LOAD -> NVPTXISD::Load (unless it's under-aligned). We have to do it
5944 // here.
5945 Opcode = NVPTXISD::LoadV2;
    // Append a "full" used-bytes mask operand right before the extension-type
    // operand, signifying that all bytes are used.
5948 Operands.push_back(Elt: DCI.DAG.getConstant(UINT32_MAX, DL, VT: MVT::i32));
5949 Operands.push_back(Elt: DCI.DAG.getIntPtrConstant(
5950 Val: cast<LoadSDNode>(Val: LD)->getExtensionType(), DL));
5951 break;
5952 case NVPTXISD::LoadV2:
5953 OldNumOutputs = 2;
5954 Opcode = NVPTXISD::LoadV4;
5955 break;
5956 case NVPTXISD::LoadV4:
5957 // V8 is only supported for f32/i32. Don't forget, we're not changing the
5958 // load size here. This is already a 256-bit load.
5959 if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
5960 return SDValue();
5961 OldNumOutputs = 4;
5962 Opcode = NVPTXISD::LoadV8;
5963 break;
5964 case NVPTXISD::LoadV8:
5965 // PTX doesn't support the next doubling of outputs
5966 return SDValue();
5967 }
5968
5969 // the non-glue, non-chain outputs in the new load
5970 const unsigned NewNumOutputs = OldNumOutputs * 2;
5971 SmallVector<EVT> NewVTs(NewNumOutputs, ElementVT.getVectorElementType());
5972 // add remaining chain and glue values
5973 NewVTs.append(in_start: LD->value_begin() + OldNumOutputs, in_end: LD->value_end());
5974
5975 // Create the new load
5976 SDValue NewLoad = DCI.DAG.getMemIntrinsicNode(
5977 Opcode, dl: DL, VTList: DCI.DAG.getVTList(VTs: NewVTs), Ops: Operands, MemVT: LD->getMemoryVT(),
5978 MMO: LD->getMemOperand());
5979
5980 // Now we use a combination of BUILD_VECTORs and a MERGE_VALUES node to keep
5981 // the outputs the same. These nodes will be optimized away in later
5982 // DAGCombiner iterations.
5983 SmallVector<SDValue> Results;
5984 for (unsigned I : seq(Size: OldNumOutputs))
5985 Results.push_back(Elt: DCI.DAG.getBuildVector(
5986 VT: ElementVT, DL, Ops: {NewLoad.getValue(R: I * 2), NewLoad.getValue(R: I * 2 + 1)}));
5987 // Add remaining chain and glue nodes
5988 for (unsigned I : seq(Size: NewLoad->getNumValues() - NewNumOutputs))
5989 Results.push_back(Elt: NewLoad.getValue(R: NewNumOutputs + I));
5990
5991 return DCI.DAG.getMergeValues(Ops: Results, dl: DL);
5992}
5993
5994/// Fold packing movs into a store.
5995///
5996/// ex:
5997/// v1: v2f16 = BUILD_VECTOR a:f16, b:f16
5998/// v2: v2f16 = BUILD_VECTOR c:f16, d:f16
5999/// StoreV2 v1, v2
6000///
6001/// ...is turned into...
6002///
6003/// StoreV4 a, b, c, d
6004static SDValue combinePackingMovIntoStore(SDNode *N,
6005 TargetLowering::DAGCombinerInfo &DCI,
6006 unsigned Front, unsigned Back) {
6007 // We want to run this as late as possible since other optimizations may
6008 // eliminate the BUILD_VECTORs.
6009 if (!DCI.isAfterLegalizeDAG())
6010 return SDValue();
6011
6012 // Get the type of the operands being stored.
6013 EVT ElementVT = N->getOperand(Num: Front).getValueType();
6014
6015 // Avoid non-packed types and v4i8
6016 if (!NVPTX::isPackedVectorTy(VT: ElementVT) || ElementVT == MVT::v4i8)
6017 return SDValue();
6018
6019 auto *ST = cast<MemSDNode>(Val: N);
6020
6021 // The new opcode after we double the number of operands.
6022 unsigned Opcode;
6023 switch (N->getOpcode()) {
6024 case ISD::STORE:
6025 // Any packed type is legal, so the legalizer will not have lowered
6026 // ISD::STORE -> NVPTXISD::Store (unless it's under-aligned). We have to do
6027 // it here.
6028 Opcode = NVPTXISD::StoreV2;
6029 break;
6030 case NVPTXISD::StoreV2:
6031 Opcode = NVPTXISD::StoreV4;
6032 break;
6033 case NVPTXISD::StoreV4:
6034 // V8 is only supported for f32/i32. Don't forget, we're not changing the
6035 // store size here. This is already a 256-bit store.
6036 if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
6037 return SDValue();
6038 Opcode = NVPTXISD::StoreV8;
6039 break;
6040 case NVPTXISD::StoreV8:
6041 // PTX doesn't support the next doubling of operands
6042 return SDValue();
6043 default:
6044 llvm_unreachable("Unhandled store opcode");
6045 }
6046
  // Scan the operands; if they are all BUILD_VECTORs, gather their elements
  // into the new operand list.
6049 SmallVector<SDValue, 4> Operands(N->ops().take_front(N: Front));
6050 for (SDValue BV : N->ops().drop_front(N: Front).drop_back(N: Back)) {
6051 if (BV.getOpcode() != ISD::BUILD_VECTOR)
6052 return SDValue();
6053
6054 // If the operand has multiple uses, this optimization can increase register
6055 // pressure.
6056 if (!BV.hasOneUse())
6057 return SDValue();
6058
6059 // DAGCombiner visits nodes bottom-up. Check the BUILD_VECTOR operands for
6060 // any signs they may be folded by some other pattern or rule.
6061 for (SDValue Op : BV->ops()) {
6062 // Peek through bitcasts
6063 if (Op.getOpcode() == ISD::BITCAST)
6064 Op = Op.getOperand(i: 0);
6065
6066 // This may be folded into a PRMT.
6067 if (Op.getValueType() == MVT::i16 && Op.getOpcode() == ISD::TRUNCATE &&
6068 Op->getOperand(Num: 0).getValueType() == MVT::i32)
6069 return SDValue();
6070
6071 // This may be folded into cvt.bf16x2
6072 if (Op.getOpcode() == ISD::FP_ROUND)
6073 return SDValue();
6074 }
6075 Operands.append(IL: {BV.getOperand(i: 0), BV.getOperand(i: 1)});
6076 }
6077 Operands.append(in_start: N->op_end() - Back, in_end: N->op_end());
6078
6079 // Now we replace the store
6080 return DCI.DAG.getMemIntrinsicNode(Opcode, dl: SDLoc(N), VTList: N->getVTList(), Ops: Operands,
6081 MemVT: ST->getMemoryVT(), MMO: ST->getMemOperand());
6082}
6083
6084static SDValue combineSTORE(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
6085 const NVPTXSubtarget &STI) {
6086
6087 if (DCI.isBeforeLegalize() && N->getOpcode() == ISD::STORE) {
6088 // Here is our chance to custom lower a store with a non-simple type.
6089 // Unfortunately, we can't do this in the legalizer because there is no
6090 // way to setOperationAction for an non-simple type.
6091 StoreSDNode *ST = cast<StoreSDNode>(Val: N);
6092 if (!ST->getValue().getValueType().isSimple())
6093 return lowerSTOREVector(Op: SDValue(ST, 0), DAG&: DCI.DAG, STI);
6094 }
6095
6096 return combinePackingMovIntoStore(N, DCI, Front: 1, Back: 2);
6097}
6098
6099static SDValue combineLOAD(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
6100 const NVPTXSubtarget &STI) {
6101 if (DCI.isBeforeLegalize() && N->getOpcode() == ISD::LOAD) {
6102 // Here is our chance to custom lower a load with a non-simple type.
6103 // Unfortunately, we can't do this in the legalizer because there is no
6104 // way to setOperationAction for an non-simple type.
6105 if (!N->getValueType(ResNo: 0).isSimple())
6106 return lowerLoadVector(N, DAG&: DCI.DAG, STI);
6107 }
6108
6109 return combineUnpackingMovIntoLoad(N, DCI);
6110}
6111
6112/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
6113///
6114static SDValue PerformADDCombine(SDNode *N,
6115 TargetLowering::DAGCombinerInfo &DCI,
6116 CodeGenOptLevel OptLevel) {
6117 if (OptLevel == CodeGenOptLevel::None)
6118 return SDValue();
6119
6120 SDValue N0 = N->getOperand(Num: 0);
6121 SDValue N1 = N->getOperand(Num: 1);
6122
  // Only handle the scalar i32 case; skip vectors and other types.
6124 EVT VT = N0.getValueType();
6125 if (VT.isVector() || VT != MVT::i32)
6126 return SDValue();
6127
6128 // First try with the default operand order.
6129 if (SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI))
6130 return Result;
6131
6132 // If that didn't work, try again with the operands commuted.
6133 return PerformADDCombineWithOperands(N, N0: N1, N1: N0, DCI);
6134}
6135
6136/// PerformFADDCombine - Target-specific dag combine xforms for ISD::FADD.
6137///
6138static SDValue PerformFADDCombine(SDNode *N,
6139 TargetLowering::DAGCombinerInfo &DCI,
6140 CodeGenOptLevel OptLevel) {
6141 SDValue N0 = N->getOperand(Num: 0);
6142 SDValue N1 = N->getOperand(Num: 1);
6143
6144 EVT VT = N0.getValueType();
6145 if (VT.isVector() || !(VT == MVT::f32 || VT == MVT::f64))
6146 return SDValue();
6147
6148 // First try with the default operand order.
6149 if (SDValue Result = PerformFADDCombineWithOperands(N, N0, N1, DCI, OptLevel))
6150 return Result;
6151
6152 // If that didn't work, try again with the operands commuted.
6153 return PerformFADDCombineWithOperands(N, N0: N1, N1: N0, DCI, OptLevel);
6154}
6155
6156/// Get 3-input version of a 2-input min/max opcode
6157static unsigned getMinMax3Opcode(unsigned MinMax2Opcode) {
6158 switch (MinMax2Opcode) {
6159 case ISD::FMAXNUM:
6160 case ISD::FMAXIMUMNUM:
6161 return NVPTXISD::FMAXNUM3;
6162 case ISD::FMINNUM:
6163 case ISD::FMINIMUMNUM:
6164 return NVPTXISD::FMINNUM3;
6165 case ISD::FMAXIMUM:
6166 return NVPTXISD::FMAXIMUM3;
6167 case ISD::FMINIMUM:
6168 return NVPTXISD::FMINIMUM3;
6169 default:
6170 llvm_unreachable("Invalid 2-input min/max opcode");
6171 }
6172}
6173
6174/// PerformFMinMaxCombine - Combine (fmaxnum (fmaxnum a, b), c) into
6175/// (fmaxnum3 a, b, c). Also covers other llvm min/max intrinsics.
6176static SDValue PerformFMinMaxCombine(SDNode *N,
6177 TargetLowering::DAGCombinerInfo &DCI,
6178 unsigned PTXVersion, unsigned SmVersion) {
6179
6180 // 3-input min/max requires PTX 8.8+ and SM_100+, and only supports f32s
6181 EVT VT = N->getValueType(ResNo: 0);
6182 if (VT != MVT::f32 || PTXVersion < 88 || SmVersion < 100)
6183 return SDValue();
6184
6185 SDValue Op0 = N->getOperand(Num: 0);
6186 SDValue Op1 = N->getOperand(Num: 1);
6187 unsigned MinMaxOp2 = N->getOpcode();
6188 unsigned MinMaxOp3 = getMinMax3Opcode(MinMax2Opcode: MinMaxOp2);
6189
6190 if (Op0.getOpcode() == MinMaxOp2 && Op0.hasOneUse()) {
6191 // (maxnum (maxnum a, b), c) -> (maxnum3 a, b, c)
6192 SDValue A = Op0.getOperand(i: 0);
6193 SDValue B = Op0.getOperand(i: 1);
6194 SDValue C = Op1;
6195 return DCI.DAG.getNode(Opcode: MinMaxOp3, DL: SDLoc(N), VT, N1: A, N2: B, N3: C, Flags: N->getFlags());
6196 } else if (Op1.getOpcode() == MinMaxOp2 && Op1.hasOneUse()) {
6197 // (maxnum a, (maxnum b, c)) -> (maxnum3 a, b, c)
6198 SDValue A = Op0;
6199 SDValue B = Op1.getOperand(i: 0);
6200 SDValue C = Op1.getOperand(i: 1);
6201 return DCI.DAG.getNode(Opcode: MinMaxOp3, DL: SDLoc(N), VT, N1: A, N2: B, N3: C, Flags: N->getFlags());
6202 }
6203 return SDValue();
6204}
6205
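// Rewrite srem/urem as sub+mul+div when a matching division of the same
// operands already exists in the DAG, e.g. (illustrative IR):
//   %q = sdiv i32 %a, %b
//   %r = srem i32 %a, %b  -->  %r = sub i32 %a, (mul i32 %q, %b)
// The newly created divide is expected to fold into the existing one.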
6206static SDValue PerformREMCombine(SDNode *N,
6207 TargetLowering::DAGCombinerInfo &DCI,
6208 CodeGenOptLevel OptLevel) {
6209 assert(N->getOpcode() == ISD::SREM || N->getOpcode() == ISD::UREM);
6210
6211 // Don't do anything at less than -O2.
6212 if (OptLevel < CodeGenOptLevel::Default)
6213 return SDValue();
6214
6215 SelectionDAG &DAG = DCI.DAG;
6216 SDLoc DL(N);
6217 EVT VT = N->getValueType(ResNo: 0);
6218 bool IsSigned = N->getOpcode() == ISD::SREM;
6219 unsigned DivOpc = IsSigned ? ISD::SDIV : ISD::UDIV;
6220
6221 const SDValue &Num = N->getOperand(Num: 0);
6222 const SDValue &Den = N->getOperand(Num: 1);
6223
6224 for (const SDNode *U : Num->users()) {
6225 if (U->getOpcode() == DivOpc && U->getOperand(Num: 0) == Num &&
6226 U->getOperand(Num: 1) == Den) {
6227 // Num % Den -> Num - (Num / Den) * Den
6228 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Num,
6229 N2: DAG.getNode(Opcode: ISD::MUL, DL, VT,
6230 N1: DAG.getNode(Opcode: DivOpc, DL, VT, N1: Num, N2: Den),
6231 N2: Den));
6232 }
6233 }
6234 return SDValue();
6235}
6236
6237// (sign_extend|zero_extend (mul|shl) x, y) -> (mul.wide x, y)
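// For example (illustrative values), (i64 (zero_extend (shl nuw i32 %x, 3)))
// is treated as a multiply by 8 and becomes (MUL_WIDE_UNSIGNED %x, 8),
// producing the i64 result directly.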
6238static SDValue combineMulWide(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
6239 CodeGenOptLevel OptLevel) {
6240 if (OptLevel == CodeGenOptLevel::None)
6241 return SDValue();
6242
6243 SDValue Op = N->getOperand(Num: 0);
6244 if (!Op.hasOneUse())
6245 return SDValue();
6246 EVT ToVT = N->getValueType(ResNo: 0);
6247 EVT FromVT = Op.getValueType();
6248 if (!((ToVT == MVT::i32 && FromVT == MVT::i16) ||
6249 (ToVT == MVT::i64 && FromVT == MVT::i32)))
6250 return SDValue();
6251 if (!(Op.getOpcode() == ISD::MUL ||
6252 (Op.getOpcode() == ISD::SHL && isa<ConstantSDNode>(Val: Op.getOperand(i: 1)))))
6253 return SDValue();
6254
6255 SDLoc DL(N);
6256 unsigned ExtOpcode = N->getOpcode();
6257 unsigned Opcode = 0;
6258 if (ExtOpcode == ISD::SIGN_EXTEND && Op->getFlags().hasNoSignedWrap())
6259 Opcode = NVPTXISD::MUL_WIDE_SIGNED;
6260 else if (ExtOpcode == ISD::ZERO_EXTEND && Op->getFlags().hasNoUnsignedWrap())
6261 Opcode = NVPTXISD::MUL_WIDE_UNSIGNED;
6262 else
6263 return SDValue();
6264 SDValue RHS = Op.getOperand(i: 1);
6265 if (Op.getOpcode() == ISD::SHL) {
6266 const auto ShiftAmt = Op.getConstantOperandVal(i: 1);
6267 const auto MulVal = APInt(FromVT.getSizeInBits(), 1) << ShiftAmt;
6268 RHS = DCI.DAG.getConstant(Val: MulVal, DL, VT: FromVT);
6269 }
6270 return DCI.DAG.getNode(Opcode, DL, VT: ToVT, N1: Op.getOperand(i: 0), N2: RHS);
6271}
6272
6273enum OperandSignedness {
6274 Signed = 0,
6275 Unsigned,
6276 Unknown
6277};
6278
6279/// IsMulWideOperandDemotable - Checks if the provided DAG node is an operand
6280/// that can be demoted to \p OptSize bits without loss of information. The
6281/// signedness of the operand, if determinable, is placed in \p S.
6282static bool IsMulWideOperandDemotable(SDValue Op,
6283 unsigned OptSize,
6284 OperandSignedness &S) {
6285 S = Unknown;
6286
6287 if (Op.getOpcode() == ISD::SIGN_EXTEND ||
6288 Op.getOpcode() == ISD::SIGN_EXTEND_INREG) {
6289 EVT OrigVT = Op.getOperand(i: 0).getValueType();
6290 if (OrigVT.getFixedSizeInBits() <= OptSize) {
6291 S = Signed;
6292 return true;
6293 }
6294 } else if (Op.getOpcode() == ISD::ZERO_EXTEND) {
6295 EVT OrigVT = Op.getOperand(i: 0).getValueType();
6296 if (OrigVT.getFixedSizeInBits() <= OptSize) {
6297 S = Unsigned;
6298 return true;
6299 }
6300 }
6301
6302 return false;
6303}
6304
6305/// AreMulWideOperandsDemotable - Checks if the given LHS and RHS operands can
6306/// be demoted to \p OptSize bits without loss of information. If the operands
6307/// contain a constant, it should appear as the RHS operand. The signedness of
6308/// the operands is placed in \p IsSigned.
6309static bool AreMulWideOperandsDemotable(SDValue LHS, SDValue RHS,
6310 unsigned OptSize,
6311 bool &IsSigned) {
6312 OperandSignedness LHSSign;
6313
6314 // The LHS operand must be a demotable op
6315 if (!IsMulWideOperandDemotable(Op: LHS, OptSize, S&: LHSSign))
6316 return false;
6317
6318 // We should have been able to determine the signedness from the LHS
6319 if (LHSSign == Unknown)
6320 return false;
6321
6322 IsSigned = (LHSSign == Signed);
6323
6324 // The RHS can be a demotable op or a constant
6325 if (ConstantSDNode *CI = dyn_cast<ConstantSDNode>(Val&: RHS)) {
6326 const APInt &Val = CI->getAPIntValue();
6327 if (LHSSign == Unsigned) {
6328 return Val.isIntN(N: OptSize);
6329 } else {
6330 return Val.isSignedIntN(N: OptSize);
6331 }
6332 } else {
6333 OperandSignedness RHSSign;
6334 if (!IsMulWideOperandDemotable(Op: RHS, OptSize, S&: RHSSign))
6335 return false;
6336
6337 return LHSSign == RHSSign;
6338 }
6339}
6340
6341/// TryMULWIDECombine - Attempt to replace a multiply of M bits with a multiply
6342/// of M/2 bits that produces an M-bit result (i.e. mul.wide). This transform
6343/// works on both multiply DAG nodes and SHL DAG nodes with a constant shift
6344/// amount.
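///
/// For example (illustrative values), (mul i64 (sign_extend i32 %a), 7) has
/// both operands representable in 32 bits, so it can be replaced with a
/// MUL_WIDE_SIGNED of the truncated operands that still produces an i64.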
6345static SDValue TryMULWIDECombine(SDNode *N,
6346 TargetLowering::DAGCombinerInfo &DCI) {
6347 EVT MulType = N->getValueType(ResNo: 0);
6348 if (MulType != MVT::i32 && MulType != MVT::i64) {
6349 return SDValue();
6350 }
6351
6352 SDLoc DL(N);
6353 unsigned OptSize = MulType.getSizeInBits() >> 1;
6354 SDValue LHS = N->getOperand(Num: 0);
6355 SDValue RHS = N->getOperand(Num: 1);
6356
6357 // Canonicalize the multiply so the constant (if any) is on the right
6358 if (N->getOpcode() == ISD::MUL) {
6359 if (isa<ConstantSDNode>(Val: LHS)) {
6360 std::swap(a&: LHS, b&: RHS);
6361 }
6362 }
6363
6364 // If we have a SHL, determine the actual multiply amount
6365 if (N->getOpcode() == ISD::SHL) {
6366 ConstantSDNode *ShlRHS = dyn_cast<ConstantSDNode>(Val&: RHS);
6367 if (!ShlRHS) {
6368 return SDValue();
6369 }
6370
6371 APInt ShiftAmt = ShlRHS->getAPIntValue();
6372 unsigned BitWidth = MulType.getSizeInBits();
6373 if (ShiftAmt.sge(RHS: 0) && ShiftAmt.slt(RHS: BitWidth)) {
6374 APInt MulVal = APInt(BitWidth, 1) << ShiftAmt;
6375 RHS = DCI.DAG.getConstant(Val: MulVal, DL, VT: MulType);
6376 } else {
6377 return SDValue();
6378 }
6379 }
6380
6381 bool Signed;
6382 // Verify that our operands are demotable
6383 if (!AreMulWideOperandsDemotable(LHS, RHS, OptSize, IsSigned&: Signed)) {
6384 return SDValue();
6385 }
6386
6387 EVT DemotedVT;
6388 if (MulType == MVT::i32) {
6389 DemotedVT = MVT::i16;
6390 } else {
6391 DemotedVT = MVT::i32;
6392 }
6393
6394 // Truncate the operands to the correct size. Note that these are just for
6395 // type consistency and will (likely) be eliminated in later phases.
6396 SDValue TruncLHS =
6397 DCI.DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: DemotedVT, Operand: LHS);
6398 SDValue TruncRHS =
6399 DCI.DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: DemotedVT, Operand: RHS);
6400
6401 unsigned Opc;
6402 if (Signed) {
6403 Opc = NVPTXISD::MUL_WIDE_SIGNED;
6404 } else {
6405 Opc = NVPTXISD::MUL_WIDE_UNSIGNED;
6406 }
6407
6408 return DCI.DAG.getNode(Opcode: Opc, DL, VT: MulType, N1: TruncLHS, N2: TruncRHS);
6409}
6410
6411static bool isConstOne(const SDValue &Operand) {
6412 const auto *Const = dyn_cast<ConstantSDNode>(Val: Operand);
6413 return Const && Const->getZExtValue() == 1;
6414}
6415
6416static SDValue matchMADConstOnePattern(SDValue Add) {
6417 if (Add->getOpcode() != ISD::ADD)
6418 return SDValue();
6419
6420 if (isConstOne(Operand: Add->getOperand(Num: 0)))
6421 return Add->getOperand(Num: 1);
6422
6423 if (isConstOne(Operand: Add->getOperand(Num: 1)))
6424 return Add->getOperand(Num: 0);
6425
6426 return SDValue();
6427}
6428
6429static SDValue combineMADConstOne(SDValue X, SDValue Add, EVT VT, SDLoc DL,
6430 TargetLowering::DAGCombinerInfo &DCI) {
6431
6432 if (SDValue Y = matchMADConstOnePattern(Add)) {
6433 SDValue Mul = DCI.DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: X, N2: Y);
6434 return DCI.DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Mul, N2: X);
6435 }
6436
6437 return SDValue();
6438}
6439
6440static SDValue combineMulSelectConstOne(SDValue X, SDValue Select, EVT VT,
6441 SDLoc DL,
6442 TargetLowering::DAGCombinerInfo &DCI) {
6443 if (Select->getOpcode() != ISD::SELECT)
6444 return SDValue();
6445
6446 SDValue Cond = Select->getOperand(Num: 0);
6447
6448 unsigned ConstOpNo;
6449 if (isConstOne(Operand: Select->getOperand(Num: 1)))
6450 ConstOpNo = 1;
6451 else if (isConstOne(Operand: Select->getOperand(Num: 2)))
6452 ConstOpNo = 2;
6453 else
6454 return SDValue();
6455
6456 SDValue Y = Select->getOperand(Num: (ConstOpNo == 1) ? 2 : 1);
6457
6458 // Do not combine if the resulting sequence is not obviously profitable.
6459 if (!matchMADConstOnePattern(Add: Y))
6460 return SDValue();
6461
6462 SDValue NewMul = DCI.DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: X, N2: Y);
6463
6464 return DCI.DAG.getNode(Opcode: ISD::SELECT, DL, VT, N1: Cond,
6465 N2: (ConstOpNo == 1) ? X : NewMul,
6466 N3: (ConstOpNo == 1) ? NewMul : X);
6467}
6468
6469static SDValue
6470PerformMULCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
6471 TargetLowering::DAGCombinerInfo &DCI) {
6472
6473 EVT VT = N0.getValueType();
6474 if (VT.isVector())
6475 return SDValue();
6476
6477 if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
6478 return SDValue();
6479
6480 SDLoc DL(N);
6481
6482 // (mul x, (add y, 1)) -> (add (mul x, y), x)
6483 if (SDValue Res = combineMADConstOne(X: N0, Add: N1, VT, DL, DCI))
6484 return Res;
6485 if (SDValue Res = combineMADConstOne(X: N1, Add: N0, VT, DL, DCI))
6486 return Res;
6487
  // (mul x, (select cond, 1, y)) -> (select cond, x, (mul x, y))
6489 if (SDValue Res = combineMulSelectConstOne(X: N0, Select: N1, VT, DL, DCI))
6490 return Res;
6491 if (SDValue Res = combineMulSelectConstOne(X: N1, Select: N0, VT, DL, DCI))
6492 return Res;
6493
6494 return SDValue();
6495}
6496
6497/// PerformMULCombine - Runs PTX-specific DAG combine patterns on MUL nodes.
6498static SDValue PerformMULCombine(SDNode *N,
6499 TargetLowering::DAGCombinerInfo &DCI,
6500 CodeGenOptLevel OptLevel) {
6501 if (OptLevel == CodeGenOptLevel::None)
6502 return SDValue();
6503
6504 if (SDValue Ret = TryMULWIDECombine(N, DCI))
6505 return Ret;
6506
6507 SDValue N0 = N->getOperand(Num: 0);
6508 SDValue N1 = N->getOperand(Num: 1);
6509 return PerformMULCombineWithOperands(N, N0, N1, DCI);
6510}
6511
6512/// PerformSHLCombine - Runs PTX-specific DAG combine patterns on SHL nodes.
6513static SDValue PerformSHLCombine(SDNode *N,
6514 TargetLowering::DAGCombinerInfo &DCI,
6515 CodeGenOptLevel OptLevel) {
6516 if (OptLevel > CodeGenOptLevel::None) {
6517 // Try mul.wide combining at OptLevel > 0
6518 if (SDValue Ret = TryMULWIDECombine(N, DCI))
6519 return Ret;
6520 }
6521
6522 return SDValue();
6523}
6524
6525static SDValue PerformSETCCCombine(SDNode *N,
6526 TargetLowering::DAGCombinerInfo &DCI,
6527 unsigned int SmVersion) {
6528 EVT CCType = N->getValueType(ResNo: 0);
6529 SDValue A = N->getOperand(Num: 0);
6530 SDValue B = N->getOperand(Num: 1);
6531
6532 EVT AType = A.getValueType();
6533 if (!(CCType == MVT::v2i1 && (AType == MVT::v2f16 || AType == MVT::v2bf16)))
6534 return SDValue();
6535
6536 if (A.getValueType() == MVT::v2bf16 && SmVersion < 90)
6537 return SDValue();
6538
6539 SDLoc DL(N);
6540 // setp.f16x2 returns two scalar predicates, which we need to
6541 // convert back to v2i1. The returned result will be scalarized by
6542 // the legalizer, but the comparison will remain a single vector
6543 // instruction.
6544 SDValue CCNode = DCI.DAG.getNode(
6545 Opcode: A.getValueType() == MVT::v2f16 ? NVPTXISD::SETP_F16X2
6546 : NVPTXISD::SETP_BF16X2,
6547 DL, VTList: DCI.DAG.getVTList(VT1: MVT::i1, VT2: MVT::i1), Ops: {A, B, N->getOperand(Num: 2)});
6548 return DCI.DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: CCType, N1: CCNode.getValue(R: 0),
6549 N2: CCNode.getValue(R: 1));
6550}
6551
6552static SDValue PerformEXTRACTCombine(SDNode *N,
6553 TargetLowering::DAGCombinerInfo &DCI) {
6554 SDValue Vector = N->getOperand(Num: 0);
6555 if (Vector->getOpcode() == ISD::FREEZE)
6556 Vector = Vector->getOperand(Num: 0);
6557 SDLoc DL(N);
6558 EVT VectorVT = Vector.getValueType();
6559 if (Vector->getOpcode() == ISD::LOAD && VectorVT.isSimple() &&
6560 IsPTXVectorType(VT: VectorVT.getSimpleVT()))
6561 return SDValue(); // Native vector loads already combine nicely w/
6562 // extract_vector_elt.
6563 // Don't mess with singletons or packed types (v2*32, v2*16, v4i8 and v8i8),
6564 // we already handle them OK.
6565 if (VectorVT.getVectorNumElements() == 1 ||
6566 NVPTX::isPackedVectorTy(VT: VectorVT) || VectorVT == MVT::v8i8)
6567 return SDValue();
6568
6569 // Don't mess with undef values as sra may be simplified to 0, not undef.
6570 if (Vector->isUndef() || ISD::allOperandsUndef(N: Vector.getNode()))
6571 return SDValue();
6572
6573 uint64_t VectorBits = VectorVT.getSizeInBits();
6574 // We only handle the types we can extract in-register.
6575 if (!(VectorBits == 16 || VectorBits == 32 || VectorBits == 64))
6576 return SDValue();
6577
6578 ConstantSDNode *Index = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
6579 // Index == 0 is handled by generic DAG combiner.
6580 if (!Index || Index->getZExtValue() == 0)
6581 return SDValue();
6582
6583 MVT IVT = MVT::getIntegerVT(BitWidth: VectorBits);
6584 EVT EltVT = VectorVT.getVectorElementType();
6585 EVT EltIVT = EltVT.changeTypeToInteger();
6586 uint64_t EltBits = EltVT.getScalarSizeInBits();
6587
6588 SDValue Result = DCI.DAG.getNode(
6589 Opcode: ISD::TRUNCATE, DL, VT: EltIVT,
6590 Operand: DCI.DAG.getNode(
6591 Opcode: ISD::SRA, DL, VT: IVT, N1: DCI.DAG.getNode(Opcode: ISD::BITCAST, DL, VT: IVT, Operand: Vector),
6592 N2: DCI.DAG.getConstant(Val: Index->getZExtValue() * EltBits, DL, VT: IVT)));
6593
6594 // If element has non-integer type, bitcast it back to the expected type.
6595 if (EltVT != EltIVT)
6596 Result = DCI.DAG.getNode(Opcode: ISD::BITCAST, DL, VT: EltVT, Operand: Result);
  // Past the legalizer, we may need to extend i8 -> i16 to match the
  // register type.
6598 if (EltVT != N->getValueType(ResNo: 0))
6599 Result = DCI.DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: N->getValueType(ResNo: 0), Operand: Result);
6600
6601 return Result;
6602}
6603
6604/// Transform patterns like:
6605/// (select (ugt shift_amt, BitWidth-1), 0, (srl/shl x, shift_amt))
6606/// (select (ult shift_amt, BitWidth), (srl/shl x, shift_amt), 0)
6607/// Into:
6608/// (NVPTXISD::SRL_CLAMP x, shift_amt) or (NVPTXISD::SHL_CLAMP x, shift_amt)
6609///
6610/// These patterns arise from C/C++ code like `shift >= 32 ? 0 : x >> shift`
6611/// which guards against undefined behavior. PTX shr/shl instructions clamp
6612/// shift amounts >= BitWidth to produce 0 for logical shifts, making the
6613/// guard redundant.
6614///
6615/// Note: We only handle SRL and SHL, not SRA, because arithmetic right
6616/// shifts could produce 0 or -1 when shift >= BitWidth.
6617/// Note: We don't handle uge or ule. These don't appear because of
6618/// canonicalization.
6619static SDValue PerformSELECTShiftCombine(SDNode *N,
6620 TargetLowering::DAGCombinerInfo &DCI) {
6621 if (!DCI.isAfterLegalizeDAG())
6622 return SDValue();
6623
6624 using namespace SDPatternMatch;
6625 unsigned BitWidth = N->getValueType(ResNo: 0).getSizeInBits();
6626 SDValue ShiftAmt, ShiftOp;
6627
6628 // Match logical shifts where the shift amount in the guard matches the shift
6629 // amount in the operation.
6630 auto LogicalShift =
6631 m_AllOf(preds: m_Value(N&: ShiftOp),
6632 preds: m_AnyOf(preds: m_Srl(L: m_Value(), R: m_TruncOrSelf(Op: m_Deferred(V&: ShiftAmt))),
6633 preds: m_Shl(L: m_Value(), R: m_TruncOrSelf(Op: m_Deferred(V&: ShiftAmt)))));
6634
6635 // shift_amt > BitWidth-1 ? 0 : shift_op
6636 bool MatchedUGT =
6637 sd_match(N, P: m_Select(Cond: m_SetCC(LHS: m_Value(N&: ShiftAmt),
6638 RHS: m_SpecificInt(V: APInt(BitWidth, BitWidth - 1)),
6639 CC: m_SpecificCondCode(CC: ISD::SETUGT)),
6640 T: m_Zero(), F: LogicalShift));
6641 // shift_amt < BitWidth ? shift_op : 0
6642 bool MatchedULT =
6643 !MatchedUGT &&
6644 sd_match(N, P: m_Select(Cond: m_SetCC(LHS: m_Value(N&: ShiftAmt),
6645 RHS: m_SpecificInt(V: APInt(BitWidth, BitWidth)),
6646 CC: m_SpecificCondCode(CC: ISD::SETULT)),
6647 T: LogicalShift, F: m_Zero()));
6648
6649 if (!MatchedUGT && !MatchedULT)
6650 return SDValue();
6651
6652 // Return a clamp shift operation, which has the same semantics as PTX shift.
6653 unsigned ClampOpc = ShiftOp.getOpcode() == ISD::SRL ? NVPTXISD::SRL_CLAMP
6654 : NVPTXISD::SHL_CLAMP;
6655 return DCI.DAG.getNode(Opcode: ClampOpc, DL: SDLoc(N), VT: ShiftOp.getValueType(),
6656 N1: ShiftOp.getOperand(i: 0), N2: ShiftOp.getOperand(i: 1));
6657}
6658
6659static SDValue PerformVSELECTCombine(SDNode *N,
6660 TargetLowering::DAGCombinerInfo &DCI) {
6661 SDValue VA = N->getOperand(Num: 1);
6662 EVT VectorVT = VA.getValueType();
6663 if (VectorVT != MVT::v4i8)
6664 return SDValue();
6665
  // We need to split the vselect into individual per-element operations.
  // Because we use BFE/BFI instructions for byte extraction/insertion, we end
  // up with 32-bit values anyway, so we may as well do the comparison as i32
  // to avoid the conversions to/from i16 normally used for i8 values.
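  //
  // For example (illustrative operands), (vselect v4i1 %c, v4i8 %a, v4i8 %b)
  // is expanded into four i32 selects on the extracted bytes, whose i8
  // results are reassembled with a BUILD_VECTOR.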
6670 SmallVector<SDValue, 4> E;
6671 SDLoc DL(N);
6672 SDValue VCond = N->getOperand(Num: 0);
6673 SDValue VB = N->getOperand(Num: 2);
6674 for (int I = 0; I < 4; ++I) {
6675 SDValue C = DCI.DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i1, N1: VCond,
6676 N2: DCI.DAG.getConstant(Val: I, DL, VT: MVT::i32));
6677 SDValue EA = DCI.DAG.getAnyExtOrTrunc(
6678 Op: DCI.DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i8, N1: VA,
6679 N2: DCI.DAG.getConstant(Val: I, DL, VT: MVT::i32)),
6680 DL, VT: MVT::i32);
6681 SDValue EB = DCI.DAG.getAnyExtOrTrunc(
6682 Op: DCI.DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i8, N1: VB,
6683 N2: DCI.DAG.getConstant(Val: I, DL, VT: MVT::i32)),
6684 DL, VT: MVT::i32);
6685 E.push_back(Elt: DCI.DAG.getAnyExtOrTrunc(
6686 Op: DCI.DAG.getNode(Opcode: ISD::SELECT, DL, VT: MVT::i32, N1: C, N2: EA, N3: EB), DL, VT: MVT::i8));
6687 }
6688 return DCI.DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: MVT::v4i8, Ops: E);
6689}
6690
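// Fold a packed v2x16 BUILD_VECTOR of two i16 truncates of i32 values into a
// single PRMT that picks the wanted bytes directly. For example (illustrative
// values), (v2i16 (build_vector (trunc i32 %a), (trunc i32 %b))) becomes a
// PRMT of %a and %b with selector 0x5410, selecting bytes {1,0} of %a for the
// low half and bytes {1,0} of %b for the high half.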
6691static SDValue
6692PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
6693 auto VT = N->getValueType(ResNo: 0);
6694 if (!DCI.isAfterLegalizeDAG() ||
6695 // only process v2*16 types
6696 !(NVPTX::isPackedVectorTy(VT) && VT.is32BitVector() &&
6697 VT.getVectorNumElements() == 2))
6698 return SDValue();
6699
6700 auto Op0 = N->getOperand(Num: 0);
6701 auto Op1 = N->getOperand(Num: 1);
6702
6703 // Start out by assuming we want to take the lower 2 bytes of each i32
6704 // operand.
6705 uint64_t Op0Bytes = 0x10;
6706 uint64_t Op1Bytes = 0x54;
6707
6708 std::pair<SDValue *, uint64_t *> OpData[2] = {{&Op0, &Op0Bytes},
6709 {&Op1, &Op1Bytes}};
6710
6711 // Check that each operand is an i16, truncated from an i32 operand. We'll
6712 // select individual bytes from those original operands. Optionally, fold in a
6713 // shift right of that original operand.
6714 for (auto &[Op, OpBytes] : OpData) {
6715 // Eat up any bitcast
6716 if (Op->getOpcode() == ISD::BITCAST)
6717 *Op = Op->getOperand(i: 0);
6718
6719 if (!(Op->getValueType() == MVT::i16 && Op->getOpcode() == ISD::TRUNCATE &&
6720 Op->getOperand(i: 0).getValueType() == MVT::i32))
6721 return SDValue();
6722
6723 // If the truncate has multiple uses, this optimization can increase
6724 // register pressure
6725 if (!Op->hasOneUse())
6726 return SDValue();
6727
6728 *Op = Op->getOperand(i: 0);
6729
6730 // Optionally, fold in a shift-right of the original operand and let permute
6731 // pick the two higher bytes of the original value directly.
6732 if (Op->getOpcode() == ISD::SRL && isa<ConstantSDNode>(Val: Op->getOperand(i: 1))) {
6733 if (cast<ConstantSDNode>(Val: Op->getOperand(i: 1))->getZExtValue() == 16) {
6734 // Shift the PRMT byte selector to pick upper bytes from each respective
6735 // value, instead of the lower ones: 0x10 -> 0x32, 0x54 -> 0x76
6736 assert((*OpBytes == 0x10 || *OpBytes == 0x54) &&
6737 "PRMT selector values out of range");
6738 *OpBytes += 0x22;
6739 *Op = Op->getOperand(i: 0);
6740 }
6741 }
6742 }
6743
6744 SDLoc DL(N);
6745 auto &DAG = DCI.DAG;
6746
6747 auto PRMT =
6748 getPRMT(A: DAG.getBitcast(VT: MVT::i32, V: Op0), B: DAG.getBitcast(VT: MVT::i32, V: Op1),
6749 Selector: (Op1Bytes << 8) | Op0Bytes, DL, DAG);
6750 return DAG.getBitcast(VT, V: PRMT);
6751}
6752
6753static SDValue combineADDRSPACECAST(SDNode *N,
6754 TargetLowering::DAGCombinerInfo &DCI) {
6755 auto *ASCN1 = cast<AddrSpaceCastSDNode>(Val: N);
6756
6757 if (auto *ASCN2 = dyn_cast<AddrSpaceCastSDNode>(Val: ASCN1->getOperand(Num: 0))) {
6758 assert(ASCN2->getDestAddressSpace() == ASCN1->getSrcAddressSpace());
6759
6760 // Fold asc[B -> A](asc[A -> B](x)) -> x
6761 if (ASCN1->getDestAddressSpace() == ASCN2->getSrcAddressSpace())
6762 return ASCN2->getOperand(Num: 0);
6763 }
6764
6765 return SDValue();
6766}
6767
6768// Given a constant selector value and a prmt mode, return the selector value
6769// normalized to the generic prmt mode. See the PTX ISA documentation for more
6770// details:
6771// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt
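//
// For example (illustrative value), in RC8 mode a selector whose low two bits
// are 2 replicates byte 2 into all four lanes, i.e. it normalizes to the
// generic selector 0x2222.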
6772static APInt getPRMTSelector(const APInt &Selector, unsigned Mode) {
6773 assert(Selector.getBitWidth() == 32 && "PRMT must have i32 operands");
6774
6775 if (Mode == NVPTX::PTXPrmtMode::NONE)
6776 return Selector;
6777
6778 const unsigned V = Selector.trunc(width: 2).getZExtValue();
6779
6780 const auto GetSelector = [](unsigned S0, unsigned S1, unsigned S2,
6781 unsigned S3) {
6782 return APInt(32, S0 | (S1 << 4) | (S2 << 8) | (S3 << 12));
6783 };
6784
6785 switch (Mode) {
6786 case NVPTX::PTXPrmtMode::F4E:
6787 return GetSelector(V, V + 1, V + 2, V + 3);
6788 case NVPTX::PTXPrmtMode::B4E:
6789 return GetSelector(V, (V - 1) & 7, (V - 2) & 7, (V - 3) & 7);
6790 case NVPTX::PTXPrmtMode::RC8:
6791 return GetSelector(V, V, V, V);
6792 case NVPTX::PTXPrmtMode::ECL:
6793 return GetSelector(V, std::max(a: V, b: 1U), std::max(a: V, b: 2U), 3U);
6794 case NVPTX::PTXPrmtMode::ECR:
6795 return GetSelector(0, std::min(a: V, b: 1U), std::min(a: V, b: 2U), V);
6796 case NVPTX::PTXPrmtMode::RC16: {
6797 unsigned V1 = (V & 1) << 1;
6798 return GetSelector(V1, V1 + 1, V1, V1 + 1);
6799 }
6800 default:
6801 llvm_unreachable("Invalid PRMT mode");
6802 }
6803}
6804
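// Constant-fold a PRMT. For example (illustrative values), with A = 0x03020100,
// B = 0x07060504 and generic selector 0x5410, the byte field {B,A} yields the
// result 0x05040100.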
6805static APInt computePRMT(APInt A, APInt B, APInt Selector, unsigned Mode) {
6806 assert(A.getBitWidth() == 32 && B.getBitWidth() == 32 &&
6807 Selector.getBitWidth() == 32 && "PRMT must have i32 operands");
6808 // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
6809 APInt BitField = B.concat(NewLSB: A);
6810 APInt SelectorVal = getPRMTSelector(Selector, Mode);
6811 APInt Result(32, 0);
6812 for (unsigned I : llvm::seq(Size: 4U)) {
6813 APInt Sel = SelectorVal.extractBits(numBits: 4, bitPosition: I * 4);
6814 unsigned Idx = Sel.getLoBits(numBits: 3).getZExtValue();
6815 unsigned Sign = Sel.getHiBits(numBits: 1).getZExtValue();
6816 APInt Byte = BitField.extractBits(numBits: 8, bitPosition: Idx * 8);
6817 if (Sign)
6818 Byte = Byte.ashr(ShiftAmt: 8);
6819 Result.insertBits(SubBits: Byte, bitPosition: I * 8);
6820 }
6821 return Result;
6822}
6823
6824static SDValue combinePRMT(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
6825 CodeGenOptLevel OptLevel) {
6826 if (OptLevel == CodeGenOptLevel::None)
6827 return SDValue();
6828
6829 // Constant fold PRMT
6830 if (isa<ConstantSDNode>(Val: N->getOperand(Num: 0)) &&
6831 isa<ConstantSDNode>(Val: N->getOperand(Num: 1)) &&
6832 isa<ConstantSDNode>(Val: N->getOperand(Num: 2)))
6833 return DCI.DAG.getConstant(Val: computePRMT(A: N->getConstantOperandAPInt(Num: 0),
6834 B: N->getConstantOperandAPInt(Num: 1),
6835 Selector: N->getConstantOperandAPInt(Num: 2),
6836 Mode: N->getConstantOperandVal(Num: 3)),
6837 DL: SDLoc(N), VT: N->getValueType(ResNo: 0));
6838 return SDValue();
6839}
6840
6841// During call lowering we wrap the return values in a ProxyReg node which
// depends on the chain value produced by the completed call. This ensures that
6843// the full call is emitted in cases where libcalls are used to legalize
6844// operations. To improve the functioning of other DAG combines we pull all
6845// operations we can through one of these nodes, ensuring that the ProxyReg
6846// directly wraps a load. That is:
6847//
6848// (ProxyReg (zext (load retval0))) => (zext (ProxyReg (load retval0)))
6849//
6850static SDValue sinkProxyReg(SDValue R, SDValue Chain,
6851 TargetLowering::DAGCombinerInfo &DCI) {
6852 switch (R.getOpcode()) {
6853 case ISD::TRUNCATE:
6854 case ISD::ANY_EXTEND:
6855 case ISD::SIGN_EXTEND:
6856 case ISD::ZERO_EXTEND:
6857 case ISD::BITCAST: {
6858 if (SDValue V = sinkProxyReg(R: R.getOperand(i: 0), Chain, DCI))
6859 return DCI.DAG.getNode(Opcode: R.getOpcode(), DL: SDLoc(R), VT: R.getValueType(), Operand: V);
6860 return SDValue();
6861 }
6862 case ISD::SHL:
6863 case ISD::SRL:
6864 case ISD::SRA:
6865 case ISD::OR: {
6866 if (SDValue A = sinkProxyReg(R: R.getOperand(i: 0), Chain, DCI))
6867 if (SDValue B = sinkProxyReg(R: R.getOperand(i: 1), Chain, DCI))
6868 return DCI.DAG.getNode(Opcode: R.getOpcode(), DL: SDLoc(R), VT: R.getValueType(), N1: A, N2: B);
6869 return SDValue();
6870 }
6871 case ISD::Constant:
6872 return R;
6873 case ISD::LOAD:
6874 case NVPTXISD::LoadV2:
6875 case NVPTXISD::LoadV4: {
6876 return DCI.DAG.getNode(Opcode: NVPTXISD::ProxyReg, DL: SDLoc(R), VT: R.getValueType(),
6877 Ops: {Chain, R});
6878 }
6879 case ISD::BUILD_VECTOR: {
6880 if (DCI.isBeforeLegalize())
6881 return SDValue();
6882
6883 SmallVector<SDValue, 16> Ops;
6884 for (auto &Op : R->ops()) {
6885 SDValue V = sinkProxyReg(R: Op, Chain, DCI);
6886 if (!V)
6887 return SDValue();
6888 Ops.push_back(Elt: V);
6889 }
6890 return DCI.DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL: SDLoc(R), VT: R.getValueType(), Ops);
6891 }
6892 case ISD::EXTRACT_VECTOR_ELT: {
6893 if (DCI.isBeforeLegalize())
6894 return SDValue();
6895
6896 if (SDValue V = sinkProxyReg(R: R.getOperand(i: 0), Chain, DCI))
6897 return DCI.DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: SDLoc(R),
6898 VT: R.getValueType(), N1: V, N2: R.getOperand(i: 1));
6899 return SDValue();
6900 }
6901 default:
6902 return SDValue();
6903 }
6904}
6905
6906static unsigned getF16SubOpc(Intrinsic::ID AddIntrinsicID) {
6907 switch (AddIntrinsicID) {
6908 default:
6909 break;
6910 case Intrinsic::nvvm_add_rn_sat_f16:
6911 case Intrinsic::nvvm_add_rn_sat_v2f16:
6912 return NVPTXISD::SUB_RN_SAT;
6913 case Intrinsic::nvvm_add_rn_ftz_sat_f16:
6914 case Intrinsic::nvvm_add_rn_ftz_sat_v2f16:
6915 return NVPTXISD::SUB_RN_FTZ_SAT;
6916 }
6917 llvm_unreachable("Invalid F16 add intrinsic");
6918}
6919
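// Fold an f16/v2f16 saturating add intrinsic with a negated operand into the
// corresponding saturating subtract (shown in PTX-like notation):
//
//   add.rn{.ftz}.sat(a, fneg(b)) -> sub.rn{.ftz}.sat(a, b)
//   add.rn{.ftz}.sat(fneg(a), b) -> sub.rn{.ftz}.sat(b, a)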
6920static SDValue combineF16AddWithNeg(SDNode *N, SelectionDAG &DAG,
6921 Intrinsic::ID AddIntrinsicID) {
6922 SDValue Op1 = N->getOperand(Num: 1);
6923 SDValue Op2 = N->getOperand(Num: 2);
6924
6925 SDValue SubOp1, SubOp2;
6926
6927 if (Op1.getOpcode() == ISD::FNEG) {
6928 SubOp1 = Op2;
6929 SubOp2 = Op1.getOperand(i: 0);
6930 } else if (Op2.getOpcode() == ISD::FNEG) {
6931 SubOp1 = Op1;
6932 SubOp2 = Op2.getOperand(i: 0);
6933 } else {
6934 return SDValue();
6935 }
6936
6937 SDLoc DL(N);
6938 return DAG.getNode(Opcode: getF16SubOpc(AddIntrinsicID), DL, VT: N->getValueType(ResNo: 0),
6939 N1: SubOp1, N2: SubOp2);
6940}
6941
6942static SDValue combineIntrinsicWOChain(SDNode *N,
6943 TargetLowering::DAGCombinerInfo &DCI,
6944 const NVPTXSubtarget &STI) {
6945 unsigned IID = N->getConstantOperandVal(Num: 0);
6946
6947 switch (IID) {
6948 default:
6949 break;
6950 case Intrinsic::nvvm_add_rn_sat_f16:
6951 case Intrinsic::nvvm_add_rn_ftz_sat_f16:
6952 case Intrinsic::nvvm_add_rn_sat_v2f16:
6953 case Intrinsic::nvvm_add_rn_ftz_sat_v2f16:
6954 return combineF16AddWithNeg(N, DAG&: DCI.DAG, AddIntrinsicID: IID);
6955 }
6956 return SDValue();
6957}
6958
6959static SDValue combineProxyReg(SDNode *N,
6960 TargetLowering::DAGCombinerInfo &DCI) {
6961
6962 SDValue Chain = N->getOperand(Num: 0);
6963 SDValue Reg = N->getOperand(Num: 1);
6964
6965 // If the ProxyReg is not wrapping a load, try to pull the operations through
6966 // the ProxyReg.
6967 if (Reg.getOpcode() != ISD::LOAD) {
6968 if (SDValue V = sinkProxyReg(R: Reg, Chain, DCI))
6969 return V;
6970 }
6971
6972 return SDValue();
6973}
6974
6975SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
6976 DAGCombinerInfo &DCI) const {
6977 CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
6978 switch (N->getOpcode()) {
6979 default:
6980 break;
6981 case ISD::ADD:
6982 return PerformADDCombine(N, DCI, OptLevel);
6983 case ISD::ADDRSPACECAST:
6984 return combineADDRSPACECAST(N, DCI);
6985 case ISD::SIGN_EXTEND:
6986 case ISD::ZERO_EXTEND:
6987 return combineMulWide(N, DCI, OptLevel);
6988 case ISD::BUILD_VECTOR:
6989 return PerformBUILD_VECTORCombine(N, DCI);
6990 case ISD::EXTRACT_VECTOR_ELT:
6991 return PerformEXTRACTCombine(N, DCI);
6992 case ISD::FADD:
6993 return PerformFADDCombine(N, DCI, OptLevel);
6994 case ISD::FMAXNUM:
6995 case ISD::FMINNUM:
6996 case ISD::FMAXIMUM:
6997 case ISD::FMINIMUM:
6998 case ISD::FMAXIMUMNUM:
6999 case ISD::FMINIMUMNUM:
7000 return PerformFMinMaxCombine(N, DCI, PTXVersion: STI.getPTXVersion(),
7001 SmVersion: STI.getSmVersion());
7002 case ISD::LOAD:
7003 case NVPTXISD::LoadV2:
7004 case NVPTXISD::LoadV4:
7005 return combineLOAD(N, DCI, STI);
7006 case ISD::MUL:
7007 return PerformMULCombine(N, DCI, OptLevel);
7008 case NVPTXISD::PRMT:
7009 return combinePRMT(N, DCI, OptLevel);
7010 case NVPTXISD::ProxyReg:
7011 return combineProxyReg(N, DCI);
7012 case ISD::SETCC:
7013 return PerformSETCCCombine(N, DCI, SmVersion: STI.getSmVersion());
7014 case ISD::SHL:
7015 return PerformSHLCombine(N, DCI, OptLevel);
7016 case ISD::SREM:
7017 case ISD::UREM:
7018 return PerformREMCombine(N, DCI, OptLevel);
7019 case ISD::STORE:
7020 case NVPTXISD::StoreV2:
7021 case NVPTXISD::StoreV4:
7022 return combineSTORE(N, DCI, STI);
7023 case ISD::SELECT:
7024 return PerformSELECTShiftCombine(N, DCI);
7025 case ISD::VSELECT:
7026 return PerformVSELECTCombine(N, DCI);
7027 case ISD::INTRINSIC_WO_CHAIN:
7028 return combineIntrinsicWOChain(N, DCI, STI);
7029 }
7030 return SDValue();
7031}
7032
7033static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
7034 SmallVectorImpl<SDValue> &Results) {
7035 // Handle bitcasting to v2i8 without hitting the default promotion
7036 // strategy which goes through stack memory.
7037 SDValue Op(Node, 0);
7038 EVT ToVT = Op->getValueType(ResNo: 0);
7039 if (ToVT != MVT::v2i8) {
7040 return;
7041 }
7042
7043 // Bitcast to i16 and unpack elements into a vector
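  // e.g. with AsInt = (i16 (bitcast X)), the v2i8 result is rebuilt as
  // (build_vector (trunc AsInt), (trunc (srl AsInt, 8))).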
7044 SDLoc DL(Node);
7045 SDValue AsInt = DAG.getBitcast(VT: MVT::i16, V: Op->getOperand(Num: 0));
7046 SDValue Vec0 = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i8, Operand: AsInt);
7047 SDValue Const8 = DAG.getConstant(Val: 8, DL, VT: MVT::i16);
7048 SDValue Vec1 =
7049 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i8,
7050 Operand: DAG.getNode(Opcode: ISD::SRL, DL, VT: MVT::i16, Ops: {AsInt, Const8}));
7051 Results.push_back(
7052 Elt: DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: MVT::v2i8, Ops: {Vec0, Vec1}));
7053}
7054
7055static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
7056 SmallVectorImpl<SDValue> &Results) {
7057 SDValue Chain = N->getOperand(Num: 0);
7058 SDValue Intrin = N->getOperand(Num: 1);
7059 SDLoc DL(N);
7060
7061 // Get the intrinsic ID
7062 unsigned IntrinNo = Intrin.getNode()->getAsZExtVal();
7063 switch (IntrinNo) {
7064 default:
7065 return;
7066 case Intrinsic::nvvm_ldu_global_i:
7067 case Intrinsic::nvvm_ldu_global_f:
7068 case Intrinsic::nvvm_ldu_global_p: {
7069 EVT ResVT = N->getValueType(ResNo: 0);
7070
7071 if (ResVT.isVector()) {
7072 // Vector LDG/LDU
7073
7074 unsigned NumElts = ResVT.getVectorNumElements();
7075 EVT EltVT = ResVT.getVectorElementType();
7076
7077 // Since LDU/LDG are target nodes, we cannot rely on DAG type
7078 // legalization.
7079 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
7080 // loaded type to i16 and propagate the "real" type as the memory type.
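      // For example, a <2 x i8> result is loaded as two i16 values and
      // truncated back to i8 when the vector is rebuilt below.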
7081 bool NeedTrunc = false;
7082 if (EltVT.getSizeInBits() < 16) {
7083 EltVT = MVT::i16;
7084 NeedTrunc = true;
7085 }
7086
7087 unsigned Opcode = 0;
7088 SDVTList LdResVTs;
7089
7090 switch (NumElts) {
7091 default:
7092 return;
7093 case 2:
7094 Opcode = NVPTXISD::LDUV2;
7095 LdResVTs = DAG.getVTList(VT1: EltVT, VT2: EltVT, VT3: MVT::Other);
7096 break;
7097 case 4: {
7098 Opcode = NVPTXISD::LDUV4;
7099 EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
7100 LdResVTs = DAG.getVTList(VTs: ListVTs);
7101 break;
7102 }
7103 }
7104
7105 SmallVector<SDValue, 8> OtherOps;
7106
7107 // Copy regular operands
7108
7109 OtherOps.push_back(Elt: Chain); // Chain
7110 // Skip operand 1 (intrinsic ID)
7111 // Others
7112 OtherOps.append(in_start: N->op_begin() + 2, in_end: N->op_end());
7113
7114 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
7115
7116 SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, dl: DL, VTList: LdResVTs, Ops: OtherOps,
7117 MemVT: MemSD->getMemoryVT(),
7118 MMO: MemSD->getMemOperand());
7119
7120 SmallVector<SDValue, 4> ScalarRes;
7121
7122 for (unsigned i = 0; i < NumElts; ++i) {
7123 SDValue Res = NewLD.getValue(R: i);
7124 if (NeedTrunc)
7125 Res =
7126 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: ResVT.getVectorElementType(), Operand: Res);
7127 ScalarRes.push_back(Elt: Res);
7128 }
7129
7130 SDValue LoadChain = NewLD.getValue(R: NumElts);
7131
7132 SDValue BuildVec =
7133 DAG.getBuildVector(VT: ResVT, DL, Ops: ScalarRes);
7134
7135 Results.push_back(Elt: BuildVec);
7136 Results.push_back(Elt: LoadChain);
7137 } else {
7138 // i8 LDG/LDU
7139 assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
7140 "Custom handling of non-i8 ldu/ldg?");
7141
7142 // Just copy all operands as-is
7143 SmallVector<SDValue, 4> Ops(N->ops());
7144
7145 // Force output to i16
7146 SDVTList LdResVTs = DAG.getVTList(VT1: MVT::i16, VT2: MVT::Other);
7147
7148 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
7149
7150 // We make sure the memory type is i8, which will be used during isel
7151 // to select the proper instruction.
7152 SDValue NewLD =
7153 DAG.getMemIntrinsicNode(Opcode: ISD::INTRINSIC_W_CHAIN, dl: DL, VTList: LdResVTs, Ops,
7154 MemVT: MVT::i8, MMO: MemSD->getMemOperand());
7155
7156 Results.push_back(Elt: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i8,
7157 Operand: NewLD.getValue(R: 0)));
7158 Results.push_back(Elt: NewLD.getValue(R: 1));
7159 }
7160 return;
7161 }
7162
7163 case Intrinsic::nvvm_tcgen05_ld_16x64b_x4:
7164 case Intrinsic::nvvm_tcgen05_ld_16x64b_x8:
7165 case Intrinsic::nvvm_tcgen05_ld_16x64b_x16:
7166 case Intrinsic::nvvm_tcgen05_ld_16x64b_x32:
7167 case Intrinsic::nvvm_tcgen05_ld_16x64b_x64:
7168 case Intrinsic::nvvm_tcgen05_ld_16x64b_x128:
7169 case Intrinsic::nvvm_tcgen05_ld_32x32b_x4:
7170 case Intrinsic::nvvm_tcgen05_ld_32x32b_x8:
7171 case Intrinsic::nvvm_tcgen05_ld_32x32b_x16:
7172 case Intrinsic::nvvm_tcgen05_ld_32x32b_x32:
7173 case Intrinsic::nvvm_tcgen05_ld_32x32b_x64:
7174 case Intrinsic::nvvm_tcgen05_ld_32x32b_x128:
7175 case Intrinsic::nvvm_tcgen05_ld_16x128b_x2:
7176 case Intrinsic::nvvm_tcgen05_ld_16x128b_x4:
7177 case Intrinsic::nvvm_tcgen05_ld_16x128b_x8:
7178 case Intrinsic::nvvm_tcgen05_ld_16x128b_x16:
7179 case Intrinsic::nvvm_tcgen05_ld_16x128b_x32:
7180 case Intrinsic::nvvm_tcgen05_ld_16x128b_x64:
7181 case Intrinsic::nvvm_tcgen05_ld_16x256b_x1:
7182 case Intrinsic::nvvm_tcgen05_ld_16x256b_x2:
7183 case Intrinsic::nvvm_tcgen05_ld_16x256b_x4:
7184 case Intrinsic::nvvm_tcgen05_ld_16x256b_x8:
7185 case Intrinsic::nvvm_tcgen05_ld_16x256b_x16:
7186 case Intrinsic::nvvm_tcgen05_ld_16x256b_x32:
7187 if (auto Res = lowerTcgen05Ld(N, DAG)) {
7188 Results.push_back(Elt: Res->first);
7189 Results.push_back(Elt: Res->second);
7190 }
7191 return;
7192
7193 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x4:
7194 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x8:
7195 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x16:
7196 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x32:
7197 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x64:
7198 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x128:
7199 if (auto Res = lowerTcgen05Ld(N, DAG, /*HasOffset=*/true)) {
7200 Results.push_back(Elt: Res->first);
7201 Results.push_back(Elt: Res->second);
7202 }
7203 return;
7204
7205 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x8_i32:
7206 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x8_f32:
7207 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x64_i32:
7208 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x64_f32:
7209 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x4_i32:
7210 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x4_f32:
7211 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x32_i32:
7212 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x32_f32:
7213 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x16_i32:
7214 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x16_f32:
7215 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x128_i32:
7216 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x128_f32:
7217 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x8_i32:
7218 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x8_f32:
7219 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x64_i32:
7220 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x64_f32:
7221 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x4_i32:
7222 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x4_f32:
7223 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x32_i32:
7224 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x32_f32:
7225 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x16_i32:
7226 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x16_f32:
7227 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x128_i32:
7228 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x128_f32:
7229 if (auto Res = lowerTcgen05LdRed(N, DAG)) {
7230 Results.push_back(Elt: std::get<0>(t&: *Res));
7231 Results.push_back(Elt: std::get<1>(t&: *Res));
7232 Results.push_back(Elt: std::get<2>(t&: *Res));
7233 }
7234 return;
7235 }
7236}
7237
7238static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
7239 SmallVectorImpl<SDValue> &Results) {
7240  // Change the CopyFromReg to produce two 64-bit results instead of one
7241  // 128-bit result so that it survives type legalization.
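  // i.e. the i128 value is rebuilt as (build_pair (i64 lo), (i64 hi)); the
  // chain and glue results are passed through unchanged.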
7242 SDLoc DL(N);
7243 SDValue Chain = N->getOperand(Num: 0);
7244 SDValue Reg = N->getOperand(Num: 1);
7245 SDValue Glue = N->getOperand(Num: 2);
7246
7247 assert(Reg.getValueType() == MVT::i128 &&
7248 "Custom lowering for CopyFromReg with 128-bit reg only");
7249 SmallVector<EVT, 4> ResultsType = {MVT::i64, MVT::i64, N->getValueType(ResNo: 1),
7250 N->getValueType(ResNo: 2)};
7251 SmallVector<SDValue, 3> NewOps = {Chain, Reg, Glue};
7252
7253 SDValue NewValue = DAG.getNode(Opcode: ISD::CopyFromReg, DL, ResultTys: ResultsType, Ops: NewOps);
7254 SDValue Pair = DAG.getNode(Opcode: ISD::BUILD_PAIR, DL, VT: MVT::i128,
7255 Ops: {NewValue.getValue(R: 0), NewValue.getValue(R: 1)});
7256
7257 Results.push_back(Elt: Pair);
7258 Results.push_back(Elt: NewValue.getValue(R: 2));
7259 Results.push_back(Elt: NewValue.getValue(R: 3));
7260}
7261
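// Legalize a ProxyReg whose result type is not a legal register type by
// rebuilding it with the corresponding register type and converting the
// result back to the original type via any-extend or truncate.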
7262static void replaceProxyReg(SDNode *N, SelectionDAG &DAG,
7263 const TargetLowering &TLI,
7264 SmallVectorImpl<SDValue> &Results) {
7265 SDValue Chain = N->getOperand(Num: 0);
7266 SDValue Reg = N->getOperand(Num: 1);
7267
7268 MVT VT = TLI.getRegisterType(Context&: *DAG.getContext(), VT: Reg.getValueType());
7269
7270 SDValue NewReg = DAG.getAnyExtOrTrunc(Op: Reg, DL: SDLoc(N), VT);
7271 SDValue NewProxy =
7272 DAG.getNode(Opcode: NVPTXISD::ProxyReg, DL: SDLoc(N), VT, Ops: {Chain, NewReg});
7273 SDValue Res = DAG.getAnyExtOrTrunc(Op: NewProxy, DL: SDLoc(N), VT: N->getValueType(ResNo: 0));
7274
7275 Results.push_back(Elt: Res);
7276}
7277
7278static void replaceAtomicSwap128(SDNode *N, SelectionDAG &DAG,
7279 const NVPTXSubtarget &STI,
7280 SmallVectorImpl<SDValue> &Results) {
7281 assert(N->getValueType(0) == MVT::i128 &&
7282 "Custom lowering for atomic128 only supports i128");
7283
7284 AtomicSDNode *AN = cast<AtomicSDNode>(Val: N);
7285 SDLoc dl(N);
7286
7287 if (!STI.hasAtomSwap128()) {
7288 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
7289 DAG.getMachineFunction().getFunction(),
7290        "Support for b128 atomics was introduced in PTX ISA version 8.3 and "
7291 "requires target sm_90.",
7292 dl.getDebugLoc()));
7293
7294 Results.push_back(Elt: DAG.getUNDEF(VT: MVT::i128));
7295 Results.push_back(Elt: AN->getOperand(Num: 0)); // Chain
7296 return;
7297 }
7298
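  // Split each remaining data operand into lo/hi i64 halves (for cmpxchg, the
  // compare value followed by the new value); the b128 atomic node then takes
  // {chain, ptr, <split values>} and produces two i64 results plus a chain.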
7299 SmallVector<SDValue, 6> Ops;
7300 Ops.push_back(Elt: AN->getOperand(Num: 0)); // Chain
7301 Ops.push_back(Elt: AN->getOperand(Num: 1)); // Ptr
7302 for (const auto &Op : AN->ops().drop_front(N: 2)) {
7303 // Low part
7304 Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_ELEMENT, DL: dl, VT: MVT::i64, N1: Op,
7305 N2: DAG.getIntPtrConstant(Val: 0, DL: dl)));
7306 // High part
7307 Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_ELEMENT, DL: dl, VT: MVT::i64, N1: Op,
7308 N2: DAG.getIntPtrConstant(Val: 1, DL: dl)));
7309 }
7310 unsigned Opcode = N->getOpcode() == ISD::ATOMIC_SWAP
7311 ? NVPTXISD::ATOMIC_SWAP_B128
7312 : NVPTXISD::ATOMIC_CMP_SWAP_B128;
7313 SDVTList Tys = DAG.getVTList(VT1: MVT::i64, VT2: MVT::i64, VT3: MVT::Other);
7314 SDValue Result = DAG.getMemIntrinsicNode(Opcode, dl, VTList: Tys, Ops, MemVT: MVT::i128,
7315 MMO: AN->getMemOperand());
7316 Results.push_back(Elt: DAG.getNode(Opcode: ISD::BUILD_PAIR, DL: dl, VT: MVT::i128,
7317 Ops: {Result.getValue(R: 0), Result.getValue(R: 1)}));
7318 Results.push_back(Elt: Result.getValue(R: 2));
7319}
7320
7321void NVPTXTargetLowering::ReplaceNodeResults(
7322 SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
7323 switch (N->getOpcode()) {
7324 default:
7325 report_fatal_error(reason: "Unhandled custom legalization");
7326 case ISD::BITCAST:
7327 ReplaceBITCAST(Node: N, DAG, Results);
7328 return;
7329 case ISD::LOAD:
7330 case ISD::MLOAD:
7331 replaceLoadVector(N, DAG, Results, STI);
7332 return;
7333 case ISD::INTRINSIC_W_CHAIN:
7334 ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
7335 return;
7336 case ISD::CopyFromReg:
7337 ReplaceCopyFromReg_128(N, DAG, Results);
7338 return;
7339 case NVPTXISD::ProxyReg:
7340 replaceProxyReg(N, DAG, TLI: *this, Results);
7341 return;
7342 case ISD::ATOMIC_CMP_SWAP:
7343 case ISD::ATOMIC_SWAP:
7344 replaceAtomicSwap128(N, DAG, STI, Results);
7345 return;
7346 }
7347}
7348
7349NVPTXTargetLowering::AtomicExpansionKind
7350NVPTXTargetLowering::shouldExpandAtomicRMWInIR(const AtomicRMWInst *AI) const {
7351 Type *Ty = AI->getValOperand()->getType();
7352
7353 if (AI->isFloatingPointOperation()) {
7354 if (AI->getOperation() == AtomicRMWInst::BinOp::FAdd) {
7355 if (Ty->isHalfTy() && STI.getSmVersion() >= 70 &&
7356 STI.getPTXVersion() >= 63)
7357 return AtomicExpansionKind::None;
7358 if (Ty->isBFloatTy() && STI.getSmVersion() >= 90 &&
7359 STI.getPTXVersion() >= 78)
7360 return AtomicExpansionKind::None;
7361 if (Ty->isFloatTy())
7362 return AtomicExpansionKind::None;
7363 if (Ty->isDoubleTy() && STI.hasAtomAddF64())
7364 return AtomicExpansionKind::None;
7365 }
7366 return AtomicExpansionKind::CmpXChg;
7367 }
7368
7369 assert(Ty->isIntegerTy() && "Ty should be integer at this point");
7370 const unsigned BitWidth = cast<IntegerType>(Val: Ty)->getBitWidth();
7371
7372 switch (AI->getOperation()) {
7373 default:
7374 return AtomicExpansionKind::CmpXChg;
7375 case AtomicRMWInst::BinOp::Xchg:
7376 if (BitWidth == 128)
7377 return AtomicExpansionKind::None;
7378 [[fallthrough]];
7379 case AtomicRMWInst::BinOp::And:
7380 case AtomicRMWInst::BinOp::Or:
7381 case AtomicRMWInst::BinOp::Xor:
7382 switch (BitWidth) {
7383 case 8:
7384 case 16:
7385 return AtomicExpansionKind::CmpXChg;
7386 case 32:
7387 return AtomicExpansionKind::None;
7388 case 64:
7389 if (STI.hasAtomBitwise64())
7390 return AtomicExpansionKind::None;
7391 return AtomicExpansionKind::CmpXChg;
7392 case 128:
7393 return AtomicExpansionKind::CmpXChg;
7394 default:
7395 llvm_unreachable("unsupported width encountered");
7396 }
7397 case AtomicRMWInst::BinOp::Add:
7398 case AtomicRMWInst::BinOp::Sub:
7399 case AtomicRMWInst::BinOp::Max:
7400 case AtomicRMWInst::BinOp::Min:
7401 case AtomicRMWInst::BinOp::UMax:
7402 case AtomicRMWInst::BinOp::UMin:
7403 switch (BitWidth) {
7404 case 8:
7405 case 16:
7406 return AtomicExpansionKind::CmpXChg;
7407 case 32:
7408 return AtomicExpansionKind::None;
7409 case 64:
7410 if (STI.hasAtomMinMax64())
7411 return AtomicExpansionKind::None;
7412 return AtomicExpansionKind::CmpXChg;
7413 case 128:
7414 return AtomicExpansionKind::CmpXChg;
7415 default:
7416 llvm_unreachable("unsupported width encountered");
7417 }
7418 case AtomicRMWInst::BinOp::UIncWrap:
7419 case AtomicRMWInst::BinOp::UDecWrap:
7420 switch (BitWidth) {
7421 case 32:
7422 return AtomicExpansionKind::None;
7423 case 8:
7424 case 16:
7425 case 64:
7426 case 128:
7427 return AtomicExpansionKind::CmpXChg;
7428 default:
7429 llvm_unreachable("unsupported width encountered");
7430 }
7431 }
7432
7433 return AtomicExpansionKind::CmpXChg;
7434}
7435
7436bool NVPTXTargetLowering::shouldInsertFencesForAtomic(
7437 const Instruction *I) const {
7438 auto *CI = dyn_cast<AtomicCmpXchgInst>(Val: I);
7439 // When CAS bitwidth is not supported on the hardware, the CAS is emulated
7440 // using a retry loop that uses a higher-bitwidth monotonic CAS. We enforce
7441 // the memory order using explicit fences around the retry loop.
7442 // The memory order of natively supported CAS operations can be enforced
7443 // by lowering to an atom.cas with the right memory synchronizing effect.
7444 // However, atom.cas only supports relaxed, acquire, release and acq_rel.
7445 // So we also use explicit fences for enforcing memory order for
7446  // seq_cst CAS with natively-supported bitwidths.
7447 return CI &&
7448 (cast<IntegerType>(Val: CI->getCompareOperand()->getType())->getBitWidth() <
7449 STI.getMinCmpXchgSizeInBits() ||
7450 CI->getMergedOrdering() == AtomicOrdering::SequentiallyConsistent);
7451}
7452
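// For a seq_cst cmpxchg of a natively supported width, the leading fence.sc
// emitted by emitLeadingFence already provides the seq_cst ordering, so the
// remaining operation only needs acquire semantics. Emulated (narrow) cmpxchg
// relies on the explicit fences inserted around its expansion and can stay
// monotonic.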
7453AtomicOrdering NVPTXTargetLowering::atomicOperationOrderAfterFenceSplit(
7454 const Instruction *I) const {
7455 auto *CI = dyn_cast<AtomicCmpXchgInst>(Val: I);
7456 bool BitwidthSupportedAndIsSeqCst =
7457 CI && CI->getMergedOrdering() == AtomicOrdering::SequentiallyConsistent &&
7458 cast<IntegerType>(Val: CI->getCompareOperand()->getType())->getBitWidth() >=
7459 STI.getMinCmpXchgSizeInBits();
7460 return BitwidthSupportedAndIsSeqCst ? AtomicOrdering::Acquire
7461 : AtomicOrdering::Monotonic;
7462}
7463
7464Instruction *NVPTXTargetLowering::emitLeadingFence(IRBuilderBase &Builder,
7465 Instruction *Inst,
7466 AtomicOrdering Ord) const {
7467 if (!isa<AtomicCmpXchgInst>(Val: Inst))
7468 return TargetLoweringBase::emitLeadingFence(Builder, Inst, Ord);
7469
7470 // Specialize for cmpxchg
7471  // Emit a leading fence.sc for a seq_cst cmpxchg that is not emulated
7472 SyncScope::ID SSID = cast<AtomicCmpXchgInst>(Val: Inst)->getSyncScopeID();
7473 if (isReleaseOrStronger(AO: Ord))
7474 return Builder.CreateFence(Ordering: Ord == AtomicOrdering::SequentiallyConsistent
7475 ? Ord
7476 : AtomicOrdering::Release,
7477 SSID);
7478
7479 return nullptr;
7480}
7481
7482Instruction *NVPTXTargetLowering::emitTrailingFence(IRBuilderBase &Builder,
7483 Instruction *Inst,
7484 AtomicOrdering Ord) const {
7485 // Specialize for cmpxchg
7486 if (!isa<AtomicCmpXchgInst>(Val: Inst))
7487 return TargetLoweringBase::emitTrailingFence(Builder, Inst, Ord);
7488
7489 auto *CI = cast<AtomicCmpXchgInst>(Val: Inst);
7490 auto CASWidth =
7491 cast<IntegerType>(Val: CI->getCompareOperand()->getType())->getBitWidth();
7492 SyncScope::ID SSID = CI->getSyncScopeID();
7493  // Do not emit a trailing fence for a seq_cst cmpxchg that is not emulated
7494 if (isAcquireOrStronger(AO: Ord) &&
7495 (Ord != AtomicOrdering::SequentiallyConsistent ||
7496 CASWidth < STI.getMinCmpXchgSizeInBits()))
7497 return Builder.CreateFence(Ordering: AtomicOrdering::Acquire, SSID);
7498
7499 return nullptr;
7500}
7501
7502// Rather than default to SINT when both UINT and SINT are custom, we only
7503// change the opcode when UINT is not legal and SINT is. UINT is preferred when
7504// both are custom since unsigned CVT instructions can lead to slightly better
7505// SASS code with fewer instructions.
7506unsigned NVPTXTargetLowering::getPreferredFPToIntOpcode(unsigned Op, EVT FromVT,
7507 EVT ToVT) const {
7508 if (isOperationLegal(Op, VT: ToVT))
7509 return Op;
7510 switch (Op) {
7511 case ISD::FP_TO_UINT:
7512 if (isOperationLegal(Op: ISD::FP_TO_SINT, VT: ToVT))
7513 return ISD::FP_TO_SINT;
7514 break;
7515 case ISD::STRICT_FP_TO_UINT:
7516 if (isOperationLegal(Op: ISD::STRICT_FP_TO_SINT, VT: ToVT))
7517 return ISD::STRICT_FP_TO_SINT;
7518 break;
7519 case ISD::VP_FP_TO_UINT:
7520 if (isOperationLegal(Op: ISD::VP_FP_TO_SINT, VT: ToVT))
7521 return ISD::VP_FP_TO_SINT;
7522 break;
7523 default:
7524 break;
7525 }
7526 return Op;
7527}
7528
7529// Pin NVPTXTargetObjectFile's vtables to this file.
7530NVPTXTargetObjectFile::~NVPTXTargetObjectFile() = default;
7531
7532MCSection *NVPTXTargetObjectFile::SelectSectionForGlobal(
7533 const GlobalObject *GO, SectionKind Kind, const TargetMachine &TM) const {
7534 return getDataSection();
7535}
7536
7537static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
7538 const SelectionDAG &DAG, unsigned Depth) {
7539 SDValue A = Op.getOperand(i: 0);
7540 SDValue B = Op.getOperand(i: 1);
7541 ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(Val: Op.getOperand(i: 2));
7542 unsigned Mode = Op.getConstantOperandVal(i: 3);
7543
7544 if (!Selector)
7545 return;
7546
7547 KnownBits AKnown = DAG.computeKnownBits(Op: A, Depth);
7548 KnownBits BKnown = DAG.computeKnownBits(Op: B, Depth);
7549
7550 // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
7551 assert(AKnown.getBitWidth() == 32 && BKnown.getBitWidth() == 32 &&
7552 "PRMT must have i32 operands");
7553 assert(Known.getBitWidth() == 32 && "PRMT must have i32 result");
7554 KnownBits BitField = BKnown.concat(Lo: AKnown);
7555
7556 APInt SelectorVal = getPRMTSelector(Selector: Selector->getAPIntValue(), Mode);
7557 for (unsigned I : llvm::seq(Size: 4)) {
7558 APInt Sel = SelectorVal.extractBits(numBits: 4, bitPosition: I * 4);
7559 unsigned Idx = Sel.getLoBits(numBits: 3).getZExtValue();
7560 unsigned Sign = Sel.getHiBits(numBits: 1).getZExtValue();
7561 KnownBits Byte = BitField.extractBits(NumBits: 8, BitPosition: Idx * 8);
7562 if (Sign)
7563 Byte = KnownBits::ashr(LHS: Byte, RHS: 8);
7564 Known.insertBits(SubBits: Byte, BitPosition: I * 8);
7565 }
7566}
7567
7568static void computeKnownBitsForLoadV(const SDValue Op, KnownBits &Known) {
7569 MemSDNode *LD = cast<MemSDNode>(Val: Op);
7570
7571  // For sign-extending loads the high bits depend on the unknown sign bit, so bail out.
7572 auto ExtType = LD->getConstantOperandVal(Num: LD->getNumOperands() - 1);
7573 if (ExtType == ISD::SEXTLOAD)
7574 return;
7575
7576  // Extending loads with vector result types are not modeled here, so bail out.
7577 auto DestVT = LD->getValueType(ResNo: 0);
7578 if (DestVT.isVector())
7579 return;
7580
7581 assert(Known.getBitWidth() == DestVT.getSizeInBits());
7582 auto ElementBitWidth = NVPTXDAGToDAGISel::getFromTypeWidthForLoad(Mem: LD);
7583 Known.Zero.setHighBits(Known.getBitWidth() - ElementBitWidth);
7584}
7585
7586void NVPTXTargetLowering::computeKnownBitsForTargetNode(
7587 const SDValue Op, KnownBits &Known, const APInt &DemandedElts,
7588 const SelectionDAG &DAG, unsigned Depth) const {
7589 Known.resetAll();
7590
7591 switch (Op.getOpcode()) {
7592 case NVPTXISD::PRMT:
7593 computeKnownBitsForPRMT(Op, Known, DAG, Depth);
7594 break;
7595 case NVPTXISD::LoadV2:
7596 case NVPTXISD::LoadV4:
7597 case NVPTXISD::LoadV8:
7598 computeKnownBitsForLoadV(Op, Known);
7599 break;
7600 default:
7601 break;
7602 }
7603}
7604
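// Map the demanded bits of a PRMT result back onto its two i32 inputs. Each
// demanded result byte demands the entire source byte selected for it, unless
// the selector nibble's sign-replicate bit is set, in which case only bit 7 of
// that source byte is demanded. For example, a selector nibble of 0x5 demands
// byte 1 of the second operand.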
7605static std::pair<APInt, APInt> getPRMTDemandedBits(const APInt &SelectorVal,
7606 const APInt &DemandedBits) {
7607 APInt DemandedLHS = APInt(32, 0);
7608 APInt DemandedRHS = APInt(32, 0);
7609
7610 for (unsigned I : llvm::seq(Size: 4)) {
7611 if (DemandedBits.extractBits(numBits: 8, bitPosition: I * 8).isZero())
7612 continue;
7613
7614 APInt Sel = SelectorVal.extractBits(numBits: 4, bitPosition: I * 4);
7615 unsigned Idx = Sel.getLoBits(numBits: 3).getZExtValue();
7616 unsigned Sign = Sel.getHiBits(numBits: 1).getZExtValue();
7617
7618 APInt &Src = Idx < 4 ? DemandedLHS : DemandedRHS;
7619 unsigned ByteStart = (Idx % 4) * 8;
7620 if (Sign)
7621 Src.setBit(ByteStart + 7);
7622 else
7623 Src.setBits(loBit: ByteStart, hiBit: ByteStart + 8);
7624 }
7625
7626 return {DemandedLHS, DemandedRHS};
7627}
7628
7629// Replace an undef input with 0, which is easier for other optimizations
7630// (such as known bits) to reason about.
7631static SDValue canonicalizePRMTInput(SDValue Op, SelectionDAG &DAG) {
7632 if (!Op)
7633 return SDValue();
7634 if (Op.isUndef())
7635 return DAG.getConstant(Val: 0, DL: SDLoc(), VT: MVT::i32);
7636 return Op;
7637}
7638
7639static SDValue simplifyDemandedBitsForPRMT(SDValue PRMT,
7640 const APInt &DemandedBits,
7641 SelectionDAG &DAG,
7642 const TargetLowering &TLI,
7643 unsigned Depth) {
7644 assert(PRMT.getOpcode() == NVPTXISD::PRMT);
7645 SDValue Op0 = PRMT.getOperand(i: 0);
7646 SDValue Op1 = PRMT.getOperand(i: 1);
7647 auto *SelectorConst = dyn_cast<ConstantSDNode>(Val: PRMT.getOperand(i: 2));
7648 if (!SelectorConst)
7649 return SDValue();
7650
7651 unsigned Mode = PRMT.getConstantOperandVal(i: 3);
7652 const APInt Selector = getPRMTSelector(Selector: SelectorConst->getAPIntValue(), Mode);
7653
7654 // Try to simplify the PRMT to one of the inputs if the used bytes are all
7655 // from the same input in the correct order.
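  // For example, if only the low 16 bits are demanded and the selector's low
  // byte is 0x10 (result bytes 0 and 1 taken from Op0's bytes 0 and 1), the
  // PRMT can be replaced by Op0.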
7656 const unsigned LeadingBytes = DemandedBits.countLeadingZeros() / 8;
7657 const unsigned SelBits = (4 - LeadingBytes) * 4;
7658 if (Selector.getLoBits(numBits: SelBits) == APInt(32, 0x3210).getLoBits(numBits: SelBits))
7659 return Op0;
7660 if (Selector.getLoBits(numBits: SelBits) == APInt(32, 0x7654).getLoBits(numBits: SelBits))
7661 return Op1;
7662
7663 auto [DemandedLHS, DemandedRHS] = getPRMTDemandedBits(SelectorVal: Selector, DemandedBits);
7664
7665 // Attempt to avoid multi-use ops if we don't need anything from them.
7666 SDValue DemandedOp0 =
7667 TLI.SimplifyMultipleUseDemandedBits(Op: Op0, DemandedBits: DemandedLHS, DAG, Depth: Depth + 1);
7668 SDValue DemandedOp1 =
7669 TLI.SimplifyMultipleUseDemandedBits(Op: Op1, DemandedBits: DemandedRHS, DAG, Depth: Depth + 1);
7670
7671 DemandedOp0 = canonicalizePRMTInput(Op: DemandedOp0, DAG);
7672 DemandedOp1 = canonicalizePRMTInput(Op: DemandedOp1, DAG);
7673 if ((DemandedOp0 && DemandedOp0 != Op0) ||
7674 (DemandedOp1 && DemandedOp1 != Op1)) {
7675 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
7676 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
7677 return getPRMT(A: Op0, B: Op1, Selector: Selector.getZExtValue(), DL: SDLoc(PRMT), DAG);
7678 }
7679
7680 return SDValue();
7681}
7682
7683bool NVPTXTargetLowering::SimplifyDemandedBitsForTargetNode(
7684 SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
7685 KnownBits &Known, TargetLoweringOpt &TLO, unsigned Depth) const {
7686 Known.resetAll();
7687
7688 switch (Op.getOpcode()) {
7689 case NVPTXISD::PRMT:
7690 if (SDValue Result = simplifyDemandedBitsForPRMT(PRMT: Op, DemandedBits, DAG&: TLO.DAG,
7691 TLI: *this, Depth)) {
7692 TLO.CombineTo(O: Op, N: Result);
7693 return true;
7694 }
7695 break;
7696 default:
7697 break;
7698 }
7699
7700 computeKnownBitsForTargetNode(Op, Known, DemandedElts, DAG: TLO.DAG, Depth);
7701 return false;
7702}
7703