1//===-- NVPTXISelLowering.cpp - NVPTX DAG Lowering Implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the interfaces that NVPTX uses to lower LLVM code into a
10// selection DAG.
11//
12//===----------------------------------------------------------------------===//
13
14#include "NVPTXISelLowering.h"
15#include "MCTargetDesc/NVPTXBaseInfo.h"
16#include "NVPTX.h"
17#include "NVPTXISelDAGToDAG.h"
18#include "NVPTXSelectionDAGInfo.h"
19#include "NVPTXSubtarget.h"
20#include "NVPTXTargetMachine.h"
21#include "NVPTXTargetObjectFile.h"
22#include "NVPTXUtilities.h"
23#include "llvm/ADT/APFloat.h"
24#include "llvm/ADT/APInt.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/SmallVector.h"
27#include "llvm/ADT/StringRef.h"
28#include "llvm/CodeGen/Analysis.h"
29#include "llvm/CodeGen/ISDOpcodes.h"
30#include "llvm/CodeGen/MachineFunction.h"
31#include "llvm/CodeGen/MachineJumpTableInfo.h"
32#include "llvm/CodeGen/MachineMemOperand.h"
33#include "llvm/CodeGen/SDPatternMatch.h"
34#include "llvm/CodeGen/SelectionDAG.h"
35#include "llvm/CodeGen/SelectionDAGNodes.h"
36#include "llvm/CodeGen/TargetCallingConv.h"
37#include "llvm/CodeGen/TargetLowering.h"
38#include "llvm/CodeGen/ValueTypes.h"
39#include "llvm/CodeGenTypes/MachineValueType.h"
40#include "llvm/IR/Argument.h"
41#include "llvm/IR/Attributes.h"
42#include "llvm/IR/Constants.h"
43#include "llvm/IR/DataLayout.h"
44#include "llvm/IR/DerivedTypes.h"
45#include "llvm/IR/DiagnosticInfo.h"
46#include "llvm/IR/FPEnv.h"
47#include "llvm/IR/Function.h"
48#include "llvm/IR/GlobalValue.h"
49#include "llvm/IR/IRBuilder.h"
50#include "llvm/IR/Instruction.h"
51#include "llvm/IR/Instructions.h"
52#include "llvm/IR/IntrinsicsNVPTX.h"
53#include "llvm/IR/Module.h"
54#include "llvm/IR/Type.h"
55#include "llvm/IR/Value.h"
56#include "llvm/Support/Alignment.h"
57#include "llvm/Support/AtomicOrdering.h"
58#include "llvm/Support/Casting.h"
59#include "llvm/Support/CodeGen.h"
60#include "llvm/Support/CommandLine.h"
61#include "llvm/Support/ErrorHandling.h"
62#include "llvm/Support/KnownBits.h"
63#include "llvm/Support/NVPTXAddrSpace.h"
64#include "llvm/Support/raw_ostream.h"
65#include "llvm/Target/TargetMachine.h"
66#include "llvm/Target/TargetOptions.h"
67#include <algorithm>
68#include <cassert>
69#include <cmath>
70#include <cstdint>
71#include <iterator>
72#include <optional>
73#include <string>
74#include <tuple>
75#include <utility>
76#include <vector>
77
78#define DEBUG_TYPE "nvptx-lower"
79
80using namespace llvm;
81
82static cl::opt<bool> sched4reg(
83 "nvptx-sched4reg",
84 cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(Val: false));
85
86static cl::opt<unsigned> FMAContractLevelOpt(
87 "nvptx-fma-level", cl::Hidden,
88 cl::desc("NVPTX Specific: FMA contraction (0: don't do it"
89 " 1: do it 2: do it aggressively"),
90 cl::init(Val: 2));
91
92static cl::opt<NVPTX::DivPrecisionLevel> UsePrecDivF32(
93 "nvptx-prec-divf32", cl::Hidden,
94 cl::desc(
95 "NVPTX Specific: Override the precision of the lowering for f32 fdiv"),
96 cl::values(
97 clEnumValN(NVPTX::DivPrecisionLevel::Approx, "0", "Use div.approx"),
98 clEnumValN(NVPTX::DivPrecisionLevel::Full, "1", "Use div.full"),
99 clEnumValN(NVPTX::DivPrecisionLevel::IEEE754, "2",
100 "Use IEEE Compliant F32 div.rnd if available (default)"),
101 clEnumValN(NVPTX::DivPrecisionLevel::IEEE754_NoFTZ, "3",
102 "Use IEEE Compliant F32 div.rnd if available, no FTZ")),
103 cl::init(Val: NVPTX::DivPrecisionLevel::IEEE754));
104
105static cl::opt<bool> UsePrecSqrtF32(
106 "nvptx-prec-sqrtf32", cl::Hidden,
107 cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."),
108 cl::init(Val: true));
109
110/// Whereas CUDA's implementation (see libdevice) uses ex2.approx for exp2(), it
111/// does NOT use lg2.approx for log2, so this is disabled by default.
112static cl::opt<bool> UseApproxLog2F32(
113 "nvptx-approx-log2f32",
114 cl::desc("NVPTX Specific: whether to use lg2.approx for log2"),
115 cl::init(Val: false));
116
117static cl::opt<bool> ForceMinByValParamAlign(
118 "nvptx-force-min-byval-param-align", cl::Hidden,
119 cl::desc("NVPTX Specific: force 4-byte minimal alignment for byval"
120 " params of device functions."),
121 cl::init(Val: false));
122
123NVPTX::DivPrecisionLevel
124NVPTXTargetLowering::getDivF32Level(const MachineFunction &MF,
125 const SDNode &N) const {
126 // If nvptx-prec-div32=N is used on the command-line, always honor it
127 if (UsePrecDivF32.getNumOccurrences() > 0)
128 return UsePrecDivF32;
129
130 const SDNodeFlags Flags = N.getFlags();
131 if (Flags.hasApproximateFuncs())
132 return NVPTX::DivPrecisionLevel::Approx;
133
134 return NVPTX::DivPrecisionLevel::IEEE754;
135}
136
137bool NVPTXTargetLowering::usePrecSqrtF32(const SDNode *N) const {
138 // If nvptx-prec-sqrtf32 is used on the command-line, always honor it
139 if (UsePrecSqrtF32.getNumOccurrences() > 0)
140 return UsePrecSqrtF32;
141
142 if (N) {
143 const SDNodeFlags Flags = N->getFlags();
144 if (Flags.hasApproximateFuncs())
145 return false;
146 }
147
148 return true;
149}
150
151bool NVPTXTargetLowering::useF32FTZ(const MachineFunction &MF) const {
152 return MF.getDenormalMode(FPType: APFloat::IEEEsingle()).Output ==
153 DenormalMode::PreserveSign;
154}
155
156static bool IsPTXVectorType(MVT VT) {
157 switch (VT.SimpleTy) {
158 default:
159 return false;
160 case MVT::v2i1:
161 case MVT::v4i1:
162 case MVT::v2i8:
163 case MVT::v4i8:
164 case MVT::v8i8: // <2 x i8x4>
165 case MVT::v16i8: // <4 x i8x4>
166 case MVT::v2i16:
167 case MVT::v4i16:
168 case MVT::v8i16: // <4 x i16x2>
169 case MVT::v2i32:
170 case MVT::v4i32:
171 case MVT::v2i64:
172 case MVT::v2f16:
173 case MVT::v4f16:
174 case MVT::v8f16: // <4 x f16x2>
175 case MVT::v2bf16:
176 case MVT::v4bf16:
177 case MVT::v8bf16: // <4 x bf16x2>
178 case MVT::v2f32:
179 case MVT::v4f32:
180 case MVT::v2f64:
181 case MVT::v4i64:
182 case MVT::v4f64:
183 case MVT::v8i32:
184 case MVT::v8f32:
185 case MVT::v16f16: // <8 x f16x2>
186 case MVT::v16bf16: // <8 x bf16x2>
187 case MVT::v16i16: // <8 x i16x2>
188 case MVT::v32i8: // <8 x i8x4>
189 return true;
190 }
191}
192
// Called when legalizing vector loads/stores. It does two things:
// 1. Determines whether the vector is something we want to custom lower;
//    std::nullopt is returned if we do not want to custom lower it.
// 2. If we do want to handle it, returns two parameters:
//    - unsigned int NumElts - The number of elements in the final vector
//    - EVT EltVT - The type of the elements in the final vector
static std::optional<std::pair<unsigned int, MVT>>
getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
                       unsigned AddressSpace) {
  // 256-bit accesses are only available for certain address spaces on targets
  // that support them (see has256BitVectorLoadStore).
  const bool CanLowerTo256Bit = STI.has256BitVectorLoadStore(AS: AddressSpace);

  // A 256-bit scalar integer is handled as <4 x i64> when wide accesses are
  // available.
  if (CanLowerTo256Bit && VectorEVT.isScalarInteger() &&
      VectorEVT.getSizeInBits() == 256)
    return {{4, MVT::i64}};

  if (!VectorEVT.isSimple())
    return std::nullopt;
  const MVT VectorVT = VectorEVT.getSimpleVT();

  if (!VectorVT.isVector()) {
    // 128-bit scalars are handled as <2 x i64>; other scalars are not custom
    // lowered here.
    if (VectorVT == MVT::i128 || VectorVT == MVT::f128)
      return {{2, MVT::i64}};
    return std::nullopt;
  }

  const MVT EltVT = VectorVT.getVectorElementType();
  const unsigned NumElts = VectorVT.getVectorNumElements();

  // The size of the PTX virtual register that holds a packed type.
  unsigned PackRegSize;

  // We only handle "native" vector sizes for now, e.g. <4 x double> is not
  // legal. We can (and should) split that into 2 stores of <2 x double> here
  // but I'm leaving that as a TODO for now.
  switch (VectorVT.SimpleTy) {
  default:
    return std::nullopt;

  case MVT::v4i64:
  case MVT::v4f64:
    // This is a "native" vector type iff the address space is global and the
    // target supports 256-bit loads/stores
    if (!CanLowerTo256Bit)
      return std::nullopt;
    [[fallthrough]];
  case MVT::v2i8:
  case MVT::v2i64:
  case MVT::v2f64:
    // This is a "native" vector type: no packing needed, keep the original
    // element count and element type.
    return std::pair(NumElts, EltVT);

  case MVT::v16f16:  // <8 x f16x2>
  case MVT::v16bf16: // <8 x bf16x2>
  case MVT::v16i16:  // <8 x i16x2>
  case MVT::v32i8:   // <8 x i8x4>
    // This can be upsized into a "native" vector type iff the address space is
    // global and the target supports 256-bit loads/stores.
    if (!CanLowerTo256Bit)
      return std::nullopt;
    [[fallthrough]];
  case MVT::v2i16: // <1 x i16x2>
  case MVT::v2f16: // <1 x f16x2>
  case MVT::v2bf16: // <1 x bf16x2>
  case MVT::v4i8: // <1 x i8x4>
  case MVT::v4i16: // <2 x i16x2>
  case MVT::v4f16: // <2 x f16x2>
  case MVT::v4bf16: // <2 x bf16x2>
  case MVT::v8i8: // <2 x i8x4>
  case MVT::v8f16: // <4 x f16x2>
  case MVT::v8bf16: // <4 x bf16x2>
  case MVT::v8i16: // <4 x i16x2>
  case MVT::v16i8: // <4 x i8x4>
    // Sub-word elements are packed into 32-bit registers.
    PackRegSize = 32;
    break;

  case MVT::v8f32: // <4 x f32x2>
  case MVT::v8i32: // <4 x i32x2>
    // This is a "native" vector type iff the address space is global and the
    // target supports 256-bit loads/stores
    if (!CanLowerTo256Bit)
      return std::nullopt;
    [[fallthrough]];
  case MVT::v2f32: // <1 x f32x2>
  case MVT::v4f32: // <2 x f32x2>
  case MVT::v2i32: // <1 x i32x2>
  case MVT::v4i32: // <2 x i32x2>
    // Without f32x2 instructions there is nothing to pack: treat as a plain
    // vector of 32-bit elements.
    if (!STI.hasF32x2Instructions())
      return std::pair(NumElts, EltVT);
    PackRegSize = 64;
    break;
  }

  // If we reach here, then we can pack 2 or more elements into a single 32-bit
  // or 64-bit PTX register and treat the vector as a new vector containing
  // packed elements.

  // Number of elements to pack in one word.
  const unsigned NPerReg = PackRegSize / EltVT.getSizeInBits();

  return std::pair(NumElts / NPerReg, MVT::getVectorVT(VT: EltVT, NumElements: NPerReg));
}
295
/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
/// legal-ish MVTs that compose it. Unlike ComputeValueVTs, this will legalize
/// the types as required by the calling convention (with special handling for
/// i8s).
/// Results are appended to \p ValueVTs, with matching byte offsets (from
/// \p StartingOffset) appended to \p Offsets.
/// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
/// same number of types as the Ins/Outs arrays in LowerFormalArguments,
/// LowerCall, and LowerReturn.
static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
                               LLVMContext &Ctx, CallingConv::ID CallConv,
                               Type *Ty, SmallVectorImpl<EVT> &ValueVTs,
                               SmallVectorImpl<uint64_t> &Offsets,
                               uint64_t StartingOffset = 0) {
  // First flatten Ty into primitive EVTs and their byte offsets.
  SmallVector<EVT, 16> TempVTs;
  SmallVector<uint64_t, 16> TempOffsets;
  ComputeValueVTs(TLI, DL, Ty, ValueVTs&: TempVTs, /*MemVTs=*/nullptr, FixedOffsets: &TempOffsets,
                  StartingOffset);

  for (const auto [VT, Off] : zip(t&: TempVTs, u&: TempOffsets)) {
    // Legalize each piece into one or more registers of the CC's register
    // type for that EVT.
    MVT RegisterVT = TLI.getRegisterTypeForCallingConv(Context&: Ctx, CC: CallConv, VT);
    unsigned NumRegs = TLI.getNumRegistersForCallingConv(Context&: Ctx, CC: CallConv, VT);

    // Since we actually can load/store b8, we need to ensure that we'll use
    // the original sized type for any i8s or i8 vectors.
    if (VT.getScalarType() == MVT::i8) {
      if (RegisterVT == MVT::i16)
        RegisterVT = MVT::i8;
      else if (RegisterVT == MVT::v2i16)
        RegisterVT = MVT::v2i8;
      else
        assert(RegisterVT == MVT::v4i8 &&
               "Expected v4i8, v2i16, or i16 for i8 RegisterVT");
    }

    // TODO: This is horribly incorrect for cases where the vector elements are
    // not a multiple of bytes (ex i1) and legal or i8. However, this problem
    // has existed for as long as NVPTX has and no one has complained, so we'll
    // leave it for now.
    for (unsigned I : seq(Size: NumRegs)) {
      ValueVTs.push_back(Elt: RegisterVT);
      Offsets.push_back(Elt: Off + I * RegisterVT.getStoreSize());
    }
  }
}
339
340// We return an EVT that can hold N VTs
341// If the VT is a vector, the resulting EVT is a flat vector with the same
342// element type as VT's element type.
343static EVT getVectorizedVT(EVT VT, unsigned N, LLVMContext &C) {
344 if (N == 1)
345 return VT;
346
347 return VT.isVector() ? EVT::getVectorVT(Context&: C, VT: VT.getScalarType(),
348 NumElements: VT.getVectorNumElements() * N)
349 : EVT::getVectorVT(Context&: C, VT, NumElements: N);
350}
351
352static SDValue getExtractVectorizedValue(SDValue V, unsigned I, EVT VT,
353 const SDLoc &dl, SelectionDAG &DAG) {
354 if (V.getValueType() == VT) {
355 assert(I == 0 && "Index must be 0 for scalar value");
356 return V;
357 }
358
359 if (!VT.isVector())
360 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT, N1: V,
361 N2: DAG.getVectorIdxConstant(Val: I, DL: dl));
362
363 return DAG.getNode(
364 Opcode: ISD::EXTRACT_SUBVECTOR, DL: dl, VT, N1: V,
365 N2: DAG.getVectorIdxConstant(Val: I * VT.getVectorNumElements(), DL: dl));
366}
367
368template <typename T>
369static inline SDValue getBuildVectorizedValue(unsigned N, const SDLoc &dl,
370 SelectionDAG &DAG, T GetElement) {
371 if (N == 1)
372 return GetElement(0);
373
374 SmallVector<SDValue, 8> Values;
375 for (const unsigned I : llvm::seq(Size: N)) {
376 SDValue Val = GetElement(I);
377 if (Val.getValueType().isVector())
378 DAG.ExtractVectorElements(Op: Val, Args&: Values);
379 else
380 Values.push_back(Elt: Val);
381 }
382
383 EVT VT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: Values[0].getValueType(),
384 NumElements: Values.size());
385 return DAG.getBuildVector(VT, DL: dl, Ops: Values);
386}
387
388/// PromoteScalarIntegerPTX
389/// Used to make sure the arguments/returns are suitable for passing
390/// and promote them to a larger size if they're not.
391///
392/// The promoted type is placed in \p PromoteVT if the function returns true.
393static EVT promoteScalarIntegerPTX(const EVT VT) {
394 if (VT.isScalarInteger()) {
395 switch (PowerOf2Ceil(A: VT.getFixedSizeInBits())) {
396 default:
397 llvm_unreachable(
398 "Promotion is not suitable for scalars of size larger than 64-bits");
399 case 1:
400 return MVT::i1;
401 case 2:
402 case 4:
403 case 8:
404 return MVT::i8;
405 case 16:
406 return MVT::i16;
407 case 32:
408 return MVT::i32;
409 case 64:
410 return MVT::i64;
411 }
412 }
413 return VT;
414}
415
416// Check whether we can merge loads/stores of some of the pieces of a
417// flattened function parameter or return value into a single vector
418// load/store.
419//
420// The flattened parameter is represented as a list of EVTs and
421// offsets, and the whole structure is aligned to ParamAlignment. This
422// function determines whether we can load/store pieces of the
423// parameter starting at index Idx using a single vectorized op of
424// size AccessSize. If so, it returns the number of param pieces
425// covered by the vector op. Otherwise, it returns 1.
426template <typename T>
427static unsigned canMergeParamLoadStoresStartingAt(
428 unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs,
429 const SmallVectorImpl<T> &Offsets, Align ParamAlignment) {
430
431 // Can't vectorize if param alignment is not sufficient.
432 if (ParamAlignment < AccessSize)
433 return 1;
434 // Can't vectorize if offset is not aligned.
435 if (Offsets[Idx] & (AccessSize - 1))
436 return 1;
437
438 EVT EltVT = ValueVTs[Idx];
439 unsigned EltSize = EltVT.getStoreSize();
440
441 // Element is too large to vectorize.
442 if (EltSize >= AccessSize)
443 return 1;
444
445 unsigned NumElts = AccessSize / EltSize;
446 // Can't vectorize if AccessBytes if not a multiple of EltSize.
447 if (AccessSize != EltSize * NumElts)
448 return 1;
449
450 // We don't have enough elements to vectorize.
451 if (Idx + NumElts > ValueVTs.size())
452 return 1;
453
454 // PTX ISA can only deal with 2- and 4-element vector ops.
455 if (NumElts != 4 && NumElts != 2)
456 return 1;
457
458 for (unsigned j = Idx + 1; j < Idx + NumElts; ++j) {
459 // Types do not match.
460 if (ValueVTs[j] != EltVT)
461 return 1;
462
463 // Elements are not contiguous.
464 if (Offsets[j] - Offsets[j - 1] != EltSize)
465 return 1;
466 }
467 // OK. We can vectorize ValueVTs[i..i+NumElts)
468 return NumElts;
469}
470
471// Computes whether and how we can vectorize the loads/stores of a
472// flattened function parameter or return value.
473//
474// The flattened parameter is represented as the list of ValueVTs and
475// Offsets, and is aligned to ParamAlignment bytes. We return a vector
476// of the same size as ValueVTs indicating how each piece should be
477// loaded/stored (i.e. as a scalar, or as part of a vector
478// load/store).
479template <typename T>
480static SmallVector<unsigned, 16>
481VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
482 const SmallVectorImpl<T> &Offsets, Align ParamAlignment,
483 bool IsVAArg = false) {
484 // Set vector size to match ValueVTs and mark all elements as
485 // scalars by default.
486
487 if (IsVAArg)
488 return SmallVector<unsigned>(ValueVTs.size(), 1);
489
490 SmallVector<unsigned, 16> VectorInfo;
491
492 const auto GetNumElts = [&](unsigned I) -> unsigned {
493 for (const unsigned AccessSize : {16, 8, 4, 2}) {
494 const unsigned NumElts = canMergeParamLoadStoresStartingAt(
495 I, AccessSize, ValueVTs, Offsets, ParamAlignment);
496 assert((NumElts == 1 || NumElts == 2 || NumElts == 4) &&
497 "Unexpected vectorization size");
498 if (NumElts != 1)
499 return NumElts;
500 }
501 return 1;
502 };
503
504 // Check what we can vectorize using 128/64/32-bit accesses.
505 for (unsigned I = 0, E = ValueVTs.size(); I != E;) {
506 const unsigned NumElts = GetNumElts(I);
507 VectorInfo.push_back(Elt: NumElts);
508 I += NumElts;
509 }
510 assert(std::accumulate(VectorInfo.begin(), VectorInfo.end(), 0u) ==
511 ValueVTs.size());
512 return VectorInfo;
513}
514
515// NVPTXTargetLowering Constructor.
516NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
517 const NVPTXSubtarget &STI)
518 : TargetLowering(TM, STI), nvTM(&TM), STI(STI), GlobalUniqueCallSite(0) {
  // Always lower memset, memcpy, and memmove intrinsics to load/store
  // instructions, rather than generating calls to memset, memcpy, or
  // memmove.
522 MaxStoresPerMemset = MaxStoresPerMemsetOptSize = (unsigned)0xFFFFFFFF;
523 MaxStoresPerMemcpy = MaxStoresPerMemcpyOptSize = (unsigned) 0xFFFFFFFF;
524 MaxStoresPerMemmove = MaxStoresPerMemmoveOptSize = (unsigned) 0xFFFFFFFF;
525
526 setBooleanContents(ZeroOrNegativeOneBooleanContent);
527 setBooleanVectorContents(ZeroOrNegativeOneBooleanContent);
528
529 // Jump is Expensive. Don't create extra control flow for 'and', 'or'
530 // condition branches.
531 setJumpIsExpensive(true);
532
533 // Wide divides are _very_ slow. Try to reduce the width of the divide if
534 // possible.
535 addBypassSlowDiv(SlowBitWidth: 64, FastBitWidth: 32);
536
537 // By default, use the Source scheduling
538 if (sched4reg)
539 setSchedulingPreference(Sched::RegPressure);
540 else
541 setSchedulingPreference(Sched::Source);
542
543 auto setFP16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
544 LegalizeAction NoF16Action) {
545 bool IsOpSupported = STI.allowFP16Math();
546 switch (Op) {
547 // Several FP16 instructions are available on sm_80 only.
548 case ISD::FMINNUM:
549 case ISD::FMAXNUM:
550 case ISD::FMAXNUM_IEEE:
551 case ISD::FMINNUM_IEEE:
552 case ISD::FMAXIMUM:
553 case ISD::FMINIMUM:
554 case ISD::FMAXIMUMNUM:
555 case ISD::FMINIMUMNUM:
556 IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
557 break;
558 case ISD::FEXP2:
559 IsOpSupported &= STI.getSmVersion() >= 75 && STI.getPTXVersion() >= 70;
560 break;
561 }
562 setOperationAction(Op, VT, Action: IsOpSupported ? Action : NoF16Action);
563 };
564
565 auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
566 LegalizeAction NoBF16Action) {
567 bool IsOpSupported = STI.hasNativeBF16Support(Opcode: Op);
568 setOperationAction(
569 Op, VT, Action: IsOpSupported ? Action : NoBF16Action);
570 };
571
572 auto setI16x2OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
573 LegalizeAction NoI16x2Action) {
574 bool IsOpSupported = false;
575 // instructions are available on sm_90 only
576 switch (Op) {
577 case ISD::ADD:
578 case ISD::SMAX:
579 case ISD::SMIN:
580 case ISD::UMIN:
581 case ISD::UMAX:
582 IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 80;
583 break;
584 }
585 setOperationAction(Op, VT, Action: IsOpSupported ? Action : NoI16x2Action);
586 };
587
588 addRegisterClass(VT: MVT::i1, RC: &NVPTX::B1RegClass);
589 addRegisterClass(VT: MVT::i16, RC: &NVPTX::B16RegClass);
590 addRegisterClass(VT: MVT::v2i16, RC: &NVPTX::B32RegClass);
591 addRegisterClass(VT: MVT::v4i8, RC: &NVPTX::B32RegClass);
592 addRegisterClass(VT: MVT::i32, RC: &NVPTX::B32RegClass);
593 addRegisterClass(VT: MVT::i64, RC: &NVPTX::B64RegClass);
594 addRegisterClass(VT: MVT::f32, RC: &NVPTX::B32RegClass);
595 addRegisterClass(VT: MVT::f64, RC: &NVPTX::B64RegClass);
596 addRegisterClass(VT: MVT::f16, RC: &NVPTX::B16RegClass);
597 addRegisterClass(VT: MVT::v2f16, RC: &NVPTX::B32RegClass);
598 addRegisterClass(VT: MVT::bf16, RC: &NVPTX::B16RegClass);
599 addRegisterClass(VT: MVT::v2bf16, RC: &NVPTX::B32RegClass);
600
601 if (STI.hasF32x2Instructions()) {
602 addRegisterClass(VT: MVT::v2f32, RC: &NVPTX::B64RegClass);
603 addRegisterClass(VT: MVT::v2i32, RC: &NVPTX::B64RegClass);
604 }
605
606 // Conversion to/from FP16/FP16x2 is always legal.
607 setOperationAction(Op: ISD::BUILD_VECTOR, VT: MVT::v2f16, Action: Custom);
608 setOperationAction(Op: ISD::EXTRACT_VECTOR_ELT, VT: MVT::v2f16, Action: Custom);
609 setOperationAction(Op: ISD::INSERT_VECTOR_ELT, VT: MVT::v2f16, Action: Expand);
610 setOperationAction(Op: ISD::VECTOR_SHUFFLE, VT: MVT::v2f16, Action: Expand);
611
612 setOperationAction(Op: ISD::READCYCLECOUNTER, VT: MVT::i64, Action: Legal);
613 if (STI.getSmVersion() >= 30 && STI.getPTXVersion() > 31)
614 setOperationAction(Op: ISD::READSTEADYCOUNTER, VT: MVT::i64, Action: Legal);
615
616 setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote);
617 setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand);
618
619 // Conversion to/from BFP16/BFP16x2 is always legal.
620 setOperationAction(Op: ISD::BUILD_VECTOR, VT: MVT::v2bf16, Action: Custom);
621 setOperationAction(Op: ISD::EXTRACT_VECTOR_ELT, VT: MVT::v2bf16, Action: Custom);
622 setOperationAction(Op: ISD::INSERT_VECTOR_ELT, VT: MVT::v2bf16, Action: Expand);
623 setOperationAction(Op: ISD::VECTOR_SHUFFLE, VT: MVT::v2bf16, Action: Expand);
624
625 setBF16OperationAction(ISD::SETCC, MVT::v2bf16, Legal, Expand);
626 setBF16OperationAction(ISD::SETCC, MVT::bf16, Legal, Promote);
627 if (getOperationAction(Op: ISD::SETCC, VT: MVT::bf16) == Promote)
628 AddPromotedToType(Opc: ISD::SETCC, OrigVT: MVT::bf16, DestVT: MVT::f32);
629
630 // Conversion to/from i16/i16x2 is always legal.
631 setOperationAction(Op: ISD::BUILD_VECTOR, VT: MVT::v2i16, Action: Custom);
632 setOperationAction(Op: ISD::EXTRACT_VECTOR_ELT, VT: MVT::v2i16, Action: Custom);
633 setOperationAction(Op: ISD::INSERT_VECTOR_ELT, VT: MVT::v2i16, Action: Expand);
634 setOperationAction(Op: ISD::VECTOR_SHUFFLE, VT: MVT::v2i16, Action: Expand);
635
636 setOperationAction(Op: ISD::BUILD_VECTOR, VT: MVT::v4i8, Action: Custom);
637 setOperationAction(Op: ISD::EXTRACT_VECTOR_ELT, VT: MVT::v4i8, Action: Custom);
638 setOperationAction(Op: ISD::INSERT_VECTOR_ELT, VT: MVT::v4i8, Action: Custom);
639 setOperationAction(Op: ISD::VECTOR_SHUFFLE, VT: MVT::v4i8, Action: Custom);
640
641 // No support for these operations with v2f32/v2i32
642 setOperationAction(Ops: ISD::INSERT_VECTOR_ELT, VTs: {MVT::v2f32, MVT::v2i32}, Action: Expand);
643 setOperationAction(Ops: ISD::VECTOR_SHUFFLE, VTs: {MVT::v2f32, MVT::v2i32}, Action: Expand);
644
645 setOperationAction(Op: ISD::TRUNCATE, VT: MVT::v2i16, Action: Expand);
646 setOperationAction(Ops: {ISD::ANY_EXTEND, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND},
647 VT: MVT::v2i32, Action: Expand);
648
649 // Need custom lowering in case the index is dynamic.
650 if (STI.hasF32x2Instructions())
651 setOperationAction(Ops: ISD::EXTRACT_VECTOR_ELT, VTs: {MVT::v2f32, MVT::v2i32},
652 Action: Custom);
653
654 // Custom conversions to/from v2i8.
655 setOperationAction(Op: ISD::BITCAST, VT: MVT::v2i8, Action: Custom);
656
657 // Only logical ops can be done on v4i8/v2i32 directly, others must be done
658 // elementwise.
659 setOperationAction(
660 Ops: {ISD::ABS, ISD::ADD, ISD::ADDC, ISD::ADDE,
661 ISD::BITREVERSE, ISD::CTLZ, ISD::CTPOP, ISD::CTTZ,
662 ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FSHL, ISD::FSHR,
663 ISD::MUL, ISD::MULHS, ISD::MULHU, ISD::PARITY,
664 ISD::ROTL, ISD::ROTR, ISD::SADDO, ISD::SADDO_CARRY,
665 ISD::SADDSAT, ISD::SDIV, ISD::SDIVREM, ISD::SELECT_CC,
666 ISD::SETCC, ISD::SHL, ISD::SINT_TO_FP, ISD::SMAX,
667 ISD::SMIN, ISD::SMULO, ISD::SMUL_LOHI, ISD::SRA,
668 ISD::SREM, ISD::SRL, ISD::SSHLSAT, ISD::SSUBO,
669 ISD::SSUBO_CARRY, ISD::SSUBSAT, ISD::SUB, ISD::SUBC,
670 ISD::SUBE, ISD::UADDO, ISD::UADDO_CARRY, ISD::UADDSAT,
671 ISD::UDIV, ISD::UDIVREM, ISD::UINT_TO_FP, ISD::UMAX,
672 ISD::UMIN, ISD::UMULO, ISD::UMUL_LOHI, ISD::UREM,
673 ISD::USHLSAT, ISD::USUBO, ISD::USUBO_CARRY, ISD::VSELECT,
674 ISD::USUBSAT},
675 VTs: {MVT::v4i8, MVT::v2i32}, Action: Expand);
676
677 // Operations not directly supported by NVPTX.
678 for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
679 MVT::v2f32, MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16,
680 MVT::v4i8, MVT::i32, MVT::v2i32, MVT::i64}) {
681 setOperationAction(Op: ISD::SELECT_CC, VT, Action: Expand);
682 setOperationAction(Op: ISD::BR_CC, VT, Action: Expand);
683 }
684
685 // We don't want ops like FMINIMUM or UMAX to be lowered to SETCC+VSELECT.
686 setOperationAction(Ops: ISD::VSELECT, VTs: {MVT::v2f32, MVT::v2i32}, Action: Expand);
687
688 // Some SIGN_EXTEND_INREG can be done using cvt instruction.
689 // For others we will expand to a SHL/SRA pair.
690 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i64, Action: Legal);
691 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i32, Action: Legal);
692 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i16, Action: Legal);
693 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i8 , Action: Legal);
694 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i1, Action: Expand);
695 setOperationAction(Ops: ISD::SIGN_EXTEND_INREG, VTs: {MVT::v2i16, MVT::v2i32}, Action: Expand);
696
697 setOperationAction(Op: ISD::SHL_PARTS, VT: MVT::i32 , Action: Custom);
698 setOperationAction(Op: ISD::SRA_PARTS, VT: MVT::i32 , Action: Custom);
699 setOperationAction(Op: ISD::SRL_PARTS, VT: MVT::i32 , Action: Custom);
700 setOperationAction(Op: ISD::SHL_PARTS, VT: MVT::i64 , Action: Custom);
701 setOperationAction(Op: ISD::SRA_PARTS, VT: MVT::i64 , Action: Custom);
702 setOperationAction(Op: ISD::SRL_PARTS, VT: MVT::i64 , Action: Custom);
703
704 setOperationAction(Op: ISD::BITREVERSE, VT: MVT::i32, Action: Legal);
705 setOperationAction(Op: ISD::BITREVERSE, VT: MVT::i64, Action: Legal);
706
707 setOperationAction(Ops: {ISD::ROTL, ISD::ROTR},
708 VTs: {MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64},
709 Action: Expand);
710
711 if (STI.hasHWROT32()) {
712 setOperationAction(Ops: {ISD::FSHL, ISD::FSHR}, VT: MVT::i32, Action: Legal);
713 setOperationAction(Ops: {ISD::ROTL, ISD::ROTR, ISD::FSHL, ISD::FSHR}, VT: MVT::i64,
714 Action: Custom);
715 }
716
717 setOperationAction(Op: ISD::BR_JT, VT: MVT::Other, Action: STI.hasBrx() ? Legal : Expand);
718 setOperationAction(Op: ISD::BRIND, VT: MVT::Other, Action: Expand);
719
720 // We want to legalize constant related memmove and memcopy
721 // intrinsics.
722 setOperationAction(Op: ISD::INTRINSIC_W_CHAIN, VT: MVT::Other, Action: Custom);
723
724 // FP extload/truncstore is not legal in PTX. We need to expand all these.
725 for (auto FloatVTs :
726 {MVT::fp_valuetypes(), MVT::fp_fixedlen_vector_valuetypes()}) {
727 for (MVT ValVT : FloatVTs) {
728 for (MVT MemVT : FloatVTs) {
729 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT, MemVT, Action: Expand);
730 setTruncStoreAction(ValVT, MemVT, Action: Expand);
731 }
732 }
733 }
734
735 // To improve CodeGen we'll legalize any-extend loads to zext loads. This is
736 // how they'll be lowered in ISel anyway, and by doing this a little earlier
737 // we allow for more DAG combine opportunities.
738 for (auto IntVTs :
739 {MVT::integer_valuetypes(), MVT::integer_fixedlen_vector_valuetypes()})
740 for (MVT ValVT : IntVTs)
741 for (MVT MemVT : IntVTs)
742 if (isTypeLegal(VT: ValVT))
743 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT, MemVT, Action: Custom);
744
745 // PTX does not support load / store predicate registers
746 setOperationAction(Ops: {ISD::LOAD, ISD::STORE}, VT: MVT::i1, Action: Custom);
747 for (MVT VT : MVT::integer_valuetypes()) {
748 setLoadExtAction(ExtTypes: {ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}, ValVT: VT, MemVT: MVT::i1,
749 Action: Promote);
750 setTruncStoreAction(ValVT: VT, MemVT: MVT::i1, Action: Expand);
751 }
752
  // Disable generation of extload/truncstore for v2i32/v2i16/v2i8. The generic
  // expansion for these nodes when they are unaligned is incorrect if the
  // type is a vector.
756 //
757 // TODO: Fix the generic expansion for these nodes found in
758 // TargetLowering::expandUnalignedLoad/Store.
759 setLoadExtAction(ExtTypes: {ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, ValVT: MVT::v2i16,
760 MemVT: MVT::v2i8, Action: Expand);
761 setLoadExtAction(ExtTypes: {ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, ValVT: MVT::v2i32,
762 MemVTs: {MVT::v2i8, MVT::v2i16}, Action: Expand);
763 setTruncStoreAction(ValVT: MVT::v2i16, MemVT: MVT::v2i8, Action: Expand);
764 setTruncStoreAction(ValVT: MVT::v2i32, MemVT: MVT::v2i16, Action: Expand);
765 setTruncStoreAction(ValVT: MVT::v2i32, MemVT: MVT::v2i8, Action: Expand);
766
767 // Register custom handling for illegal type loads/stores. We'll try to custom
768 // lower almost all illegal types and logic in the lowering will discard cases
769 // we can't handle.
770 setOperationAction(Ops: {ISD::LOAD, ISD::STORE}, VTs: {MVT::i128, MVT::i256, MVT::f128},
771 Action: Custom);
772 for (MVT VT : MVT::fixedlen_vector_valuetypes())
773 if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
774 setOperationAction(Ops: {ISD::STORE, ISD::LOAD, ISD::MSTORE, ISD::MLOAD}, VT,
775 Action: Custom);
776
777 // Custom legalization for LDU intrinsics.
778 // TODO: The logic to lower these is not very robust and we should rewrite it.
779 // Perhaps LDU should not be represented as an intrinsic at all.
780 setOperationAction(Op: ISD::INTRINSIC_W_CHAIN, VT: MVT::i8, Action: Custom);
781 for (MVT VT : MVT::fixedlen_vector_valuetypes())
782 if (IsPTXVectorType(VT))
783 setOperationAction(Op: ISD::INTRINSIC_W_CHAIN, VT, Action: Custom);
784
785 setCondCodeAction(CCs: {ISD::SETNE, ISD::SETEQ, ISD::SETUGE, ISD::SETULE,
786 ISD::SETUGT, ISD::SETULT, ISD::SETGT, ISD::SETLT,
787 ISD::SETGE, ISD::SETLE},
788 VT: MVT::i1, Action: Expand);
789
790 // This is legal in NVPTX
791 setOperationAction(Op: ISD::ConstantFP, VT: MVT::f64, Action: Legal);
792 setOperationAction(Op: ISD::ConstantFP, VT: MVT::f32, Action: Legal);
793 setOperationAction(Op: ISD::ConstantFP, VT: MVT::f16, Action: Legal);
794 setOperationAction(Op: ISD::ConstantFP, VT: MVT::bf16, Action: Legal);
795
796 setOperationAction(Ops: ISD::DYNAMIC_STACKALLOC, VTs: {MVT::i32, MVT::i64}, Action: Custom);
797 setOperationAction(Ops: {ISD::STACKRESTORE, ISD::STACKSAVE}, VT: MVT::Other, Action: Custom);
798
799 // TRAP can be lowered to PTX trap
800 setOperationAction(Op: ISD::TRAP, VT: MVT::Other, Action: Legal);
801 // DEBUGTRAP can be lowered to PTX brkpt
802 setOperationAction(Op: ISD::DEBUGTRAP, VT: MVT::Other, Action: Legal);
803
804 // Support varargs.
805 setOperationAction(Op: ISD::VASTART, VT: MVT::Other, Action: Custom);
806 setOperationAction(Op: ISD::VAARG, VT: MVT::Other, Action: Custom);
807 setOperationAction(Op: ISD::VACOPY, VT: MVT::Other, Action: Expand);
808 setOperationAction(Op: ISD::VAEND, VT: MVT::Other, Action: Expand);
809
810 setOperationAction(Ops: {ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX},
811 VTs: {MVT::i16, MVT::i32, MVT::i64}, Action: Legal);
812
813 setOperationAction(Ops: {ISD::CTPOP, ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF}, VT: MVT::i16,
814 Action: Promote);
815 setOperationAction(Ops: {ISD::CTPOP, ISD::CTLZ}, VT: MVT::i32, Action: Legal);
816 setOperationAction(Ops: {ISD::CTPOP, ISD::CTLZ}, VT: MVT::i64, Action: Custom);
817
818 setI16x2OperationAction(ISD::ABS, MVT::v2i16, Legal, Custom);
819 setI16x2OperationAction(ISD::SMIN, MVT::v2i16, Legal, Custom);
820 setI16x2OperationAction(ISD::SMAX, MVT::v2i16, Legal, Custom);
821 setI16x2OperationAction(ISD::UMIN, MVT::v2i16, Legal, Custom);
822 setI16x2OperationAction(ISD::UMAX, MVT::v2i16, Legal, Custom);
823 setI16x2OperationAction(ISD::CTPOP, MVT::v2i16, Legal, Expand);
824 setI16x2OperationAction(ISD::CTLZ, MVT::v2i16, Legal, Expand);
825
826 setI16x2OperationAction(ISD::ADD, MVT::v2i16, Legal, Custom);
827 setI16x2OperationAction(ISD::SUB, MVT::v2i16, Legal, Custom);
828 setI16x2OperationAction(ISD::MUL, MVT::v2i16, Legal, Custom);
829 setI16x2OperationAction(ISD::SHL, MVT::v2i16, Legal, Custom);
830 setI16x2OperationAction(ISD::SREM, MVT::v2i16, Legal, Custom);
831 setI16x2OperationAction(ISD::UREM, MVT::v2i16, Legal, Custom);
832
833 // Other arithmetic and logic ops are unsupported.
834 setOperationAction(Ops: {ISD::SDIV, ISD::UDIV, ISD::SRA, ISD::SRL, ISD::MULHS,
835 ISD::MULHU, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
836 ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::SETCC},
837 VTs: {MVT::v2i16, MVT::v2i32}, Action: Expand);
838
839 // v2i32 is not supported for any arithmetic operations
840 setOperationAction(Ops: {ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX,
841 ISD::CTPOP, ISD::CTLZ, ISD::ADD, ISD::SUB, ISD::MUL,
842 ISD::SHL, ISD::SRA, ISD::SRL, ISD::OR, ISD::AND, ISD::XOR,
843 ISD::SREM, ISD::UREM},
844 VT: MVT::v2i32, Action: Expand);
845
846 setOperationAction(Op: ISD::ADDC, VT: MVT::i32, Action: Legal);
847 setOperationAction(Op: ISD::ADDE, VT: MVT::i32, Action: Legal);
848 setOperationAction(Op: ISD::SUBC, VT: MVT::i32, Action: Legal);
849 setOperationAction(Op: ISD::SUBE, VT: MVT::i32, Action: Legal);
850 if (STI.getPTXVersion() >= 43) {
851 setOperationAction(Op: ISD::ADDC, VT: MVT::i64, Action: Legal);
852 setOperationAction(Op: ISD::ADDE, VT: MVT::i64, Action: Legal);
853 setOperationAction(Op: ISD::SUBC, VT: MVT::i64, Action: Legal);
854 setOperationAction(Op: ISD::SUBE, VT: MVT::i64, Action: Legal);
855 }
856
857 setOperationAction(Op: ISD::CTTZ, VT: MVT::i16, Action: Expand);
858 setOperationAction(Ops: ISD::CTTZ, VTs: {MVT::v2i16, MVT::v2i32}, Action: Expand);
859 setOperationAction(Op: ISD::CTTZ, VT: MVT::i32, Action: Expand);
860 setOperationAction(Op: ISD::CTTZ, VT: MVT::i64, Action: Expand);
861
862 // PTX does not directly support SELP of i1, so promote to i32 first
863 setOperationAction(Op: ISD::SELECT, VT: MVT::i1, Action: Custom);
864
865 // PTX cannot multiply two i64s in a single instruction.
866 setOperationAction(Op: ISD::SMUL_LOHI, VT: MVT::i64, Action: Expand);
867 setOperationAction(Op: ISD::UMUL_LOHI, VT: MVT::i64, Action: Expand);
868
869 // We have some custom DAG combine patterns for these nodes
870 setTargetDAGCombine({ISD::ADD,
871 ISD::AND,
872 ISD::EXTRACT_VECTOR_ELT,
873 ISD::FADD,
874 ISD::FMAXNUM,
875 ISD::FMINNUM,
876 ISD::FMAXIMUM,
877 ISD::FMINIMUM,
878 ISD::FMAXIMUMNUM,
879 ISD::FMINIMUMNUM,
880 ISD::MUL,
881 ISD::SELECT,
882 ISD::SHL,
883 ISD::SREM,
884 ISD::UREM,
885 ISD::VSELECT,
886 ISD::BUILD_VECTOR,
887 ISD::ADDRSPACECAST,
888 ISD::LOAD,
889 ISD::STORE,
890 ISD::ZERO_EXTEND,
891 ISD::SIGN_EXTEND,
892 ISD::INTRINSIC_WO_CHAIN});
893
894 // If the vector operands require register coalescing, scalarize instead
895 if (STI.hasF32x2Instructions())
896 setTargetDAGCombine({ISD::FMA, ISD::FMUL, ISD::FSUB});
897
898 // setcc for f16x2 and bf16x2 needs special handling to prevent
899 // legalizer's attempt to scalarize it due to v2i1 not being legal.
900 if (STI.allowFP16Math() || STI.hasBF16Math())
901 setTargetDAGCombine(ISD::SETCC);
902
903 // Vector reduction operations. These may be turned into shuffle or tree
904 // reductions depending on what instructions are available for each type.
905 for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
906 MVT EltVT = VT.getVectorElementType();
907 if (EltVT == MVT::f32 || EltVT == MVT::f64) {
908 setOperationAction(Ops: {ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
909 ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
910 VT, Action: Custom);
911 }
912 }
913
914 // Promote fp16 arithmetic if fp16 hardware isn't available or the
915 // user passed --nvptx-no-fp16-math. The flag is useful because,
916 // although sm_53+ GPUs have some sort of FP16 support in
917 // hardware, only sm_53 and sm_60 have full implementation. Others
918 // only have token amount of hardware and are likely to run faster
919 // by using fp32 units instead.
920 for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
921 setFP16OperationAction(Op, MVT::f16, Legal, Promote);
922 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
923 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
924 // bf16 must be promoted to f32.
925 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
926 if (getOperationAction(Op, VT: MVT::bf16) == Promote)
927 AddPromotedToType(Opc: Op, OrigVT: MVT::bf16, DestVT: MVT::f32);
928 setOperationAction(Op, VT: MVT::v2f32,
929 Action: STI.hasF32x2Instructions() ? Legal : Expand);
930 }
931
932 // On SM80, we select add/mul/sub as fma to avoid promotion to float
933 for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
934 for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
935 if (!STI.hasNativeBF16Support(Opcode: Op) && STI.hasNativeBF16Support(Opcode: ISD::FMA)) {
936 setOperationAction(Op, VT, Action: Custom);
937 }
938 }
939 }
940
941 // f16/f16x2 neg was introduced in PTX 60, SM_53.
942 const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
943 STI.getPTXVersion() >= 60 &&
944 STI.allowFP16Math();
945 for (const auto &VT : {MVT::f16, MVT::v2f16})
946 setOperationAction(Op: ISD::FNEG, VT,
947 Action: IsFP16FP16x2NegAvailable ? Legal : Expand);
948
949 setBF16OperationAction(ISD::FNEG, MVT::bf16, Legal, Expand);
950 setBF16OperationAction(ISD::FNEG, MVT::v2bf16, Legal, Expand);
951 setOperationAction(Op: ISD::FNEG, VT: MVT::v2f32, Action: Expand);
952 // (would be) Library functions.
953
954 // These map to conversion instructions for scalar FP types.
955 for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
956 ISD::FROUNDEVEN, ISD::FTRUNC}) {
957 setOperationAction(Op, VT: MVT::f16, Action: Legal);
958 setOperationAction(Op, VT: MVT::f32, Action: Legal);
959 setOperationAction(Op, VT: MVT::f64, Action: Legal);
960 setOperationAction(Op, VT: MVT::v2f16, Action: Expand);
961 setOperationAction(Op, VT: MVT::v2bf16, Action: Expand);
962 setOperationAction(Op, VT: MVT::v2f32, Action: Expand);
963 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
964 if (getOperationAction(Op, VT: MVT::bf16) == Promote)
965 AddPromotedToType(Opc: Op, OrigVT: MVT::bf16, DestVT: MVT::f32);
966 }
967
968 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71) {
969 setOperationAction(Op: ISD::BF16_TO_FP, VT: MVT::f32, Action: Expand);
970 }
971 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
972 for (MVT VT : {MVT::bf16, MVT::f32, MVT::f64}) {
973 setOperationAction(Op: ISD::FP_EXTEND, VT, Action: Custom);
974 setOperationAction(Op: ISD::FP_ROUND, VT, Action: Custom);
975 }
976 }
977
978 // Expand v2f32 = fp_extend
979 setOperationAction(Op: ISD::FP_EXTEND, VT: MVT::v2f32, Action: Expand);
980 // Expand v2[b]f16 = fp_round v2f32
981 setOperationAction(Ops: ISD::FP_ROUND, VTs: {MVT::v2bf16, MVT::v2f16}, Action: Expand);
982
983 // sm_80 only has conversions between f32 and bf16. Custom lower all other
984 // bf16 conversions.
985 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
986 for (MVT VT : {MVT::i1, MVT::i16, MVT::i32, MVT::i64}) {
987 setOperationAction(
988 Ops: {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
989 VT, Action: Custom);
990 }
991 setOperationAction(
992 Ops: {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
993 VT: MVT::bf16, Action: Custom);
994 }
995
996 setOperationAction(Op: ISD::FROUND, VT: MVT::f16, Action: Promote);
997 setOperationAction(Op: ISD::FROUND, VT: MVT::v2f16, Action: Expand);
998 setOperationAction(Op: ISD::FROUND, VT: MVT::v2bf16, Action: Expand);
999 setOperationAction(Op: ISD::FROUND, VT: MVT::f32, Action: Custom);
1000 setOperationAction(Op: ISD::FROUND, VT: MVT::f64, Action: Custom);
1001 setOperationAction(Op: ISD::FROUND, VT: MVT::bf16, Action: Promote);
1002 AddPromotedToType(Opc: ISD::FROUND, OrigVT: MVT::bf16, DestVT: MVT::f32);
1003
1004 // 'Expand' implements FCOPYSIGN without calling an external library.
1005 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::f16, Action: Expand);
1006 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::v2f16, Action: Expand);
1007 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::bf16, Action: Expand);
1008 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::v2bf16, Action: Expand);
1009 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::f32, Action: Custom);
1010 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::f64, Action: Custom);
1011
1012 // These map to corresponding instructions for f32/f64. f16 must be
1013 // promoted to f32. v2f16 is expanded to f16, which is then promoted
1014 // to f32.
1015 for (const auto &Op :
1016 {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS, ISD::FTANH}) {
1017 setOperationAction(Op, VT: MVT::f16, Action: Promote);
1018 setOperationAction(Op, VT: MVT::f32, Action: Legal);
1019 // only div/rem/sqrt are legal for f64
1020 if (Op == ISD::FDIV || Op == ISD::FREM || Op == ISD::FSQRT) {
1021 setOperationAction(Op, VT: MVT::f64, Action: Legal);
1022 }
1023 setOperationAction(Ops: Op, VTs: {MVT::v2f16, MVT::v2bf16, MVT::v2f32}, Action: Expand);
1024 setOperationAction(Op, VT: MVT::bf16, Action: Promote);
1025 AddPromotedToType(Opc: Op, OrigVT: MVT::bf16, DestVT: MVT::f32);
1026 }
1027 setOperationAction(Ops: ISD::FREM, VTs: {MVT::f32, MVT::f64}, Action: Custom);
1028
1029 setOperationAction(Ops: ISD::FABS, VTs: {MVT::f32, MVT::f64}, Action: Legal);
1030 setOperationAction(Op: ISD::FABS, VT: MVT::v2f32, Action: Expand);
1031 if (STI.getPTXVersion() >= 65) {
1032 setFP16OperationAction(ISD::FABS, MVT::f16, Legal, Promote);
1033 setFP16OperationAction(ISD::FABS, MVT::v2f16, Legal, Expand);
1034 } else {
1035 setOperationAction(Op: ISD::FABS, VT: MVT::f16, Action: Promote);
1036 setOperationAction(Op: ISD::FABS, VT: MVT::v2f16, Action: Expand);
1037 }
1038 setBF16OperationAction(ISD::FABS, MVT::v2bf16, Legal, Expand);
1039 setBF16OperationAction(ISD::FABS, MVT::bf16, Legal, Promote);
1040 if (getOperationAction(Op: ISD::FABS, VT: MVT::bf16) == Promote)
1041 AddPromotedToType(Opc: ISD::FABS, OrigVT: MVT::bf16, DestVT: MVT::f32);
1042
1043 for (const auto &Op :
1044 {ISD::FMINNUM, ISD::FMAXNUM, ISD::FMINIMUMNUM, ISD::FMAXIMUMNUM}) {
1045 setOperationAction(Op, VT: MVT::f32, Action: Legal);
1046 setOperationAction(Op, VT: MVT::f64, Action: Legal);
1047 setFP16OperationAction(Op, MVT::f16, Legal, Promote);
1048 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
1049 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
1050 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
1051 if (getOperationAction(Op, VT: MVT::bf16) == Promote)
1052 AddPromotedToType(Opc: Op, OrigVT: MVT::bf16, DestVT: MVT::f32);
1053 setOperationAction(Op, VT: MVT::v2f32, Action: Expand);
1054 }
1055 bool SupportsF32MinMaxNaN =
1056 STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
1057 for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
1058 setOperationAction(Op, VT: MVT::f32, Action: SupportsF32MinMaxNaN ? Legal : Expand);
1059 setFP16OperationAction(Op, MVT::f16, Legal, Expand);
1060 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
1061 setBF16OperationAction(Op, MVT::bf16, Legal, Expand);
1062 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
1063 setOperationAction(Op, VT: MVT::v2f32, Action: Expand);
1064 }
1065
1066 // Custom lowering for inline asm with 128-bit operands
1067 setOperationAction(Op: ISD::CopyToReg, VT: MVT::i128, Action: Custom);
1068 setOperationAction(Op: ISD::CopyFromReg, VT: MVT::i128, Action: Custom);
1069
1070 // FEXP2 support:
1071 // - f32
1072 // - f16/f16x2 (sm_70+, PTX 7.0+)
1073 // - bf16/bf16x2 (sm_90+, PTX 7.8+)
1074 // When f16/bf16 types aren't supported, they are promoted/expanded to f32.
1075 setOperationAction(Op: ISD::FEXP2, VT: MVT::f32, Action: Legal);
1076 setOperationAction(Op: ISD::FEXP2, VT: MVT::v2f32, Action: Expand);
1077 setFP16OperationAction(ISD::FEXP2, MVT::f16, Legal, Promote);
1078 setFP16OperationAction(ISD::FEXP2, MVT::v2f16, Legal, Expand);
1079 setBF16OperationAction(ISD::FEXP2, MVT::bf16, Legal, Promote);
1080 setBF16OperationAction(ISD::FEXP2, MVT::v2bf16, Legal, Expand);
1081
1082 // FLOG2 supports f32 only
1083 // f16/bf16 types aren't supported, but they are promoted/expanded to f32.
1084 if (UseApproxLog2F32) {
1085 setOperationAction(Op: ISD::FLOG2, VT: MVT::f32, Action: Legal);
1086 setOperationPromotedToType(Opc: ISD::FLOG2, OrigVT: MVT::f16, DestVT: MVT::f32);
1087 setOperationPromotedToType(Opc: ISD::FLOG2, OrigVT: MVT::bf16, DestVT: MVT::f32);
1088 setOperationAction(Ops: ISD::FLOG2, VTs: {MVT::v2f16, MVT::v2bf16, MVT::v2f32},
1089 Action: Expand);
1090 }
1091
1092 setOperationAction(Ops: ISD::ADDRSPACECAST, VTs: {MVT::i32, MVT::i64}, Action: Custom);
1093
1094 setOperationAction(Ops: ISD::ATOMIC_LOAD_SUB, VTs: {MVT::i32, MVT::i64}, Action: Expand);
1095
1096 // atom.b128 is legal in PTX but since we don't represent i128 as a legal
1097 // type, we need to custom lower it.
1098 setOperationAction(Ops: {ISD::ATOMIC_CMP_SWAP, ISD::ATOMIC_SWAP}, VT: MVT::i128,
1099 Action: Custom);
1100
1101 // Now deduce the information based on the above mentioned
1102 // actions
1103 computeRegisterProperties(TRI: STI.getRegisterInfo());
1104
1105 // PTX support for 16-bit CAS is emulated. Only use 32+
1106 setMinCmpXchgSizeInBits(STI.getMinCmpXchgSizeInBits());
1107 setMaxAtomicSizeInBitsSupported(STI.hasAtomSwap128() ? 128 : 64);
1108 setMaxDivRemBitWidthSupported(64);
1109
1110 // Custom lowering for tcgen05.ld vector operands
1111 setOperationAction(Ops: ISD::INTRINSIC_W_CHAIN,
1112 VTs: {MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
1113 MVT::v32i32, MVT::v64i32, MVT::v128i32, MVT::v2f32,
1114 MVT::v4f32, MVT::v8f32, MVT::v16f32, MVT::v32f32,
1115 MVT::v64f32, MVT::v128f32},
1116 Action: Custom);
1117
1118 // Custom lowering for tcgen05.st vector operands
1119 setOperationAction(Ops: ISD::INTRINSIC_VOID,
1120 VTs: {MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
1121 MVT::v32i32, MVT::v64i32, MVT::v128i32, MVT::Other},
1122 Action: Custom);
1123
1124 // Enable custom lowering for the following:
1125 // * MVT::i128 - clusterlaunchcontrol
1126 // * MVT::i32 - prmt
1127 // * MVT::v4f32 - cvt_rs fp{4/6/8}x4 intrinsics
1128 // * MVT::Other - internal.addrspace.wrap
1129 setOperationAction(Ops: ISD::INTRINSIC_WO_CHAIN,
1130 VTs: {MVT::i32, MVT::i128, MVT::v4f32, MVT::Other}, Action: Custom);
1131
1132 // Custom lowering for bswap
1133 setOperationAction(Ops: ISD::BSWAP, VTs: {MVT::i16, MVT::i32, MVT::i64, MVT::v2i16},
1134 Action: Custom);
1135}
1136
1137TargetLoweringBase::LegalizeTypeAction
1138NVPTXTargetLowering::getPreferredVectorAction(MVT VT) const {
1139 if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
1140 VT.getScalarType() == MVT::i1)
1141 return TypeSplitVector;
1142 return TargetLoweringBase::getPreferredVectorAction(VT);
1143}
1144
// Emit a hardware sqrt/rsqrt approximation for f32/f64 operands, or return an
// empty SDValue when approximations are not permitted (or the type is
// unsupported) so the generic lowering takes over. Only Operand's type and the
// per-function FTZ setting influence which NVVM intrinsic is chosen.
SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
                                             int Enabled, int &ExtraSteps,
                                             bool &UseOneConst,
                                             bool Reciprocal) const {
  // Approximate only when estimates are explicitly enabled, or when
  // unspecified and the function does not require a precise f32 sqrt.
  if (!(Enabled == ReciprocalEstimate::Enabled ||
        (Enabled == ReciprocalEstimate::Unspecified && !usePrecSqrtF32())))
    return SDValue();

  // No refinement-step count requested: perform none.
  if (ExtraSteps == ReciprocalEstimate::Unspecified)
    ExtraSteps = 0;

  SDLoc DL(Operand);
  EVT VT = Operand.getValueType();
  // Whether to use the flush-to-zero f32 intrinsic variants for this function.
  bool Ftz = useF32FTZ(MF: DAG.getMachineFunction());

  // Wrap Operand in a call to the given NVVM approximation intrinsic,
  // producing a value of the same type.
  auto MakeIntrinsicCall = [&](Intrinsic::ID IID) {
    return DAG.getNode(Opcode: ISD::INTRINSIC_WO_CHAIN, DL, VT,
                       N1: DAG.getConstant(Val: IID, DL, VT: MVT::i32), N2: Operand);
  };

  // The sqrt and rsqrt refinement processes assume we always start out with an
  // approximation of the rsqrt. Therefore, if we're going to do any refinement
  // (i.e. ExtraSteps > 0), we must return an rsqrt. But if we're *not* doing
  // any refinement, we must return a regular sqrt.
  if (Reciprocal || ExtraSteps > 0) {
    if (VT == MVT::f32)
      return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_rsqrt_approx_ftz_f
                                   : Intrinsic::nvvm_rsqrt_approx_f);
    else if (VT == MVT::f64)
      return MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d);
    else
      // Unsupported type (f16/bf16/vectors): no estimate.
      return SDValue();
  } else {
    if (VT == MVT::f32)
      return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f
                                   : Intrinsic::nvvm_sqrt_approx_f);
    else {
      // There's no sqrt.approx.f64 instruction, so we emit
      // reciprocal(rsqrt(x)). This is faster than
      // select(x == 0, 0, x * rsqrt(x)). (In fact, it's faster than plain
      // x * rsqrt(x).)
      // NOTE: the ftz f64 rcp intrinsic is used here regardless of the Ftz
      // flag computed above.
      return DAG.getNode(
          Opcode: ISD::INTRINSIC_WO_CHAIN, DL, VT,
          N1: DAG.getConstant(Val: Intrinsic::nvvm_rcp_approx_ftz_d, DL, VT: MVT::i32),
          N2: MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d));
    }
  }
}
1193
// Build the PTX ".callprototype" declaration string used for indirect calls:
//   prototype_<N> : .callprototype (<ret>) _ (<params>) [.noreturn];
// Aggregate values are emitted as aligned .b8 byte arrays; scalars as
// .b<size> params (promoted to at least 32 bits per the PTX ABI). When
// FirstVAArg is set, all variadic arguments are represented by a single
// trailing unsized byte-array param. UniqueCallSite disambiguates prototype
// names across call sites.
std::string NVPTXTargetLowering::getPrototype(
    const DataLayout &DL, Type *RetTy, const ArgListTy &Args,
    const SmallVectorImpl<ISD::OutputArg> &Outs,
    std::optional<unsigned> FirstVAArg, const CallBase &CB,
    unsigned UniqueCallSite) const {
  auto PtrVT = getPointerTy(DL);

  std::string Prototype;
  raw_string_ostream O(Prototype);
  O << "prototype_" << UniqueCallSite << " : .callprototype ";

  if (RetTy->isVoidTy()) {
    O << "()";
  } else {
    O << "(";
    // Aggregate-style returns become an aligned byte array.
    if (shouldPassAsArray(Ty: RetTy)) {
      const Align RetAlign = getArgumentAlignment(CB: &CB, Ty: RetTy, Idx: 0, DL);
      O << ".param .align " << RetAlign.value() << " .b8 _["
        << DL.getTypeAllocSize(Ty: RetTy) << "]";
    } else if (RetTy->isFloatingPointTy() || RetTy->isIntegerTy()) {
      unsigned size = 0;
      if (auto *ITy = dyn_cast<IntegerType>(Val: RetTy)) {
        size = ITy->getBitWidth();
      } else {
        assert(RetTy->isFloatingPointTy() &&
               "Floating point type expected here");
        size = RetTy->getPrimitiveSizeInBits();
      }
      // PTX ABI requires all scalar return values to be at least 32
      // bits in size. fp16 normally uses .b16 as its storage type in
      // PTX, so its size must be adjusted here, too.
      size = promoteScalarArgumentSize(size);

      O << ".param .b" << size << " _";
    } else if (isa<PointerType>(Val: RetTy)) {
      // Pointers are returned in a register of pointer width.
      O << ".param .b" << PtrVT.getSizeInBits() << " _";
    } else {
      llvm_unreachable("Unknown return type");
    }
    O << ") ";
  }
  O << "_ (";

  bool first = true;

  // Fixed (non-variadic) parameters. Outs entries are grouped by the original
  // IR argument index they were decomposed from; consume each group as we
  // emit the corresponding parameter.
  const unsigned NumArgs = FirstVAArg.value_or(u: Args.size());
  auto AllOuts = ArrayRef(Outs);
  for (const unsigned I : llvm::seq(Size: NumArgs)) {
    const auto ArgOuts =
        AllOuts.take_while(Pred: [I](auto O) { return O.OrigArgIndex == I; });
    AllOuts = AllOuts.drop_front(N: ArgOuts.size());

    Type *Ty = Args[I].Ty;
    if (!first) {
      O << ", ";
    }
    first = false;

    if (ArgOuts[0].Flags.isByVal()) {
      // Indirect calls need strict ABI alignment so we disable optimizations by
      // not providing a function to optimize.
      Type *ETy = Args[I].IndirectType;
      Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
      Align ParamByValAlign =
          getFunctionByValParamAlign(/*F=*/nullptr, ArgTy: ETy, InitialAlign, DL);

      O << ".param .align " << ParamByValAlign.value() << " .b8 _["
        << ArgOuts[0].Flags.getByValSize() << "]";
    } else {
      // Aggregates are emitted as aligned byte arrays.
      if (shouldPassAsArray(Ty)) {
        Align ParamAlign =
            getArgumentAlignment(CB: &CB, Ty, Idx: I + AttributeList::FirstArgIndex, DL);
        O << ".param .align " << ParamAlign.value() << " .b8 _["
          << DL.getTypeAllocSize(Ty) << "]";
        continue;
      }
      // i8 types in IR will be i16 types in SDAG
      assert((getValueType(DL, Ty) == ArgOuts[0].VT ||
              (getValueType(DL, Ty) == MVT::i8 && ArgOuts[0].VT == MVT::i16)) &&
             "type mismatch between callee prototype and arguments");
      // scalar type
      unsigned sz = 0;
      if (auto *ITy = dyn_cast<IntegerType>(Val: Ty)) {
        sz = promoteScalarArgumentSize(size: ITy->getBitWidth());
      } else if (isa<PointerType>(Val: Ty)) {
        sz = PtrVT.getSizeInBits();
      } else {
        sz = Ty->getPrimitiveSizeInBits();
      }
      O << ".param .b" << sz << " _";
    }
  }

  // All variadic arguments share one unsized byte-array param, aligned to the
  // maximum alignment the target can require; its size is patched later.
  if (FirstVAArg)
    O << (first ? "" : ",") << " .param .align "
      << STI.getMaxRequiredAlignment() << " .b8 _[]";
  O << ")";
  if (shouldEmitPTXNoReturn(V: &CB, TM: *nvTM))
    O << " .noreturn";
  O << ";";

  return Prototype;
}
1297
1298Align NVPTXTargetLowering::getFunctionArgumentAlignment(
1299 const Function *F, Type *Ty, unsigned Idx, const DataLayout &DL) const {
1300 return getAlign(F: *F, Index: Idx).value_or(u: getFunctionParamOptimizedAlign(F, ArgTy: Ty, DL));
1301}
1302
1303Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
1304 unsigned Idx,
1305 const DataLayout &DL) const {
1306 if (!CB) {
1307 // CallSite is zero, fallback to ABI type alignment
1308 return DL.getABITypeAlign(Ty);
1309 }
1310
1311 const Function *DirectCallee = CB->getCalledFunction();
1312
1313 if (!DirectCallee) {
1314 // We don't have a direct function symbol, but that may be because of
1315 // constant cast instructions in the call.
1316
1317 // With bitcast'd call targets, the instruction will be the call
1318 if (const auto *CI = dyn_cast<CallInst>(Val: CB)) {
1319 // Check if we have call alignment metadata
1320 if (MaybeAlign StackAlign = getAlign(*CI, Idx))
1321 return StackAlign.value();
1322 }
1323 DirectCallee = getMaybeBitcastedCallee(CB);
1324 }
1325
1326 // Check for function alignment information if we found that the
1327 // ultimate target is a Function
1328 if (DirectCallee)
1329 return getFunctionArgumentAlignment(F: DirectCallee, Ty, Idx, DL);
1330
1331 // Call is indirect, fall back to the ABI type alignment
1332 return DL.getABITypeAlign(Ty);
1333}
1334
1335static bool shouldConvertToIndirectCall(const CallBase *CB,
1336 const GlobalAddressSDNode *Func) {
1337 if (!Func)
1338 return false;
1339 if (auto *CalleeFunc = dyn_cast<Function>(Val: Func->getGlobal()))
1340 return CB->getFunctionType() != CalleeFunc->getFunctionType();
1341 return false;
1342}
1343
1344static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
1345 const DataLayout &DL,
1346 const TargetLowering &TL) {
1347 if (Ptr->getOpcode() == ISD::FrameIndex) {
1348 auto Ty = TL.getPointerTy(DL, AS: ADDRESS_SPACE_LOCAL);
1349 Ptr = DAG.getAddrSpaceCast(dl: SDLoc(), VT: Ty, Ptr, SrcAS: ADDRESS_SPACE_GENERIC,
1350 DestAS: ADDRESS_SPACE_LOCAL);
1351
1352 return MachinePointerInfo(ADDRESS_SPACE_LOCAL);
1353 }
1354
1355 // Peel of an addrspacecast to generic and load directly from the specific
1356 // address space.
1357 if (Ptr->getOpcode() == ISD::ADDRSPACECAST) {
1358 const auto *ASC = cast<AddrSpaceCastSDNode>(Val&: Ptr);
1359 if (ASC->getDestAddressSpace() == ADDRESS_SPACE_GENERIC) {
1360 Ptr = ASC->getOperand(Num: 0);
1361 return MachinePointerInfo(ASC->getSrcAddressSpace());
1362 }
1363 }
1364
1365 return MachinePointerInfo();
1366}
1367
1368static ISD::NodeType getExtOpcode(const ISD::ArgFlagsTy &Flags) {
1369 if (Flags.isSExt())
1370 return ISD::SIGN_EXTEND;
1371 if (Flags.isZExt())
1372 return ISD::ZERO_EXTEND;
1373 return ISD::ANY_EXTEND;
1374}
1375
1376static SDValue correctParamType(SDValue V, EVT ExpectedVT,
1377 ISD::ArgFlagsTy Flags, SelectionDAG &DAG,
1378 SDLoc dl) {
1379 const EVT ActualVT = V.getValueType();
1380 assert((ActualVT == ExpectedVT ||
1381 (ExpectedVT.isInteger() && ActualVT.isInteger())) &&
1382 "Non-integer argument type size mismatch");
1383 if (ExpectedVT.bitsGT(VT: ActualVT))
1384 return DAG.getNode(Opcode: getExtOpcode(Flags), DL: dl, VT: ExpectedVT, Operand: V);
1385 if (ExpectedVT.bitsLT(VT: ActualVT))
1386 return DAG.getNode(Opcode: ISD::TRUNCATE, DL: dl, VT: ExpectedVT, Operand: V);
1387
1388 return V;
1389}
1390
1391SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
1392 SmallVectorImpl<SDValue> &InVals) const {
1393
1394 if (CLI.IsVarArg && (STI.getPTXVersion() < 60 || STI.getSmVersion() < 30))
1395 report_fatal_error(
1396 reason: "Support for variadic functions (unsized array parameter) introduced "
1397 "in PTX ISA version 6.0 and requires target sm_30.");
1398
1399 SelectionDAG &DAG = CLI.DAG;
1400 SDLoc dl = CLI.DL;
1401 const SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
1402 SDValue Callee = CLI.Callee;
1403 ArgListTy &Args = CLI.getArgs();
1404 Type *RetTy = CLI.RetTy;
1405 const CallBase *CB = CLI.CB;
1406 const DataLayout &DL = DAG.getDataLayout();
1407 LLVMContext &Ctx = *DAG.getContext();
1408
1409 const auto GetI32 = [&](const unsigned I) {
1410 return DAG.getConstant(Val: I, DL: dl, VT: MVT::i32);
1411 };
1412
1413 const unsigned UniqueCallSite = GlobalUniqueCallSite++;
1414 const SDValue CallChain = CLI.Chain;
1415 const SDValue StartChain =
1416 DAG.getCALLSEQ_START(Chain: CallChain, InSize: UniqueCallSite, OutSize: 0, DL: dl);
1417 SDValue DeclareGlue = StartChain.getValue(R: 1);
1418
1419 SmallVector<SDValue, 16> CallPrereqs{StartChain};
1420
1421 const auto MakeDeclareScalarParam = [&](SDValue Symbol, unsigned Size) {
1422 // PTX ABI requires integral types to be at least 32 bits in size. FP16 is
1423 // loaded/stored using i16, so it's handled here as well.
1424 const unsigned SizeBits = promoteScalarArgumentSize(size: Size * 8);
1425 SDValue Declare =
1426 DAG.getNode(Opcode: NVPTXISD::DeclareScalarParam, DL: dl, ResultTys: {MVT::Other, MVT::Glue},
1427 Ops: {StartChain, Symbol, GetI32(SizeBits), DeclareGlue});
1428 CallPrereqs.push_back(Elt: Declare);
1429 DeclareGlue = Declare.getValue(R: 1);
1430 return Declare;
1431 };
1432
1433 const auto MakeDeclareArrayParam = [&](SDValue Symbol, Align Align,
1434 unsigned Size) {
1435 SDValue Declare = DAG.getNode(
1436 Opcode: NVPTXISD::DeclareArrayParam, DL: dl, ResultTys: {MVT::Other, MVT::Glue},
1437 Ops: {StartChain, Symbol, GetI32(Align.value()), GetI32(Size), DeclareGlue});
1438 CallPrereqs.push_back(Elt: Declare);
1439 DeclareGlue = Declare.getValue(R: 1);
1440 return Declare;
1441 };
1442
1443 // Variadic arguments.
1444 //
1445 // Normally, for each argument, we declare a param scalar or a param
1446 // byte array in the .param space, and store the argument value to that
1447 // param scalar or array starting at offset 0.
1448 //
1449 // In the case of the first variadic argument, we declare a vararg byte array
1450 // with size 0. The exact size of this array isn't known at this point, so
1451 // it'll be patched later. All the variadic arguments will be stored to this
1452 // array at a certain offset (which gets tracked by 'VAOffset'). The offset is
1453 // initially set to 0, so it can be used for non-variadic arguments (which use
1454 // 0 offset) to simplify the code.
1455 //
1456 // After all vararg is processed, 'VAOffset' holds the size of the
1457 // vararg byte array.
1458 assert((CLI.IsVarArg || CLI.Args.size() == CLI.NumFixedArgs) &&
1459 "Non-VarArg function with extra arguments");
1460
1461 const unsigned FirstVAArg = CLI.NumFixedArgs; // position of first variadic
1462 unsigned VAOffset = 0; // current offset in the param array
1463
1464 const SDValue VADeclareParam =
1465 CLI.Args.size() > FirstVAArg
1466 ? MakeDeclareArrayParam(getCallParamSymbol(DAG, I: FirstVAArg, T: MVT::i32),
1467 Align(STI.getMaxRequiredAlignment()), 0)
1468 : SDValue();
1469
1470 // Args.size() and Outs.size() need not match.
1471 // Outs.size() will be larger
1472 // * if there is an aggregate argument with multiple fields (each field
1473 // showing up separately in Outs)
1474 // * if there is a vector argument with more than typical vector-length
1475 // elements (generally if more than 4) where each vector element is
1476 // individually present in Outs.
1477 // So a different index should be used for indexing into Outs/OutVals.
1478 // See similar issue in LowerFormalArguments.
1479 auto AllOuts = ArrayRef(CLI.Outs);
1480 auto AllOutVals = ArrayRef(CLI.OutVals);
1481 assert(AllOuts.size() == AllOutVals.size() &&
1482 "Outs and OutVals must be the same size");
1483 // Declare the .params or .reg need to pass values
1484 // to the function
1485 for (const auto E : llvm::enumerate(First&: Args)) {
1486 const auto ArgI = E.index();
1487 const auto Arg = E.value();
1488 const auto ArgOuts =
1489 AllOuts.take_while(Pred: [&](auto O) { return O.OrigArgIndex == ArgI; });
1490 const auto ArgOutVals = AllOutVals.take_front(N: ArgOuts.size());
1491 AllOuts = AllOuts.drop_front(N: ArgOuts.size());
1492 AllOutVals = AllOutVals.drop_front(N: ArgOuts.size());
1493
1494 const bool IsVAArg = (ArgI >= FirstVAArg);
1495 const bool IsByVal = Arg.IsByVal;
1496
1497 const SDValue ParamSymbol =
1498 getCallParamSymbol(DAG, I: IsVAArg ? FirstVAArg : ArgI, T: MVT::i32);
1499
1500 assert((!IsByVal || Arg.IndirectType) &&
1501 "byval arg must have indirect type");
1502 Type *ETy = (IsByVal ? Arg.IndirectType : Arg.Ty);
1503
1504 const Align ArgAlign = [&]() {
1505 if (IsByVal) {
1506 // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
1507 // so we don't need to worry whether it's naturally aligned or not.
1508 // See TargetLowering::LowerCallTo().
1509 const Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
1510 return getFunctionByValParamAlign(F: CB->getCalledFunction(), ArgTy: ETy,
1511 InitialAlign, DL);
1512 }
1513 return getArgumentAlignment(CB, Ty: Arg.Ty, Idx: ArgI + 1, DL);
1514 }();
1515
1516 const unsigned TySize = DL.getTypeAllocSize(Ty: ETy);
1517 assert((!IsByVal || TySize == ArgOuts[0].Flags.getByValSize()) &&
1518 "type size mismatch");
1519
1520 const SDValue ArgDeclare = [&]() {
1521 if (IsVAArg)
1522 return VADeclareParam;
1523
1524 if (IsByVal || shouldPassAsArray(Ty: Arg.Ty))
1525 return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TySize);
1526
1527 assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
1528 assert((ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint()) &&
1529 "Only int and float types are supported as non-array arguments");
1530
1531 return MakeDeclareScalarParam(ParamSymbol, TySize);
1532 }();
1533
1534 if (IsByVal) {
1535 assert(ArgOutVals.size() == 1 && "We must pass only one value as byval");
1536 SDValue SrcPtr = ArgOutVals[0];
1537 const auto PointerInfo = refinePtrAS(Ptr&: SrcPtr, DAG, DL, TL: *this);
1538 const Align BaseSrcAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
1539
1540 if (IsVAArg)
1541 VAOffset = alignTo(Size: VAOffset, A: ArgAlign);
1542
1543 SmallVector<EVT, 4> ValueVTs, MemVTs;
1544 SmallVector<TypeSize, 4> Offsets;
1545 ComputeValueVTs(TLI: *this, DL, Ty: ETy, ValueVTs, MemVTs: &MemVTs, Offsets: &Offsets);
1546
1547 unsigned J = 0;
1548 const auto VI = VectorizePTXValueVTs(ValueVTs: MemVTs, Offsets, ParamAlignment: ArgAlign, IsVAArg);
1549 for (const unsigned NumElts : VI) {
1550 EVT LoadVT = getVectorizedVT(VT: MemVTs[J], N: NumElts, C&: Ctx);
1551 Align SrcAlign = commonAlignment(A: BaseSrcAlign, Offset: Offsets[J]);
1552 SDValue SrcAddr = DAG.getObjectPtrOffset(SL: dl, Ptr: SrcPtr, Offset: Offsets[J]);
1553 SDValue SrcLoad =
1554 DAG.getLoad(VT: LoadVT, dl, Chain: CallChain, Ptr: SrcAddr, PtrInfo: PointerInfo, Alignment: SrcAlign);
1555
1556 TypeSize ParamOffset = Offsets[J].getWithIncrement(RHS: VAOffset);
1557 Align ParamAlign = commonAlignment(A: ArgAlign, Offset: ParamOffset);
1558 SDValue ParamAddr =
1559 DAG.getObjectPtrOffset(SL: dl, Ptr: ParamSymbol, Offset: ParamOffset);
1560 SDValue StoreParam =
1561 DAG.getStore(Chain: ArgDeclare, dl, Val: SrcLoad, Ptr: ParamAddr,
1562 PtrInfo: MachinePointerInfo(ADDRESS_SPACE_PARAM), Alignment: ParamAlign);
1563 CallPrereqs.push_back(Elt: StoreParam);
1564
1565 J += NumElts;
1566 }
1567 if (IsVAArg)
1568 VAOffset += TySize;
1569 } else {
1570 SmallVector<EVT, 16> VTs;
1571 SmallVector<uint64_t, 16> Offsets;
1572 ComputePTXValueVTs(TLI: *this, DL, Ctx, CallConv: CLI.CallConv, Ty: Arg.Ty, ValueVTs&: VTs, Offsets,
1573 StartingOffset: VAOffset);
1574 assert(VTs.size() == Offsets.size() && "Size mismatch");
1575 assert(VTs.size() == ArgOuts.size() && "Size mismatch");
1576
1577 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
1578 // than 32-bits are sign extended or zero extended, depending on
1579 // whether they are signed or unsigned types. This case applies
1580 // only to scalar parameters and not to aggregate values.
1581 const bool ExtendIntegerParam =
1582 Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty: Arg.Ty) < 32;
1583
1584 const auto GetStoredValue = [&](const unsigned I) {
1585 SDValue StVal = ArgOutVals[I];
1586 assert(promoteScalarIntegerPTX(StVal.getValueType()) ==
1587 StVal.getValueType() &&
1588 "OutVal type should always be legal");
1589
1590 const EVT VTI = promoteScalarIntegerPTX(VT: VTs[I]);
1591 const EVT StoreVT =
1592 ExtendIntegerParam ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
1593
1594 return correctParamType(V: StVal, ExpectedVT: StoreVT, Flags: ArgOuts[I].Flags, DAG, dl);
1595 };
1596
1597 unsigned J = 0;
1598 const auto VI = VectorizePTXValueVTs(ValueVTs: VTs, Offsets, ParamAlignment: ArgAlign, IsVAArg);
1599 for (const unsigned NumElts : VI) {
1600 const EVT EltVT = promoteScalarIntegerPTX(VT: VTs[J]);
1601
1602 unsigned Offset;
1603 if (IsVAArg) {
1604 // TODO: We may need to support vector types that can be passed
1605 // as scalars in variadic arguments.
1606 assert(NumElts == 1 &&
1607 "Vectorization should be disabled for vaargs.");
1608
1609 // Align each part of the variadic argument to their type.
1610 VAOffset = alignTo(Size: VAOffset, A: DAG.getEVTAlign(MemoryVT: EltVT));
1611 Offset = VAOffset;
1612
1613 const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
1614 VAOffset += DL.getTypeAllocSize(Ty: TheStoreType.getTypeForEVT(Context&: Ctx));
1615 } else {
1616 assert(VAOffset == 0 && "VAOffset must be 0 for non-VA args");
1617 Offset = Offsets[J];
1618 }
1619
1620 SDValue Ptr =
1621 DAG.getObjectPtrOffset(SL: dl, Ptr: ParamSymbol, Offset: TypeSize::getFixed(ExactSize: Offset));
1622
1623 const MaybeAlign CurrentAlign = ExtendIntegerParam
1624 ? MaybeAlign(std::nullopt)
1625 : commonAlignment(A: ArgAlign, Offset);
1626
1627 SDValue Val =
1628 getBuildVectorizedValue(N: NumElts, dl, DAG, GetElement: [&](unsigned K) {
1629 return GetStoredValue(J + K);
1630 });
1631
1632 SDValue StoreParam =
1633 DAG.getStore(Chain: ArgDeclare, dl, Val, Ptr,
1634 PtrInfo: MachinePointerInfo(ADDRESS_SPACE_PARAM), Alignment: CurrentAlign);
1635 CallPrereqs.push_back(Elt: StoreParam);
1636
1637 J += NumElts;
1638 }
1639 }
1640 }
1641
1642 // Handle Result
1643 if (!Ins.empty()) {
1644 const SDValue RetSymbol = DAG.getExternalSymbol(Sym: "retval0", VT: MVT::i32);
1645 const unsigned ResultSize = DL.getTypeAllocSize(Ty: RetTy);
1646 if (shouldPassAsArray(Ty: RetTy)) {
1647 const Align RetAlign = getArgumentAlignment(CB, Ty: RetTy, Idx: 0, DL);
1648 MakeDeclareArrayParam(RetSymbol, RetAlign, ResultSize);
1649 } else {
1650 MakeDeclareScalarParam(RetSymbol, ResultSize);
1651 }
1652 }
1653
1654 // Set the size of the vararg param byte array if the callee is a variadic
1655 // function and the variadic part is not empty.
1656 if (VADeclareParam) {
1657 SDValue DeclareParamOps[] = {VADeclareParam.getOperand(i: 0),
1658 VADeclareParam.getOperand(i: 1),
1659 VADeclareParam.getOperand(i: 2), GetI32(VAOffset),
1660 VADeclareParam.getOperand(i: 4)};
1661 DAG.MorphNodeTo(N: VADeclareParam.getNode(), Opc: VADeclareParam.getOpcode(),
1662 VTs: VADeclareParam->getVTList(), Ops: DeclareParamOps);
1663 }
1664
1665 const auto *Func = dyn_cast<GlobalAddressSDNode>(Val: Callee.getNode());
1666 // If the type of the callsite does not match that of the function, convert
1667 // the callsite to an indirect call.
1668 const bool ConvertToIndirectCall = shouldConvertToIndirectCall(CB, Func);
1669
1670 // Both indirect calls and libcalls have nullptr Func. In order to distinguish
1671 // between them we must rely on the call site value which is valid for
1672 // indirect calls but is always null for libcalls.
1673 const bool IsIndirectCall = (!Func && CB) || ConvertToIndirectCall;
1674
1675 if (isa<ExternalSymbolSDNode>(Val: Callee)) {
1676 Function* CalleeFunc = nullptr;
1677
1678 // Try to find the callee in the current module.
1679 Callee = DAG.getSymbolFunctionGlobalAddress(Op: Callee, TargetFunction: &CalleeFunc);
1680 assert(CalleeFunc != nullptr && "Libcall callee must be set.");
1681
1682 // Set the "libcall callee" attribute to indicate that the function
1683 // must always have a declaration.
1684 CalleeFunc->addFnAttr(Kind: "nvptx-libcall-callee", Val: "true");
1685 }
1686
1687 if (IsIndirectCall) {
1688 // This is indirect function call case : PTX requires a prototype of the
1689 // form
1690 // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
1691 // to be emitted, and the label has to used as the last arg of call
1692 // instruction.
1693 // The prototype is embedded in a string and put as the operand for a
1694 // CallPrototype SDNode which will print out to the value of the string.
1695 const bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
1696 std::string Proto =
1697 getPrototype(DL, RetTy, Args, Outs: CLI.Outs,
1698 FirstVAArg: HasVAArgs ? std::optional(FirstVAArg) : std::nullopt, CB: *CB,
1699 UniqueCallSite);
1700 const char *ProtoStr = nvTM->getStrPool().save(S: Proto).data();
1701 const SDValue PrototypeDeclare = DAG.getNode(
1702 Opcode: NVPTXISD::CallPrototype, DL: dl, VT: MVT::Other,
1703 Ops: {StartChain, DAG.getTargetExternalSymbol(Sym: ProtoStr, VT: MVT::i32)});
1704 CallPrereqs.push_back(Elt: PrototypeDeclare);
1705 }
1706
1707 const unsigned Proto = IsIndirectCall ? UniqueCallSite : 0;
1708 const unsigned NumArgs =
1709 std::min<unsigned>(a: CLI.NumFixedArgs + 1, b: Args.size());
1710 /// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns,
1711 /// NumParams, Callee, Proto)
1712 const SDValue CallToken = DAG.getTokenFactor(DL: dl, Vals&: CallPrereqs);
1713 const SDValue Call = DAG.getNode(
1714 Opcode: NVPTXISD::CALL, DL: dl, VT: MVT::Other,
1715 Ops: {CallToken, GetI32(CLI.IsConvergent), GetI32(IsIndirectCall),
1716 GetI32(Ins.empty() ? 0 : 1), GetI32(NumArgs), Callee, GetI32(Proto)});
1717
1718 SmallVector<SDValue, 16> LoadChains{Call};
1719 SmallVector<SDValue, 16> ProxyRegOps;
1720 if (!Ins.empty()) {
1721 SmallVector<EVT, 16> VTs;
1722 SmallVector<uint64_t, 16> Offsets;
1723 ComputePTXValueVTs(TLI: *this, DL, Ctx, CallConv: CLI.CallConv, Ty: RetTy, ValueVTs&: VTs, Offsets);
1724 assert(VTs.size() == Ins.size() && "Bad value decomposition");
1725
1726 const Align RetAlign = getArgumentAlignment(CB, Ty: RetTy, Idx: 0, DL);
1727 const SDValue RetSymbol = DAG.getExternalSymbol(Sym: "retval0", VT: MVT::i32);
1728
1729 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
1730 // 32-bits are sign extended or zero extended, depending on whether
1731 // they are signed or unsigned types.
1732 const bool ExtendIntegerRetVal =
1733 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty: RetTy) < 32;
1734
1735 unsigned I = 0;
1736 const auto VI = VectorizePTXValueVTs(ValueVTs: VTs, Offsets, ParamAlignment: RetAlign);
1737 for (const unsigned NumElts : VI) {
1738 const MaybeAlign CurrentAlign =
1739 ExtendIntegerRetVal ? MaybeAlign(std::nullopt)
1740 : commonAlignment(A: RetAlign, Offset: Offsets[I]);
1741
1742 const EVT VTI = promoteScalarIntegerPTX(VT: VTs[I]);
1743 const EVT LoadVT =
1744 ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
1745 const EVT VecVT = getVectorizedVT(VT: LoadVT, N: NumElts, C&: Ctx);
1746 SDValue Ptr =
1747 DAG.getObjectPtrOffset(SL: dl, Ptr: RetSymbol, Offset: TypeSize::getFixed(ExactSize: Offsets[I]));
1748
1749 SDValue R =
1750 DAG.getLoad(VT: VecVT, dl, Chain: Call, Ptr,
1751 PtrInfo: MachinePointerInfo(ADDRESS_SPACE_PARAM), Alignment: CurrentAlign);
1752
1753 LoadChains.push_back(Elt: R.getValue(R: 1));
1754 for (const unsigned J : llvm::seq(Size: NumElts))
1755 ProxyRegOps.push_back(Elt: getExtractVectorizedValue(V: R, I: J, VT: LoadVT, dl, DAG));
1756 I += NumElts;
1757 }
1758 }
1759
1760 const SDValue EndToken = DAG.getTokenFactor(DL: dl, Vals&: LoadChains);
1761 const SDValue CallEnd = DAG.getCALLSEQ_END(Chain: EndToken, Size1: UniqueCallSite,
1762 Size2: UniqueCallSite + 1, Glue: SDValue(), DL: dl);
1763
1764 // Append ProxyReg instructions to the chain to make sure that `callseq_end`
1765 // will not get lost. Otherwise, during libcalls expansion, the nodes can become
1766 // dangling.
1767 for (const auto [I, Reg] : llvm::enumerate(First&: ProxyRegOps)) {
1768 SDValue Proxy =
1769 DAG.getNode(Opcode: NVPTXISD::ProxyReg, DL: dl, VT: Reg.getValueType(), Ops: {CallEnd, Reg});
1770 SDValue Ret = correctParamType(V: Proxy, ExpectedVT: Ins[I].VT, Flags: Ins[I].Flags, DAG, dl);
1771 InVals.push_back(Elt: Ret);
1772 }
1773
1774 // set IsTailCall to false for now, until we figure out how to express
1775 // tail call optimization in PTX
1776 CLI.IsTailCall = false;
1777 return CallEnd;
1778}
1779
1780SDValue NVPTXTargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op,
1781 SelectionDAG &DAG) const {
1782
1783 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1784 const Function &Fn = DAG.getMachineFunction().getFunction();
1785
1786 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
1787 Fn,
1788 "Support for dynamic alloca introduced in PTX ISA version 7.3 and "
1789 "requires target sm_52.",
1790 SDLoc(Op).getDebugLoc()));
1791 auto Ops = {DAG.getConstant(Val: 0, DL: SDLoc(), VT: Op.getValueType()),
1792 Op.getOperand(i: 0)};
1793 return DAG.getMergeValues(Ops, dl: SDLoc());
1794 }
1795
1796 SDLoc DL(Op.getNode());
1797 SDValue Chain = Op.getOperand(i: 0);
1798 SDValue Size = Op.getOperand(i: 1);
1799 uint64_t Align = Op.getConstantOperandVal(i: 2);
1800
1801 // The alignment on a ISD::DYNAMIC_STACKALLOC node may be 0 to indicate that
1802 // the default stack alignment should be used.
1803 if (Align == 0)
1804 Align = DAG.getSubtarget().getFrameLowering()->getStackAlign().value();
1805
1806 // The size for ptx alloca instruction is 64-bit for m64 and 32-bit for m32.
1807 const MVT LocalVT = getPointerTy(DL: DAG.getDataLayout(), AS: ADDRESS_SPACE_LOCAL);
1808
1809 SDValue Alloc =
1810 DAG.getNode(Opcode: NVPTXISD::DYNAMIC_STACKALLOC, DL, ResultTys: {LocalVT, MVT::Other},
1811 Ops: {Chain, DAG.getZExtOrTrunc(Op: Size, DL, VT: LocalVT),
1812 DAG.getTargetConstant(Val: Align, DL, VT: MVT::i32)});
1813
1814 SDValue ASC = DAG.getAddrSpaceCast(
1815 dl: DL, VT: Op.getValueType(), Ptr: Alloc, SrcAS: ADDRESS_SPACE_LOCAL, DestAS: ADDRESS_SPACE_GENERIC);
1816
1817 return DAG.getMergeValues(Ops: {ASC, SDValue(Alloc.getNode(), 1)}, dl: DL);
1818}
1819
1820SDValue NVPTXTargetLowering::LowerSTACKRESTORE(SDValue Op,
1821 SelectionDAG &DAG) const {
1822 SDLoc DL(Op.getNode());
1823 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1824 const Function &Fn = DAG.getMachineFunction().getFunction();
1825
1826 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
1827 Fn,
1828 "Support for stackrestore requires PTX ISA version >= 7.3 and target "
1829 ">= sm_52.",
1830 DL.getDebugLoc()));
1831 return Op.getOperand(i: 0);
1832 }
1833
1834 const MVT LocalVT = getPointerTy(DL: DAG.getDataLayout(), AS: ADDRESS_SPACE_LOCAL);
1835 SDValue Chain = Op.getOperand(i: 0);
1836 SDValue Ptr = Op.getOperand(i: 1);
1837 SDValue ASC = DAG.getAddrSpaceCast(dl: DL, VT: LocalVT, Ptr, SrcAS: ADDRESS_SPACE_GENERIC,
1838 DestAS: ADDRESS_SPACE_LOCAL);
1839 return DAG.getNode(Opcode: NVPTXISD::STACKRESTORE, DL, VT: MVT::Other, Ops: {Chain, ASC});
1840}
1841
1842SDValue NVPTXTargetLowering::LowerSTACKSAVE(SDValue Op,
1843 SelectionDAG &DAG) const {
1844 SDLoc DL(Op.getNode());
1845 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1846 const Function &Fn = DAG.getMachineFunction().getFunction();
1847
1848 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
1849 Fn,
1850 "Support for stacksave requires PTX ISA version >= 7.3 and target >= "
1851 "sm_52.",
1852 DL.getDebugLoc()));
1853 auto Ops = {DAG.getConstant(Val: 0, DL, VT: Op.getValueType()), Op.getOperand(i: 0)};
1854 return DAG.getMergeValues(Ops, dl: DL);
1855 }
1856
1857 const MVT LocalVT = getPointerTy(DL: DAG.getDataLayout(), AS: ADDRESS_SPACE_LOCAL);
1858 SDValue Chain = Op.getOperand(i: 0);
1859 SDValue SS =
1860 DAG.getNode(Opcode: NVPTXISD::STACKSAVE, DL, ResultTys: {LocalVT, MVT::Other}, Ops: Chain);
1861 SDValue ASC = DAG.getAddrSpaceCast(
1862 dl: DL, VT: Op.getValueType(), Ptr: SS, SrcAS: ADDRESS_SPACE_LOCAL, DestAS: ADDRESS_SPACE_GENERIC);
1863 return DAG.getMergeValues(Ops: {ASC, SDValue(SS.getNode(), 1)}, dl: DL);
1864}
1865
1866// By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
1867// (see LegalizeDAG.cpp). This is slow and uses local memory.
1868// We use extract/insert/build vector just as what LegalizeOp() does in llvm 2.5
1869SDValue
1870NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
1871 SDNode *Node = Op.getNode();
1872 SDLoc dl(Node);
1873 SmallVector<SDValue, 8> Ops;
1874 unsigned NumOperands = Node->getNumOperands();
1875 for (unsigned i = 0; i < NumOperands; ++i) {
1876 SDValue SubOp = Node->getOperand(Num: i);
1877 EVT VVT = SubOp.getNode()->getValueType(ResNo: 0);
1878 EVT EltVT = VVT.getVectorElementType();
1879 unsigned NumSubElem = VVT.getVectorNumElements();
1880 for (unsigned j = 0; j < NumSubElem; ++j) {
1881 Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT: EltVT, N1: SubOp,
1882 N2: DAG.getIntPtrConstant(Val: j, DL: dl)));
1883 }
1884 }
1885 return DAG.getBuildVector(VT: Node->getValueType(ResNo: 0), DL: dl, Ops);
1886}
1887
1888static SDValue getPRMT(SDValue A, SDValue B, SDValue Selector, SDLoc DL,
1889 SelectionDAG &DAG,
1890 unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
1891 assert(A.getValueType() == MVT::i32 && B.getValueType() == MVT::i32 &&
1892 Selector.getValueType() == MVT::i32 && "PRMT must have i32 operands");
1893 return DAG.getNode(Opcode: NVPTXISD::PRMT, DL, VT: MVT::i32,
1894 Ops: {A, B, Selector, DAG.getConstant(Val: Mode, DL, VT: MVT::i32)});
1895}
1896
1897static SDValue getPRMT(SDValue A, SDValue B, uint64_t Selector, SDLoc DL,
1898 SelectionDAG &DAG,
1899 unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
1900 return getPRMT(A, B, Selector: DAG.getConstant(Val: Selector, DL, VT: MVT::i32), DL, DAG, Mode);
1901}
1902
/// Reduces the elements using the scalar operations provided. The operations
/// are sorted descending in number of inputs they take. The flags on the
/// original reduction operation will be propagated to each scalar operation.
/// Nearby elements are grouped in tree reduction, unlike the shuffle reduction
/// used in ExpandReductions and SelectionDAG.
///
/// \param Elements the scalar values to reduce.
/// \param EltTy    common value type of the elements and of the result.
/// \param Ops      (opcode, arity) pairs sorted by decreasing arity; a wider
///                 operator is only abandoned for the next narrower one once
///                 it can no longer consume a full group at the current level.
/// \param Flags    node flags copied onto every scalar node built here.
/// \returns the single remaining value once the tree is fully reduced.
static SDValue buildTreeReduction(
    const SmallVector<SDValue> &Elements, EVT EltTy,
    ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
    const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
  // Build the reduction tree at each level, starting with all the elements.
  SmallVector<SDValue> Level = Elements;

  unsigned OpIdx = 0;
  while (Level.size() > 1) {
    // Try to reduce this level using the current operator.
    const auto [Op, NumInputs] = Ops[OpIdx];

    // Build the next level by partially reducing all elements.
    SmallVector<SDValue> ReducedLevel;
    unsigned I = 0, E = Level.size();
    for (; I + NumInputs <= E; I += NumInputs) {
      // Reduce elements in groups of [NumInputs], as much as possible.
      ReducedLevel.push_back(Elt: DAG.getNode(
          Opcode: Op, DL, VT: EltTy, Ops: ArrayRef<SDValue>(Level).slice(N: I, M: NumInputs), Flags));
    }

    if (I < E) {
      // Handle leftover elements.

      if (ReducedLevel.empty()) {
        // We didn't reduce anything at this level. We need to pick a smaller
        // operator. Note: the level is retried (not consumed) with the
        // narrower operator on the next iteration.
        ++OpIdx;
        assert(OpIdx < Ops.size() && "no smaller operators for reduction");
        continue;
      }

      // We reduced some things but there's still more left, meaning the
      // operator's number of inputs doesn't evenly divide this level size. Move
      // these elements to the next level.
      for (; I < E; ++I)
        ReducedLevel.push_back(Elt: Level[I]);
    }

    // Process the next level.
    Level = ReducedLevel;
  }

  return *Level.begin();
}
1953
1954// Get scalar reduction opcode
1955static ISD::NodeType getScalarOpcodeForReduction(unsigned ReductionOpcode) {
1956 switch (ReductionOpcode) {
1957 case ISD::VECREDUCE_FMAX:
1958 return ISD::FMAXNUM;
1959 case ISD::VECREDUCE_FMIN:
1960 return ISD::FMINNUM;
1961 case ISD::VECREDUCE_FMAXIMUM:
1962 return ISD::FMAXIMUM;
1963 case ISD::VECREDUCE_FMINIMUM:
1964 return ISD::FMINIMUM;
1965 default:
1966 llvm_unreachable("unhandled reduction opcode");
1967 }
1968}
1969
1970/// Get 3-input scalar reduction opcode
1971static std::optional<unsigned>
1972getScalar3OpcodeForReduction(unsigned ReductionOpcode) {
1973 switch (ReductionOpcode) {
1974 case ISD::VECREDUCE_FMAX:
1975 return NVPTXISD::FMAXNUM3;
1976 case ISD::VECREDUCE_FMIN:
1977 return NVPTXISD::FMINNUM3;
1978 case ISD::VECREDUCE_FMAXIMUM:
1979 return NVPTXISD::FMAXIMUM3;
1980 case ISD::VECREDUCE_FMINIMUM:
1981 return NVPTXISD::FMINIMUM3;
1982 default:
1983 return std::nullopt;
1984 }
1985}
1986
1987/// Lower reductions to either a sequence of operations or a tree if
1988/// reassociations are allowed. This method will use larger operations like
1989/// max3/min3 when the target supports them.
1990SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
1991 SelectionDAG &DAG) const {
1992 SDLoc DL(Op);
1993 const SDNodeFlags Flags = Op->getFlags();
1994 SDValue Vector = Op.getOperand(i: 0);
1995
1996 const unsigned Opcode = Op->getOpcode();
1997 const EVT EltTy = Vector.getValueType().getVectorElementType();
1998
1999 // Whether we can use 3-input min/max when expanding the reduction.
2000 const bool CanUseMinMax3 =
2001 EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
2002 STI.getPTXVersion() >= 88 &&
2003 (Opcode == ISD::VECREDUCE_FMAX || Opcode == ISD::VECREDUCE_FMIN ||
2004 Opcode == ISD::VECREDUCE_FMAXIMUM || Opcode == ISD::VECREDUCE_FMINIMUM);
2005
2006 // A list of SDNode opcodes with equivalent semantics, sorted descending by
2007 // number of inputs they take.
2008 SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
2009
2010 if (auto Opcode3Elem = getScalar3OpcodeForReduction(ReductionOpcode: Opcode);
2011 CanUseMinMax3 && Opcode3Elem)
2012 ScalarOps.push_back(Elt: {*Opcode3Elem, 3});
2013 ScalarOps.push_back(Elt: {getScalarOpcodeForReduction(ReductionOpcode: Opcode), 2});
2014
2015 SmallVector<SDValue> Elements;
2016 DAG.ExtractVectorElements(Op: Vector, Args&: Elements);
2017
2018 return buildTreeReduction(Elements, EltTy, Ops: ScalarOps, DL, Flags, DAG);
2019}
2020
2021SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
2022 // Handle bitcasting from v2i8 without hitting the default promotion
2023 // strategy which goes through stack memory.
2024 EVT FromVT = Op->getOperand(Num: 0)->getValueType(ResNo: 0);
2025 if (FromVT != MVT::v2i8) {
2026 return Op;
2027 }
2028
2029 // Pack vector elements into i16 and bitcast to final type
2030 SDLoc DL(Op);
2031 SDValue Vec0 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i8,
2032 N1: Op->getOperand(Num: 0), N2: DAG.getIntPtrConstant(Val: 0, DL));
2033 SDValue Vec1 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i8,
2034 N1: Op->getOperand(Num: 0), N2: DAG.getIntPtrConstant(Val: 1, DL));
2035 SDValue Extend0 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: MVT::i16, Operand: Vec0);
2036 SDValue Extend1 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: MVT::i16, Operand: Vec1);
2037 SDValue Const8 = DAG.getConstant(Val: 8, DL, VT: MVT::i16);
2038 SDValue AsInt = DAG.getNode(
2039 Opcode: ISD::OR, DL, VT: MVT::i16,
2040 Ops: {Extend0, DAG.getNode(Opcode: ISD::SHL, DL, VT: MVT::i16, Ops: {Extend1, Const8})});
2041 EVT ToVT = Op->getValueType(ResNo: 0);
2042 return DAG.getBitcast(VT: ToVT, V: AsInt);
2043}
2044
// We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it
// would get lowered as two constant loads and vector-packing move.
// Instead we want just a constant move:
//        mov.b32         %r2, 0x40003C00
SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
                                               SelectionDAG &DAG) const {
  EVT VT = Op->getValueType(ResNo: 0);
  // Only 32-bit packed vector types get this treatment.
  if (!(NVPTX::isPackedVectorTy(VT) && VT.is32BitVector()))
    return Op;
  SDLoc DL(Op);

  // If any operand is non-constant (and not undef), only v4i8 gets a custom
  // lowering here; other packed types fall back to the default expansion.
  if (!llvm::all_of(Range: Op->ops(), P: [](SDValue Operand) {
        return Operand->isUndef() || isa<ConstantSDNode>(Val: Operand) ||
               isa<ConstantFPSDNode>(Val: Operand);
      })) {
    if (VT != MVT::v4i8)
      return Op;
    // Lower non-const v4i8 vector as byte-wise constructed i32, which allows us
    // to optimize calculation of constant parts.
    // Helper: permute two i32 values (optionally extended from the i8-carrying
    // operands) with the given PRMT selector.
    auto GetPRMT = [&](const SDValue Left, const SDValue Right, bool Cast,
                       uint64_t SelectionValue) -> SDValue {
      SDValue L = Left;
      SDValue R = Right;
      if (Cast) {
        L = DAG.getAnyExtOrTrunc(Op: L, DL, VT: MVT::i32);
        R = DAG.getAnyExtOrTrunc(Op: R, DL, VT: MVT::i32);
      }
      return getPRMT(A: L, B: R, Selector: SelectionValue, DL, DAG);
    };
    // Pack operand pairs into the two low bytes of intermediate values, then
    // merge both halves into the final 4 bytes (selector nibbles address
    // source bytes; see the PTX `prmt` instruction for the encoding).
    auto PRMT__10 = GetPRMT(Op->getOperand(Num: 0), Op->getOperand(Num: 1), true, 0x3340);
    auto PRMT__32 = GetPRMT(Op->getOperand(Num: 2), Op->getOperand(Num: 3), true, 0x3340);
    auto PRMT3210 = GetPRMT(PRMT__10, PRMT__32, false, 0x5410);
    return DAG.getBitcast(VT, V: PRMT3210);
  }

  // Get value or the Nth operand as an APInt(32). Undef values treated as 0.
  auto GetOperand = [](SDValue Op, int N) -> APInt {
    const SDValue &Operand = Op->getOperand(Num: N);
    EVT VT = Op->getValueType(ResNo: 0);
    if (Operand->isUndef())
      return APInt(32, 0);
    APInt Value;
    if (VT == MVT::v2f16 || VT == MVT::v2bf16)
      Value = cast<ConstantFPSDNode>(Val: Operand)->getValueAPF().bitcastToAPInt();
    else if (VT == MVT::v2i16 || VT == MVT::v4i8)
      Value = Operand->getAsAPIntVal();
    else
      llvm_unreachable("Unsupported type");
    // i8 values are carried around as i16, so we need to zero out upper bits,
    // so they do not get in the way of combining individual byte values
    if (VT == MVT::v4i8)
      Value = Value.trunc(width: 8);
    return Value.zext(width: 32);
  };

  // Construct a 32-bit constant by shifting into place smaller values
  // (elements of the vector type VT).
  // For example, if VT has 2 elements, then N == 2:
  // ShiftAmount = 32 / N = 16
  // Value |= Op0 (b16) << 0
  // Value |= Op1 (b16) << 16
  // If N == 4:
  // ShiftAmount = 32 / N = 8
  // Value |= Op0 (b8) << 0
  // Value |= Op1 (b8) << 8
  // Value |= Op2 (b8) << 16
  // Value |= Op3 (b8) << 24
  // ...etc
  APInt Value(32, 0);
  const unsigned NumElements = VT.getVectorNumElements();
  assert(32 % NumElements == 0 && "must evenly divide bit length");
  const unsigned ShiftAmount = 32 / NumElements;
  for (unsigned ElementNo : seq(Size: NumElements))
    Value |= GetOperand(Op, ElementNo).shl(shiftAmt: ElementNo * ShiftAmount);
  // Emit the whole vector as one i32 immediate and bitcast it back.
  SDValue Const = DAG.getConstant(Val: Value, DL, VT: MVT::i32);
  return DAG.getNode(Opcode: ISD::BITCAST, DL, VT: Op->getValueType(ResNo: 0), Operand: Const);
}
2122
SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
                                                     SelectionDAG &DAG) const {
  SDValue Index = Op->getOperand(Num: 1);
  SDValue Vector = Op->getOperand(Num: 0);
  SDLoc DL(Op);
  EVT VectorVT = Vector.getValueType();

  if (VectorVT == MVT::v4i8) {
    // Extract the byte with PRMT: selector 0x777<Idx> routes the requested
    // byte into lane 0 while the remaining lanes are sourced from the zero
    // operand, so the PRMT result is the byte zero-extended in an i32.
    SDValue Selector = DAG.getNode(Opcode: ISD::OR, DL, VT: MVT::i32,
                                   N1: DAG.getZExtOrTrunc(Op: Index, DL, VT: MVT::i32),
                                   N2: DAG.getConstant(Val: 0x7770, DL, VT: MVT::i32));
    SDValue PRMT = getPRMT(A: DAG.getBitcast(VT: MVT::i32, V: Vector),
                           B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector, DL, DAG);
    SDValue Ext = DAG.getAnyExtOrTrunc(Op: PRMT, DL, VT: Op->getValueType(ResNo: 0));
    // Since the extracted value fits in 8 bits, mark the narrowing/widening
    // as non-wrapping accordingly.
    SDNodeFlags Flags;
    Flags.setNoSignedWrap(Ext.getScalarValueSizeInBits() > 8);
    Flags.setNoUnsignedWrap(Ext.getScalarValueSizeInBits() >= 8);
    Ext->setFlags(Flags);
    return Ext;
  }

  // Constant index will be matched by tablegen.
  if (isa<ConstantSDNode>(Val: Index.getNode()))
    return Op;

  // Extract individual elements and select one of them.
  assert(NVPTX::isPackedVectorTy(VectorVT) &&
         VectorVT.getVectorNumElements() == 2 && "Unexpected vector type.");
  EVT EltVT = VectorVT.getVectorElementType();

  // For a variable index into a 2-element packed vector, extract both lanes
  // and pick the right one with a select on (Index == 0).
  SDLoc dl(Op.getNode());
  SDValue E0 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT: EltVT, N1: Vector,
                           N2: DAG.getIntPtrConstant(Val: 0, DL: dl));
  SDValue E1 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT: EltVT, N1: Vector,
                           N2: DAG.getIntPtrConstant(Val: 1, DL: dl));
  return DAG.getSelectCC(DL: dl, LHS: Index, RHS: DAG.getIntPtrConstant(Val: 0, DL: dl), True: E0, False: E1,
                         Cond: ISD::CondCode::SETEQ);
}
2161
2162SDValue NVPTXTargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
2163 SelectionDAG &DAG) const {
2164 SDValue Vector = Op->getOperand(Num: 0);
2165 EVT VectorVT = Vector.getValueType();
2166
2167 if (VectorVT != MVT::v4i8)
2168 return Op;
2169 SDLoc DL(Op);
2170 SDValue Value = Op->getOperand(Num: 1);
2171 if (Value->isUndef())
2172 return Vector;
2173
2174 SDValue Index = Op->getOperand(Num: 2);
2175
2176 SDValue BFI =
2177 DAG.getNode(Opcode: NVPTXISD::BFI, DL, VT: MVT::i32,
2178 Ops: {DAG.getZExtOrTrunc(Op: Value, DL, VT: MVT::i32), Vector,
2179 DAG.getNode(Opcode: ISD::MUL, DL, VT: MVT::i32,
2180 N1: DAG.getZExtOrTrunc(Op: Index, DL, VT: MVT::i32),
2181 N2: DAG.getConstant(Val: 8, DL, VT: MVT::i32)),
2182 DAG.getConstant(Val: 8, DL, VT: MVT::i32)});
2183 return DAG.getNode(Opcode: ISD::BITCAST, DL, VT: Op->getValueType(ResNo: 0), Operand: BFI);
2184}
2185
2186SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
2187 SelectionDAG &DAG) const {
2188 SDValue V1 = Op.getOperand(i: 0);
2189 EVT VectorVT = V1.getValueType();
2190 if (VectorVT != MVT::v4i8 || Op.getValueType() != MVT::v4i8)
2191 return Op;
2192
2193 // Lower shuffle to PRMT instruction.
2194 const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Val: Op.getNode());
2195 SDValue V2 = Op.getOperand(i: 1);
2196 uint32_t Selector = 0;
2197 for (auto I : llvm::enumerate(First: SVN->getMask())) {
2198 if (I.value() != -1) // -1 is a placeholder for undef.
2199 Selector |= (I.value() << (I.index() * 4));
2200 }
2201
2202 SDLoc DL(Op);
2203 SDValue PRMT = getPRMT(A: DAG.getBitcast(VT: MVT::i32, V: V1),
2204 B: DAG.getBitcast(VT: MVT::i32, V: V2), Selector, DL, DAG);
2205 return DAG.getBitcast(VT: Op.getValueType(), V: PRMT);
2206}
2207/// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
2208/// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
2209/// amount, or
2210/// 2) returns two i64 values and take a 2 x i64 value to shift plus a shift
2211/// amount.
2212SDValue NVPTXTargetLowering::LowerShiftRightParts(SDValue Op,
2213 SelectionDAG &DAG) const {
2214 assert(Op.getNumOperands() == 3 && "Not a double-shift!");
2215 assert(Op.getOpcode() == ISD::SRA_PARTS || Op.getOpcode() == ISD::SRL_PARTS);
2216
2217 EVT VT = Op.getValueType();
2218 unsigned VTBits = VT.getSizeInBits();
2219 SDLoc dl(Op);
2220 SDValue ShOpLo = Op.getOperand(i: 0);
2221 SDValue ShOpHi = Op.getOperand(i: 1);
2222 SDValue ShAmt = Op.getOperand(i: 2);
2223 unsigned Opc = (Op.getOpcode() == ISD::SRA_PARTS) ? ISD::SRA : ISD::SRL;
2224
2225 if (VTBits == 32 && STI.getSmVersion() >= 35) {
2226 // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
2227 // {dHi, dLo} = {aHi, aLo} >> Amt
2228 // dHi = aHi >> Amt
2229 // dLo = shf.r.clamp aLo, aHi, Amt
2230
2231 SDValue Hi = DAG.getNode(Opcode: Opc, DL: dl, VT, N1: ShOpHi, N2: ShAmt);
2232 SDValue Lo =
2233 DAG.getNode(Opcode: NVPTXISD::FSHR_CLAMP, DL: dl, VT, N1: ShOpHi, N2: ShOpLo, N3: ShAmt);
2234
2235 SDValue Ops[2] = { Lo, Hi };
2236 return DAG.getMergeValues(Ops, dl);
2237 }
2238 else {
2239 // {dHi, dLo} = {aHi, aLo} >> Amt
2240 // - if (Amt>=size) then
2241 // dLo = aHi >> (Amt-size)
2242 // dHi = aHi >> Amt (this is either all 0 or all 1)
2243 // else
2244 // dLo = (aLo >>logic Amt) | (aHi << (size-Amt))
2245 // dHi = aHi >> Amt
2246
2247 SDValue RevShAmt = DAG.getNode(Opcode: ISD::SUB, DL: dl, VT: MVT::i32,
2248 N1: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32),
2249 N2: ShAmt);
2250 SDValue Tmp1 = DAG.getNode(Opcode: ISD::SRL, DL: dl, VT, N1: ShOpLo, N2: ShAmt);
2251 SDValue ExtraShAmt = DAG.getNode(Opcode: ISD::SUB, DL: dl, VT: MVT::i32, N1: ShAmt,
2252 N2: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32));
2253 SDValue Tmp2 = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpHi, N2: RevShAmt);
2254 SDValue FalseVal = DAG.getNode(Opcode: ISD::OR, DL: dl, VT, N1: Tmp1, N2: Tmp2);
2255 SDValue TrueVal = DAG.getNode(Opcode: Opc, DL: dl, VT, N1: ShOpHi, N2: ExtraShAmt);
2256
2257 SDValue Cmp = DAG.getSetCC(DL: dl, VT: MVT::i1, LHS: ShAmt,
2258 RHS: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32),
2259 Cond: ISD::SETGE);
2260 SDValue Hi = DAG.getNode(Opcode: Opc, DL: dl, VT, N1: ShOpHi, N2: ShAmt);
2261 SDValue Lo = DAG.getNode(Opcode: ISD::SELECT, DL: dl, VT, N1: Cmp, N2: TrueVal, N3: FalseVal);
2262
2263 SDValue Ops[2] = { Lo, Hi };
2264 return DAG.getMergeValues(Ops, dl);
2265 }
2266}
2267
/// LowerShiftLeftParts - Lower SHL_PARTS, which
/// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
/// amount, or
/// 2) returns two i64 values and take a 2 x i64 value to shift plus a shift
/// amount.
SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op,
                                                 SelectionDAG &DAG) const {
  assert(Op.getNumOperands() == 3 && "Not a double-shift!");
  assert(Op.getOpcode() == ISD::SHL_PARTS);

  EVT VT = Op.getValueType();
  unsigned VTBits = VT.getSizeInBits();
  SDLoc dl(Op);
  SDValue ShOpLo = Op.getOperand(i: 0);  // low half of the double-wide value
  SDValue ShOpHi = Op.getOperand(i: 1);  // high half of the double-wide value
  SDValue ShAmt = Op.getOperand(i: 2);   // shift amount (i32)

  if (VTBits == 32 && STI.getSmVersion() >= 35) {
    // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
    // {dHi, dLo} = {aHi, aLo} << Amt
    // dHi = shf.l.clamp aLo, aHi, Amt
    // dLo = aLo << Amt
    // shf.l.clamp saturates the shift amount itself, so the Amt >= VTBits
    // case needs no explicit handling for the high word.

    SDValue Hi =
        DAG.getNode(Opcode: NVPTXISD::FSHL_CLAMP, DL: dl, VT, N1: ShOpHi, N2: ShOpLo, N3: ShAmt);
    SDValue Lo = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpLo, N2: ShAmt);

    SDValue Ops[2] = { Lo, Hi };
    return DAG.getMergeValues(Ops, dl);
  }
  else {
    // {dHi, dLo} = {aHi, aLo} << Amt
    // - if (Amt>=size) then
    //    dLo = aLo << Amt (all 0)
    //    dHi = aLo << (Amt-size)
    // else
    //    dLo = aLo << Amt
    //    dHi = (aHi << Amt) | (aLo >> (size-Amt))

    SDValue RevShAmt = DAG.getNode(Opcode: ISD::SUB, DL: dl, VT: MVT::i32,
                                   N1: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32),
                                   N2: ShAmt);
    // High word for the Amt < size case: (aHi << Amt) | (aLo >> (size-Amt)).
    SDValue Tmp1 = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpHi, N2: ShAmt);
    SDValue ExtraShAmt = DAG.getNode(Opcode: ISD::SUB, DL: dl, VT: MVT::i32, N1: ShAmt,
                                     N2: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32));
    SDValue Tmp2 = DAG.getNode(Opcode: ISD::SRL, DL: dl, VT, N1: ShOpLo, N2: RevShAmt);
    SDValue FalseVal = DAG.getNode(Opcode: ISD::OR, DL: dl, VT, N1: Tmp1, N2: Tmp2);
    // High word for the Amt >= size case: aLo << (Amt-size).
    SDValue TrueVal = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpLo, N2: ExtraShAmt);

    // Select between the two high-word computations based on Amt >= size.
    SDValue Cmp = DAG.getSetCC(DL: dl, VT: MVT::i1, LHS: ShAmt,
                               RHS: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32),
                               Cond: ISD::SETGE);
    SDValue Lo = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpLo, N2: ShAmt);
    SDValue Hi = DAG.getNode(Opcode: ISD::SELECT, DL: dl, VT, N1: Cmp, N2: TrueVal, N3: FalseVal);

    SDValue Ops[2] = { Lo, Hi };
    return DAG.getMergeValues(Ops, dl);
  }
}
2327
2328/// If the types match, convert the generic copysign to the NVPTXISD version,
2329/// otherwise bail ensuring that mismatched cases are properly expaned.
2330SDValue NVPTXTargetLowering::LowerFCOPYSIGN(SDValue Op,
2331 SelectionDAG &DAG) const {
2332 EVT VT = Op.getValueType();
2333 SDLoc DL(Op);
2334
2335 SDValue In1 = Op.getOperand(i: 0);
2336 SDValue In2 = Op.getOperand(i: 1);
2337 EVT SrcVT = In2.getValueType();
2338
2339 if (!SrcVT.bitsEq(VT))
2340 return SDValue();
2341
2342 return DAG.getNode(Opcode: NVPTXISD::FCOPYSIGN, DL, VT, N1: In1, N2: In2);
2343}
2344
2345SDValue NVPTXTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const {
2346 EVT VT = Op.getValueType();
2347
2348 if (VT == MVT::f32)
2349 return LowerFROUND32(Op, DAG);
2350
2351 if (VT == MVT::f64)
2352 return LowerFROUND64(Op, DAG);
2353
2354 llvm_unreachable("unhandled type");
2355}
2356
// This is the rounding method used in CUDA libdevice in C like code:
// float roundf(float A)
// {
//   float RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f));
//   RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
//   return abs(A) < 0.5 ? (float)(int)A : RoundedA;
// }
SDValue NVPTXTargetLowering::LowerFROUND32(SDValue Op,
                                           SelectionDAG &DAG) const {
  SDLoc SL(Op);
  SDValue A = Op.getOperand(i: 0);
  EVT VT = Op.getValueType();

  SDValue AbsA = DAG.getNode(Opcode: ISD::FABS, DL: SL, VT, Operand: A);

  // RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f))
  // Construct +/-0.5 with the same sign as A by OR-ing A's sign bit into the
  // bit pattern of 0.5f, then add it to A and truncate toward zero.
  SDValue Bitcast  = DAG.getNode(Opcode: ISD::BITCAST, DL: SL, VT: MVT::i32, Operand: A);
  const unsigned SignBitMask = 0x80000000;
  SDValue Sign = DAG.getNode(Opcode: ISD::AND, DL: SL, VT: MVT::i32, N1: Bitcast,
                             N2: DAG.getConstant(Val: SignBitMask, DL: SL, VT: MVT::i32));
  const unsigned PointFiveInBits = 0x3F000000;  // IEEE-754 bits of 0.5f
  SDValue PointFiveWithSignRaw =
      DAG.getNode(Opcode: ISD::OR, DL: SL, VT: MVT::i32, N1: Sign,
                  N2: DAG.getConstant(Val: PointFiveInBits, DL: SL, VT: MVT::i32));
  SDValue PointFiveWithSign =
      DAG.getNode(Opcode: ISD::BITCAST, DL: SL, VT, Operand: PointFiveWithSignRaw);
  SDValue AdjustedA = DAG.getNode(Opcode: ISD::FADD, DL: SL, VT, N1: A, N2: PointFiveWithSign);
  SDValue RoundedA = DAG.getNode(Opcode: ISD::FTRUNC, DL: SL, VT, Operand: AdjustedA);

  // RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
  // Values at or above 2^23 are already integral; keep them unchanged.
  EVT SetCCVT = getSetCCResultType(DL: DAG.getDataLayout(), Ctx&: *DAG.getContext(), VT);
  SDValue IsLarge =
      DAG.getSetCC(DL: SL, VT: SetCCVT, LHS: AbsA, RHS: DAG.getConstantFP(Val: pow(x: 2.0, y: 23.0), DL: SL, VT),
                   Cond: ISD::SETOGT);
  RoundedA = DAG.getNode(Opcode: ISD::SELECT, DL: SL, VT, N1: IsLarge, N2: A, N3: RoundedA);

  // return abs(A) < 0.5 ? (float)(int)A : RoundedA;
  // For |A| < 0.5 the result is signed zero, produced by truncating A itself.
  SDValue IsSmall =DAG.getSetCC(DL: SL, VT: SetCCVT, LHS: AbsA,
                                RHS: DAG.getConstantFP(Val: 0.5, DL: SL, VT), Cond: ISD::SETOLT);
  SDValue RoundedAForSmallA = DAG.getNode(Opcode: ISD::FTRUNC, DL: SL, VT, Operand: A);
  return DAG.getNode(Opcode: ISD::SELECT, DL: SL, VT, N1: IsSmall, N2: RoundedAForSmallA, N3: RoundedA);
}
2399
2400// The implementation of round(double) is similar to that of round(float) in
2401// that they both separate the value range into three regions and use a method
2402// specific to the region to round the values. However, round(double) first
2403// calculates the round of the absolute value and then adds the sign back while
2404// round(float) directly rounds the value with sign.
2405SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
2406 SelectionDAG &DAG) const {
2407 SDLoc SL(Op);
2408 SDValue A = Op.getOperand(i: 0);
2409 EVT VT = Op.getValueType();
2410
2411 SDValue AbsA = DAG.getNode(Opcode: ISD::FABS, DL: SL, VT, Operand: A);
2412
2413 // double RoundedA = (double) (int) (abs(A) + 0.5f);
2414 SDValue AdjustedA = DAG.getNode(Opcode: ISD::FADD, DL: SL, VT, N1: AbsA,
2415 N2: DAG.getConstantFP(Val: 0.5, DL: SL, VT));
2416 SDValue RoundedA = DAG.getNode(Opcode: ISD::FTRUNC, DL: SL, VT, Operand: AdjustedA);
2417
2418 // RoundedA = abs(A) < 0.5 ? (double)0 : RoundedA;
2419 EVT SetCCVT = getSetCCResultType(DL: DAG.getDataLayout(), Ctx&: *DAG.getContext(), VT);
2420 SDValue IsSmall =DAG.getSetCC(DL: SL, VT: SetCCVT, LHS: AbsA,
2421 RHS: DAG.getConstantFP(Val: 0.5, DL: SL, VT), Cond: ISD::SETOLT);
2422 RoundedA = DAG.getNode(Opcode: ISD::SELECT, DL: SL, VT, N1: IsSmall,
2423 N2: DAG.getConstantFP(Val: 0, DL: SL, VT),
2424 N3: RoundedA);
2425
2426 // Add sign to rounded_A
2427 RoundedA = DAG.getNode(Opcode: ISD::FCOPYSIGN, DL: SL, VT, N1: RoundedA, N2: A);
2428 DAG.getNode(Opcode: ISD::FTRUNC, DL: SL, VT, Operand: A);
2429
2430 // RoundedA = abs(A) > 0x1.0p52 ? A : RoundedA;
2431 SDValue IsLarge =
2432 DAG.getSetCC(DL: SL, VT: SetCCVT, LHS: AbsA, RHS: DAG.getConstantFP(Val: pow(x: 2.0, y: 52.0), DL: SL, VT),
2433 Cond: ISD::SETOGT);
2434 return DAG.getNode(Opcode: ISD::SELECT, DL: SL, VT, N1: IsLarge, N2: A, N3: RoundedA);
2435}
2436
2437static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) {
2438 EVT VT = N->getValueType(ResNo: 0);
2439 EVT NVT = MVT::f32;
2440 if (VT.isVector()) {
2441 NVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NVT, EC: VT.getVectorElementCount());
2442 }
2443 SDLoc DL(N);
2444 SDValue Tmp0 = DAG.getFPExtendOrRound(Op: N->getOperand(Num: 0), DL, VT: NVT);
2445 SDValue Tmp1 = DAG.getFPExtendOrRound(Op: N->getOperand(Num: 1), DL, VT: NVT);
2446 SDValue Res = DAG.getNode(Opcode: N->getOpcode(), DL, VT: NVT, N1: Tmp0, N2: Tmp1, Flags: N->getFlags());
2447 return DAG.getFPExtendOrRound(Op: Res, DL, VT);
2448}
2449
2450SDValue NVPTXTargetLowering::PromoteBinOpIfF32FTZ(SDValue Op,
2451 SelectionDAG &DAG) const {
2452 if (useF32FTZ(MF: DAG.getMachineFunction())) {
2453 return PromoteBinOpToF32(N: Op.getNode(), DAG);
2454 }
2455 return Op;
2456}
2457
2458SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
2459 SelectionDAG &DAG) const {
2460 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2461
2462 if (Op.getValueType() == MVT::bf16) {
2463 SDLoc Loc(Op);
2464 return DAG.getNode(
2465 Opcode: ISD::FP_ROUND, DL: Loc, VT: MVT::bf16,
2466 N1: DAG.getNode(Opcode: Op.getOpcode(), DL: Loc, VT: MVT::f32, Operand: Op.getOperand(i: 0)),
2467 N2: DAG.getIntPtrConstant(Val: 0, DL: Loc, /*isTarget=*/true));
2468 }
2469
2470 // Everything else is considered legal.
2471 return Op;
2472}
2473
2474SDValue NVPTXTargetLowering::LowerFP_TO_INT(SDValue Op,
2475 SelectionDAG &DAG) const {
2476 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2477
2478 if (Op.getOperand(i: 0).getValueType() == MVT::bf16) {
2479 SDLoc Loc(Op);
2480 return DAG.getNode(
2481 Opcode: Op.getOpcode(), DL: Loc, VT: Op.getValueType(),
2482 Operand: DAG.getNode(Opcode: ISD::FP_EXTEND, DL: Loc, VT: MVT::f32, Operand: Op.getOperand(i: 0)));
2483 }
2484
2485 // Everything else is considered legal.
2486 return Op;
2487}
2488
2489SDValue NVPTXTargetLowering::LowerFP_ROUND(SDValue Op,
2490 SelectionDAG &DAG) const {
2491 EVT NarrowVT = Op.getValueType();
2492 SDValue Wide = Op.getOperand(i: 0);
2493 EVT WideVT = Wide.getValueType();
2494 if (NarrowVT.getScalarType() == MVT::bf16) {
2495 const TargetLowering *TLI = STI.getTargetLowering();
2496 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 70) {
2497 return TLI->expandFP_ROUND(Node: Op.getNode(), DAG);
2498 }
2499 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
2500 // This combination was the first to support f32 -> bf16.
2501 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70) {
2502 if (WideVT.getScalarType() == MVT::f32) {
2503 return Op;
2504 }
2505 if (WideVT.getScalarType() == MVT::f64) {
2506 SDLoc Loc(Op);
2507 // Round-inexact-to-odd f64 to f32, then do the final rounding using
2508 // the hardware f32 -> bf16 instruction.
2509 SDValue rod = TLI->expandRoundInexactToOdd(
2510 ResultVT: WideVT.changeElementType(Context&: *DAG.getContext(), EltVT: MVT::f32), Op: Wide, DL: Loc,
2511 DAG);
2512 return DAG.getFPExtendOrRound(Op: rod, DL: Loc, VT: NarrowVT);
2513 }
2514 }
2515 return TLI->expandFP_ROUND(Node: Op.getNode(), DAG);
2516 }
2517 }
2518
2519 // Everything else is considered legal.
2520 return Op;
2521}
2522
// Custom-lower FP_EXTEND from bf16 on targets lacking the needed native
// conversion, falling back to the generic BF16_TO_FP node and/or a two-step
// bf16 -> f32 -> f64 extension.
SDValue NVPTXTargetLowering::LowerFP_EXTEND(SDValue Op,
                                            SelectionDAG &DAG) const {
  SDValue Narrow = Op.getOperand(i: 0);
  EVT NarrowVT = Narrow.getValueType();
  EVT WideVT = Op.getValueType();
  if (NarrowVT.getScalarType() == MVT::bf16) {
    // bf16 -> f32 without native support: use the generic BF16_TO_FP node.
    if (WideVT.getScalarType() == MVT::f32 &&
        (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71)) {
      SDLoc Loc(Op);
      return DAG.getNode(Opcode: ISD::BF16_TO_FP, DL: Loc, VT: WideVT, Operand: Narrow);
    }
    // bf16 -> f64 without native support: extend to f32 first (natively when
    // the target allows it, via BF16_TO_FP otherwise), then extend to f64.
    if (WideVT.getScalarType() == MVT::f64 &&
        (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
      EVT F32 = NarrowVT.changeElementType(Context&: *DAG.getContext(), EltVT: MVT::f32);
      SDLoc Loc(Op);
      if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 71) {
        Op = DAG.getNode(Opcode: ISD::FP_EXTEND, DL: Loc, VT: F32, Operand: Narrow);
      } else {
        Op = DAG.getNode(Opcode: ISD::BF16_TO_FP, DL: Loc, VT: F32, Operand: Narrow);
      }
      return DAG.getNode(Opcode: ISD::FP_EXTEND, DL: Loc, VT: WideVT, Operand: Op);
    }
  }

  // Everything else is considered legal.
  return Op;
}
2550
2551static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) {
2552 SDLoc DL(Op);
2553 if (Op.getValueType() != MVT::v2i16)
2554 return Op;
2555 EVT EltVT = Op.getValueType().getVectorElementType();
2556 SmallVector<SDValue> VecElements;
2557 for (int I = 0, E = Op.getValueType().getVectorNumElements(); I < E; I++) {
2558 SmallVector<SDValue> ScalarArgs;
2559 llvm::transform(Range: Op->ops(), d_first: std::back_inserter(x&: ScalarArgs),
2560 F: [&](const SDUse &O) {
2561 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT,
2562 N1: O.get(), N2: DAG.getIntPtrConstant(Val: I, DL));
2563 });
2564 VecElements.push_back(Elt: DAG.getNode(Opcode: Op.getOpcode(), DL, VT: EltVT, Ops: ScalarArgs));
2565 }
2566 SDValue V =
2567 DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: Op.getValueType(), Ops: VecElements);
2568 return V;
2569}
2570
// Lower a tcgen05.st intrinsic: flatten every vector operand into its scalar
// elements and re-emit the memory intrinsic with the flattened operand list,
// preserving the original memory operand.
static SDValue lowerTcgen05St(SDValue Op, SelectionDAG &DAG) {
  SDNode *N = Op.getNode();
  SDLoc DL(N);
  SmallVector<SDValue, 32> Ops;

  // split the vector argument
  for (size_t I = 0; I < N->getNumOperands(); I++) {
    SDValue Val = N->getOperand(Num: I);
    EVT ValVT = Val.getValueType();
    if (ValVT.isVector()) {
      EVT EltVT = ValVT.getVectorElementType();
      for (unsigned J = 0, NElts = ValVT.getVectorNumElements(); J < NElts; J++)
        Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Val,
                                  N2: DAG.getIntPtrConstant(Val: J, DL)));
    } else
      Ops.push_back(Elt: Val);
  }

  // Rebuild the node with the same VTs and memory info, only the operand
  // list changes.
  MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
  SDValue Tcgen05StNode =
      DAG.getMemIntrinsicNode(Opcode: ISD::INTRINSIC_VOID, dl: DL, VTList: N->getVTList(), Ops,
                              MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());

  return Tcgen05StNode;
}
2596
// Lower ISD::BSWAP using the PTX 'prmt' (byte permute) instruction: each
// selector nibble picks one source byte, so byte reversal of a 32-bit word
// is a single permute.
static SDValue lowerBSWAP(SDValue Op, SelectionDAG &DAG) {
  SDLoc DL(Op);
  SDValue Src = Op.getOperand(i: 0);
  EVT VT = Op.getValueType();

  switch (VT.getSimpleVT().SimpleTy) {
  case MVT::i16: {
    // Widen to i32, swap the two low bytes (selector nibbles 1,0), then
    // truncate; the upper result bytes are discarded by the truncate.
    SDValue Extended = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: MVT::i32, Operand: Src);
    SDValue Swapped =
        getPRMT(A: Extended, B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x7701, DL, DAG);
    return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i16, Operand: Swapped);
  }
  case MVT::i32: {
    // Selector 0x0123 reverses all four bytes of the word.
    return getPRMT(A: Src, B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x0123, DL, DAG);
  }
  case MVT::v2i16: {
    // Swap the bytes within each i16 lane (selector 0x2301) by permuting the
    // underlying i32 bit pattern.
    SDValue Converted = DAG.getBitcast(VT: MVT::i32, V: Src);
    SDValue Swapped =
        getPRMT(A: Converted, B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x2301, DL, DAG);
    return DAG.getNode(Opcode: ISD::BITCAST, DL, VT: MVT::v2i16, Operand: Swapped);
  }
  case MVT::i64: {
    // Split into two 32-bit halves, byte-reverse each half, then swap the
    // halves when repacking (low half becomes high and vice versa).
    SDValue UnpackSrc =
        DAG.getNode(Opcode: NVPTXISD::UNPACK_VECTOR, DL, ResultTys: {MVT::i32, MVT::i32}, Ops: Src);
    SDValue SwappedLow =
        getPRMT(A: UnpackSrc.getValue(R: 0), B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x0123,
                DL, DAG);
    SDValue SwappedHigh =
        getPRMT(A: UnpackSrc.getValue(R: 1), B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x0123,
                DL, DAG);
    return DAG.getNode(Opcode: NVPTXISD::BUILD_VECTOR, DL, VT: MVT::i64,
                       Ops: {SwappedHigh, SwappedLow});
  }
  default:
    llvm_unreachable("unsupported type for bswap");
  }
}
2634
2635static unsigned getTcgen05MMADisableOutputLane(unsigned IID) {
2636 switch (IID) {
2637 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
2638 return NVPTXISD::TCGEN05_MMA_SHARED_DISABLE_OUTPUT_LANE_CG1;
2639 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2:
2640 return NVPTXISD::TCGEN05_MMA_SHARED_DISABLE_OUTPUT_LANE_CG2;
2641 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1:
2642 return NVPTXISD::TCGEN05_MMA_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
2643 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2:
2644 return NVPTXISD::TCGEN05_MMA_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
2645 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1:
2646 return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG1;
2647 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2:
2648 return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG2;
2649 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1:
2650 return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
2651 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2:
2652 return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
2653 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift:
2654 return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
2655 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift:
2656 return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
2657 case Intrinsic::
2658 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift:
2659 return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
2660 case Intrinsic::
2661 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift:
2662 return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
2663 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1:
2664 return NVPTXISD::TCGEN05_MMA_SP_SHARED_DISABLE_OUTPUT_LANE_CG1;
2665 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2:
2666 return NVPTXISD::TCGEN05_MMA_SP_SHARED_DISABLE_OUTPUT_LANE_CG2;
2667 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1:
2668 return NVPTXISD::TCGEN05_MMA_SP_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
2669 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2:
2670 return NVPTXISD::TCGEN05_MMA_SP_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
2671 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1:
2672 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG1;
2673 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2:
2674 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG2;
2675 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift:
2676 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
2677 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift:
2678 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
2679 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1:
2680 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
2681 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2:
2682 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
2683 case Intrinsic::
2684 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift:
2685 return NVPTXISD::
2686 TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
2687 case Intrinsic::
2688 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift:
2689 return NVPTXISD::
2690 TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
2691 };
2692 llvm_unreachable("unhandled tcgen05.mma.disable_output_lane intrinsic");
2693}
2694
// Lower a tcgen05.mma.*.disable_output_lane intrinsic: drop the intrinsic-ID
// operand, flatten any vector operands into scalars, and emit the matching
// NVPTXISD memory-intrinsic node selected by getTcgen05MMADisableOutputLane.
static SDValue LowerTcgen05MMADisableOutputLane(SDValue Op, SelectionDAG &DAG) {
  SDNode *N = Op.getNode();
  SDLoc DL(N);
  unsigned IID = cast<ConstantSDNode>(Val: N->getOperand(Num: 1))->getZExtValue();

  SmallVector<SDValue, 16> Ops;
  // split the vector argument
  for (size_t I = 0; I < N->getNumOperands(); I++) {
    if (I == 1)
      continue; // skip IID
    SDValue Val = N->getOperand(Num: I);
    EVT ValVT = Val.getValueType();
    if (ValVT.isVector()) {
      EVT EltVT = ValVT.getVectorElementType();
      for (unsigned J = 0, NElts = ValVT.getVectorNumElements(); J < NElts; J++)
        Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Val,
                                  N2: DAG.getIntPtrConstant(Val: J, DL)));
    } else
      Ops.push_back(Elt: Val);
  }

  // Preserve the original VT list and memory operand on the new node.
  MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
  SDValue Tcgen05MMANode = DAG.getMemIntrinsicNode(
      Opcode: getTcgen05MMADisableOutputLane(IID), dl: DL, VTList: N->getVTList(), Ops,
      MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());

  return Tcgen05MMANode;
}
2723
// Lower vector return type of tcgen05.ld intrinsics: replace the single
// vector result with NumElts scalar i32 results plus the chain, then rebuild
// the vector with BUILD_VECTOR. Returns the (vector, chain) pair, or an empty
// optional when the result is already scalar (nothing to do).
static std::optional<std::pair<SDValue, SDValue>>
lowerTcgen05Ld(SDNode *N, SelectionDAG &DAG, bool HasOffset = false) {
  SDLoc DL(N);
  EVT ResVT = N->getValueType(ResNo: 0);
  if (!ResVT.isVector())
    return {}; // already legalized.

  const unsigned NumElts = ResVT.getVectorNumElements();

  // Create the return type of the instructions
  SmallVector<EVT, 5> ListVTs;
  for (unsigned i = 0; i < NumElts; ++i)
    ListVTs.push_back(Elt: MVT::i32);

  ListVTs.push_back(Elt: N->getValueType(ResNo: 1)); // Chain

  SDVTList ResVTs = DAG.getVTList(VTs: ListVTs);

  // Operands: chain, intrinsic ID, address, [offset,] pack flag.
  SmallVector<SDValue, 8> Ops{N->getOperand(Num: 0), N->getOperand(Num: 1),
                              N->getOperand(Num: 2)};

  if (HasOffset) {
    Ops.push_back(Elt: N->getOperand(Num: 3)); // offset
    Ops.push_back(Elt: N->getOperand(Num: 4)); // Pack flag
  } else
    Ops.push_back(Elt: N->getOperand(Num: 3)); // Pack flag

  MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
  SDValue NewNode =
      DAG.getMemIntrinsicNode(Opcode: ISD::INTRINSIC_W_CHAIN, dl: DL, VTList: ResVTs, Ops,
                              MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());

  // split the vector result
  SmallVector<SDValue, 4> ScalarRes;
  for (unsigned i = 0; i < NumElts; ++i) {
    SDValue Res = NewNode.getValue(R: i);
    ScalarRes.push_back(Elt: Res);
  }

  // The chain is the last result of the new node.
  SDValue Chain = NewNode.getValue(R: NumElts);
  SDValue BuildVector = DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: ResVT, Ops: ScalarRes);
  return {{BuildVector, Chain}};
}
2768
// Emit a DiagnosticInfoUnsupported explaining that the given immediate value
// of a tensormap.replace intrinsic is not supported on this target, then
// return the intrinsic's chain operand so lowering can proceed by dropping
// the node.
static SDValue reportInvalidTensormapReplaceUsage(SDValue Op, SelectionDAG &DAG,
                                                  unsigned Val) {
  SDNode *N = Op.getNode();
  SDLoc DL(N);

  const Function &Fn = DAG.getMachineFunction().getFunction();

  // Recover the pointer overload type so the diagnostic prints the exact
  // mangled intrinsic name.
  unsigned AS = 0;
  if (auto *MemN = dyn_cast<MemIntrinsicSDNode>(Val: N))
    AS = MemN->getAddressSpace();
  Type *PtrTy = PointerType::get(C&: *DAG.getContext(), AddressSpace: AS);
  Module *M = DAG.getMachineFunction().getFunction().getParent();

  DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
      Fn,
      "Intrinsic " +
          Intrinsic::getName(Id: N->getConstantOperandVal(Num: 1), Tys: {PtrTy}, M) +
          " with value " + Twine(Val) +
          " is not supported on the given target.",
      DL.getDebugLoc()));
  // Drop the intrinsic: callers replace the node with its incoming chain.
  return Op.getOperand(i: 0);
}
2791
2792static SDValue lowerTensormapReplaceElemtype(SDValue Op, SelectionDAG &DAG) {
2793 SDNode *N = Op.getNode();
2794 SDLoc DL(N);
2795
2796 // immediate argument representing elemtype
2797 unsigned Val = N->getConstantOperandVal(Num: 3);
2798
2799 if (!DAG.getSubtarget<NVPTXSubtarget>().hasTensormapReplaceElemtypeSupport(
2800 value: Val))
2801 return reportInvalidTensormapReplaceUsage(Op, DAG, Val);
2802
2803 return Op;
2804}
2805
2806static SDValue lowerTensormapReplaceSwizzleMode(SDValue Op, SelectionDAG &DAG) {
2807 SDNode *N = Op.getNode();
2808 SDLoc DL(N);
2809
2810 // immediate argument representing swizzle mode
2811 unsigned Val = N->getConstantOperandVal(Num: 3);
2812
2813 if (!DAG.getSubtarget<NVPTXSubtarget>().hasTensormapReplaceSwizzleModeSupport(
2814 value: Val))
2815 return reportInvalidTensormapReplaceUsage(Op, DAG, Val);
2816
2817 return Op;
2818}
2819
// Custom-lower INTRINSIC_VOID nodes that need special handling: tcgen05.st
// variants (vector-operand flattening), tcgen05.mma disable_output_lane
// variants, and tensormap.replace validity checks. Everything else is
// returned unchanged.
static SDValue lowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
  SDNode *N = Op.getNode();
  SDValue Intrin = N->getOperand(Num: 1);

  // Get the intrinsic ID
  unsigned IntrinNo = cast<ConstantSDNode>(Val: Intrin.getNode())->getZExtValue();
  switch (IntrinNo) {
  default:
    break;
  case Intrinsic::nvvm_tcgen05_st_16x64b_x1:
  case Intrinsic::nvvm_tcgen05_st_16x64b_x2:
  case Intrinsic::nvvm_tcgen05_st_16x64b_x4:
  case Intrinsic::nvvm_tcgen05_st_16x64b_x8:
  case Intrinsic::nvvm_tcgen05_st_16x64b_x16:
  case Intrinsic::nvvm_tcgen05_st_16x64b_x32:
  case Intrinsic::nvvm_tcgen05_st_16x64b_x128:
  case Intrinsic::nvvm_tcgen05_st_16x128b_x1:
  case Intrinsic::nvvm_tcgen05_st_16x128b_x2:
  case Intrinsic::nvvm_tcgen05_st_16x128b_x4:
  case Intrinsic::nvvm_tcgen05_st_16x128b_x8:
  case Intrinsic::nvvm_tcgen05_st_16x128b_x16:
  case Intrinsic::nvvm_tcgen05_st_16x128b_x32:
  case Intrinsic::nvvm_tcgen05_st_16x128b_x64:
  case Intrinsic::nvvm_tcgen05_st_16x256b_x1:
  case Intrinsic::nvvm_tcgen05_st_16x256b_x2:
  case Intrinsic::nvvm_tcgen05_st_16x256b_x4:
  case Intrinsic::nvvm_tcgen05_st_16x256b_x8:
  case Intrinsic::nvvm_tcgen05_st_16x256b_x16:
  case Intrinsic::nvvm_tcgen05_st_16x256b_x32:
  case Intrinsic::nvvm_tcgen05_st_16x32bx2_x1:
  case Intrinsic::nvvm_tcgen05_st_16x32bx2_x2:
  case Intrinsic::nvvm_tcgen05_st_16x32bx2_x4:
  case Intrinsic::nvvm_tcgen05_st_16x32bx2_x8:
  case Intrinsic::nvvm_tcgen05_st_16x32bx2_x16:
  case Intrinsic::nvvm_tcgen05_st_16x32bx2_x32:
  case Intrinsic::nvvm_tcgen05_st_16x32bx2_x64:
  case Intrinsic::nvvm_tcgen05_st_16x32bx2_x128:
  case Intrinsic::nvvm_tcgen05_st_32x32b_x1:
  case Intrinsic::nvvm_tcgen05_st_32x32b_x2:
  case Intrinsic::nvvm_tcgen05_st_32x32b_x4:
  case Intrinsic::nvvm_tcgen05_st_32x32b_x8:
  case Intrinsic::nvvm_tcgen05_st_32x32b_x16:
  case Intrinsic::nvvm_tcgen05_st_32x32b_x32:
  case Intrinsic::nvvm_tcgen05_st_16x64b_x64:
  case Intrinsic::nvvm_tcgen05_st_32x32b_x64:
  case Intrinsic::nvvm_tcgen05_st_32x32b_x128:
    return lowerTcgen05St(Op, DAG);
  case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
  case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2:
  case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1:
  case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2:
  case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1:
  case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2:
  case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1:
  case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2:
  case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1:
  case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2:
  case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1:
  case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2:
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1:
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2:
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1:
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2:
  case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift:
  case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift:
  case Intrinsic::
      nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift:
  case Intrinsic::
      nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift:
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift:
  case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift:
  case Intrinsic::
      nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift:
  case Intrinsic::
      nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift:
    return LowerTcgen05MMADisableOutputLane(Op, DAG);
  case Intrinsic::nvvm_tensormap_replace_elemtype:
    return lowerTensormapReplaceElemtype(Op, DAG);
  case Intrinsic::nvvm_tensormap_replace_swizzle_mode:
    return lowerTensormapReplaceSwizzleMode(Op, DAG);
  }
  return Op;
}
2903
// Lower the clusterlaunchcontrol.query_cancel.* intrinsics: split the i128
// try-cancel response into two i64 halves and emit the matching NVPTXISD node
// that consumes the pair.
static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op,
                                                    SelectionDAG &DAG) {

  SDNode *N = Op.getNode();
  if (N->getOperand(Num: 1).getValueType() != MVT::i128) {
    // return, if the operand is already lowered
    return SDValue();
  }

  unsigned IID =
      cast<ConstantSDNode>(Val: N->getOperand(Num: 0).getNode())->getZExtValue();
  // Select the NVPTXISD opcode for this query flavor.
  auto Opcode = [&]() {
    switch (IID) {
    case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:
      return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_IS_CANCELED;
    case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x:
      return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_X;
    case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y:
      return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Y;
    case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z:
      return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z;
    default:
      llvm_unreachable("unsupported/unhandled intrinsic");
    }
  }();

  // Bitcast i128 -> v2i64 and extract the two halves as separate operands.
  SDLoc DL(N);
  SDValue TryCancelResponse = N->getOperand(Num: 1);
  SDValue Cast = DAG.getNode(Opcode: ISD::BITCAST, DL, VT: MVT::v2i64, Operand: TryCancelResponse);
  SDValue TryCancelResponse0 =
      DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i64, N1: Cast,
                  N2: DAG.getIntPtrConstant(Val: 0, DL));
  SDValue TryCancelResponse1 =
      DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i64, N1: Cast,
                  N2: DAG.getIntPtrConstant(Val: 1, DL));

  return DAG.getNode(Opcode, DL, VTList: N->getVTList(),
                     Ops: {TryCancelResponse0, TryCancelResponse1});
}
2943
// Lower the f32x4 -> packed narrow-float "rs" (stochastic rounding)
// conversion intrinsics: extract the four f32 lanes, append the random-bits
// operand and a CvtMode flag, and emit the matching NVPTXISD conversion node.
static SDValue lowerCvtRSIntrinsics(SDValue Op, SelectionDAG &DAG) {
  SDNode *N = Op.getNode();
  SDLoc DL(N);
  SDValue F32Vec = N->getOperand(Num: 1);
  SDValue RBits = N->getOperand(Num: 2);

  unsigned IntrinsicID = N->getConstantOperandVal(Num: 0);

  // Extract the 4 float elements from the vector
  SmallVector<SDValue, 6> Ops;
  for (unsigned i = 0; i < 4; ++i)
    Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::f32, N1: F32Vec,
                              N2: DAG.getIntPtrConstant(Val: i, DL)));

  using NVPTX::PTXCvtMode::CvtMode;

  // Map the intrinsic ID to (ISD opcode, result type, conversion-mode flag).
  // The e2m1 (4-bit) variants pack four results into an i16; all other
  // formats pack into a v4i8.
  auto [OpCode, RetTy, CvtModeFlag] =
      [&]() -> std::tuple<unsigned, MVT::SimpleValueType, uint32_t> {
    switch (IntrinsicID) {
    case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
      return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8,
              CvtMode::RS | CvtMode::RELU_FLAG};
    case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
      return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
    case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
      return {NVPTXISD::CVT_E5M2X4_F32X4_RS_SF, MVT::v4i8,
              CvtMode::RS | CvtMode::RELU_FLAG};
    case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
      return {NVPTXISD::CVT_E5M2X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
    case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
      return {NVPTXISD::CVT_E2M3X4_F32X4_RS_SF, MVT::v4i8,
              CvtMode::RS | CvtMode::RELU_FLAG};
    case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
      return {NVPTXISD::CVT_E2M3X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
    case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
      return {NVPTXISD::CVT_E3M2X4_F32X4_RS_SF, MVT::v4i8,
              CvtMode::RS | CvtMode::RELU_FLAG};
    case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
      return {NVPTXISD::CVT_E3M2X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
    case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
      return {NVPTXISD::CVT_E2M1X4_F32X4_RS_SF, MVT::i16,
              CvtMode::RS | CvtMode::RELU_FLAG};
    case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite:
      return {NVPTXISD::CVT_E2M1X4_F32X4_RS_SF, MVT::i16, CvtMode::RS};
    default:
      llvm_unreachable("unsupported/unhandled intrinsic");
    }
  }();

  Ops.push_back(Elt: RBits);
  Ops.push_back(Elt: DAG.getConstant(Val: CvtModeFlag, DL, VT: MVT::i32));

  return DAG.getNode(Opcode: OpCode, DL, VT: RetTy, Ops);
}
2998
// Lower the nvvm.prmt family of intrinsics to a PRMT node with the matching
// permute mode.
static SDValue lowerPrmtIntrinsic(SDValue Op, SelectionDAG &DAG) {
  // Map the intrinsic ID to the PTX prmt mode.
  const unsigned Mode = [&]() {
    switch (Op->getConstantOperandVal(Num: 0)) {
    case Intrinsic::nvvm_prmt:
      return NVPTX::PTXPrmtMode::NONE;
    case Intrinsic::nvvm_prmt_b4e:
      return NVPTX::PTXPrmtMode::B4E;
    case Intrinsic::nvvm_prmt_ecl:
      return NVPTX::PTXPrmtMode::ECL;
    case Intrinsic::nvvm_prmt_ecr:
      return NVPTX::PTXPrmtMode::ECR;
    case Intrinsic::nvvm_prmt_f4e:
      return NVPTX::PTXPrmtMode::F4E;
    case Intrinsic::nvvm_prmt_rc16:
      return NVPTX::PTXPrmtMode::RC16;
    case Intrinsic::nvvm_prmt_rc8:
      return NVPTX::PTXPrmtMode::RC8;
    default:
      llvm_unreachable("unsupported/unhandled intrinsic");
    }
  }();
  SDLoc DL(Op);
  SDValue A = Op->getOperand(Num: 1);
  // Variants with only one source operand get zero as the second source.
  SDValue B = Op.getNumOperands() == 4 ? Op.getOperand(i: 2)
                                       : DAG.getConstant(Val: 0, DL, VT: MVT::i32);
  // The selector is always the last operand, whichever form this is.
  SDValue Selector = (Op->op_end() - 1)->get();
  return getPRMT(A, B, Selector, DL, DAG, Mode);
}
3027
// Expands to the intrinsic ID for a tcgen05.ld.red variant, e.g.
// Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_f32.
#define TCGEN05_LD_RED_INTR(SHAPE, NUM, TYPE)                                  \
  Intrinsic::nvvm_tcgen05_ld_red_##SHAPE##_x##NUM##_##TYPE

// Expands to the matching NVPTXISD opcode for the same variant.
#define TCGEN05_LD_RED_INST(SHAPE, NUM, TYPE)                                  \
  NVPTXISD::TCGEN05_LD_RED_##SHAPE##_X##NUM##_##TYPE

// Map a tcgen05.ld.red intrinsic ID to the corresponding NVPTXISD opcode.
// The table covers both shapes (32x32b, 16x32bx2), both element types
// (f32, i32), and vector widths x2 through x128.
static unsigned getTcgen05LdRedID(Intrinsic::ID IID) {
  switch (IID) {
  case TCGEN05_LD_RED_INTR(32x32b, 2, f32):
    return TCGEN05_LD_RED_INST(32x32b, 2, F32);
  case TCGEN05_LD_RED_INTR(32x32b, 4, f32):
    return TCGEN05_LD_RED_INST(32x32b, 4, F32);
  case TCGEN05_LD_RED_INTR(32x32b, 8, f32):
    return TCGEN05_LD_RED_INST(32x32b, 8, F32);
  case TCGEN05_LD_RED_INTR(32x32b, 16, f32):
    return TCGEN05_LD_RED_INST(32x32b, 16, F32);
  case TCGEN05_LD_RED_INTR(32x32b, 32, f32):
    return TCGEN05_LD_RED_INST(32x32b, 32, F32);
  case TCGEN05_LD_RED_INTR(32x32b, 64, f32):
    return TCGEN05_LD_RED_INST(32x32b, 64, F32);
  case TCGEN05_LD_RED_INTR(32x32b, 128, f32):
    return TCGEN05_LD_RED_INST(32x32b, 128, F32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 2, f32):
    return TCGEN05_LD_RED_INST(16x32bx2, 2, F32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 4, f32):
    return TCGEN05_LD_RED_INST(16x32bx2, 4, F32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 8, f32):
    return TCGEN05_LD_RED_INST(16x32bx2, 8, F32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 16, f32):
    return TCGEN05_LD_RED_INST(16x32bx2, 16, F32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 32, f32):
    return TCGEN05_LD_RED_INST(16x32bx2, 32, F32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 64, f32):
    return TCGEN05_LD_RED_INST(16x32bx2, 64, F32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 128, f32):
    return TCGEN05_LD_RED_INST(16x32bx2, 128, F32);
  case TCGEN05_LD_RED_INTR(32x32b, 2, i32):
    return TCGEN05_LD_RED_INST(32x32b, 2, I32);
  case TCGEN05_LD_RED_INTR(32x32b, 4, i32):
    return TCGEN05_LD_RED_INST(32x32b, 4, I32);
  case TCGEN05_LD_RED_INTR(32x32b, 8, i32):
    return TCGEN05_LD_RED_INST(32x32b, 8, I32);
  case TCGEN05_LD_RED_INTR(32x32b, 16, i32):
    return TCGEN05_LD_RED_INST(32x32b, 16, I32);
  case TCGEN05_LD_RED_INTR(32x32b, 32, i32):
    return TCGEN05_LD_RED_INST(32x32b, 32, I32);
  case TCGEN05_LD_RED_INTR(32x32b, 64, i32):
    return TCGEN05_LD_RED_INST(32x32b, 64, I32);
  case TCGEN05_LD_RED_INTR(32x32b, 128, i32):
    return TCGEN05_LD_RED_INST(32x32b, 128, I32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 2, i32):
    return TCGEN05_LD_RED_INST(16x32bx2, 2, I32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 4, i32):
    return TCGEN05_LD_RED_INST(16x32bx2, 4, I32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 8, i32):
    return TCGEN05_LD_RED_INST(16x32bx2, 8, I32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 16, i32):
    return TCGEN05_LD_RED_INST(16x32bx2, 16, I32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 32, i32):
    return TCGEN05_LD_RED_INST(16x32bx2, 32, I32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 64, i32):
    return TCGEN05_LD_RED_INST(16x32bx2, 64, I32);
  case TCGEN05_LD_RED_INTR(16x32bx2, 128, i32):
    return TCGEN05_LD_RED_INST(16x32bx2, 128, I32);
  default:
    llvm_unreachable("Invalid tcgen05.ld.red intrinsic ID");
  }
}
3096
// Lower vector return type of tcgen05.ld intrinsics
//
// Replaces the single vector result with a multi-result memory-intrinsic node
// (NumElts scalar values + reduction value + chain) and rebuilds the vector
// with BUILD_VECTOR. Returns {vector-result, reduction-value, chain}, or
// nullopt when the result type is already scalar (nothing to do).
static std::optional<std::tuple<SDValue, SDValue, SDValue>>
lowerTcgen05LdRed(SDNode *N, SelectionDAG &DAG) {
  SDLoc DL(N);
  EVT ResVT = N->getValueType(ResNo: 0);
  if (!ResVT.isVector())
    return {}; // already legalized.

  const unsigned NumElts = ResVT.getVectorNumElements();

  // Create the return type of the instructions
  // +1 represents the reduction value
  SmallVector<EVT, 132> ListVTs{
      NumElts + 1,
      ResVT.getVectorElementType().isFloatingPoint() ? MVT::f32 : MVT::i32};

  ListVTs.push_back(Elt: MVT::Other); // Chain

  SDVTList ResVTs = DAG.getVTList(VTs: ListVTs);

  // Prepare the Operands
  SmallVector<SDValue, 8> Ops{N->getOperand(Num: 0)}; // Chain

  // skip IID at index 1
  for (unsigned i = 2; i < N->getNumOperands(); i++)
    Ops.push_back(Elt: N->getOperand(Num: i));

  unsigned IID = cast<ConstantSDNode>(Val: N->getOperand(Num: 1))->getZExtValue();
  // Preserve the original memory operand so aliasing info carries over.
  MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
  SDValue NewNode =
      DAG.getMemIntrinsicNode(Opcode: getTcgen05LdRedID(IID), dl: DL, VTList: ResVTs, Ops,
                              MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());

  // Split vector result
  SmallVector<SDValue, 132> ScalarRes;
  for (unsigned i = 0; i < NumElts; ++i) {
    SDValue Res = NewNode.getValue(R: i);
    ScalarRes.push_back(Elt: Res);
  }

  // Results of the new node: [0, NumElts) = vector elements,
  // NumElts = reduction value, NumElts + 1 = chain.
  SDValue BuildVector = DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: ResVT, Ops: ScalarRes);
  SDValue RedResult = NewNode.getValue(R: NumElts);
  SDValue Chain = NewNode.getValue(R: NumElts + 1);
  return {{BuildVector, RedResult, Chain}};
}
3142
// Custom lowering for INTRINSIC_W_CHAIN nodes. Most intrinsics are returned
// unchanged; only the tcgen05.ld variants whose (legal) vector results must be
// expanded here are handled specially.
static SDValue lowerIntrinsicWChain(SDValue Op, SelectionDAG &DAG) {
  switch (Op->getConstantOperandVal(Num: 1)) {
  default:
    return Op;

  // These tcgen05 intrinsics return a v2i32, which is legal, so we have to
  // lower them through LowerOperation() instead of ReplaceNodeResults().
  case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
  case Intrinsic::nvvm_tcgen05_ld_16x128b_x1:
  case Intrinsic::nvvm_tcgen05_ld_32x32b_x2:
    if (auto Res = lowerTcgen05Ld(N: Op.getNode(), DAG))
      return DAG.getMergeValues(Ops: {Res->first, Res->second}, dl: SDLoc(Op));
    return SDValue();

  // This variant carries an extra offset operand.
  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2:
    if (auto Res = lowerTcgen05Ld(N: Op.getNode(), DAG, /*HasOffset=*/true))
      return DAG.getMergeValues(Ops: {Res->first, Res->second}, dl: SDLoc(Op));
    return SDValue();

  // The ld.red variants additionally produce a reduction value.
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x2_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x2_i32:
    if (auto Res = lowerTcgen05LdRed(N: Op.getNode(), DAG))
      return DAG.getMergeValues(
          Ops: {std::get<0>(t&: *Res), std::get<1>(t&: *Res), std::get<2>(t&: *Res)}, dl: SDLoc(Op));
    return SDValue();
  }
}
3172
// Custom lowering for INTRINSIC_WO_CHAIN nodes. Dispatches the intrinsics
// that need target-specific expansion; everything else is returned unchanged.
static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG) {
  switch (Op->getConstantOperandVal(Num: 0)) {
  default:
    return Op;
  case Intrinsic::nvvm_prmt:
  case Intrinsic::nvvm_prmt_b4e:
  case Intrinsic::nvvm_prmt_ecl:
  case Intrinsic::nvvm_prmt_ecr:
  case Intrinsic::nvvm_prmt_f4e:
  case Intrinsic::nvvm_prmt_rc16:
  case Intrinsic::nvvm_prmt_rc8:
    return lowerPrmtIntrinsic(Op, DAG);
  // The address-space wrap is a no-op at the DAG level: just forward the
  // pointer operand.
  case Intrinsic::nvvm_internal_addrspace_wrap:
    return Op.getOperand(i: 1);
  case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:
  case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x:
  case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y:
  case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z:
    return LowerClusterLaunchControlQueryCancel(Op, DAG);
  // f32x4 -> FP4/FP6/FP8 conversions with stochastic rounding (RS).
  case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
  case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
  case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
  case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
  case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
  case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
  case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
  case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
  case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite:
  case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
    return lowerCvtRSIntrinsics(Op, DAG);
  }
}
3205
3206// In PTX 64-bit CTLZ and CTPOP are supported, but they return a 32-bit value.
3207// Lower these into a node returning the correct type which is zero-extended
3208// back to the correct size.
3209static SDValue lowerCTLZCTPOP(SDValue Op, SelectionDAG &DAG) {
3210 SDValue V = Op->getOperand(Num: 0);
3211 assert(V.getValueType() == MVT::i64 &&
3212 "Unexpected CTLZ/CTPOP type to legalize");
3213
3214 SDLoc DL(Op);
3215 SDValue CT = DAG.getNode(Opcode: Op->getOpcode(), DL, VT: MVT::i32, Operand: V);
3216 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: MVT::i64, Operand: CT, Flags: SDNodeFlags::NonNeg);
3217}
3218
// Expand a 64-bit funnel shift into two 32-bit funnel shifts over the unpacked
// halves. Only constant shift amounts are handled; otherwise returns an empty
// SDValue so the caller falls back to default expansion.
static SDValue expandFSH64(SDValue A, SDValue B, SDValue ShiftAmount, SDLoc DL,
                           unsigned Opcode, SelectionDAG &DAG) {
  assert(A.getValueType() == MVT::i64 && B.getValueType() == MVT::i64);

  const auto *AmtConst = dyn_cast<ConstantSDNode>(Val&: ShiftAmount);
  if (!AmtConst)
    return SDValue();
  // Funnel-shift amounts are modulo the bit width (64).
  const auto Amt = AmtConst->getZExtValue() & 63;

  SDValue UnpackA =
      DAG.getNode(Opcode: NVPTXISD::UNPACK_VECTOR, DL, ResultTys: {MVT::i32, MVT::i32}, Ops: A);
  SDValue UnpackB =
      DAG.getNode(Opcode: NVPTXISD::UNPACK_VECTOR, DL, ResultTys: {MVT::i32, MVT::i32}, Ops: B);

  // Arch is little endian: 0 = low bits, 1 = high bits
  SDValue ALo = UnpackA.getValue(R: 0);
  SDValue AHi = UnpackA.getValue(R: 1);
  SDValue BLo = UnpackB.getValue(R: 0);
  SDValue BHi = UnpackB.getValue(R: 1);

  // The bitfield consists of { AHi : ALo : BHi : BLo }
  //
  // * FSHL, Amt <  32 - The window will contain { AHi : ALo : BHi }
  // * FSHL, Amt >= 32 - The window will contain { ALo : BHi : BLo }
  // * FSHR, Amt <  32 - The window will contain { ALo : BHi : BLo }
  // * FSHR, Amt >= 32 - The window will contain { AHi : ALo : BHi }
  //
  // Note that Amt = 0 and Amt = 32 are special cases where 32-bit funnel shifts
  // are not needed at all. Amt = 0 is a no-op producing either A or B depending
  // on the direction. Amt = 32 can be implemented by a packing and unpacking
  // move to select and arrange the 32bit values. For simplicity, these cases
  // are not handled here explicitly and instead we rely on DAGCombiner to
  // remove the no-op funnel shifts we insert.
  auto [High, Mid, Low] = ((Opcode == ISD::FSHL) == (Amt < 32))
                              ? std::make_tuple(args&: AHi, args&: ALo, args&: BHi)
                              : std::make_tuple(args&: ALo, args&: BHi, args&: BLo);

  // Each 32-bit half is produced by a funnel shift of adjacent window words by
  // the amount modulo 32.
  SDValue NewAmt = DAG.getConstant(Val: Amt & 31, DL, VT: MVT::i32);
  SDValue RHi = DAG.getNode(Opcode, DL, VT: MVT::i32, Ops: {High, Mid, NewAmt});
  SDValue RLo = DAG.getNode(Opcode, DL, VT: MVT::i32, Ops: {Mid, Low, NewAmt});

  // Repack the two halves into an i64 result.
  return DAG.getNode(Opcode: NVPTXISD::BUILD_VECTOR, DL, VT: MVT::i64, Ops: {RLo, RHi});
}
3262
3263static SDValue lowerFSH(SDValue Op, SelectionDAG &DAG) {
3264 return expandFSH64(A: Op->getOperand(Num: 0), B: Op->getOperand(Num: 1), ShiftAmount: Op->getOperand(Num: 2),
3265 DL: SDLoc(Op), Opcode: Op->getOpcode(), DAG);
3266}
3267
3268static SDValue lowerROT(SDValue Op, SelectionDAG &DAG) {
3269 unsigned Opcode = Op->getOpcode() == ISD::ROTL ? ISD::FSHL : ISD::FSHR;
3270 return expandFSH64(A: Op->getOperand(Num: 0), B: Op->getOperand(Num: 0), ShiftAmount: Op->getOperand(Num: 1),
3271 DL: SDLoc(Op), Opcode, DAG);
3272}
3273
static SDValue lowerFREM(SDValue Op, SelectionDAG &DAG) {
  // Lower (frem x, y) into (sub x, (mul (ftrunc (div x, y)) y)),
  // i.e. "poor man's fmod()". When y is infinite, x is returned. This matches
  // the semantics of LLVM's frem.
  SDLoc DL(Op);
  SDValue X = Op->getOperand(Num: 0);
  SDValue Y = Op->getOperand(Num: 1);
  EVT Ty = Op.getValueType();
  SDNodeFlags Flags = Op->getFlags();

  SDValue Div = DAG.getNode(Opcode: ISD::FDIV, DL, VT: Ty, N1: X, N2: Y, Flags);
  SDValue Trunc = DAG.getNode(Opcode: ISD::FTRUNC, DL, VT: Ty, Operand: Div, Flags);
  // AllowContract lets the mul/sub pair fuse into an FMA.
  SDValue Mul = DAG.getNode(Opcode: ISD::FMUL, DL, VT: Ty, N1: Trunc, N2: Y,
                            Flags: Flags | SDNodeFlags::AllowContract);
  SDValue Sub = DAG.getNode(Opcode: ISD::FSUB, DL, VT: Ty, N1: X, N2: Mul,
                            Flags: Flags | SDNodeFlags::AllowContract);

  // With no-infs, the infinite-divisor special case below can be skipped.
  if (Flags.hasNoInfs())
    return Sub;

  // If Y is infinite, return X
  SDValue AbsY = DAG.getNode(Opcode: ISD::FABS, DL, VT: Ty, Operand: Y);
  SDValue Inf =
      DAG.getConstantFP(Val: APFloat::getInf(Sem: Ty.getFltSemantics()), DL, VT: Ty);
  SDValue IsInf = DAG.getSetCC(DL, VT: MVT::i1, LHS: AbsY, RHS: Inf, Cond: ISD::SETEQ);
  return DAG.getSelect(DL, VT: Ty, Cond: IsInf, LHS: X, RHS: Sub);
}
3301
// Custom-lower an i1 SELECT: either push the select through matching
// truncates of its operands, or expand it to AND/OR logic on i1.
static SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) {
  assert(Op.getValueType() == MVT::i1 && "Custom lowering enabled only for i1");

  SDValue Cond = Op->getOperand(Num: 0);
  SDValue TrueVal = Op->getOperand(Num: 1);
  SDValue FalseVal = Op->getOperand(Num: 2);
  SDLoc DL(Op);

  // If both operands are truncated, we push the select through the truncates.
  if (TrueVal.getOpcode() == ISD::TRUNCATE &&
      FalseVal.getOpcode() == ISD::TRUNCATE) {
    TrueVal = TrueVal.getOperand(i: 0);
    FalseVal = FalseVal.getOperand(i: 0);

    // Select at the narrower of the two pre-truncate types, extending or
    // truncating the other operand to match.
    EVT VT = TrueVal.getSimpleValueType().bitsLE(VT: FalseVal.getSimpleValueType())
                 ? TrueVal.getValueType()
                 : FalseVal.getValueType();
    TrueVal = DAG.getAnyExtOrTrunc(Op: TrueVal, DL, VT);
    FalseVal = DAG.getAnyExtOrTrunc(Op: FalseVal, DL, VT);
    SDValue Select = DAG.getSelect(DL, VT, Cond, LHS: TrueVal, RHS: FalseVal);
    return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i1, Operand: Select);
  }

  // Otherwise, expand the select into a series of logical operations. These
  // often can be folded into other operations either by us or ptxas.
  // Freezing avoids duplicating poison/undef through the expansion:
  // (Cond & TrueVal) | (!Cond & FalseVal).
  TrueVal = DAG.getFreeze(V: TrueVal);
  FalseVal = DAG.getFreeze(V: FalseVal);
  SDValue And1 = DAG.getNode(Opcode: ISD::AND, DL, VT: MVT::i1, N1: Cond, N2: TrueVal);
  SDValue NotCond = DAG.getNOT(DL, Val: Cond, VT: MVT::i1);
  SDValue And2 = DAG.getNode(Opcode: ISD::AND, DL, VT: MVT::i1, N1: NotCond, N2: FalseVal);
  SDValue Or = DAG.getNode(Opcode: ISD::OR, DL, VT: MVT::i1, N1: And1, N2: And2);
  return Or;
}
3335
// Lower a masked vector store (MSTORE) into a StoreV4/StoreV8 target node.
// The i1 mask must be a constant BUILD_VECTOR; masked-off lanes are encoded
// as a sentinel register 0 operand and handled later in tablegen.
static SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) {
  SDNode *N = Op.getNode();

  SDValue Chain = N->getOperand(Num: 0);
  SDValue Val = N->getOperand(Num: 1);
  SDValue BasePtr = N->getOperand(Num: 2);
  SDValue Offset = N->getOperand(Num: 3);
  SDValue Mask = N->getOperand(Num: 4);

  SDLoc DL(N);
  EVT ValVT = Val.getValueType();
  MemSDNode *MemSD = cast<MemSDNode>(Val: N);
  assert(ValVT.isVector() && "Masked vector store must have vector type");
  assert(MemSD->getAlign() >= DAG.getEVTAlign(ValVT) &&
         "Unexpected alignment for masked store");

  // Only the 256-bit vector shapes are expected here (see LowerOperation).
  unsigned Opcode = 0;
  switch (ValVT.getSimpleVT().SimpleTy) {
  default:
    llvm_unreachable("Unexpected masked vector store type");
  case MVT::v4i64:
  case MVT::v4f64: {
    Opcode = NVPTXISD::StoreV4;
    break;
  }
  case MVT::v8i32:
  case MVT::v8f32: {
    Opcode = NVPTXISD::StoreV8;
    break;
  }
  }

  SmallVector<SDValue, 8> Ops;

  // Construct the new SDNode. First operand is the chain.
  Ops.push_back(Elt: Chain);

  // The next N operands are the values to store. Encode the mask into the
  // values using the sentinel register 0 to represent a masked-off element.
  assert(Mask.getValueType().isVector() &&
         Mask.getValueType().getVectorElementType() == MVT::i1 &&
         "Mask must be a vector of i1");
  assert(Mask.getOpcode() == ISD::BUILD_VECTOR &&
         "Mask expected to be a BUILD_VECTOR");
  assert(Mask.getValueType().getVectorNumElements() ==
             ValVT.getVectorNumElements() &&
         "Mask size must be the same as the vector size");
  for (auto [I, Op] : enumerate(First: Mask->ops())) {
    // Mask elements must be constants.
    if (Op.getNode()->getAsZExtVal() == 0) {
      // Append a sentinel register 0 to the Ops vector to represent a masked
      // off element, this will be handled in tablegen
      Ops.push_back(Elt: DAG.getRegister(Reg: MCRegister::NoRegister,
                                    VT: ValVT.getVectorElementType()));
    } else {
      // Extract the element from the vector to store
      SDValue ExtVal =
          DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ValVT.getVectorElementType(),
                      N1: Val, N2: DAG.getIntPtrConstant(Val: I, DL));
      Ops.push_back(Elt: ExtVal);
    }
  }

  // Next, the pointer operand.
  Ops.push_back(Elt: BasePtr);

  // Finally, the offset operand. We expect this to always be undef, and it will
  // be ignored in lowering, but to mirror the handling of the other vector
  // store instructions we include it in the new SDNode.
  assert(Offset.getOpcode() == ISD::UNDEF &&
         "Offset operand expected to be undef");
  Ops.push_back(Elt: Offset);

  SDValue NewSt =
      DAG.getMemIntrinsicNode(Opcode, dl: DL, VTList: DAG.getVTList(VT: MVT::Other), Ops,
                              MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());

  return NewSt;
}
3415
/// Central dispatch for all operations marked Custom in the NVPTX target:
/// routes each opcode to its dedicated lowering routine.
SDValue
NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
  switch (Op.getOpcode()) {
  // RETURNADDR/FRAMEADDR: no lowering is produced (empty SDValue).
  case ISD::RETURNADDR:
    return SDValue();
  case ISD::FRAMEADDR:
    return SDValue();
  case ISD::ADDRSPACECAST:
    return LowerADDRSPACECAST(Op, DAG);
  case ISD::INTRINSIC_W_CHAIN:
    return lowerIntrinsicWChain(Op, DAG);
  case ISD::INTRINSIC_WO_CHAIN:
    return lowerIntrinsicWOChain(Op, DAG);
  case ISD::INTRINSIC_VOID:
    return lowerIntrinsicVoid(Op, DAG);
  case ISD::BUILD_VECTOR:
    return LowerBUILD_VECTOR(Op, DAG);
  case ISD::BITCAST:
    return LowerBITCAST(Op, DAG);
  case ISD::EXTRACT_SUBVECTOR:
    return Op;
  case ISD::EXTRACT_VECTOR_ELT:
    return LowerEXTRACT_VECTOR_ELT(Op, DAG);
  case ISD::INSERT_VECTOR_ELT:
    return LowerINSERT_VECTOR_ELT(Op, DAG);
  case ISD::VECTOR_SHUFFLE:
    return LowerVECTOR_SHUFFLE(Op, DAG);
  case ISD::CONCAT_VECTORS:
    return LowerCONCAT_VECTORS(Op, DAG);
  case ISD::VECREDUCE_FMAX:
  case ISD::VECREDUCE_FMIN:
  case ISD::VECREDUCE_FMAXIMUM:
  case ISD::VECREDUCE_FMINIMUM:
    return LowerVECREDUCE(Op, DAG);
  case ISD::STORE:
    return LowerSTORE(Op, DAG);
  // Masked stores are only legal with 256-bit vector load/store support.
  case ISD::MSTORE: {
    assert(STI.has256BitVectorLoadStore(
               cast<MemSDNode>(Op.getNode())->getAddressSpace()) &&
           "Masked store vector not supported on subtarget.");
    return lowerMSTORE(Op, DAG);
  }
  case ISD::LOAD:
    return LowerLOAD(Op, DAG);
  case ISD::MLOAD:
    return LowerMLOAD(Op, DAG);
  case ISD::SHL_PARTS:
    return LowerShiftLeftParts(Op, DAG);
  case ISD::SRA_PARTS:
  case ISD::SRL_PARTS:
    return LowerShiftRightParts(Op, DAG);
  case ISD::SELECT:
    return lowerSELECT(Op, DAG);
  case ISD::FROUND:
    return LowerFROUND(Op, DAG);
  case ISD::FCOPYSIGN:
    return LowerFCOPYSIGN(Op, DAG);
  case ISD::SINT_TO_FP:
  case ISD::UINT_TO_FP:
    return LowerINT_TO_FP(Op, DAG);
  case ISD::FP_TO_SINT:
  case ISD::FP_TO_UINT:
    return LowerFP_TO_INT(Op, DAG);
  case ISD::FP_ROUND:
    return LowerFP_ROUND(Op, DAG);
  case ISD::FP_EXTEND:
    return LowerFP_EXTEND(Op, DAG);
  case ISD::VAARG:
    return LowerVAARG(Op, DAG);
  case ISD::VASTART:
    return LowerVASTART(Op, DAG);
  case ISD::FSHL:
  case ISD::FSHR:
    return lowerFSH(Op, DAG);
  case ISD::ROTL:
  case ISD::ROTR:
    return lowerROT(Op, DAG);
  case ISD::ABS:
  case ISD::SMIN:
  case ISD::SMAX:
  case ISD::UMIN:
  case ISD::UMAX:
  case ISD::ADD:
  case ISD::SUB:
  case ISD::MUL:
  case ISD::SHL:
  case ISD::SREM:
  case ISD::UREM:
    return LowerVectorArith(Op, DAG);
  case ISD::DYNAMIC_STACKALLOC:
    return LowerDYNAMIC_STACKALLOC(Op, DAG);
  case ISD::STACKRESTORE:
    return LowerSTACKRESTORE(Op, DAG);
  case ISD::STACKSAVE:
    return LowerSTACKSAVE(Op, DAG);
  case ISD::CopyToReg:
    return LowerCopyToReg_128(Op, DAG);
  case ISD::FADD:
  case ISD::FSUB:
  case ISD::FMUL:
    // Used only for bf16 on SM80, where we select fma for non-ftz operation
    return PromoteBinOpIfF32FTZ(Op, DAG);
  case ISD::CTPOP:
  case ISD::CTLZ:
    return lowerCTLZCTPOP(Op, DAG);
  case ISD::FREM:
    return lowerFREM(Op, DAG);
  case ISD::BSWAP:
    return lowerBSWAP(Op, DAG);
  default:
    llvm_unreachable("Custom lowering not defined for operation");
  }
}
3529
// This will prevent AsmPrinter from trying to print the jump tables itself.
// EK_Inline means the jump table entries are emitted inline rather than in a
// separate jump-table section.
unsigned NVPTXTargetLowering::getJumpTableEncoding() const {
  return MachineJumpTableInfo::EK_Inline;
}
3534
3535SDValue NVPTXTargetLowering::LowerADDRSPACECAST(SDValue Op,
3536 SelectionDAG &DAG) const {
3537 AddrSpaceCastSDNode *N = cast<AddrSpaceCastSDNode>(Val: Op.getNode());
3538 unsigned SrcAS = N->getSrcAddressSpace();
3539 unsigned DestAS = N->getDestAddressSpace();
3540 if (SrcAS != llvm::ADDRESS_SPACE_GENERIC &&
3541 DestAS != llvm::ADDRESS_SPACE_GENERIC) {
3542 // Shared and SharedCluster can be converted to each other through generic
3543 // space
3544 if ((SrcAS == llvm::ADDRESS_SPACE_SHARED &&
3545 DestAS == llvm::ADDRESS_SPACE_SHARED_CLUSTER) ||
3546 (SrcAS == llvm::ADDRESS_SPACE_SHARED_CLUSTER &&
3547 DestAS == llvm::ADDRESS_SPACE_SHARED)) {
3548 SDLoc DL(Op.getNode());
3549 const MVT GenerictVT =
3550 getPointerTy(DL: DAG.getDataLayout(), AS: ADDRESS_SPACE_GENERIC);
3551 SDValue GenericConversion = DAG.getAddrSpaceCast(
3552 dl: DL, VT: GenerictVT, Ptr: Op.getOperand(i: 0), SrcAS, DestAS: ADDRESS_SPACE_GENERIC);
3553 SDValue SharedClusterConversion =
3554 DAG.getAddrSpaceCast(dl: DL, VT: Op.getValueType(), Ptr: GenericConversion,
3555 SrcAS: ADDRESS_SPACE_GENERIC, DestAS);
3556 return SharedClusterConversion;
3557 }
3558
3559 return DAG.getUNDEF(VT: Op.getValueType());
3560 }
3561
3562 return Op;
3563}
3564
// This function is almost a copy of SelectionDAG::expandVAArg().
// The only diff is that this one produces loads from local address space.
SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
  const TargetLowering *TLI = STI.getTargetLowering();
  SDLoc DL(Op);

  SDNode *Node = Op.getNode();
  const Value *V = cast<SrcValueSDNode>(Val: Node->getOperand(Num: 2))->getValue();
  EVT VT = Node->getValueType(ResNo: 0);
  auto *Ty = VT.getTypeForEVT(Context&: *DAG.getContext());
  SDValue Tmp1 = Node->getOperand(Num: 0); // Chain
  SDValue Tmp2 = Node->getOperand(Num: 1); // Pointer to the va_list
  const MaybeAlign MA(Node->getConstantOperandVal(Num: 3));

  // Load the current va_list cursor.
  SDValue VAListLoad = DAG.getLoad(VT: TLI->getPointerTy(DL: DAG.getDataLayout()), dl: DL,
                                   Chain: Tmp1, Ptr: Tmp2, PtrInfo: MachinePointerInfo(V));
  SDValue VAList = VAListLoad;

  // Round the cursor up to the argument's alignment when it exceeds the
  // minimum stack argument alignment: VAList = (VAList + MA - 1) & -MA.
  if (MA && *MA > TLI->getMinStackArgumentAlignment()) {
    VAList = DAG.getNode(
        Opcode: ISD::ADD, DL, VT: VAList.getValueType(), N1: VAList,
        N2: DAG.getConstant(Val: MA->value() - 1, DL, VT: VAList.getValueType()));

    VAList = DAG.getNode(Opcode: ISD::AND, DL, VT: VAList.getValueType(), N1: VAList,
                         N2: DAG.getSignedConstant(Val: -(int64_t)MA->value(), DL,
                                                VT: VAList.getValueType()));
  }

  // Increment the pointer, VAList, to the next vaarg
  Tmp1 = DAG.getNode(Opcode: ISD::ADD, DL, VT: VAList.getValueType(), N1: VAList,
                     N2: DAG.getConstant(Val: DAG.getDataLayout().getTypeAllocSize(Ty),
                                      DL, VT: VAList.getValueType()));

  // Store the incremented VAList to the legalized pointer
  Tmp1 = DAG.getStore(Chain: VAListLoad.getValue(R: 1), dl: DL, Val: Tmp1, Ptr: Tmp2,
                      PtrInfo: MachinePointerInfo(V));

  // Tag the final load with a local-address-space pointer value so it is
  // emitted as a local load.
  const Value *SrcV = Constant::getNullValue(
      Ty: PointerType::get(C&: *DAG.getContext(), AddressSpace: ADDRESS_SPACE_LOCAL));

  // Load the actual argument out of the pointer VAList
  return DAG.getLoad(VT, dl: DL, Chain: Tmp1, Ptr: VAList, PtrInfo: MachinePointerInfo(SrcV));
}
3608
3609SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
3610 const TargetLowering *TLI = STI.getTargetLowering();
3611 SDLoc DL(Op);
3612 EVT PtrVT = TLI->getPointerTy(DL: DAG.getDataLayout());
3613
3614 // Store the address of unsized array <function>_vararg[] in the ap object.
3615 SDValue VAReg = getParamSymbol(DAG, /* vararg */ I: -1, T: PtrVT);
3616
3617 const Value *SV = cast<SrcValueSDNode>(Val: Op.getOperand(i: 2))->getValue();
3618 return DAG.getStore(Chain: Op.getOperand(i: 0), dl: DL, Val: VAReg, Ptr: Op.getOperand(i: 1),
3619 PtrInfo: MachinePointerInfo(SV));
3620}
3621
// Rewrite a masked vector load (MLOAD) as an ordinary vector load plus a
// uint32_t byte mask: bit i of the mask is set iff byte i of the loaded
// vector is actually used. The new load is later expanded by
// replaceLoadVector; the mask feeds the "used bytes" pragma when supported.
static std::pair<MemSDNode *, uint32_t>
convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG,
                                    const NVPTXSubtarget &STI) {
  SDValue Chain = N->getOperand(Num: 0);
  SDValue BasePtr = N->getOperand(Num: 1);
  SDValue Mask = N->getOperand(Num: 3);
  [[maybe_unused]] SDValue Passthru = N->getOperand(Num: 4);

  SDLoc DL(N);
  EVT ResVT = N->getValueType(ResNo: 0);
  assert(ResVT.isVector() && "Masked vector load must have vector type");
  // While we only expect poison passthru vectors as an input to the backend,
  // when the legalization framework splits a poison vector in half, it creates
  // two undef vectors, so we can technically expect those too.
  assert((Passthru.getOpcode() == ISD::POISON ||
          Passthru.getOpcode() == ISD::UNDEF) &&
         "Passthru operand expected to be poison or undef");

  // Extract the mask and convert it to a uint32_t representing the used bytes
  // of the entire vector load
  uint32_t UsedBytesMask = 0;
  uint32_t ElementSizeInBits = ResVT.getVectorElementType().getSizeInBits();
  assert(ElementSizeInBits % 8 == 0 && "Unexpected element size");
  uint32_t ElementSizeInBytes = ElementSizeInBits / 8;
  // Per-element run of set bits, one bit per byte of the element.
  uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;

  // Iterate elements highest-first so each element's byte bits end up at the
  // correct position after the running shift.
  for (SDValue Op : reverse(C: Mask->ops())) {
    // We technically only want to do this shift for every
    // iteration *but* the first, but in the first iteration UsedBytesMask is 0,
    // so this shift is a no-op.
    UsedBytesMask <<= ElementSizeInBytes;

    // Mask elements must be constants.
    if (Op->getAsZExtVal() != 0)
      UsedBytesMask |= ElementMask;
  }

  assert(UsedBytesMask != 0 && UsedBytesMask != UINT32_MAX &&
         "Unexpected masked load with elements masked all on or all off");

  // Create a new load sd node to be handled normally by ReplaceLoadVector.
  MemSDNode *NewLD = cast<MemSDNode>(
      Val: DAG.getLoad(VT: ResVT, dl: DL, Chain, Ptr: BasePtr, MMO: N->getMemOperand()).getNode());

  // If our subtarget does not support the used bytes mask pragma, "drop" the
  // mask by setting it to UINT32_MAX
  if (!STI.hasUsedBytesMaskPragma())
    UsedBytesMask = UINT32_MAX;

  return {NewLD, UsedBytesMask};
}
3673
/// replaceLoadVector - Convert vector loads into multi-output scalar loads.
/// Returns the {load-value, chain} pair, or nullopt when this load cannot (or
/// should not) be widened here, in which case the caller falls back to the
/// default legalization.
static std::optional<std::pair<SDValue, SDValue>>
replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
  MemSDNode *LD = cast<MemSDNode>(Val: N);
  const EVT ResVT = LD->getValueType(ResNo: 0);
  const EVT MemVT = LD->getMemoryVT();

  // If we're doing sign/zero extension as part of the load, avoid lowering to
  // a LoadV node. TODO: consider relaxing this restriction.
  if (ResVT != MemVT)
    return std::nullopt;

  const auto NumEltsAndEltVT =
      getVectorLoweringShape(VectorEVT: ResVT, STI, AddressSpace: LD->getAddressSpace());
  if (!NumEltsAndEltVT)
    return std::nullopt;
  const auto [NumElts, EltVT] = NumEltsAndEltVT.value();

  Align Alignment = LD->getAlign();
  const auto &TD = DAG.getDataLayout();
  Align PrefAlign = TD.getPrefTypeAlign(Ty: MemVT.getTypeForEVT(Context&: *DAG.getContext()));
  if (Alignment < PrefAlign) {
    // This load is not sufficiently aligned, so bail out and let this vector
    // load be scalarized.  Note that we may still be able to emit smaller
    // vector loads.  For example, if we are loading a <4 x float> with an
    // alignment of 8, this check will fail but the legalizer will try again
    // with 2 x <2 x float>, which will succeed with an alignment of 8.
    return std::nullopt;
  }

  // If we have a masked load, convert it to a normal load now
  std::optional<uint32_t> UsedBytesMask = std::nullopt;
  if (LD->getOpcode() == ISD::MLOAD)
    std::tie(args&: LD, args&: UsedBytesMask) =
        convertMLOADToLoadWithUsedBytesMask(N: LD, DAG, STI);

  // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
  // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
  // loaded type to i16 and propagate the "real" type as the memory type.
  const MVT LoadEltVT = (EltVT.getSizeInBits() < 16) ? MVT::i16 : EltVT;

  unsigned Opcode;
  switch (NumElts) {
  default:
    return std::nullopt;
  case 2:
    Opcode = NVPTXISD::LoadV2;
    break;
  case 4:
    Opcode = NVPTXISD::LoadV4;
    break;
  case 8:
    Opcode = NVPTXISD::LoadV8;
    break;
  }
  // Result list: NumElts scalar (or sub-vector) values followed by the chain.
  auto ListVTs = SmallVector<EVT, 9>(NumElts, LoadEltVT);
  ListVTs.push_back(Elt: MVT::Other);
  SDVTList LdResVTs = DAG.getVTList(VTs: ListVTs);

  SDLoc DL(LD);

  // Copy regular operands
  SmallVector<SDValue, 8> OtherOps(LD->ops());

  // Used-bytes mask operand; UINT32_MAX means "all bytes used".
  OtherOps.push_back(
      Elt: DAG.getConstant(Val: UsedBytesMask.value_or(UINT32_MAX), DL, VT: MVT::i32));

  // The select routine does not have access to the LoadSDNode instance, so
  // pass along the extension information
  OtherOps.push_back(
      Elt: DAG.getIntPtrConstant(Val: cast<LoadSDNode>(Val: LD)->getExtensionType(), DL));

  SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, dl: DL, VTList: LdResVTs, Ops: OtherOps, MemVT,
                                          MMO: LD->getMemOperand());

  SmallVector<SDValue> ScalarRes;
  if (EltVT.isVector()) {
    assert(EVT(EltVT.getVectorElementType()) == ResVT.getVectorElementType());
    assert(NumElts * EltVT.getVectorNumElements() ==
           ResVT.getVectorNumElements());
    // Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
    // into individual elements.
    for (const unsigned I : llvm::seq(Size: NumElts)) {
      SDValue SubVector = NewLD.getValue(R: I);
      DAG.ExtractVectorElements(Op: SubVector, Args&: ScalarRes);
    }
  } else {
    for (const unsigned I : llvm::seq(Size: NumElts)) {
      SDValue Res = NewLD.getValue(R: I);
      // Narrow back down from the widened i16 load type if necessary.
      if (LoadEltVT != EltVT)
        Res = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: EltVT, Operand: Res);
      ScalarRes.push_back(Elt: Res);
    }
  }

  SDValue LoadChain = NewLD.getValue(R: NumElts);

  // Reassemble the original vector type from the scalar pieces.
  const MVT BuildVecVT =
      MVT::getVectorVT(VT: EltVT.getScalarType(), NumElements: ScalarRes.size());
  SDValue BuildVec = DAG.getBuildVector(VT: BuildVecVT, DL, Ops: ScalarRes);
  SDValue LoadValue = DAG.getBitcast(VT: ResVT, V: BuildVec);

  return {{LoadValue, LoadChain}};
}
3778
3779static void replaceLoadVector(SDNode *N, SelectionDAG &DAG,
3780 SmallVectorImpl<SDValue> &Results,
3781 const NVPTXSubtarget &STI) {
3782 if (auto Res = replaceLoadVector(N, DAG, STI))
3783 Results.append(IL: {Res->first, Res->second});
3784}
3785
3786static SDValue lowerLoadVector(SDNode *N, SelectionDAG &DAG,
3787 const NVPTXSubtarget &STI) {
3788 if (auto Res = replaceLoadVector(N, DAG, STI))
3789 return DAG.getMergeValues(Ops: {Res->first, Res->second}, dl: SDLoc(N));
3790 return SDValue();
3791}
3792
3793// v = ld i1* addr
3794// =>
3795// v1 = ld i8* addr (-> i16)
3796// v = trunc i16 to i1
3797static SDValue lowerLOADi1(LoadSDNode *LD, SelectionDAG &DAG) {
3798 SDLoc dl(LD);
3799 assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
3800 assert(LD->getValueType(0) == MVT::i1 && "Custom lowering for i1 load only");
3801 SDValue newLD = DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl, VT: MVT::i16, Chain: LD->getChain(),
3802 Ptr: LD->getBasePtr(), PtrInfo: LD->getPointerInfo(),
3803 MemVT: MVT::i8, Alignment: LD->getAlign(),
3804 MMOFlags: LD->getMemOperand()->getFlags());
3805 SDValue result = DAG.getNode(Opcode: ISD::TRUNCATE, DL: dl, VT: MVT::i1, Operand: newLD);
3806 // The legalizer (the caller) is expecting two values from the legalized
3807 // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
3808 // in LegalizeDAG.cpp which also uses MergeValues.
3809 return DAG.getMergeValues(Ops: {result, LD->getChain()}, dl);
3810}
3811
3812SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
3813 LoadSDNode *LD = cast<LoadSDNode>(Val&: Op);
3814
3815 if (Op.getValueType() == MVT::i1)
3816 return lowerLOADi1(LD, DAG);
3817
3818 // To improve CodeGen we'll legalize any-extend loads to zext loads. This is
3819 // how they'll be lowered in ISel anyway, and by doing this a little earlier
3820 // we allow for more DAG combine opportunities.
3821 if (LD->getExtensionType() == ISD::EXTLOAD) {
3822 assert(LD->getValueType(0).isInteger() && LD->getMemoryVT().isInteger() &&
3823 "Unexpected fpext-load");
3824 return DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl: SDLoc(Op), VT: Op.getValueType(),
3825 Chain: LD->getChain(), Ptr: LD->getBasePtr(), MemVT: LD->getMemoryVT(),
3826 MMO: LD->getMemOperand());
3827 }
3828
3829 llvm_unreachable("Unexpected custom lowering for load");
3830}
3831
// Custom-lower a masked load of a packed vector type into an NVPTXISD::MLoad
// node whose trailing operands carry the used-bytes mask and the load's
// extension type. Returns an empty SDValue for non-packed types so generic
// handling applies.
SDValue NVPTXTargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
  // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
  // masked loads of these types and have to handle them here.
  // v2f32 also needs to be handled here if the subtarget has f32x2
  // instructions, making it legal.
  //
  // Note: misaligned masked loads should never reach this point
  // because the override of isLegalMaskedLoad in NVPTXTargetTransformInfo.cpp
  // will validate alignment. Therefore, we do not need to special case handle
  // them here.
  EVT VT = Op.getValueType();
  if (NVPTX::isPackedVectorTy(VT)) {
    // Rewrite the masked load as a plain memory node plus an i32 mask of the
    // bytes that are actually used.
    auto Result = convertMLOADToLoadWithUsedBytesMask(
        N: cast<MemSDNode>(Val: Op.getNode()), DAG, STI);
    MemSDNode *LD = std::get<0>(in&: Result);
    uint32_t UsedBytesMask = std::get<1>(in&: Result);

    SDLoc DL(LD);

    // Copy regular operands
    SmallVector<SDValue, 8> OtherOps(LD->ops());

    OtherOps.push_back(Elt: DAG.getConstant(Val: UsedBytesMask, DL, VT: MVT::i32));

    // We currently are not lowering extending loads, but pass the extension
    // type anyway as later handling expects it.
    OtherOps.push_back(
        Elt: DAG.getIntPtrConstant(Val: cast<LoadSDNode>(Val: LD)->getExtensionType(), DL));
    SDValue NewLD =
        DAG.getMemIntrinsicNode(Opcode: NVPTXISD::MLoad, dl: DL, VTList: LD->getVTList(), Ops: OtherOps,
                                MemVT: LD->getMemoryVT(), MMO: LD->getMemOperand());
    return NewLD;
  }
  return SDValue();
}
3867
// Lower a vector store into an NVPTXISD::StoreV{2,4,8} node, splitting the
// stored value into NumElts pieces — either scalars or packed subvectors
// (e.g. v2f16) that are stored as b32s. Returns an empty SDValue when the
// store should instead be handled by generic legalization (truncating
// stores, unsupported shapes, or insufficient alignment).
static SDValue lowerSTOREVector(SDValue Op, SelectionDAG &DAG,
                                const NVPTXSubtarget &STI) {
  MemSDNode *N = cast<MemSDNode>(Val: Op.getNode());
  SDValue Val = N->getOperand(Num: 1);
  SDLoc DL(N);
  const EVT ValVT = Val.getValueType();
  const EVT MemVT = N->getMemoryVT();

  // If we're truncating as part of the store, avoid lowering to a StoreV node.
  // TODO: consider relaxing this restriction.
  if (ValVT != MemVT)
    return SDValue();

  // Decide how many pieces to split into and what type each piece has.
  const auto NumEltsAndEltVT =
      getVectorLoweringShape(VectorEVT: ValVT, STI, AddressSpace: N->getAddressSpace());
  if (!NumEltsAndEltVT)
    return SDValue();
  const auto [NumElts, EltVT] = NumEltsAndEltVT.value();

  const DataLayout &TD = DAG.getDataLayout();

  Align Alignment = N->getAlign();
  Align PrefAlign = TD.getPrefTypeAlign(Ty: ValVT.getTypeForEVT(Context&: *DAG.getContext()));
  if (Alignment < PrefAlign) {
    // This store is not sufficiently aligned, so bail out and let this vector
    // store be scalarized. Note that we may still be able to emit smaller
    // vector stores. For example, if we are storing a <4 x float> with an
    // alignment of 8, this check will fail but the legalizer will try again
    // with 2 x <2 x float>, which will succeed with an alignment of 8.
    return SDValue();
  }

  // Pick the StoreV opcode matching the element count.
  unsigned Opcode;
  switch (NumElts) {
  default:
    return SDValue();
  case 2:
    Opcode = NVPTXISD::StoreV2;
    break;
  case 4:
    Opcode = NVPTXISD::StoreV4;
    break;
  case 8:
    Opcode = NVPTXISD::StoreV8;
    break;
  }

  SmallVector<SDValue, 8> Ops;

  // First is the chain
  Ops.push_back(Elt: N->getOperand(Num: 0));

  // Then the split values
  if (EltVT.isVector()) {
    assert(EVT(EltVT.getVectorElementType()) == ValVT.getVectorElementType());
    assert(NumElts * EltVT.getVectorNumElements() ==
           ValVT.getVectorNumElements());
    // Combine individual elements into v2[i,f,bf]16/v4i8 subvectors to be
    // stored as b32s
    const unsigned NumEltsPerSubVector = EltVT.getVectorNumElements();
    for (const unsigned I : llvm::seq(Size: NumElts)) {
      SmallVector<SDValue, 4> SubVectorElts;
      DAG.ExtractVectorElements(Op: Val, Args&: SubVectorElts, Start: I * NumEltsPerSubVector,
                                Count: NumEltsPerSubVector);
      Ops.push_back(Elt: DAG.getBuildVector(VT: EltVT, DL, Ops: SubVectorElts));
    }
  } else {
    SDValue V = DAG.getBitcast(VT: MVT::getVectorVT(VT: EltVT, NumElements: NumElts), V: Val);
    for (const unsigned I : llvm::seq(Size: NumElts)) {
      SDValue ExtVal = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: V,
                                   N2: DAG.getIntPtrConstant(Val: I, DL));

      // Since StoreV2 is a target node, we cannot rely on DAG type
      // legalization. Therefore, we must ensure the type is legal. For i1 and
      // i8, we set the stored type to i16 and propagate the "real" type as the
      // memory type.
      if (EltVT.getSizeInBits() < 16)
        ExtVal = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: MVT::i16, Operand: ExtVal);
      Ops.push_back(Elt: ExtVal);
    }
  }

  // Then any remaining arguments
  Ops.append(in_start: N->op_begin() + 2, in_end: N->op_end());

  SDValue NewSt =
      DAG.getMemIntrinsicNode(Opcode, dl: DL, VTList: DAG.getVTList(VT: MVT::Other), Ops,
                              MemVT: N->getMemoryVT(), MMO: N->getMemOperand());

  return NewSt;
}
3960
3961SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
3962 StoreSDNode *Store = cast<StoreSDNode>(Val&: Op);
3963 EVT VT = Store->getMemoryVT();
3964
3965 if (VT == MVT::i1)
3966 return LowerSTOREi1(Op, DAG);
3967
3968 // Lower store of any other vector type, including v2f32 as we want to break
3969 // it apart since this is not a widely-supported type.
3970 return lowerSTOREVector(Op, DAG, STI);
3971}
3972
3973// st i1 v, addr
3974// =>
3975// v1 = zxt v to i16
3976// st.u8 i16, addr
3977SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
3978 SDNode *Node = Op.getNode();
3979 SDLoc dl(Node);
3980 StoreSDNode *ST = cast<StoreSDNode>(Val: Node);
3981 SDValue Tmp1 = ST->getChain();
3982 SDValue Tmp2 = ST->getBasePtr();
3983 SDValue Tmp3 = ST->getValue();
3984 assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
3985 Tmp3 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: dl, VT: MVT::i16, Operand: Tmp3);
3986 SDValue Result =
3987 DAG.getTruncStore(Chain: Tmp1, dl, Val: Tmp3, Ptr: Tmp2, PtrInfo: ST->getPointerInfo(), SVT: MVT::i8,
3988 Alignment: ST->getAlign(), MMOFlags: ST->getMemOperand()->getFlags());
3989 return Result;
3990}
3991
// Lower a CopyToReg of an i128 value by splitting the 128-bit operand into
// two i64 halves, since a single 128-bit operand would not survive
// legalization.
SDValue NVPTXTargetLowering::LowerCopyToReg_128(SDValue Op,
                                                SelectionDAG &DAG) const {
  // Change the CopyToReg to take in two 64-bit operands instead of a 128-bit
  // operand so that it can pass the legalization.

  assert(Op.getOperand(1).getValueType() == MVT::i128 &&
         "Custom lowering for 128-bit CopyToReg only");

  SDNode *Node = Op.getNode();
  SDLoc DL(Node);

  // Bitcast the i128 payload (operand 2) to v2i64 and extract the two
  // 64-bit halves.
  SDValue Cast = DAG.getBitcast(VT: MVT::v2i64, V: Op->getOperand(Num: 2));
  SDValue Lo = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i64, N1: Cast,
                           N2: DAG.getIntPtrConstant(Val: 0, DL));
  SDValue Hi = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i64, N1: Cast,
                           N2: DAG.getIntPtrConstant(Val: 1, DL));

  // Rebuild the operand list: the single 128-bit value is replaced by the
  // two halves, so the new node has one more operand than the original.
  SmallVector<SDValue, 5> NewOps(Op->getNumOperands() + 1);
  SmallVector<EVT, 3> ResultsType(Node->values());

  NewOps[0] = Op->getOperand(Num: 0); // Chain
  NewOps[1] = Op->getOperand(Num: 1); // Dst Reg
  NewOps[2] = Lo;                     // Lower 64-bit
  NewOps[3] = Hi;                     // Higher 64-bit
  if (Op.getNumOperands() == 4)
    NewOps[4] = Op->getOperand(Num: 3); // Glue if exists

  return DAG.getNode(Opcode: ISD::CopyToReg, DL, ResultTys: ResultsType, Ops: NewOps);
}
4021
4022unsigned NVPTXTargetLowering::getNumRegisters(
4023 LLVMContext &Context, EVT VT,
4024 std::optional<MVT> RegisterVT = std::nullopt) const {
4025 if (VT == MVT::i128 && RegisterVT == MVT::i128)
4026 return 1;
4027 return TargetLoweringBase::getNumRegisters(Context, VT, RegisterVT);
4028}
4029
4030bool NVPTXTargetLowering::splitValueIntoRegisterParts(
4031 SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
4032 unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
4033 if (Val.getValueType() == MVT::i128 && NumParts == 1) {
4034 Parts[0] = Val;
4035 return true;
4036 }
4037 return false;
4038}
4039
4040// This creates target external symbol for a function parameter.
4041// Name of the symbol is composed from its index and the function name.
4042// Negative index corresponds to special parameter (unsized array) used for
4043// passing variable arguments.
4044SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int I,
4045 EVT T) const {
4046 StringRef SavedStr = nvTM->getStrPool().save(
4047 S: getParamName(F: &DAG.getMachineFunction().getFunction(), Idx: I));
4048 return DAG.getExternalSymbol(Sym: SavedStr.data(), VT: T);
4049}
4050
4051SDValue NVPTXTargetLowering::getCallParamSymbol(SelectionDAG &DAG, int I,
4052 EVT T) const {
4053 const StringRef SavedStr = nvTM->getStrPool().save(S: "param" + Twine(I));
4054 return DAG.getExternalSymbol(Sym: SavedStr.data(), VT: T);
4055}
4056
// Lower the incoming formal arguments of a function: each IR argument is
// materialized from its ".param" symbol — ByVal pointers via MoveParam (plus
// an addrspacecast for non-kernels), everything else via (possibly
// vectorized) loads from the param address space.
SDValue NVPTXTargetLowering::LowerFormalArguments(
    SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
    const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
    SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
  const DataLayout &DL = DAG.getDataLayout();
  LLVMContext &Ctx = *DAG.getContext();
  auto PtrVT = getPointerTy(DL: DAG.getDataLayout());

  const Function &F = DAG.getMachineFunction().getFunction();

  SDValue Root = DAG.getRoot();
  SmallVector<SDValue, 16> OutChains;

  // argTypes.size() (or theArgs.size()) and Ins.size() need not match.
  // Ins.size() will be larger
  // * if there is an aggregate argument with multiple fields (each field
  //   showing up separately in Ins)
  // * if there is a vector argument with more than typical vector-length
  //   elements (generally if more than 4) where each vector element is
  //   individually present in Ins.
  // So a different index should be used for indexing into Ins.
  // See similar issue in LowerCall.

  // Walk the IR arguments, consuming the run of Ins entries that belongs to
  // each one (matched via OrigArgIndex).
  auto AllIns = ArrayRef(Ins);
  for (const auto &Arg : F.args()) {
    const auto ArgIns = AllIns.take_while(
        Pred: [&](auto I) { return I.OrigArgIndex == Arg.getArgNo(); });
    AllIns = AllIns.drop_front(N: ArgIns.size());

    Type *Ty = Arg.getType();

    if (ArgIns.empty())
      report_fatal_error(reason: "Empty parameter types are not supported");

    if (Arg.use_empty()) {
      // argument is dead
      for (const auto &In : ArgIns) {
        assert(!In.Used && "Arg.use_empty() is true but Arg is used?");
        InVals.push_back(Elt: DAG.getUNDEF(VT: In.VT));
      }
      continue;
    }

    SDValue ArgSymbol = getParamSymbol(DAG, I: Arg.getArgNo(), T: PtrVT);

    // In the following cases, assign a node order of "i+1"
    // to newly created nodes. The SDNodes for params have to
    // appear in the same order as their order of appearance
    // in the original function. "i+1" holds that order.
    if (Arg.hasByValAttr()) {
      // Param has ByVal attribute
      // Return MoveParam(param symbol).
      // Ideally, the param symbol can be returned directly,
      // but when SDNode builder decides to use it in a CopyToReg(),
      // machine instruction fails because TargetExternalSymbol
      // (not lowered) is target dependent, and CopyToReg assumes
      // the source is lowered.
      assert(ArgIns.size() == 1 && "ByVal argument must be a pointer");
      const auto &ByvalIn = ArgIns[0];
      assert(getValueType(DL, Ty) == ByvalIn.VT &&
             "Ins type did not match function type");
      assert(ByvalIn.VT == PtrVT && "ByVal argument must be a pointer");

      SDValue P;
      if (isKernelFunction(F)) {
        // Kernel byval params are grid_constant: the param symbol itself is
        // the argument value.
        assert(isParamGridConstant(Arg) && "ByVal argument must be lowered to "
                                           "grid_constant by NVPTXLowerArgs");
        P = ArgSymbol;
        P.getNode()->setIROrder(Arg.getArgNo() + 1);
      } else {
        // Device functions wrap the symbol in MoveParam and cast the local
        // address back to generic so callers' pointer arithmetic works.
        P = DAG.getNode(Opcode: NVPTXISD::MoveParam, DL: dl, VT: ByvalIn.VT, Operand: ArgSymbol);
        P.getNode()->setIROrder(Arg.getArgNo() + 1);
        P = DAG.getAddrSpaceCast(dl, VT: ByvalIn.VT, Ptr: P, SrcAS: ADDRESS_SPACE_LOCAL,
                                 DestAS: ADDRESS_SPACE_GENERIC);
      }
      InVals.push_back(Elt: P);
    } else {
      // Non-byval: decompose the argument type into PTX value types, then
      // load each (possibly vectorized) group from the param symbol.
      SmallVector<EVT, 16> VTs;
      SmallVector<uint64_t, 16> Offsets;
      ComputePTXValueVTs(TLI: *this, DL, Ctx, CallConv, Ty, ValueVTs&: VTs, Offsets);
      assert(VTs.size() == ArgIns.size() && "Size mismatch");
      assert(VTs.size() == Offsets.size() && "Size mismatch");

      const Align ArgAlign = getFunctionArgumentAlignment(
          F: &F, Ty, Idx: Arg.getArgNo() + AttributeList::FirstArgIndex, DL);

      unsigned I = 0;
      const auto VI = VectorizePTXValueVTs(ValueVTs: VTs, Offsets, ParamAlignment: ArgAlign);
      for (const unsigned NumElts : VI) {
        // i1 is loaded/stored as i8
        const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
        const EVT VecVT = getVectorizedVT(VT: LoadVT, N: NumElts, C&: Ctx);

        SDValue VecAddr = DAG.getObjectPtrOffset(
            SL: dl, Ptr: ArgSymbol, Offset: TypeSize::getFixed(ExactSize: Offsets[I]));

        // Param memory is read-only for the callee, so mark the load
        // dereferenceable and invariant.
        const Align PartAlign = commonAlignment(A: ArgAlign, Offset: Offsets[I]);
        SDValue P =
            DAG.getLoad(VT: VecVT, dl, Chain: Root, Ptr: VecAddr,
                        PtrInfo: MachinePointerInfo(ADDRESS_SPACE_PARAM), Alignment: PartAlign,
                        MMOFlags: MachineMemOperand::MODereferenceable |
                            MachineMemOperand::MOInvariant);
        P.getNode()->setIROrder(Arg.getArgNo() + 1);
        for (const unsigned J : llvm::seq(Size: NumElts)) {
          SDValue Elt = getExtractVectorizedValue(V: P, I: J, VT: LoadVT, dl, DAG);

          Elt = correctParamType(V: Elt, ExpectedVT: ArgIns[I + J].VT, Flags: ArgIns[I + J].Flags,
                                 DAG, dl);
          InVals.push_back(Elt);
        }
        I += NumElts;
      }
    }
  }

  // NOTE(review): nothing in this function appends to OutChains, so this
  // token factor is currently dead — presumably kept for symmetry with other
  // targets or future use; confirm before removing.
  if (!OutChains.empty())
    DAG.setRoot(DAG.getTokenFactor(DL: dl, Vals&: OutChains));

  return Chain;
}
4177
// Lower a function return: decompose the return value per the PTX ABI and
// store each (possibly vectorized) piece into the "func_retval0" param-space
// symbol, then terminate with NVPTXISD::RET_GLUE.
SDValue
NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
                                 bool isVarArg,
                                 const SmallVectorImpl<ISD::OutputArg> &Outs,
                                 const SmallVectorImpl<SDValue> &OutVals,
                                 const SDLoc &dl, SelectionDAG &DAG) const {
  const Function &F = DAG.getMachineFunction().getFunction();
  Type *RetTy = F.getReturnType();

  // void returns emit only the glue node; there is nothing to store.
  if (RetTy->isVoidTy()) {
    assert(OutVals.empty() && Outs.empty() && "Return value expected for void");
    return DAG.getNode(Opcode: NVPTXISD::RET_GLUE, DL: dl, VT: MVT::Other, Operand: Chain);
  }

  const DataLayout &DL = DAG.getDataLayout();
  LLVMContext &Ctx = *DAG.getContext();

  const SDValue RetSymbol = DAG.getExternalSymbol(Sym: "func_retval0", VT: MVT::i32);
  const auto RetAlign = getFunctionParamOptimizedAlign(F: &F, ArgTy: RetTy, DL);

  // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
  // 32-bits are sign extended or zero extended, depending on whether
  // they are signed or unsigned types.
  const bool ExtendIntegerRetVal =
      RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty: RetTy) < 32;

  SmallVector<EVT, 16> VTs;
  SmallVector<uint64_t, 16> Offsets;
  ComputePTXValueVTs(TLI: *this, DL, Ctx, CallConv, Ty: RetTy, ValueVTs&: VTs, Offsets);
  assert(VTs.size() == OutVals.size() && "Bad return value decomposition");

  // Fetch piece I of the return value, converted to the type it must have
  // when stored (extended to i32, or i1 widened to i8).
  const auto GetRetVal = [&](unsigned I) -> SDValue {
    SDValue RetVal = OutVals[I];
    assert(promoteScalarIntegerPTX(RetVal.getValueType()) ==
               RetVal.getValueType() &&
           "OutVal type should always be legal");

    const EVT VTI = promoteScalarIntegerPTX(VT: VTs[I]);
    const EVT StoreVT =
        ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
    return correctParamType(V: RetVal, ExpectedVT: StoreVT, Flags: Outs[I].Flags, DAG, dl);
  };

  // Group the pieces into vectorized stores where alignment allows, and
  // chain each store at its offset from the return symbol.
  unsigned I = 0;
  const auto VI = VectorizePTXValueVTs(ValueVTs: VTs, Offsets, ParamAlignment: RetAlign);
  for (const unsigned NumElts : VI) {
    const MaybeAlign CurrentAlign = ExtendIntegerRetVal
                                        ? MaybeAlign(std::nullopt)
                                        : commonAlignment(A: RetAlign, Offset: Offsets[I]);

    SDValue Val = getBuildVectorizedValue(
        N: NumElts, dl, DAG, GetElement: [&](unsigned K) { return GetRetVal(I + K); });

    SDValue Ptr =
        DAG.getObjectPtrOffset(SL: dl, Ptr: RetSymbol, Offset: TypeSize::getFixed(ExactSize: Offsets[I]));

    Chain = DAG.getStore(Chain, dl, Val, Ptr,
                         PtrInfo: MachinePointerInfo(ADDRESS_SPACE_PARAM), Alignment: CurrentAlign);

    I += NumElts;
  }

  return DAG.getNode(Opcode: NVPTXISD::RET_GLUE, DL: dl, VT: MVT::Other, Operand: Chain);
}
4242
4243void NVPTXTargetLowering::LowerAsmOperandForConstraint(
4244 SDValue Op, StringRef Constraint, std::vector<SDValue> &Ops,
4245 SelectionDAG &DAG) const {
4246 if (Constraint.size() > 1)
4247 return;
4248 TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
4249}
4250
// llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
// TgtMemIntrinsic because we need information that is only available in the
// "Value" type of the destination pointer — in particular, its address space.
4256void NVPTXTargetLowering::getTgtMemIntrinsic(
4257 SmallVectorImpl<IntrinsicInfo> &Infos, const CallBase &I,
4258 MachineFunction &MF, unsigned Intrinsic) const {
4259 IntrinsicInfo Info;
4260 switch (Intrinsic) {
4261 default:
4262 return;
4263 case Intrinsic::nvvm_match_all_sync_i32p:
4264 case Intrinsic::nvvm_match_all_sync_i64p:
4265 Info.opc = ISD::INTRINSIC_W_CHAIN;
4266 // memVT is bogus. These intrinsics have IntrInaccessibleMemOnly attribute
4267 // in order to model data exchange with other threads, but perform no real
4268 // memory accesses.
4269 Info.memVT = MVT::i1;
4270
4271 // Our result depends on both our and other thread's arguments.
4272 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
4273 Infos.push_back(Elt: Info);
4274 return;
4275 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col:
4276 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row:
4277 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col_stride:
4278 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride:
4279 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col:
4280 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row:
4281 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride:
4282 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride:
4283 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col:
4284 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row:
4285 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride:
4286 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride:
4287 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col:
4288 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row:
4289 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride:
4290 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride:
4291 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col:
4292 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row:
4293 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride:
4294 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride:
4295 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col:
4296 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row:
4297 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride:
4298 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride: {
4299 Info.opc = ISD::INTRINSIC_W_CHAIN;
4300 Info.memVT = MVT::v8f16;
4301 Info.ptrVal = I.getArgOperand(i: 0);
4302 Info.offset = 0;
4303 Info.flags = MachineMemOperand::MOLoad;
4304 Info.align = Align(16);
4305 Infos.push_back(Elt: Info);
4306 return;
4307 }
4308 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col:
4309 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col_stride:
4310 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col_stride:
4311 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col:
4312 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row:
4313 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row_stride:
4314 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row_stride:
4315 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row:
4316 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col:
4317 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col_stride:
4318 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row:
4319 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row_stride:
4320 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col:
4321 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col_stride:
4322 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col_stride:
4323 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col:
4324 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row:
4325 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row_stride:
4326 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row_stride:
4327 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row:
4328 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col:
4329 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col_stride:
4330 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row:
4331 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row_stride: {
4332 Info.opc = ISD::INTRINSIC_W_CHAIN;
4333 Info.memVT = MVT::v2i32;
4334 Info.ptrVal = I.getArgOperand(i: 0);
4335 Info.offset = 0;
4336 Info.flags = MachineMemOperand::MOLoad;
4337 Info.align = Align(8);
4338 Infos.push_back(Elt: Info);
4339 return;
4340 }
4341
4342 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col:
4343 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col_stride:
4344 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col_stride:
4345 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col:
4346 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row:
4347 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row_stride:
4348 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row_stride:
4349 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row:
4350 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col:
4351 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col_stride:
4352 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row:
4353 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row_stride:
4354 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col:
4355 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col_stride:
4356 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row:
4357 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row_stride:
4358
4359 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col:
4360 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col_stride:
4361 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col_stride:
4362 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col:
4363 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row:
4364 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row_stride:
4365 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row_stride:
4366 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row:
4367 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col:
4368 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col_stride:
4369 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row:
4370 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row_stride:
4371 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col:
4372 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col_stride:
4373 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
4374 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride:
4375 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16:
4376 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16:
4377 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8:
4378 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64:
4379 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32:
4380 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64:
4381 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32: {
4382 Info.opc = ISD::INTRINSIC_W_CHAIN;
4383 Info.memVT = MVT::v4i32;
4384 Info.ptrVal = I.getArgOperand(i: 0);
4385 Info.offset = 0;
4386 Info.flags = MachineMemOperand::MOLoad;
4387 Info.align = Align(16);
4388 Infos.push_back(Elt: Info);
4389 return;
4390 }
4391
4392 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col:
4393 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col_stride:
4394 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col_stride:
4395 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col:
4396 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row:
4397 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row_stride:
4398 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row_stride:
4399 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row:
4400
4401 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col:
4402 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col_stride:
4403 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col_stride:
4404 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col:
4405 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row:
4406 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row_stride:
4407 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row_stride:
4408 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row:
4409 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row:
4410 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row_stride:
4411 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col:
4412 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col_stride:
4413 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row:
4414 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row_stride:
4415 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row_stride:
4416 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row:
4417 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col:
4418 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col_stride:
4419 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
4420 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col:
4421 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16:
4422 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16:
4423 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64:
4424 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32: {
4425 Info.opc = ISD::INTRINSIC_W_CHAIN;
4426 Info.memVT = MVT::i32;
4427 Info.ptrVal = I.getArgOperand(i: 0);
4428 Info.offset = 0;
4429 Info.flags = MachineMemOperand::MOLoad;
4430 Info.align = Align(4);
4431 Infos.push_back(Elt: Info);
4432 return;
4433 }
4434
4435 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col:
4436 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row:
4437 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride:
4438 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride:
4439 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col:
4440 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row:
4441 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride:
4442 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride:
4443 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col:
4444 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row:
4445 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride:
4446 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride: {
4447 Info.opc = ISD::INTRINSIC_W_CHAIN;
4448 Info.memVT = MVT::v4f16;
4449 Info.ptrVal = I.getArgOperand(i: 0);
4450 Info.offset = 0;
4451 Info.flags = MachineMemOperand::MOLoad;
4452 Info.align = Align(16);
4453 Infos.push_back(Elt: Info);
4454 return;
4455 }
4456
4457 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col:
4458 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row:
4459 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride:
4460 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride:
4461 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col:
4462 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row:
4463 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride:
4464 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride:
4465 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col:
4466 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row:
4467 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride:
4468 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride:
4469 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col:
4470 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row:
4471 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col_stride:
4472 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row_stride: {
4473 Info.opc = ISD::INTRINSIC_W_CHAIN;
4474 Info.memVT = MVT::v8f32;
4475 Info.ptrVal = I.getArgOperand(i: 0);
4476 Info.offset = 0;
4477 Info.flags = MachineMemOperand::MOLoad;
4478 Info.align = Align(16);
4479 Infos.push_back(Elt: Info);
4480 return;
4481 }
4482
4483 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col:
4484 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col_stride:
4485 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row:
4486 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row_stride:
4487
4488 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col:
4489 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col_stride:
4490 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row:
4491 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row_stride:
4492
4493 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col:
4494 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col_stride:
4495 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row:
4496 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row_stride:
4497 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col:
4498 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col_stride:
4499 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row:
4500 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row_stride:
4501 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col:
4502 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col_stride:
4503 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row:
4504 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row_stride: {
4505 Info.opc = ISD::INTRINSIC_W_CHAIN;
4506 Info.memVT = MVT::v8i32;
4507 Info.ptrVal = I.getArgOperand(i: 0);
4508 Info.offset = 0;
4509 Info.flags = MachineMemOperand::MOLoad;
4510 Info.align = Align(16);
4511 Infos.push_back(Elt: Info);
4512 return;
4513 }
4514
4515 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col:
4516 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col_stride:
4517 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row:
4518 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row_stride:
4519 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col:
4520 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col_stride:
4521 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
4522 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride:
4523 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16:
4524 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16:
4525 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8:
4526 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64:
4527 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32:
4528 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64:
4529 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32: {
4530 Info.opc = ISD::INTRINSIC_W_CHAIN;
4531 Info.memVT = MVT::v2i32;
4532 Info.ptrVal = I.getArgOperand(i: 0);
4533 Info.offset = 0;
4534 Info.flags = MachineMemOperand::MOLoad;
4535 Info.align = Align(8);
4536 Infos.push_back(Elt: Info);
4537 return;
4538 }
4539
4540 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col:
4541 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col_stride:
4542 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row:
4543 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row_stride:
4544
4545 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col:
4546 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col_stride:
4547 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row:
4548 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row_stride: {
4549 Info.opc = ISD::INTRINSIC_W_CHAIN;
4550 Info.memVT = MVT::f64;
4551 Info.ptrVal = I.getArgOperand(i: 0);
4552 Info.offset = 0;
4553 Info.flags = MachineMemOperand::MOLoad;
4554 Info.align = Align(8);
4555 Infos.push_back(Elt: Info);
4556 return;
4557 }
4558
4559 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col:
4560 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col_stride:
4561 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row:
4562 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row_stride: {
4563 Info.opc = ISD::INTRINSIC_W_CHAIN;
4564 Info.memVT = MVT::v2f64;
4565 Info.ptrVal = I.getArgOperand(i: 0);
4566 Info.offset = 0;
4567 Info.flags = MachineMemOperand::MOLoad;
4568 Info.align = Align(16);
4569 Infos.push_back(Elt: Info);
4570 return;
4571 }
4572
4573 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
4574 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
4575 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
4576 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride:
4577 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col:
4578 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row:
4579 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride:
4580 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride:
4581 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col:
4582 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row:
4583 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride:
4584 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride: {
4585 Info.opc = ISD::INTRINSIC_VOID;
4586 Info.memVT = MVT::v4f16;
4587 Info.ptrVal = I.getArgOperand(i: 0);
4588 Info.offset = 0;
4589 Info.flags = MachineMemOperand::MOStore;
4590 Info.align = Align(16);
4591 Infos.push_back(Elt: Info);
4592 return;
4593 }
4594
4595 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col:
4596 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row:
4597 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride:
4598 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride:
4599 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col:
4600 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row:
4601 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride:
4602 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride:
4603 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col:
4604 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row:
4605 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride:
4606 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride:
4607 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col:
4608 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row:
4609 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col_stride:
4610 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row_stride: {
4611 Info.opc = ISD::INTRINSIC_VOID;
4612 Info.memVT = MVT::v8f32;
4613 Info.ptrVal = I.getArgOperand(i: 0);
4614 Info.offset = 0;
4615 Info.flags = MachineMemOperand::MOStore;
4616 Info.align = Align(16);
4617 Infos.push_back(Elt: Info);
4618 return;
4619 }
4620
4621 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col:
4622 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col_stride:
4623 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row:
4624 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row_stride:
4625 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col:
4626 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col_stride:
4627 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row:
4628 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row_stride:
4629 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col:
4630 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col_stride:
4631 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row:
4632 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row_stride: {
4633 Info.opc = ISD::INTRINSIC_VOID;
4634 Info.memVT = MVT::v8i32;
4635 Info.ptrVal = I.getArgOperand(i: 0);
4636 Info.offset = 0;
4637 Info.flags = MachineMemOperand::MOStore;
4638 Info.align = Align(16);
4639 Infos.push_back(Elt: Info);
4640 return;
4641 }
4642
4643 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col:
4644 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col_stride:
4645 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row:
4646 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row_stride:
4647 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
4648 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
4649 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
4650 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride:
4651 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16:
4652 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16:
4653 case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: {
4654 Info.opc = ISD::INTRINSIC_VOID;
4655 Info.memVT = MVT::v2i32;
4656 Info.ptrVal = I.getArgOperand(i: 0);
4657 Info.offset = 0;
4658 Info.flags = MachineMemOperand::MOStore;
4659 Info.align = Align(8);
4660 Infos.push_back(Elt: Info);
4661 return;
4662 }
4663
4664 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col:
4665 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col_stride:
4666 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row:
4667 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row_stride: {
4668 Info.opc = ISD::INTRINSIC_VOID;
4669 Info.memVT = MVT::v2f64;
4670 Info.ptrVal = I.getArgOperand(i: 0);
4671 Info.offset = 0;
4672 Info.flags = MachineMemOperand::MOStore;
4673 Info.align = Align(16);
4674 Infos.push_back(Elt: Info);
4675 return;
4676 }
4677
4678 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16:
4679 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16:
4680 case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: {
4681 Info.opc = ISD::INTRINSIC_VOID;
4682 Info.memVT = MVT::i32;
4683 Info.ptrVal = I.getArgOperand(i: 0);
4684 Info.offset = 0;
4685 Info.flags = MachineMemOperand::MOStore;
4686 Info.align = Align(4);
4687 Infos.push_back(Elt: Info);
4688 return;
4689 }
4690
4691 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16:
4692 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16:
4693 case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: {
4694 Info.opc = ISD::INTRINSIC_VOID;
4695 Info.memVT = MVT::v4i32;
4696 Info.ptrVal = I.getArgOperand(i: 0);
4697 Info.offset = 0;
4698 Info.flags = MachineMemOperand::MOStore;
4699 Info.align = Align(16);
4700 Infos.push_back(Elt: Info);
4701 return;
4702 }
4703
4704 case Intrinsic::nvvm_atomic_add_gen_f_cta:
4705 case Intrinsic::nvvm_atomic_add_gen_f_sys:
4706 case Intrinsic::nvvm_atomic_add_gen_i_cta:
4707 case Intrinsic::nvvm_atomic_add_gen_i_sys:
4708 case Intrinsic::nvvm_atomic_and_gen_i_cta:
4709 case Intrinsic::nvvm_atomic_and_gen_i_sys:
4710 case Intrinsic::nvvm_atomic_cas_gen_i_cta:
4711 case Intrinsic::nvvm_atomic_cas_gen_i_sys:
4712 case Intrinsic::nvvm_atomic_dec_gen_i_cta:
4713 case Intrinsic::nvvm_atomic_dec_gen_i_sys:
4714 case Intrinsic::nvvm_atomic_inc_gen_i_cta:
4715 case Intrinsic::nvvm_atomic_inc_gen_i_sys:
4716 case Intrinsic::nvvm_atomic_max_gen_i_cta:
4717 case Intrinsic::nvvm_atomic_max_gen_i_sys:
4718 case Intrinsic::nvvm_atomic_min_gen_i_cta:
4719 case Intrinsic::nvvm_atomic_min_gen_i_sys:
4720 case Intrinsic::nvvm_atomic_or_gen_i_cta:
4721 case Intrinsic::nvvm_atomic_or_gen_i_sys:
4722 case Intrinsic::nvvm_atomic_exch_gen_i_cta:
4723 case Intrinsic::nvvm_atomic_exch_gen_i_sys:
4724 case Intrinsic::nvvm_atomic_xor_gen_i_cta:
4725 case Intrinsic::nvvm_atomic_xor_gen_i_sys: {
4726 auto &DL = I.getDataLayout();
4727 Info.opc = ISD::INTRINSIC_W_CHAIN;
4728 Info.memVT = getValueType(DL, Ty: I.getType());
4729 Info.ptrVal = I.getArgOperand(i: 0);
4730 Info.offset = 0;
4731 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
4732 Info.align.reset();
4733 Infos.push_back(Elt: Info);
4734 return;
4735 }
4736
4737 case Intrinsic::nvvm_prefetch_tensormap: {
4738 auto &DL = I.getDataLayout();
4739 Info.opc = ISD::INTRINSIC_VOID;
4740 Info.memVT = getPointerTy(DL);
4741 Info.ptrVal = I.getArgOperand(i: 0);
4742 Info.offset = 0;
4743 Info.flags =
4744 MachineMemOperand::MOLoad | MachineMemOperand::MODereferenceable;
4745 Info.align.reset();
4746 Infos.push_back(Elt: Info);
4747 return;
4748 }
4749
4750 case Intrinsic::nvvm_tensormap_replace_global_address:
4751 case Intrinsic::nvvm_tensormap_replace_global_stride: {
4752 Info.opc = ISD::INTRINSIC_VOID;
4753 Info.memVT = MVT::i64;
4754 Info.ptrVal = I.getArgOperand(i: 0);
4755 Info.offset = 0;
4756 Info.flags = MachineMemOperand::MOStore;
4757 Info.align.reset();
4758 Infos.push_back(Elt: Info);
4759 return;
4760 }
4761
4762 case Intrinsic::nvvm_tensormap_replace_rank:
4763 case Intrinsic::nvvm_tensormap_replace_box_dim:
4764 case Intrinsic::nvvm_tensormap_replace_global_dim:
4765 case Intrinsic::nvvm_tensormap_replace_element_stride:
4766 case Intrinsic::nvvm_tensormap_replace_elemtype:
4767 case Intrinsic::nvvm_tensormap_replace_interleave_layout:
4768 case Intrinsic::nvvm_tensormap_replace_swizzle_mode:
4769 case Intrinsic::nvvm_tensormap_replace_swizzle_atomicity:
4770 case Intrinsic::nvvm_tensormap_replace_fill_mode: {
4771 Info.opc = ISD::INTRINSIC_VOID;
4772 Info.memVT = MVT::i32;
4773 Info.ptrVal = I.getArgOperand(i: 0);
4774 Info.offset = 0;
4775 Info.flags = MachineMemOperand::MOStore;
4776 Info.align.reset();
4777 Infos.push_back(Elt: Info);
4778 return;
4779 }
4780
4781 case Intrinsic::nvvm_ldu_global_i:
4782 case Intrinsic::nvvm_ldu_global_f:
4783 case Intrinsic::nvvm_ldu_global_p: {
4784 Info.opc = ISD::INTRINSIC_W_CHAIN;
4785 Info.memVT = getValueType(DL: I.getDataLayout(), Ty: I.getType());
4786 Info.ptrVal = I.getArgOperand(i: 0);
4787 Info.offset = 0;
4788 Info.flags = MachineMemOperand::MOLoad;
4789 Info.align = cast<ConstantInt>(Val: I.getArgOperand(i: 1))->getMaybeAlignValue();
4790
4791 Infos.push_back(Elt: Info);
4792 return;
4793 }
4794 case Intrinsic::nvvm_tex_1d_v4f32_s32:
4795 case Intrinsic::nvvm_tex_1d_v4f32_f32:
4796 case Intrinsic::nvvm_tex_1d_level_v4f32_f32:
4797 case Intrinsic::nvvm_tex_1d_grad_v4f32_f32:
4798 case Intrinsic::nvvm_tex_1d_array_v4f32_s32:
4799 case Intrinsic::nvvm_tex_1d_array_v4f32_f32:
4800 case Intrinsic::nvvm_tex_1d_array_level_v4f32_f32:
4801 case Intrinsic::nvvm_tex_1d_array_grad_v4f32_f32:
4802 case Intrinsic::nvvm_tex_2d_v4f32_s32:
4803 case Intrinsic::nvvm_tex_2d_v4f32_f32:
4804 case Intrinsic::nvvm_tex_2d_level_v4f32_f32:
4805 case Intrinsic::nvvm_tex_2d_grad_v4f32_f32:
4806 case Intrinsic::nvvm_tex_2d_array_v4f32_s32:
4807 case Intrinsic::nvvm_tex_2d_array_v4f32_f32:
4808 case Intrinsic::nvvm_tex_2d_array_level_v4f32_f32:
4809 case Intrinsic::nvvm_tex_2d_array_grad_v4f32_f32:
4810 case Intrinsic::nvvm_tex_3d_v4f32_s32:
4811 case Intrinsic::nvvm_tex_3d_v4f32_f32:
4812 case Intrinsic::nvvm_tex_3d_level_v4f32_f32:
4813 case Intrinsic::nvvm_tex_3d_grad_v4f32_f32:
4814 case Intrinsic::nvvm_tex_cube_v4f32_f32:
4815 case Intrinsic::nvvm_tex_cube_level_v4f32_f32:
4816 case Intrinsic::nvvm_tex_cube_array_v4f32_f32:
4817 case Intrinsic::nvvm_tex_cube_array_level_v4f32_f32:
4818 case Intrinsic::nvvm_tld4_r_2d_v4f32_f32:
4819 case Intrinsic::nvvm_tld4_g_2d_v4f32_f32:
4820 case Intrinsic::nvvm_tld4_b_2d_v4f32_f32:
4821 case Intrinsic::nvvm_tld4_a_2d_v4f32_f32:
4822 case Intrinsic::nvvm_tex_unified_1d_v4f32_s32:
4823 case Intrinsic::nvvm_tex_unified_1d_v4f32_f32:
4824 case Intrinsic::nvvm_tex_unified_1d_level_v4f32_f32:
4825 case Intrinsic::nvvm_tex_unified_1d_grad_v4f32_f32:
4826 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_s32:
4827 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_f32:
4828 case Intrinsic::nvvm_tex_unified_1d_array_level_v4f32_f32:
4829 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4f32_f32:
4830 case Intrinsic::nvvm_tex_unified_2d_v4f32_s32:
4831 case Intrinsic::nvvm_tex_unified_2d_v4f32_f32:
4832 case Intrinsic::nvvm_tex_unified_2d_level_v4f32_f32:
4833 case Intrinsic::nvvm_tex_unified_2d_grad_v4f32_f32:
4834 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_s32:
4835 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_f32:
4836 case Intrinsic::nvvm_tex_unified_2d_array_level_v4f32_f32:
4837 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4f32_f32:
4838 case Intrinsic::nvvm_tex_unified_3d_v4f32_s32:
4839 case Intrinsic::nvvm_tex_unified_3d_v4f32_f32:
4840 case Intrinsic::nvvm_tex_unified_3d_level_v4f32_f32:
4841 case Intrinsic::nvvm_tex_unified_3d_grad_v4f32_f32:
4842 case Intrinsic::nvvm_tex_unified_cube_v4f32_f32:
4843 case Intrinsic::nvvm_tex_unified_cube_level_v4f32_f32:
4844 case Intrinsic::nvvm_tex_unified_cube_array_v4f32_f32:
4845 case Intrinsic::nvvm_tex_unified_cube_array_level_v4f32_f32:
4846 case Intrinsic::nvvm_tex_unified_cube_grad_v4f32_f32:
4847 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4f32_f32:
4848 case Intrinsic::nvvm_tld4_unified_r_2d_v4f32_f32:
4849 case Intrinsic::nvvm_tld4_unified_g_2d_v4f32_f32:
4850 case Intrinsic::nvvm_tld4_unified_b_2d_v4f32_f32:
4851 case Intrinsic::nvvm_tld4_unified_a_2d_v4f32_f32:
4852 Info.opc = ISD::INTRINSIC_W_CHAIN;
4853 Info.memVT = MVT::v4f32;
4854 Info.ptrVal = nullptr;
4855 Info.offset = 0;
4856 Info.flags = MachineMemOperand::MOLoad;
4857 Info.align = Align(16);
4858 Infos.push_back(Elt: Info);
4859 return;
4860
4861 case Intrinsic::nvvm_tex_1d_v4s32_s32:
4862 case Intrinsic::nvvm_tex_1d_v4s32_f32:
4863 case Intrinsic::nvvm_tex_1d_level_v4s32_f32:
4864 case Intrinsic::nvvm_tex_1d_grad_v4s32_f32:
4865 case Intrinsic::nvvm_tex_1d_array_v4s32_s32:
4866 case Intrinsic::nvvm_tex_1d_array_v4s32_f32:
4867 case Intrinsic::nvvm_tex_1d_array_level_v4s32_f32:
4868 case Intrinsic::nvvm_tex_1d_array_grad_v4s32_f32:
4869 case Intrinsic::nvvm_tex_2d_v4s32_s32:
4870 case Intrinsic::nvvm_tex_2d_v4s32_f32:
4871 case Intrinsic::nvvm_tex_2d_level_v4s32_f32:
4872 case Intrinsic::nvvm_tex_2d_grad_v4s32_f32:
4873 case Intrinsic::nvvm_tex_2d_array_v4s32_s32:
4874 case Intrinsic::nvvm_tex_2d_array_v4s32_f32:
4875 case Intrinsic::nvvm_tex_2d_array_level_v4s32_f32:
4876 case Intrinsic::nvvm_tex_2d_array_grad_v4s32_f32:
4877 case Intrinsic::nvvm_tex_3d_v4s32_s32:
4878 case Intrinsic::nvvm_tex_3d_v4s32_f32:
4879 case Intrinsic::nvvm_tex_3d_level_v4s32_f32:
4880 case Intrinsic::nvvm_tex_3d_grad_v4s32_f32:
4881 case Intrinsic::nvvm_tex_cube_v4s32_f32:
4882 case Intrinsic::nvvm_tex_cube_level_v4s32_f32:
4883 case Intrinsic::nvvm_tex_cube_array_v4s32_f32:
4884 case Intrinsic::nvvm_tex_cube_array_level_v4s32_f32:
4885 case Intrinsic::nvvm_tex_cube_v4u32_f32:
4886 case Intrinsic::nvvm_tex_cube_level_v4u32_f32:
4887 case Intrinsic::nvvm_tex_cube_array_v4u32_f32:
4888 case Intrinsic::nvvm_tex_cube_array_level_v4u32_f32:
4889 case Intrinsic::nvvm_tex_1d_v4u32_s32:
4890 case Intrinsic::nvvm_tex_1d_v4u32_f32:
4891 case Intrinsic::nvvm_tex_1d_level_v4u32_f32:
4892 case Intrinsic::nvvm_tex_1d_grad_v4u32_f32:
4893 case Intrinsic::nvvm_tex_1d_array_v4u32_s32:
4894 case Intrinsic::nvvm_tex_1d_array_v4u32_f32:
4895 case Intrinsic::nvvm_tex_1d_array_level_v4u32_f32:
4896 case Intrinsic::nvvm_tex_1d_array_grad_v4u32_f32:
4897 case Intrinsic::nvvm_tex_2d_v4u32_s32:
4898 case Intrinsic::nvvm_tex_2d_v4u32_f32:
4899 case Intrinsic::nvvm_tex_2d_level_v4u32_f32:
4900 case Intrinsic::nvvm_tex_2d_grad_v4u32_f32:
4901 case Intrinsic::nvvm_tex_2d_array_v4u32_s32:
4902 case Intrinsic::nvvm_tex_2d_array_v4u32_f32:
4903 case Intrinsic::nvvm_tex_2d_array_level_v4u32_f32:
4904 case Intrinsic::nvvm_tex_2d_array_grad_v4u32_f32:
4905 case Intrinsic::nvvm_tex_3d_v4u32_s32:
4906 case Intrinsic::nvvm_tex_3d_v4u32_f32:
4907 case Intrinsic::nvvm_tex_3d_level_v4u32_f32:
4908 case Intrinsic::nvvm_tex_3d_grad_v4u32_f32:
4909 case Intrinsic::nvvm_tld4_r_2d_v4s32_f32:
4910 case Intrinsic::nvvm_tld4_g_2d_v4s32_f32:
4911 case Intrinsic::nvvm_tld4_b_2d_v4s32_f32:
4912 case Intrinsic::nvvm_tld4_a_2d_v4s32_f32:
4913 case Intrinsic::nvvm_tld4_r_2d_v4u32_f32:
4914 case Intrinsic::nvvm_tld4_g_2d_v4u32_f32:
4915 case Intrinsic::nvvm_tld4_b_2d_v4u32_f32:
4916 case Intrinsic::nvvm_tld4_a_2d_v4u32_f32:
4917 case Intrinsic::nvvm_tex_unified_1d_v4s32_s32:
4918 case Intrinsic::nvvm_tex_unified_1d_v4s32_f32:
4919 case Intrinsic::nvvm_tex_unified_1d_level_v4s32_f32:
4920 case Intrinsic::nvvm_tex_unified_1d_grad_v4s32_f32:
4921 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_s32:
4922 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_f32:
4923 case Intrinsic::nvvm_tex_unified_1d_array_level_v4s32_f32:
4924 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4s32_f32:
4925 case Intrinsic::nvvm_tex_unified_2d_v4s32_s32:
4926 case Intrinsic::nvvm_tex_unified_2d_v4s32_f32:
4927 case Intrinsic::nvvm_tex_unified_2d_level_v4s32_f32:
4928 case Intrinsic::nvvm_tex_unified_2d_grad_v4s32_f32:
4929 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_s32:
4930 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_f32:
4931 case Intrinsic::nvvm_tex_unified_2d_array_level_v4s32_f32:
4932 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4s32_f32:
4933 case Intrinsic::nvvm_tex_unified_3d_v4s32_s32:
4934 case Intrinsic::nvvm_tex_unified_3d_v4s32_f32:
4935 case Intrinsic::nvvm_tex_unified_3d_level_v4s32_f32:
4936 case Intrinsic::nvvm_tex_unified_3d_grad_v4s32_f32:
4937 case Intrinsic::nvvm_tex_unified_1d_v4u32_s32:
4938 case Intrinsic::nvvm_tex_unified_1d_v4u32_f32:
4939 case Intrinsic::nvvm_tex_unified_1d_level_v4u32_f32:
4940 case Intrinsic::nvvm_tex_unified_1d_grad_v4u32_f32:
4941 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_s32:
4942 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_f32:
4943 case Intrinsic::nvvm_tex_unified_1d_array_level_v4u32_f32:
4944 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4u32_f32:
4945 case Intrinsic::nvvm_tex_unified_2d_v4u32_s32:
4946 case Intrinsic::nvvm_tex_unified_2d_v4u32_f32:
4947 case Intrinsic::nvvm_tex_unified_2d_level_v4u32_f32:
4948 case Intrinsic::nvvm_tex_unified_2d_grad_v4u32_f32:
4949 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_s32:
4950 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_f32:
4951 case Intrinsic::nvvm_tex_unified_2d_array_level_v4u32_f32:
4952 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4u32_f32:
4953 case Intrinsic::nvvm_tex_unified_3d_v4u32_s32:
4954 case Intrinsic::nvvm_tex_unified_3d_v4u32_f32:
4955 case Intrinsic::nvvm_tex_unified_3d_level_v4u32_f32:
4956 case Intrinsic::nvvm_tex_unified_3d_grad_v4u32_f32:
4957 case Intrinsic::nvvm_tex_unified_cube_v4s32_f32:
4958 case Intrinsic::nvvm_tex_unified_cube_level_v4s32_f32:
4959 case Intrinsic::nvvm_tex_unified_cube_array_v4s32_f32:
4960 case Intrinsic::nvvm_tex_unified_cube_array_level_v4s32_f32:
4961 case Intrinsic::nvvm_tex_unified_cube_v4u32_f32:
4962 case Intrinsic::nvvm_tex_unified_cube_level_v4u32_f32:
4963 case Intrinsic::nvvm_tex_unified_cube_array_v4u32_f32:
4964 case Intrinsic::nvvm_tex_unified_cube_array_level_v4u32_f32:
4965 case Intrinsic::nvvm_tex_unified_cube_grad_v4s32_f32:
4966 case Intrinsic::nvvm_tex_unified_cube_grad_v4u32_f32:
4967 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4s32_f32:
4968 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4u32_f32:
4969 case Intrinsic::nvvm_tld4_unified_r_2d_v4s32_f32:
4970 case Intrinsic::nvvm_tld4_unified_g_2d_v4s32_f32:
4971 case Intrinsic::nvvm_tld4_unified_b_2d_v4s32_f32:
4972 case Intrinsic::nvvm_tld4_unified_a_2d_v4s32_f32:
4973 case Intrinsic::nvvm_tld4_unified_r_2d_v4u32_f32:
4974 case Intrinsic::nvvm_tld4_unified_g_2d_v4u32_f32:
4975 case Intrinsic::nvvm_tld4_unified_b_2d_v4u32_f32:
4976 case Intrinsic::nvvm_tld4_unified_a_2d_v4u32_f32:
4977 Info.opc = ISD::INTRINSIC_W_CHAIN;
4978 Info.memVT = MVT::v4i32;
4979 Info.ptrVal = nullptr;
4980 Info.offset = 0;
4981 Info.flags = MachineMemOperand::MOLoad;
4982 Info.align = Align(16);
4983 Infos.push_back(Elt: Info);
4984 return;
4985
4986 case Intrinsic::nvvm_suld_1d_i8_clamp:
4987 case Intrinsic::nvvm_suld_1d_v2i8_clamp:
4988 case Intrinsic::nvvm_suld_1d_v4i8_clamp:
4989 case Intrinsic::nvvm_suld_1d_array_i8_clamp:
4990 case Intrinsic::nvvm_suld_1d_array_v2i8_clamp:
4991 case Intrinsic::nvvm_suld_1d_array_v4i8_clamp:
4992 case Intrinsic::nvvm_suld_2d_i8_clamp:
4993 case Intrinsic::nvvm_suld_2d_v2i8_clamp:
4994 case Intrinsic::nvvm_suld_2d_v4i8_clamp:
4995 case Intrinsic::nvvm_suld_2d_array_i8_clamp:
4996 case Intrinsic::nvvm_suld_2d_array_v2i8_clamp:
4997 case Intrinsic::nvvm_suld_2d_array_v4i8_clamp:
4998 case Intrinsic::nvvm_suld_3d_i8_clamp:
4999 case Intrinsic::nvvm_suld_3d_v2i8_clamp:
5000 case Intrinsic::nvvm_suld_3d_v4i8_clamp:
5001 case Intrinsic::nvvm_suld_1d_i8_trap:
5002 case Intrinsic::nvvm_suld_1d_v2i8_trap:
5003 case Intrinsic::nvvm_suld_1d_v4i8_trap:
5004 case Intrinsic::nvvm_suld_1d_array_i8_trap:
5005 case Intrinsic::nvvm_suld_1d_array_v2i8_trap:
5006 case Intrinsic::nvvm_suld_1d_array_v4i8_trap:
5007 case Intrinsic::nvvm_suld_2d_i8_trap:
5008 case Intrinsic::nvvm_suld_2d_v2i8_trap:
5009 case Intrinsic::nvvm_suld_2d_v4i8_trap:
5010 case Intrinsic::nvvm_suld_2d_array_i8_trap:
5011 case Intrinsic::nvvm_suld_2d_array_v2i8_trap:
5012 case Intrinsic::nvvm_suld_2d_array_v4i8_trap:
5013 case Intrinsic::nvvm_suld_3d_i8_trap:
5014 case Intrinsic::nvvm_suld_3d_v2i8_trap:
5015 case Intrinsic::nvvm_suld_3d_v4i8_trap:
5016 case Intrinsic::nvvm_suld_1d_i8_zero:
5017 case Intrinsic::nvvm_suld_1d_v2i8_zero:
5018 case Intrinsic::nvvm_suld_1d_v4i8_zero:
5019 case Intrinsic::nvvm_suld_1d_array_i8_zero:
5020 case Intrinsic::nvvm_suld_1d_array_v2i8_zero:
5021 case Intrinsic::nvvm_suld_1d_array_v4i8_zero:
5022 case Intrinsic::nvvm_suld_2d_i8_zero:
5023 case Intrinsic::nvvm_suld_2d_v2i8_zero:
5024 case Intrinsic::nvvm_suld_2d_v4i8_zero:
5025 case Intrinsic::nvvm_suld_2d_array_i8_zero:
5026 case Intrinsic::nvvm_suld_2d_array_v2i8_zero:
5027 case Intrinsic::nvvm_suld_2d_array_v4i8_zero:
5028 case Intrinsic::nvvm_suld_3d_i8_zero:
5029 case Intrinsic::nvvm_suld_3d_v2i8_zero:
5030 case Intrinsic::nvvm_suld_3d_v4i8_zero:
5031 Info.opc = ISD::INTRINSIC_W_CHAIN;
5032 Info.memVT = MVT::i8;
5033 Info.ptrVal = nullptr;
5034 Info.offset = 0;
5035 Info.flags = MachineMemOperand::MOLoad;
5036 Info.align = Align(16);
5037 Infos.push_back(Elt: Info);
5038 return;
5039
5040 case Intrinsic::nvvm_suld_1d_i16_clamp:
5041 case Intrinsic::nvvm_suld_1d_v2i16_clamp:
5042 case Intrinsic::nvvm_suld_1d_v4i16_clamp:
5043 case Intrinsic::nvvm_suld_1d_array_i16_clamp:
5044 case Intrinsic::nvvm_suld_1d_array_v2i16_clamp:
5045 case Intrinsic::nvvm_suld_1d_array_v4i16_clamp:
5046 case Intrinsic::nvvm_suld_2d_i16_clamp:
5047 case Intrinsic::nvvm_suld_2d_v2i16_clamp:
5048 case Intrinsic::nvvm_suld_2d_v4i16_clamp:
5049 case Intrinsic::nvvm_suld_2d_array_i16_clamp:
5050 case Intrinsic::nvvm_suld_2d_array_v2i16_clamp:
5051 case Intrinsic::nvvm_suld_2d_array_v4i16_clamp:
5052 case Intrinsic::nvvm_suld_3d_i16_clamp:
5053 case Intrinsic::nvvm_suld_3d_v2i16_clamp:
5054 case Intrinsic::nvvm_suld_3d_v4i16_clamp:
5055 case Intrinsic::nvvm_suld_1d_i16_trap:
5056 case Intrinsic::nvvm_suld_1d_v2i16_trap:
5057 case Intrinsic::nvvm_suld_1d_v4i16_trap:
5058 case Intrinsic::nvvm_suld_1d_array_i16_trap:
5059 case Intrinsic::nvvm_suld_1d_array_v2i16_trap:
5060 case Intrinsic::nvvm_suld_1d_array_v4i16_trap:
5061 case Intrinsic::nvvm_suld_2d_i16_trap:
5062 case Intrinsic::nvvm_suld_2d_v2i16_trap:
5063 case Intrinsic::nvvm_suld_2d_v4i16_trap:
5064 case Intrinsic::nvvm_suld_2d_array_i16_trap:
5065 case Intrinsic::nvvm_suld_2d_array_v2i16_trap:
5066 case Intrinsic::nvvm_suld_2d_array_v4i16_trap:
5067 case Intrinsic::nvvm_suld_3d_i16_trap:
5068 case Intrinsic::nvvm_suld_3d_v2i16_trap:
5069 case Intrinsic::nvvm_suld_3d_v4i16_trap:
5070 case Intrinsic::nvvm_suld_1d_i16_zero:
5071 case Intrinsic::nvvm_suld_1d_v2i16_zero:
5072 case Intrinsic::nvvm_suld_1d_v4i16_zero:
5073 case Intrinsic::nvvm_suld_1d_array_i16_zero:
5074 case Intrinsic::nvvm_suld_1d_array_v2i16_zero:
5075 case Intrinsic::nvvm_suld_1d_array_v4i16_zero:
5076 case Intrinsic::nvvm_suld_2d_i16_zero:
5077 case Intrinsic::nvvm_suld_2d_v2i16_zero:
5078 case Intrinsic::nvvm_suld_2d_v4i16_zero:
5079 case Intrinsic::nvvm_suld_2d_array_i16_zero:
5080 case Intrinsic::nvvm_suld_2d_array_v2i16_zero:
5081 case Intrinsic::nvvm_suld_2d_array_v4i16_zero:
5082 case Intrinsic::nvvm_suld_3d_i16_zero:
5083 case Intrinsic::nvvm_suld_3d_v2i16_zero:
5084 case Intrinsic::nvvm_suld_3d_v4i16_zero:
5085 Info.opc = ISD::INTRINSIC_W_CHAIN;
5086 Info.memVT = MVT::i16;
5087 Info.ptrVal = nullptr;
5088 Info.offset = 0;
5089 Info.flags = MachineMemOperand::MOLoad;
5090 Info.align = Align(16);
5091 Infos.push_back(Elt: Info);
5092 return;
5093
5094 case Intrinsic::nvvm_suld_1d_i32_clamp:
5095 case Intrinsic::nvvm_suld_1d_v2i32_clamp:
5096 case Intrinsic::nvvm_suld_1d_v4i32_clamp:
5097 case Intrinsic::nvvm_suld_1d_array_i32_clamp:
5098 case Intrinsic::nvvm_suld_1d_array_v2i32_clamp:
5099 case Intrinsic::nvvm_suld_1d_array_v4i32_clamp:
5100 case Intrinsic::nvvm_suld_2d_i32_clamp:
5101 case Intrinsic::nvvm_suld_2d_v2i32_clamp:
5102 case Intrinsic::nvvm_suld_2d_v4i32_clamp:
5103 case Intrinsic::nvvm_suld_2d_array_i32_clamp:
5104 case Intrinsic::nvvm_suld_2d_array_v2i32_clamp:
5105 case Intrinsic::nvvm_suld_2d_array_v4i32_clamp:
5106 case Intrinsic::nvvm_suld_3d_i32_clamp:
5107 case Intrinsic::nvvm_suld_3d_v2i32_clamp:
5108 case Intrinsic::nvvm_suld_3d_v4i32_clamp:
5109 case Intrinsic::nvvm_suld_1d_i32_trap:
5110 case Intrinsic::nvvm_suld_1d_v2i32_trap:
5111 case Intrinsic::nvvm_suld_1d_v4i32_trap:
5112 case Intrinsic::nvvm_suld_1d_array_i32_trap:
5113 case Intrinsic::nvvm_suld_1d_array_v2i32_trap:
5114 case Intrinsic::nvvm_suld_1d_array_v4i32_trap:
5115 case Intrinsic::nvvm_suld_2d_i32_trap:
5116 case Intrinsic::nvvm_suld_2d_v2i32_trap:
5117 case Intrinsic::nvvm_suld_2d_v4i32_trap:
5118 case Intrinsic::nvvm_suld_2d_array_i32_trap:
5119 case Intrinsic::nvvm_suld_2d_array_v2i32_trap:
5120 case Intrinsic::nvvm_suld_2d_array_v4i32_trap:
5121 case Intrinsic::nvvm_suld_3d_i32_trap:
5122 case Intrinsic::nvvm_suld_3d_v2i32_trap:
5123 case Intrinsic::nvvm_suld_3d_v4i32_trap:
5124 case Intrinsic::nvvm_suld_1d_i32_zero:
5125 case Intrinsic::nvvm_suld_1d_v2i32_zero:
5126 case Intrinsic::nvvm_suld_1d_v4i32_zero:
5127 case Intrinsic::nvvm_suld_1d_array_i32_zero:
5128 case Intrinsic::nvvm_suld_1d_array_v2i32_zero:
5129 case Intrinsic::nvvm_suld_1d_array_v4i32_zero:
5130 case Intrinsic::nvvm_suld_2d_i32_zero:
5131 case Intrinsic::nvvm_suld_2d_v2i32_zero:
5132 case Intrinsic::nvvm_suld_2d_v4i32_zero:
5133 case Intrinsic::nvvm_suld_2d_array_i32_zero:
5134 case Intrinsic::nvvm_suld_2d_array_v2i32_zero:
5135 case Intrinsic::nvvm_suld_2d_array_v4i32_zero:
5136 case Intrinsic::nvvm_suld_3d_i32_zero:
5137 case Intrinsic::nvvm_suld_3d_v2i32_zero:
5138 case Intrinsic::nvvm_suld_3d_v4i32_zero:
5139 Info.opc = ISD::INTRINSIC_W_CHAIN;
5140 Info.memVT = MVT::i32;
5141 Info.ptrVal = nullptr;
5142 Info.offset = 0;
5143 Info.flags = MachineMemOperand::MOLoad;
5144 Info.align = Align(16);
5145 Infos.push_back(Elt: Info);
5146 return;
5147
5148 case Intrinsic::nvvm_suld_1d_i64_clamp:
5149 case Intrinsic::nvvm_suld_1d_v2i64_clamp:
5150 case Intrinsic::nvvm_suld_1d_array_i64_clamp:
5151 case Intrinsic::nvvm_suld_1d_array_v2i64_clamp:
5152 case Intrinsic::nvvm_suld_2d_i64_clamp:
5153 case Intrinsic::nvvm_suld_2d_v2i64_clamp:
5154 case Intrinsic::nvvm_suld_2d_array_i64_clamp:
5155 case Intrinsic::nvvm_suld_2d_array_v2i64_clamp:
5156 case Intrinsic::nvvm_suld_3d_i64_clamp:
5157 case Intrinsic::nvvm_suld_3d_v2i64_clamp:
5158 case Intrinsic::nvvm_suld_1d_i64_trap:
5159 case Intrinsic::nvvm_suld_1d_v2i64_trap:
5160 case Intrinsic::nvvm_suld_1d_array_i64_trap:
5161 case Intrinsic::nvvm_suld_1d_array_v2i64_trap:
5162 case Intrinsic::nvvm_suld_2d_i64_trap:
5163 case Intrinsic::nvvm_suld_2d_v2i64_trap:
5164 case Intrinsic::nvvm_suld_2d_array_i64_trap:
5165 case Intrinsic::nvvm_suld_2d_array_v2i64_trap:
5166 case Intrinsic::nvvm_suld_3d_i64_trap:
5167 case Intrinsic::nvvm_suld_3d_v2i64_trap:
5168 case Intrinsic::nvvm_suld_1d_i64_zero:
5169 case Intrinsic::nvvm_suld_1d_v2i64_zero:
5170 case Intrinsic::nvvm_suld_1d_array_i64_zero:
5171 case Intrinsic::nvvm_suld_1d_array_v2i64_zero:
5172 case Intrinsic::nvvm_suld_2d_i64_zero:
5173 case Intrinsic::nvvm_suld_2d_v2i64_zero:
5174 case Intrinsic::nvvm_suld_2d_array_i64_zero:
5175 case Intrinsic::nvvm_suld_2d_array_v2i64_zero:
5176 case Intrinsic::nvvm_suld_3d_i64_zero:
5177 case Intrinsic::nvvm_suld_3d_v2i64_zero:
5178 Info.opc = ISD::INTRINSIC_W_CHAIN;
5179 Info.memVT = MVT::i64;
5180 Info.ptrVal = nullptr;
5181 Info.offset = 0;
5182 Info.flags = MachineMemOperand::MOLoad;
5183 Info.align = Align(16);
5184 Infos.push_back(Elt: Info);
5185 return;
5186
5187 case Intrinsic::nvvm_tcgen05_ld_16x64b_x1:
5188 case Intrinsic::nvvm_tcgen05_ld_32x32b_x1:
5189 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x1: {
5190 Info.opc = ISD::INTRINSIC_W_CHAIN;
5191 Info.memVT = MVT::v1i32;
5192 Info.ptrVal = I.getArgOperand(i: 0);
5193 Info.offset = 0;
5194 Info.flags = MachineMemOperand::MOLoad;
5195 Info.align.reset();
5196 Infos.push_back(Elt: Info);
5197 return;
5198 }
5199
5200 case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
5201 case Intrinsic::nvvm_tcgen05_ld_16x128b_x1:
5202 case Intrinsic::nvvm_tcgen05_ld_32x32b_x2:
5203 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2:
5204 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_i32:
5205 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x2_i32: {
5206 Info.opc = ISD::INTRINSIC_W_CHAIN;
5207 Info.memVT = MVT::v2i32;
5208 Info.ptrVal = I.getArgOperand(i: 0);
5209 Info.offset = 0;
5210 Info.flags = MachineMemOperand::MOLoad;
5211 Info.align.reset();
5212 Infos.push_back(Elt: Info);
5213 return;
5214 }
5215
5216 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_f32:
5217 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x2_f32: {
5218 Info.opc = ISD::INTRINSIC_W_CHAIN;
5219 Info.memVT = MVT::v2f32;
5220 Info.ptrVal = I.getArgOperand(i: 0);
5221 Info.offset = 0;
5222 Info.flags = MachineMemOperand::MOLoad;
5223 Info.align.reset();
5224 Infos.push_back(Elt: Info);
5225 return;
5226 }
5227
5228 case Intrinsic::nvvm_tcgen05_ld_16x64b_x4:
5229 case Intrinsic::nvvm_tcgen05_ld_16x128b_x2:
5230 case Intrinsic::nvvm_tcgen05_ld_32x32b_x4:
5231 case Intrinsic::nvvm_tcgen05_ld_16x256b_x1:
5232 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x4:
5233 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x4_i32:
5234 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x4_i32: {
5235 Info.opc = ISD::INTRINSIC_W_CHAIN;
5236 Info.memVT = MVT::v4i32;
5237 Info.ptrVal = I.getArgOperand(i: 0);
5238 Info.offset = 0;
5239 Info.flags = MachineMemOperand::MOLoad;
5240 Info.align.reset();
5241 Infos.push_back(Elt: Info);
5242 return;
5243 }
5244
5245 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x4_f32:
5246 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x4_f32: {
5247 Info.opc = ISD::INTRINSIC_W_CHAIN;
5248 Info.memVT = MVT::v4f32;
5249 Info.ptrVal = I.getArgOperand(i: 0);
5250 Info.offset = 0;
5251 Info.flags = MachineMemOperand::MOLoad;
5252 Info.align.reset();
5253 Infos.push_back(Elt: Info);
5254 return;
5255 }
5256
5257 case Intrinsic::nvvm_tcgen05_ld_16x64b_x8:
5258 case Intrinsic::nvvm_tcgen05_ld_16x128b_x4:
5259 case Intrinsic::nvvm_tcgen05_ld_16x256b_x2:
5260 case Intrinsic::nvvm_tcgen05_ld_32x32b_x8:
5261 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x8:
5262 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x8_i32:
5263 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x8_i32: {
5264 Info.opc = ISD::INTRINSIC_W_CHAIN;
5265 Info.memVT = MVT::v8i32;
5266 Info.ptrVal = I.getArgOperand(i: 0);
5267 Info.offset = 0;
5268 Info.flags = MachineMemOperand::MOLoad;
5269 Info.align.reset();
5270 Infos.push_back(Elt: Info);
5271 return;
5272 }
5273
5274 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x8_f32:
5275 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x8_f32: {
5276 Info.opc = ISD::INTRINSIC_W_CHAIN;
5277 Info.memVT = MVT::v8f32;
5278 Info.ptrVal = I.getArgOperand(i: 0);
5279 Info.offset = 0;
5280 Info.flags = MachineMemOperand::MOLoad;
5281 Info.align.reset();
5282 Infos.push_back(Elt: Info);
5283 return;
5284 }
5285
5286 case Intrinsic::nvvm_tcgen05_ld_16x64b_x16:
5287 case Intrinsic::nvvm_tcgen05_ld_16x128b_x8:
5288 case Intrinsic::nvvm_tcgen05_ld_16x256b_x4:
5289 case Intrinsic::nvvm_tcgen05_ld_32x32b_x16:
5290 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x16:
5291 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x16_i32:
5292 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x16_i32: {
5293 Info.opc = ISD::INTRINSIC_W_CHAIN;
5294 Info.memVT = MVT::v16i32;
5295 Info.ptrVal = I.getArgOperand(i: 0);
5296 Info.offset = 0;
5297 Info.flags = MachineMemOperand::MOLoad;
5298 Info.align.reset();
5299 Infos.push_back(Elt: Info);
5300 return;
5301 }
5302
5303 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x16_f32:
5304 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x16_f32: {
5305 Info.opc = ISD::INTRINSIC_W_CHAIN;
5306 Info.memVT = MVT::v16f32;
5307 Info.ptrVal = I.getArgOperand(i: 0);
5308 Info.offset = 0;
5309 Info.flags = MachineMemOperand::MOLoad;
5310 Info.align.reset();
5311 Infos.push_back(Elt: Info);
5312 return;
5313 }
5314
5315 case Intrinsic::nvvm_tcgen05_ld_16x64b_x32:
5316 case Intrinsic::nvvm_tcgen05_ld_16x128b_x16:
5317 case Intrinsic::nvvm_tcgen05_ld_16x256b_x8:
5318 case Intrinsic::nvvm_tcgen05_ld_32x32b_x32:
5319 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x32:
5320 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x32_i32:
5321 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x32_i32: {
5322 Info.opc = ISD::INTRINSIC_W_CHAIN;
5323 Info.memVT = MVT::v32i32;
5324 Info.ptrVal = I.getArgOperand(i: 0);
5325 Info.offset = 0;
5326 Info.flags = MachineMemOperand::MOLoad;
5327 Info.align.reset();
5328 Infos.push_back(Elt: Info);
5329 return;
5330 }
5331
5332 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x32_f32:
5333 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x32_f32: {
5334 Info.opc = ISD::INTRINSIC_W_CHAIN;
5335 Info.memVT = MVT::v32f32;
5336 Info.ptrVal = I.getArgOperand(i: 0);
5337 Info.offset = 0;
5338 Info.flags = MachineMemOperand::MOLoad;
5339 Info.align.reset();
5340 Infos.push_back(Elt: Info);
5341 return;
5342 }
5343
5344 case Intrinsic::nvvm_tcgen05_ld_16x64b_x64:
5345 case Intrinsic::nvvm_tcgen05_ld_16x128b_x32:
5346 case Intrinsic::nvvm_tcgen05_ld_16x256b_x16:
5347 case Intrinsic::nvvm_tcgen05_ld_32x32b_x64:
5348 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x64:
5349 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x64_i32:
5350 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x64_i32: {
5351 Info.opc = ISD::INTRINSIC_W_CHAIN;
5352 Info.memVT = MVT::v64i32;
5353 Info.ptrVal = I.getArgOperand(i: 0);
5354 Info.offset = 0;
5355 Info.flags = MachineMemOperand::MOLoad;
5356 Info.align.reset();
5357 Infos.push_back(Elt: Info);
5358 return;
5359 }
5360
5361 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x64_f32:
5362 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x64_f32: {
5363 Info.opc = ISD::INTRINSIC_W_CHAIN;
5364 Info.memVT = MVT::v64f32;
5365 Info.ptrVal = I.getArgOperand(i: 0);
5366 Info.offset = 0;
5367 Info.flags = MachineMemOperand::MOLoad;
5368 Info.align.reset();
5369 Infos.push_back(Elt: Info);
5370 return;
5371 }
5372
5373 case Intrinsic::nvvm_tcgen05_ld_16x64b_x128:
5374 case Intrinsic::nvvm_tcgen05_ld_16x128b_x64:
5375 case Intrinsic::nvvm_tcgen05_ld_16x256b_x32:
5376 case Intrinsic::nvvm_tcgen05_ld_32x32b_x128:
5377 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x128:
5378 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x128_i32:
5379 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x128_i32: {
5380 Info.opc = ISD::INTRINSIC_W_CHAIN;
5381 Info.memVT = MVT::v128i32;
5382 Info.ptrVal = I.getArgOperand(i: 0);
5383 Info.offset = 0;
5384 Info.flags = MachineMemOperand::MOLoad;
5385 Info.align.reset();
5386 Infos.push_back(Elt: Info);
5387 return;
5388 }
5389
5390 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x128_f32:
5391 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x128_f32: {
5392 Info.opc = ISD::INTRINSIC_W_CHAIN;
5393 Info.memVT = MVT::v128f32;
5394 Info.ptrVal = I.getArgOperand(i: 0);
5395 Info.offset = 0;
5396 Info.flags = MachineMemOperand::MOLoad;
5397 Info.align.reset();
5398 Infos.push_back(Elt: Info);
5399 return;
5400 }
5401
5402 case Intrinsic::nvvm_tcgen05_st_16x64b_x1:
5403 case Intrinsic::nvvm_tcgen05_st_32x32b_x1:
5404 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x1: {
5405 Info.opc = ISD::INTRINSIC_VOID;
5406 Info.memVT = MVT::i32;
5407 Info.ptrVal = I.getArgOperand(i: 0);
5408 Info.offset = 0;
5409 Info.flags = MachineMemOperand::MOStore;
5410 Info.align.reset();
5411 Infos.push_back(Elt: Info);
5412 return;
5413 }
5414
5415 case Intrinsic::nvvm_tcgen05_st_16x64b_x2:
5416 case Intrinsic::nvvm_tcgen05_st_16x128b_x1:
5417 case Intrinsic::nvvm_tcgen05_st_32x32b_x2:
5418 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x2: {
5419 Info.opc = ISD::INTRINSIC_VOID;
5420 Info.memVT = MVT::v2i32;
5421 Info.ptrVal = I.getArgOperand(i: 0);
5422 Info.offset = 0;
5423 Info.flags = MachineMemOperand::MOStore;
5424 Info.align.reset();
5425 Infos.push_back(Elt: Info);
5426 return;
5427 }
5428
5429 case Intrinsic::nvvm_tcgen05_st_16x64b_x4:
5430 case Intrinsic::nvvm_tcgen05_st_16x128b_x2:
5431 case Intrinsic::nvvm_tcgen05_st_16x256b_x1:
5432 case Intrinsic::nvvm_tcgen05_st_32x32b_x4:
5433 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x4: {
5434 Info.opc = ISD::INTRINSIC_VOID;
5435 Info.memVT = MVT::v4i32;
5436 Info.ptrVal = I.getArgOperand(i: 0);
5437 Info.offset = 0;
5438 Info.flags = MachineMemOperand::MOStore;
5439 Info.align.reset();
5440 Infos.push_back(Elt: Info);
5441 return;
5442 }
5443
5444 case Intrinsic::nvvm_tcgen05_st_16x64b_x8:
5445 case Intrinsic::nvvm_tcgen05_st_16x128b_x4:
5446 case Intrinsic::nvvm_tcgen05_st_16x256b_x2:
5447 case Intrinsic::nvvm_tcgen05_st_32x32b_x8:
5448 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x8: {
5449 Info.opc = ISD::INTRINSIC_VOID;
5450 Info.memVT = MVT::v8i32;
5451 Info.ptrVal = I.getArgOperand(i: 0);
5452 Info.offset = 0;
5453 Info.flags = MachineMemOperand::MOStore;
5454 Info.align.reset();
5455 Infos.push_back(Elt: Info);
5456 return;
5457 }
5458
5459 case Intrinsic::nvvm_tcgen05_st_16x64b_x16:
5460 case Intrinsic::nvvm_tcgen05_st_16x128b_x8:
5461 case Intrinsic::nvvm_tcgen05_st_16x256b_x4:
5462 case Intrinsic::nvvm_tcgen05_st_32x32b_x16:
5463 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x16: {
5464 Info.opc = ISD::INTRINSIC_VOID;
5465 Info.memVT = MVT::v16i32;
5466 Info.ptrVal = I.getArgOperand(i: 0);
5467 Info.offset = 0;
5468 Info.flags = MachineMemOperand::MOStore;
5469 Info.align.reset();
5470 Infos.push_back(Elt: Info);
5471 return;
5472 }
5473
5474 case Intrinsic::nvvm_tcgen05_st_16x64b_x32:
5475 case Intrinsic::nvvm_tcgen05_st_16x128b_x16:
5476 case Intrinsic::nvvm_tcgen05_st_16x256b_x8:
5477 case Intrinsic::nvvm_tcgen05_st_32x32b_x32:
5478 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x32: {
5479 Info.opc = ISD::INTRINSIC_VOID;
5480 Info.memVT = MVT::v32i32;
5481 Info.ptrVal = I.getArgOperand(i: 0);
5482 Info.offset = 0;
5483 Info.flags = MachineMemOperand::MOStore;
5484 Info.align.reset();
5485 Infos.push_back(Elt: Info);
5486 return;
5487 }
5488
5489 case Intrinsic::nvvm_tcgen05_st_16x64b_x64:
5490 case Intrinsic::nvvm_tcgen05_st_16x128b_x32:
5491 case Intrinsic::nvvm_tcgen05_st_16x256b_x16:
5492 case Intrinsic::nvvm_tcgen05_st_32x32b_x64:
5493 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x64: {
5494 Info.opc = ISD::INTRINSIC_VOID;
5495 Info.memVT = MVT::v64i32;
5496 Info.ptrVal = I.getArgOperand(i: 0);
5497 Info.offset = 0;
5498 Info.flags = MachineMemOperand::MOStore;
5499 Info.align.reset();
5500 Infos.push_back(Elt: Info);
5501 return;
5502 }
5503
5504 case Intrinsic::nvvm_tcgen05_st_16x64b_x128:
5505 case Intrinsic::nvvm_tcgen05_st_16x128b_x64:
5506 case Intrinsic::nvvm_tcgen05_st_16x256b_x32:
5507 case Intrinsic::nvvm_tcgen05_st_32x32b_x128:
5508 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x128: {
5509 Info.opc = ISD::INTRINSIC_VOID;
5510 Info.memVT = MVT::v128i32;
5511 Info.ptrVal = I.getArgOperand(i: 0);
5512 Info.offset = 0;
5513 Info.flags = MachineMemOperand::MOStore;
5514 Info.align.reset();
5515 Infos.push_back(Elt: Info);
5516 return;
5517 }
5518 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
5519 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1:
5520 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1:
5521 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1:
5522 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1:
5523 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1:
5524 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift:
5525 case Intrinsic::
5526 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift:
5527 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1:
5528 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1:
5529 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift:
5530 case Intrinsic::
5531 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift: {
5532 // We are reading and writing back to TMem
5533 Info.opc = ISD::INTRINSIC_VOID;
5534 Info.memVT = MVT::v4i32;
5535 Info.ptrVal = I.getArgOperand(i: 0);
5536 Info.offset = 0;
5537 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
5538 Info.align = Align(16);
5539 Infos.push_back(Elt: Info);
5540 return;
5541 }
5542
5543 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2:
5544 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2:
5545 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2:
5546 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2:
5547 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2:
5548 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2:
5549 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2:
5550 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2:
5551 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift:
5552 case Intrinsic::
5553 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift:
5554 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift:
5555 case Intrinsic::
5556 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift: {
5557 // We are reading and writing back to TMem
5558 Info.opc = ISD::INTRINSIC_VOID;
5559 Info.memVT = MVT::v8i32;
5560 Info.ptrVal = I.getArgOperand(i: 0);
5561 Info.offset = 0;
5562 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
5563 Info.align = Align(16);
5564 Infos.push_back(Elt: Info);
5565 return;
5566 }
5567 }
5568}
5569
5570/// getFunctionParamOptimizedAlign - since function arguments are passed via
5571/// .param space, we may want to increase their alignment in a way that
5572/// ensures that we can effectively vectorize their loads & stores. We can
5573/// increase alignment only if the function has internal or has private
5574/// linkage as for other linkage types callers may already rely on default
5575/// alignment. To allow using 128-bit vectorized loads/stores, this function
5576/// ensures that alignment is 16 or greater.
5577Align NVPTXTargetLowering::getFunctionParamOptimizedAlign(
5578 const Function *F, Type *ArgTy, const DataLayout &DL) const {
5579 // Capping the alignment to 128 bytes as that is the maximum alignment
5580 // supported by PTX.
5581 const Align ABITypeAlign = std::min(a: Align(128), b: DL.getABITypeAlign(Ty: ArgTy));
5582
5583 // If a function has linkage different from internal or private, we
5584 // must use default ABI alignment as external users rely on it. Same
5585 // for a function that may be called from a function pointer.
5586 if (!F || !F->hasLocalLinkage() ||
5587 F->hasAddressTaken(/*Users=*/nullptr,
5588 /*IgnoreCallbackUses=*/false,
5589 /*IgnoreAssumeLikeCalls=*/true,
5590 /*IgnoreLLVMUsed=*/IngoreLLVMUsed: true))
5591 return ABITypeAlign;
5592
5593 assert(!isKernelFunction(*F) && "Expect kernels to have non-local linkage");
5594 return std::max(a: Align(16), b: ABITypeAlign);
5595}
5596
5597/// Helper for computing alignment of a device function byval parameter.
5598Align NVPTXTargetLowering::getFunctionByValParamAlign(
5599 const Function *F, Type *ArgTy, Align InitialAlign,
5600 const DataLayout &DL) const {
5601 Align ArgAlign = InitialAlign;
5602 // Try to increase alignment to enhance vectorization options.
5603 if (F)
5604 ArgAlign = std::max(a: ArgAlign, b: getFunctionParamOptimizedAlign(F, ArgTy, DL));
5605
5606 // Old ptx versions have a bug. When PTX code takes address of
5607 // byval parameter with alignment < 4, ptxas generates code to
5608 // spill argument into memory. Alas on sm_50+ ptxas generates
5609 // SASS code that fails with misaligned access. To work around
5610 // the problem, make sure that we align byval parameters by at
5611 // least 4. This bug seems to be fixed at least starting from
5612 // ptxas > 9.0.
5613 // TODO: remove this after verifying the bug is not reproduced
5614 // on non-deprecated ptxas versions.
5615 if (ForceMinByValParamAlign)
5616 ArgAlign = std::max(a: ArgAlign, b: Align(4));
5617
5618 return ArgAlign;
5619}
5620
5621// Helper for getting a function parameter name. Name is composed from
5622// its index and the function name. Negative index corresponds to special
5623// parameter (unsized array) used for passing variable arguments.
5624std::string NVPTXTargetLowering::getParamName(const Function *F,
5625 int Idx) const {
5626 std::string ParamName;
5627 raw_string_ostream ParamStr(ParamName);
5628
5629 ParamStr << getTargetMachine().getSymbol(GV: F)->getName();
5630 if (Idx < 0)
5631 ParamStr << "_vararg";
5632 else
5633 ParamStr << "_param_" << Idx;
5634
5635 return ParamName;
5636}
5637
5638/// isLegalAddressingMode - Return true if the addressing mode represented
5639/// by AM is legal for this target, for a load/store of the specified type.
5640/// Used to guide target specific optimizations, like loop strength reduction
5641/// (LoopStrengthReduce.cpp) and memory optimization for address mode
5642/// (CodeGenPrepare.cpp)
5643bool NVPTXTargetLowering::isLegalAddressingMode(const DataLayout &DL,
5644 const AddrMode &AM, Type *Ty,
5645 unsigned AS, Instruction *I) const {
5646 // AddrMode - This represents an addressing mode of:
5647 // BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
5648 //
5649 // The legal address modes are
5650 // - [avar]
5651 // - [areg]
5652 // - [areg+immoff]
5653 // - [immAddr]
5654
5655 // immoff must fit in a signed 32-bit int
5656 if (!APInt(64, AM.BaseOffs).isSignedIntN(N: 32))
5657 return false;
5658
5659 if (AM.BaseGV)
5660 return !AM.BaseOffs && !AM.HasBaseReg && !AM.Scale;
5661
5662 switch (AM.Scale) {
5663 case 0: // "r", "r+i" or "i" is allowed
5664 break;
5665 case 1:
5666 if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
5667 return false;
5668 // Otherwise we have r+i.
5669 break;
5670 default:
5671 // No scale > 1 is allowed
5672 return false;
5673 }
5674 return true;
5675}
5676
5677//===----------------------------------------------------------------------===//
5678// NVPTX Inline Assembly Support
5679//===----------------------------------------------------------------------===//
5680
5681/// getConstraintType - Given a constraint letter, return the type of
5682/// constraint it is for this target.
5683NVPTXTargetLowering::ConstraintType
5684NVPTXTargetLowering::getConstraintType(StringRef Constraint) const {
5685 if (Constraint.size() == 1) {
5686 switch (Constraint[0]) {
5687 default:
5688 break;
5689 case 'b':
5690 case 'r':
5691 case 'h':
5692 case 'c':
5693 case 'l':
5694 case 'f':
5695 case 'd':
5696 case 'q':
5697 case '0':
5698 case 'N':
5699 return C_RegisterClass;
5700 }
5701 }
5702 return TargetLowering::getConstraintType(Constraint);
5703}
5704
5705std::pair<unsigned, const TargetRegisterClass *>
5706NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
5707 StringRef Constraint,
5708 MVT VT) const {
5709 if (Constraint.size() == 1) {
5710 switch (Constraint[0]) {
5711 case 'b':
5712 return std::make_pair(x: 0U, y: &NVPTX::B1RegClass);
5713 case 'c':
5714 case 'h':
5715 return std::make_pair(x: 0U, y: &NVPTX::B16RegClass);
5716 case 'r':
5717 case 'f':
5718 return std::make_pair(x: 0U, y: &NVPTX::B32RegClass);
5719 case 'l':
5720 case 'N':
5721 case 'd':
5722 return std::make_pair(x: 0U, y: &NVPTX::B64RegClass);
5723 case 'q': {
5724 if (STI.getSmVersion() < 70)
5725 report_fatal_error(reason: "Inline asm with 128 bit operands is only "
5726 "supported for sm_70 and higher!");
5727 return std::make_pair(x: 0U, y: &NVPTX::B128RegClass);
5728 }
5729 }
5730 }
5731 return TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
5732}
5733
5734//===----------------------------------------------------------------------===//
5735// NVPTX DAG Combining
5736//===----------------------------------------------------------------------===//
5737
5738bool NVPTXTargetLowering::allowFMA(MachineFunction &MF,
5739 CodeGenOptLevel OptLevel) const {
5740 // Always honor command-line argument
5741 if (FMAContractLevelOpt.getNumOccurrences() > 0)
5742 return FMAContractLevelOpt > 0;
5743
5744 // Do not contract if we're not optimizing the code.
5745 if (OptLevel == CodeGenOptLevel::None)
5746 return false;
5747
5748 // Honor TargetOptions flags that explicitly say fusion is okay.
5749 if (MF.getTarget().Options.AllowFPOpFusion == FPOpFusion::Fast)
5750 return true;
5751
5752 return false;
5753}
5754
5755static bool isConstZero(const SDValue &Operand) {
5756 const auto *Const = dyn_cast<ConstantSDNode>(Val: Operand);
5757 return Const && Const->getZExtValue() == 0;
5758}
5759
5760/// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
5761/// operands N0 and N1. This is a helper for PerformADDCombine that is
5762/// called with the default operands, and if that fails, with commuted
5763/// operands.
5764static SDValue
5765PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
5766 TargetLowering::DAGCombinerInfo &DCI) {
5767 EVT VT = N0.getValueType();
5768
5769 // Since integer multiply-add costs the same as integer multiply
5770 // but is more costly than integer add, do the fusion only when
5771 // the mul is only used in the add.
5772 // TODO: this may not be true for later architectures, consider relaxing this
5773 if (!N0.getNode()->hasOneUse())
5774 return SDValue();
5775
5776 // fold (add (select cond, 0, (mul a, b)), c)
5777 // -> (select cond, c, (add (mul a, b), c))
5778 //
5779 if (N0.getOpcode() == ISD::SELECT) {
5780 unsigned ZeroOpNum;
5781 if (isConstZero(Operand: N0->getOperand(Num: 1)))
5782 ZeroOpNum = 1;
5783 else if (isConstZero(Operand: N0->getOperand(Num: 2)))
5784 ZeroOpNum = 2;
5785 else
5786 return SDValue();
5787
5788 SDValue M = N0->getOperand(Num: (ZeroOpNum == 1) ? 2 : 1);
5789 if (M->getOpcode() != ISD::MUL || !M.getNode()->hasOneUse())
5790 return SDValue();
5791
5792 SDLoc DL(N);
5793 SDValue Mul =
5794 DCI.DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: M->getOperand(Num: 0), N2: M->getOperand(Num: 1));
5795 SDValue MAD = DCI.DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Mul, N2: N1);
5796 return DCI.DAG.getSelect(DL: SDLoc(N), VT, Cond: N0->getOperand(Num: 0),
5797 LHS: ((ZeroOpNum == 1) ? N1 : MAD),
5798 RHS: ((ZeroOpNum == 1) ? MAD : N1));
5799 }
5800
5801 return SDValue();
5802}
5803
/// Try to fold (fadd (fmul a, b), c) into (fma a, b, c).
///
/// The fold requires contraction to be permitted — either globally via
/// allowFMA() or per-node via allow-contract fast-math flags on both the
/// FADD and the FMUL — and uses IR-order heuristics to avoid fusions that
/// would likely increase register pressure.
static SDValue
PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
                               TargetLowering::DAGCombinerInfo &DCI,
                               CodeGenOptLevel OptLevel) {
  EVT VT = N0.getValueType();
  if (N0.getOpcode() == ISD::FMUL) {
    const auto *TLI = static_cast<const NVPTXTargetLowering *>(
        &DCI.DAG.getTargetLoweringInfo());
    // Bail out unless contraction is allowed by target policy or by
    // allow-contract flags on both nodes.
    if (!(TLI->allowFMA(MF&: DCI.DAG.getMachineFunction(), OptLevel) ||
          (N->getFlags().hasAllowContract() &&
           N0->getFlags().hasAllowContract())))
      return SDValue();

    // For floating point:
    // Do the fusion only when the mul has fewer than 5 uses and all
    // are adds.
    // The heuristic is that if a use is not an add, then that use
    // cannot be fused into an fma, therefore the mul is still needed anyway.
    // If there are more than 4 uses, even if they are all adds, fusing
    // them will increase register pressure.
    //
    int numUses = 0;
    int nonAddCount = 0;
    for (const SDNode *User : N0.getNode()->users()) {
      numUses++;
      if (User->getOpcode() != ISD::FADD)
        ++nonAddCount;
      if (numUses >= 5)
        return SDValue();
    }
    if (nonAddCount) {
      // The mul has a non-add user, so it survives the fusion; decide
      // whether keeping both mul and fma live is affordable.
      int orderNo = N->getIROrder();
      int orderNo2 = N0.getNode()->getIROrder();
      // Simple heuristic for potential register pressure: the IR-order
      // difference approximates the def-use distance, and a longer
      // distance is more likely to cause register pressure.
      if (orderNo - orderNo2 < 500)
        return SDValue();

      // Now, check if at least one of the FMUL's operands is live beyond the
      // node N, which guarantees that the FMA will not increase register
      // pressure at node N.
      bool opIsLive = false;
      const SDNode *left = N0.getOperand(i: 0).getNode();
      const SDNode *right = N0.getOperand(i: 1).getNode();

      // Constants are rematerializable, so treat them as always live.
      if (isa<ConstantSDNode>(Val: left) || isa<ConstantSDNode>(Val: right))
        opIsLive = true;

      // Otherwise, an operand is live past N if any of its users appears
      // after N in IR order.
      if (!opIsLive)
        for (const SDNode *User : left->users()) {
          int orderNo3 = User->getIROrder();
          if (orderNo3 > orderNo) {
            opIsLive = true;
            break;
          }
        }

      if (!opIsLive)
        for (const SDNode *User : right->users()) {
          int orderNo3 = User->getIROrder();
          if (orderNo3 > orderNo) {
            opIsLive = true;
            break;
          }
        }

      if (!opIsLive)
        return SDValue();
    }

    return DCI.DAG.getNode(Opcode: ISD::FMA, DL: SDLoc(N), VT, N1: N0.getOperand(i: 0),
                           N2: N0.getOperand(i: 1), N3: N1);
  }

  return SDValue();
}
5882
5883/// Fold unpacking movs into a load by increasing the number of return values.
5884///
5885/// ex:
5886/// L: v2f16,ch = load <p>
5887/// a: f16 = extractelt L:0, 0
5888/// b: f16 = extractelt L:0, 1
5889/// use(a, b)
5890///
5891/// ...is turned into...
5892///
5893/// L: f16,f16,ch = LoadV2 <p>
5894/// use(L:0, L:1)
static SDValue
combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
  // Don't run this optimization before the legalizer; the packed-type
  // loads we create below are only meaningful post-legalization.
  if (!DCI.isAfterLegalizeDAG())
    return SDValue();

  EVT ElementVT = N->getValueType(ResNo: 0);
  // Avoid non-packed types and v4i8
  if (!NVPTX::isPackedVectorTy(VT: ElementVT) || ElementVT == MVT::v4i8)
    return SDValue();

  // Check whether all outputs are either used by an extractelt or are
  // glue/chain nodes
  if (!all_of(Range: N->uses(), P: [&](SDUse &U) {
        // Skip glue, chain nodes
        if (U.getValueType() == MVT::Glue || U.getValueType() == MVT::Other)
          return true;
        if (U.getUser()->getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
          if (N->getOpcode() != ISD::LOAD)
            return true;
          // Since this is an ISD::LOAD, check all extractelts are used. If
          // any are not used, we don't want to defeat another optimization that
          // will narrow the load.
          //
          // For example:
          //
          // L: v2f16,ch = load <p>
          // e0: f16 = extractelt L:0, 0
          // e1: f16 = extractelt L:0, 1 <-- unused
          // store e0
          //
          // Can be optimized by DAGCombiner to:
          //
          // L: f16,ch = load <p>
          // store L:0
          return !U.getUser()->use_empty();
        }

        // Otherwise, this use prevents us from splitting a value.
        return false;
      }))
    return SDValue();

  auto *LD = cast<MemSDNode>(Val: N);
  SDLoc DL(LD);

  // The new opcode after we double the number of outputs; note the total
  // number of bytes loaded does not change, only how they are split.
  unsigned Opcode;
  SmallVector<SDValue> Operands(LD->ops());
  unsigned OldNumOutputs; // non-glue, non-chain outputs
  switch (LD->getOpcode()) {
  case ISD::LOAD:
    OldNumOutputs = 1;
    // Any packed type is legal, so the legalizer will not have lowered
    // ISD::LOAD -> NVPTXISD::Load (unless it's under-aligned). We have to do it
    // here.
    Opcode = NVPTXISD::LoadV2;
    // append a "full" used bytes mask operand right before the extension type
    // operand, signifying that all bytes are used.
    Operands.push_back(Elt: DCI.DAG.getConstant(UINT32_MAX, DL, VT: MVT::i32));
    Operands.push_back(Elt: DCI.DAG.getIntPtrConstant(
        Val: cast<LoadSDNode>(Val: LD)->getExtensionType(), DL));
    break;
  case NVPTXISD::LoadV2:
    OldNumOutputs = 2;
    Opcode = NVPTXISD::LoadV4;
    break;
  case NVPTXISD::LoadV4:
    // V8 is only supported for f32/i32. Don't forget, we're not changing the
    // load size here. This is already a 256-bit load.
    if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
      return SDValue();
    OldNumOutputs = 4;
    Opcode = NVPTXISD::LoadV8;
    break;
  case NVPTXISD::LoadV8:
    // PTX doesn't support the next doubling of outputs
    return SDValue();
  }

  // the non-glue, non-chain outputs in the new load: each old packed output
  // becomes two scalar outputs of the element type.
  const unsigned NewNumOutputs = OldNumOutputs * 2;
  SmallVector<EVT> NewVTs(NewNumOutputs, ElementVT.getVectorElementType());
  // add remaining chain and glue values
  NewVTs.append(in_start: LD->value_begin() + OldNumOutputs, in_end: LD->value_end());

  // Create the new load
  SDValue NewLoad = DCI.DAG.getMemIntrinsicNode(
      Opcode, dl: DL, VTList: DCI.DAG.getVTList(VTs: NewVTs), Ops: Operands, MemVT: LD->getMemoryVT(),
      MMO: LD->getMemOperand());

  // Now we use a combination of BUILD_VECTORs and a MERGE_VALUES node to keep
  // the outputs the same. These nodes will be optimized away in later
  // DAGCombiner iterations.
  SmallVector<SDValue> Results;
  for (unsigned I : seq(Size: OldNumOutputs))
    Results.push_back(Elt: DCI.DAG.getBuildVector(
        VT: ElementVT, DL, Ops: {NewLoad.getValue(R: I * 2), NewLoad.getValue(R: I * 2 + 1)}));
  // Add remaining chain and glue nodes
  for (unsigned I : seq(Size: NewLoad->getNumValues() - NewNumOutputs))
    Results.push_back(Elt: NewLoad.getValue(R: NewNumOutputs + I));

  return DCI.DAG.getMergeValues(Ops: Results, dl: DL);
}
5999
6000/// Fold packing movs into a store.
6001///
6002/// ex:
6003/// v1: v2f16 = BUILD_VECTOR a:f16, b:f16
6004/// v2: v2f16 = BUILD_VECTOR c:f16, d:f16
6005/// StoreV2 v1, v2
6006///
6007/// ...is turned into...
6008///
6009/// StoreV4 a, b, c, d
// Front/Back give the number of leading and trailing non-value operands to
// skip when scanning for BUILD_VECTORs — for ISD::STORE the caller passes
// Front=1 (chain) and Back=2 (presumably base pointer and offset; confirm
// against the node's operand layout).
static SDValue combinePackingMovIntoStore(SDNode *N,
                                          TargetLowering::DAGCombinerInfo &DCI,
                                          unsigned Front, unsigned Back) {
  // We want to run this as late as possible since other optimizations may
  // eliminate the BUILD_VECTORs.
  if (!DCI.isAfterLegalizeDAG())
    return SDValue();

  // Get the type of the operands being stored.
  EVT ElementVT = N->getOperand(Num: Front).getValueType();

  // Avoid non-packed types and v4i8
  if (!NVPTX::isPackedVectorTy(VT: ElementVT) || ElementVT == MVT::v4i8)
    return SDValue();

  auto *ST = cast<MemSDNode>(Val: N);

  // The new opcode after we double the number of operands.
  unsigned Opcode;
  switch (N->getOpcode()) {
  case ISD::STORE:
    // Any packed type is legal, so the legalizer will not have lowered
    // ISD::STORE -> NVPTXISD::Store (unless it's under-aligned). We have to do
    // it here.
    Opcode = NVPTXISD::StoreV2;
    break;
  case NVPTXISD::StoreV2:
    Opcode = NVPTXISD::StoreV4;
    break;
  case NVPTXISD::StoreV4:
    // V8 is only supported for f32/i32. Don't forget, we're not changing the
    // store size here. This is already a 256-bit store.
    if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
      return SDValue();
    Opcode = NVPTXISD::StoreV8;
    break;
  case NVPTXISD::StoreV8:
    // PTX doesn't support the next doubling of operands
    return SDValue();
  default:
    llvm_unreachable("Unhandled store opcode");
  }

  // Scan the operands and if they're all BUILD_VECTORs, we'll have gathered
  // their elements. Bail out (returning SDValue()) at the first operand
  // that is not a profitable-to-unpack BUILD_VECTOR.
  SmallVector<SDValue, 4> Operands(N->ops().take_front(N: Front));
  for (SDValue BV : N->ops().drop_front(N: Front).drop_back(N: Back)) {
    if (BV.getOpcode() != ISD::BUILD_VECTOR)
      return SDValue();

    // If the operand has multiple uses, this optimization can increase register
    // pressure.
    if (!BV.hasOneUse())
      return SDValue();

    // DAGCombiner visits nodes bottom-up. Check the BUILD_VECTOR operands for
    // any signs they may be folded by some other pattern or rule.
    for (SDValue Op : BV->ops()) {
      // Peek through bitcasts
      if (Op.getOpcode() == ISD::BITCAST)
        Op = Op.getOperand(i: 0);

      // This may be folded into a PRMT.
      if (Op.getValueType() == MVT::i16 && Op.getOpcode() == ISD::TRUNCATE &&
          Op->getOperand(Num: 0).getValueType() == MVT::i32)
        return SDValue();

      // This may be folded into cvt.bf16x2
      if (Op.getOpcode() == ISD::FP_ROUND)
        return SDValue();
    }
    // Replace the packed operand with its two scalar elements.
    Operands.append(IL: {BV.getOperand(i: 0), BV.getOperand(i: 1)});
  }
  // Keep the trailing (address) operands unchanged.
  Operands.append(in_start: N->op_end() - Back, in_end: N->op_end());

  // Now we replace the store
  return DCI.DAG.getMemIntrinsicNode(Opcode, dl: SDLoc(N), VTList: N->getVTList(), Ops: Operands,
                                     MemVT: ST->getMemoryVT(), MMO: ST->getMemOperand());
}
6089
6090static SDValue combineSTORE(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
6091 const NVPTXSubtarget &STI) {
6092
6093 if (DCI.isBeforeLegalize() && N->getOpcode() == ISD::STORE) {
6094 // Here is our chance to custom lower a store with a non-simple type.
6095 // Unfortunately, we can't do this in the legalizer because there is no
6096 // way to setOperationAction for an non-simple type.
6097 StoreSDNode *ST = cast<StoreSDNode>(Val: N);
6098 if (!ST->getValue().getValueType().isSimple())
6099 return lowerSTOREVector(Op: SDValue(ST, 0), DAG&: DCI.DAG, STI);
6100 }
6101
6102 return combinePackingMovIntoStore(N, DCI, Front: 1, Back: 2);
6103}
6104
6105static SDValue combineLOAD(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
6106 const NVPTXSubtarget &STI) {
6107 if (DCI.isBeforeLegalize() && N->getOpcode() == ISD::LOAD) {
6108 // Here is our chance to custom lower a load with a non-simple type.
6109 // Unfortunately, we can't do this in the legalizer because there is no
6110 // way to setOperationAction for an non-simple type.
6111 if (!N->getValueType(ResNo: 0).isSimple())
6112 return lowerLoadVector(N, DAG&: DCI.DAG, STI);
6113 }
6114
6115 return combineUnpackingMovIntoLoad(N, DCI);
6116}
6117
6118/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
6119///
6120static SDValue PerformADDCombine(SDNode *N,
6121 TargetLowering::DAGCombinerInfo &DCI,
6122 CodeGenOptLevel OptLevel) {
6123 if (OptLevel == CodeGenOptLevel::None)
6124 return SDValue();
6125
6126 SDValue N0 = N->getOperand(Num: 0);
6127 SDValue N1 = N->getOperand(Num: 1);
6128
6129 // Skip non-integer, non-scalar case
6130 EVT VT = N0.getValueType();
6131 if (VT.isVector() || VT != MVT::i32)
6132 return SDValue();
6133
6134 // First try with the default operand order.
6135 if (SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI))
6136 return Result;
6137
6138 // If that didn't work, try again with the operands commuted.
6139 return PerformADDCombineWithOperands(N, N0: N1, N1: N0, DCI);
6140}
6141
6142/// Check if a v2f32 BUILD_VECTOR provably packs values from non-adjacent
6143/// register pairs (non-coalescable).
6144static bool isNonCoalescableBuildVector(const SDValue &BV) {
6145 if (BV.getOpcode() != ISD::BUILD_VECTOR || BV.getValueType() != MVT::v2f32)
6146 return false;
6147
6148 SDValue Elt0 = BV.getOperand(i: 0);
6149 SDValue Elt1 = BV.getOperand(i: 1);
6150
6151 bool IsExt0 = Elt0.getOpcode() == ISD::EXTRACT_VECTOR_ELT;
6152 bool IsExt1 = Elt1.getOpcode() == ISD::EXTRACT_VECTOR_ELT;
6153
6154 // If neither element is an EXTRACT_VECTOR_ELT they are free-standing
6155 // scalars and the register allocator can still place them side-by-side.
6156 if (!IsExt0 && !IsExt1)
6157 return false;
6158
6159 // If exactly one element is an EXTRACT_VECTOR_ELT, the other is a scalar
6160 // that cannot generally occupy the adjacent register slot.
6161 if (IsExt0 != IsExt1)
6162 return true;
6163
6164 // At this point both sources are extracting from vectors. If they are from
6165 // different vectors, then the BUILD_VECTOR is non-coalescable.
6166 SDValue Src0 = Elt0.getOperand(i: 0);
6167 SDValue Src1 = Elt1.getOperand(i: 0);
6168 if (Src0 != Src1)
6169 return true;
6170
6171 auto *Idx0 = dyn_cast<ConstantSDNode>(Val: Elt0.getOperand(i: 1));
6172 auto *Idx1 = dyn_cast<ConstantSDNode>(Val: Elt1.getOperand(i: 1));
6173 // If both indices are dynamic they will be lowered to
6174 // loads and the vector will be spilled to local memory. The register
6175 // allocator can easily place the results in adjacent registers.
6176 if (!Idx0 && !Idx1)
6177 return false;
6178
6179 // If one index is dynamic and the other is constant, the value from the
6180 // constant load will result in an additional register to pair with the result
6181 // from the dynamic load. We consider this non-coalescable.
6182 if ((Idx0 && !Idx1) || (!Idx0 && Idx1))
6183 return true;
6184
6185 // Both are constant, adjacent pairs are coalescable
6186 return std::abs(i: Idx0->getSExtValue() - Idx1->getSExtValue()) != 1;
6187}
6188
6189/// Scalarize a v2f32 arithmetic node (FADD, FMUL, FSUB, FMA) when at least
6190/// one operand is a BUILD_VECTOR that repacks values from non-adjacent register
6191/// pairs. Without this combine the BUILD_VECTOR forces allocation of a
6192/// temporary 64-bit register, increasing register pressure.
6193///
6194/// Example - before:
6195/// t0: v2f32,v2f32,ch = LoadV2 ...
6196/// t1: f32 = extract_vector_elt t0, 0
6197/// t2: f32 = extract_vector_elt t0:1, 0
6198/// t3: v2f32 = BUILD_VECTOR t1, t2 ;; non-coalescable repack
6199/// t4: v2f32 = fma t_a, t3, t_c
6200///
6201/// After:
6202/// t0: v2f32,v2f32,ch = LoadV2 ...
6203/// t1: f32 = extract_vector_elt t0, 0
6204/// t2: f32 = extract_vector_elt t0:1, 0
6205/// a0: f32 = extract_vector_elt t_a, 0
6206/// a1: f32 = extract_vector_elt t_a, 1
6207/// c0: f32 = extract_vector_elt t_c, 0
6208/// c1: f32 = extract_vector_elt t_c, 1
6209/// r0: f32 = fma a0, t1, c0
6210/// r1: f32 = fma a1, t2, c1
6211/// t4: v2f32 = BUILD_VECTOR r0, r1
6212static SDValue PerformScalarizeV2F32Op(SDNode *N,
6213 TargetLowering::DAGCombinerInfo &DCI) {
6214 EVT VT = N->getValueType(ResNo: 0);
6215 if (VT != MVT::v2f32)
6216 return SDValue();
6217
6218 // Only scalarize when at least one operand is a BUILD_VECTOR whose elements
6219 // are guaranteed to reside in different register pairs.
6220 if (none_of(Range: N->ops(), P: isNonCoalescableBuildVector))
6221 return SDValue();
6222
6223 SelectionDAG &DAG = DCI.DAG;
6224 SDLoc DL(N);
6225 EVT EltVT = VT.getVectorElementType();
6226 unsigned Opc = N->getOpcode();
6227
6228 // For each operand, get the scalar element at the given index: if the operand
6229 // is a BUILD_VECTOR, grab the element directly; otherwise, emit an
6230 // EXTRACT_VECTOR_ELT.
6231 auto GetElement = [&](SDValue Op, unsigned Index) -> SDValue {
6232 if (Op.getOpcode() == ISD::BUILD_VECTOR)
6233 return Op.getOperand(i: Index);
6234 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Op,
6235 N2: DAG.getVectorIdxConstant(Val: Index, DL));
6236 };
6237
6238 // Build scalar operand lists for element 0 and element 1.
6239 SmallVector<SDValue, 3> Ops0, Ops1;
6240 for (const SDValue &Op : N->ops()) {
6241 Ops0.push_back(Elt: GetElement(Op, 0));
6242 Ops1.push_back(Elt: GetElement(Op, 1));
6243 }
6244
6245 SDValue Res0 = DAG.getNode(Opcode: Opc, DL, VT: EltVT, Ops: Ops0, Flags: N->getFlags());
6246 SDValue Res1 = DAG.getNode(Opcode: Opc, DL, VT: EltVT, Ops: Ops1, Flags: N->getFlags());
6247
6248 return DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT, N1: Res0, N2: Res1);
6249}
6250
6251/// PerformFADDCombine - Target-specific dag combine xforms for ISD::FADD.
6252///
6253static SDValue PerformFADDCombine(SDNode *N,
6254 TargetLowering::DAGCombinerInfo &DCI,
6255 CodeGenOptLevel OptLevel) {
6256 SDValue N0 = N->getOperand(Num: 0);
6257 SDValue N1 = N->getOperand(Num: 1);
6258
6259 if (SDValue Result = PerformScalarizeV2F32Op(N, DCI))
6260 return Result;
6261
6262 EVT VT = N0.getValueType();
6263 if (VT.isVector() || !(VT == MVT::f32 || VT == MVT::f64))
6264 return SDValue();
6265
6266 // First try with the default operand order.
6267 if (SDValue Result = PerformFADDCombineWithOperands(N, N0, N1, DCI, OptLevel))
6268 return Result;
6269
6270 // If that didn't work, try again with the operands commuted.
6271 return PerformFADDCombineWithOperands(N, N0: N1, N1: N0, DCI, OptLevel);
6272}
6273
/// Get 3-input version of a 2-input min/max opcode
/// Note: both the NUM and IMUMNUM 2-input variants map onto the same
/// 3-input node (and likewise for min).
static unsigned getMinMax3Opcode(unsigned MinMax2Opcode) {
  switch (MinMax2Opcode) {
  case ISD::FMAXNUM:
  case ISD::FMAXIMUMNUM:
    return NVPTXISD::FMAXNUM3;
  case ISD::FMINNUM:
  case ISD::FMINIMUMNUM:
    return NVPTXISD::FMINNUM3;
  case ISD::FMAXIMUM:
    return NVPTXISD::FMAXIMUM3;
  case ISD::FMINIMUM:
    return NVPTXISD::FMINIMUM3;
  default:
    llvm_unreachable("Invalid 2-input min/max opcode");
  }
}
6291
6292/// PerformFMinMaxCombine - Combine (fmaxnum (fmaxnum a, b), c) into
6293/// (fmaxnum3 a, b, c). Also covers other llvm min/max intrinsics.
6294static SDValue PerformFMinMaxCombine(SDNode *N,
6295 TargetLowering::DAGCombinerInfo &DCI,
6296 unsigned PTXVersion, unsigned SmVersion) {
6297
6298 // 3-input min/max requires PTX 8.8+ and SM_100+, and only supports f32s
6299 EVT VT = N->getValueType(ResNo: 0);
6300 if (VT != MVT::f32 || PTXVersion < 88 || SmVersion < 100)
6301 return SDValue();
6302
6303 SDValue Op0 = N->getOperand(Num: 0);
6304 SDValue Op1 = N->getOperand(Num: 1);
6305 unsigned MinMaxOp2 = N->getOpcode();
6306 unsigned MinMaxOp3 = getMinMax3Opcode(MinMax2Opcode: MinMaxOp2);
6307
6308 if (Op0.getOpcode() == MinMaxOp2 && Op0.hasOneUse()) {
6309 // (maxnum (maxnum a, b), c) -> (maxnum3 a, b, c)
6310 SDValue A = Op0.getOperand(i: 0);
6311 SDValue B = Op0.getOperand(i: 1);
6312 SDValue C = Op1;
6313 return DCI.DAG.getNode(Opcode: MinMaxOp3, DL: SDLoc(N), VT, N1: A, N2: B, N3: C, Flags: N->getFlags());
6314 } else if (Op1.getOpcode() == MinMaxOp2 && Op1.hasOneUse()) {
6315 // (maxnum a, (maxnum b, c)) -> (maxnum3 a, b, c)
6316 SDValue A = Op0;
6317 SDValue B = Op1.getOperand(i: 0);
6318 SDValue C = Op1.getOperand(i: 1);
6319 return DCI.DAG.getNode(Opcode: MinMaxOp3, DL: SDLoc(N), VT, N1: A, N2: B, N3: C, Flags: N->getFlags());
6320 }
6321 return SDValue();
6322}
6323
6324static SDValue PerformREMCombine(SDNode *N,
6325 TargetLowering::DAGCombinerInfo &DCI,
6326 CodeGenOptLevel OptLevel) {
6327 assert(N->getOpcode() == ISD::SREM || N->getOpcode() == ISD::UREM);
6328
6329 // Don't do anything at less than -O2.
6330 if (OptLevel < CodeGenOptLevel::Default)
6331 return SDValue();
6332
6333 SelectionDAG &DAG = DCI.DAG;
6334 SDLoc DL(N);
6335 EVT VT = N->getValueType(ResNo: 0);
6336 bool IsSigned = N->getOpcode() == ISD::SREM;
6337 unsigned DivOpc = IsSigned ? ISD::SDIV : ISD::UDIV;
6338
6339 const SDValue &Num = N->getOperand(Num: 0);
6340 const SDValue &Den = N->getOperand(Num: 1);
6341
6342 for (const SDNode *U : Num->users()) {
6343 if (U->getOpcode() == DivOpc && U->getOperand(Num: 0) == Num &&
6344 U->getOperand(Num: 1) == Den) {
6345 // Num % Den -> Num - (Num / Den) * Den
6346 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Num,
6347 N2: DAG.getNode(Opcode: ISD::MUL, DL, VT,
6348 N1: DAG.getNode(Opcode: DivOpc, DL, VT, N1: Num, N2: Den),
6349 N2: Den));
6350 }
6351 }
6352 return SDValue();
6353}
6354
6355// (sign_extend|zero_extend (mul|shl) x, y) -> (mul.wide x, y)
6356static SDValue combineMulWide(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
6357 CodeGenOptLevel OptLevel) {
6358 if (OptLevel == CodeGenOptLevel::None)
6359 return SDValue();
6360
6361 SDValue Op = N->getOperand(Num: 0);
6362 if (!Op.hasOneUse())
6363 return SDValue();
6364 EVT ToVT = N->getValueType(ResNo: 0);
6365 EVT FromVT = Op.getValueType();
6366 if (!((ToVT == MVT::i32 && FromVT == MVT::i16) ||
6367 (ToVT == MVT::i64 && FromVT == MVT::i32)))
6368 return SDValue();
6369 if (!(Op.getOpcode() == ISD::MUL ||
6370 (Op.getOpcode() == ISD::SHL && isa<ConstantSDNode>(Val: Op.getOperand(i: 1)))))
6371 return SDValue();
6372
6373 SDLoc DL(N);
6374 unsigned ExtOpcode = N->getOpcode();
6375 unsigned Opcode = 0;
6376 if (ExtOpcode == ISD::SIGN_EXTEND && Op->getFlags().hasNoSignedWrap())
6377 Opcode = NVPTXISD::MUL_WIDE_SIGNED;
6378 else if (ExtOpcode == ISD::ZERO_EXTEND && Op->getFlags().hasNoUnsignedWrap())
6379 Opcode = NVPTXISD::MUL_WIDE_UNSIGNED;
6380 else
6381 return SDValue();
6382 SDValue RHS = Op.getOperand(i: 1);
6383 if (Op.getOpcode() == ISD::SHL) {
6384 const auto ShiftAmt = Op.getConstantOperandVal(i: 1);
6385 const auto MulVal = APInt(FromVT.getSizeInBits(), 1) << ShiftAmt;
6386 RHS = DCI.DAG.getConstant(Val: MulVal, DL, VT: FromVT);
6387 }
6388 return DCI.DAG.getNode(Opcode, DL, VT: ToVT, N1: Op.getOperand(i: 0), N2: RHS);
6389}
6390
// Signedness classification of a candidate mul.wide operand, as determined by
// IsMulWideOperandDemotable below.
enum OperandSignedness {
  Signed = 0,
  Unsigned,
  Unknown
};
6396
6397/// IsMulWideOperandDemotable - Checks if the provided DAG node is an operand
6398/// that can be demoted to \p OptSize bits without loss of information. The
6399/// signedness of the operand, if determinable, is placed in \p S.
6400static bool IsMulWideOperandDemotable(SDValue Op,
6401 unsigned OptSize,
6402 OperandSignedness &S) {
6403 S = Unknown;
6404
6405 if (Op.getOpcode() == ISD::SIGN_EXTEND ||
6406 Op.getOpcode() == ISD::SIGN_EXTEND_INREG) {
6407 EVT OrigVT = Op.getOperand(i: 0).getValueType();
6408 if (OrigVT.getFixedSizeInBits() <= OptSize) {
6409 S = Signed;
6410 return true;
6411 }
6412 } else if (Op.getOpcode() == ISD::ZERO_EXTEND) {
6413 EVT OrigVT = Op.getOperand(i: 0).getValueType();
6414 if (OrigVT.getFixedSizeInBits() <= OptSize) {
6415 S = Unsigned;
6416 return true;
6417 }
6418 }
6419
6420 return false;
6421}
6422
6423/// AreMulWideOperandsDemotable - Checks if the given LHS and RHS operands can
6424/// be demoted to \p OptSize bits without loss of information. If the operands
6425/// contain a constant, it should appear as the RHS operand. The signedness of
6426/// the operands is placed in \p IsSigned.
6427static bool AreMulWideOperandsDemotable(SDValue LHS, SDValue RHS,
6428 unsigned OptSize,
6429 bool &IsSigned) {
6430 OperandSignedness LHSSign;
6431
6432 // The LHS operand must be a demotable op
6433 if (!IsMulWideOperandDemotable(Op: LHS, OptSize, S&: LHSSign))
6434 return false;
6435
6436 // We should have been able to determine the signedness from the LHS
6437 if (LHSSign == Unknown)
6438 return false;
6439
6440 IsSigned = (LHSSign == Signed);
6441
6442 // The RHS can be a demotable op or a constant
6443 if (ConstantSDNode *CI = dyn_cast<ConstantSDNode>(Val&: RHS)) {
6444 const APInt &Val = CI->getAPIntValue();
6445 if (LHSSign == Unsigned) {
6446 return Val.isIntN(N: OptSize);
6447 } else {
6448 return Val.isSignedIntN(N: OptSize);
6449 }
6450 } else {
6451 OperandSignedness RHSSign;
6452 if (!IsMulWideOperandDemotable(Op: RHS, OptSize, S&: RHSSign))
6453 return false;
6454
6455 return LHSSign == RHSSign;
6456 }
6457}
6458
6459/// TryMULWIDECombine - Attempt to replace a multiply of M bits with a multiply
6460/// of M/2 bits that produces an M-bit result (i.e. mul.wide). This transform
6461/// works on both multiply DAG nodes and SHL DAG nodes with a constant shift
6462/// amount.
6463static SDValue TryMULWIDECombine(SDNode *N,
6464 TargetLowering::DAGCombinerInfo &DCI) {
6465 EVT MulType = N->getValueType(ResNo: 0);
6466 if (MulType != MVT::i32 && MulType != MVT::i64) {
6467 return SDValue();
6468 }
6469
6470 SDLoc DL(N);
6471 unsigned OptSize = MulType.getSizeInBits() >> 1;
6472 SDValue LHS = N->getOperand(Num: 0);
6473 SDValue RHS = N->getOperand(Num: 1);
6474
6475 // Canonicalize the multiply so the constant (if any) is on the right
6476 if (N->getOpcode() == ISD::MUL) {
6477 if (isa<ConstantSDNode>(Val: LHS)) {
6478 std::swap(a&: LHS, b&: RHS);
6479 }
6480 }
6481
6482 // If we have a SHL, determine the actual multiply amount
6483 if (N->getOpcode() == ISD::SHL) {
6484 ConstantSDNode *ShlRHS = dyn_cast<ConstantSDNode>(Val&: RHS);
6485 if (!ShlRHS) {
6486 return SDValue();
6487 }
6488
6489 APInt ShiftAmt = ShlRHS->getAPIntValue();
6490 unsigned BitWidth = MulType.getSizeInBits();
6491 if (ShiftAmt.sge(RHS: 0) && ShiftAmt.slt(RHS: BitWidth)) {
6492 APInt MulVal = APInt(BitWidth, 1) << ShiftAmt;
6493 RHS = DCI.DAG.getConstant(Val: MulVal, DL, VT: MulType);
6494 } else {
6495 return SDValue();
6496 }
6497 }
6498
6499 bool Signed;
6500 // Verify that our operands are demotable
6501 if (!AreMulWideOperandsDemotable(LHS, RHS, OptSize, IsSigned&: Signed)) {
6502 return SDValue();
6503 }
6504
6505 EVT DemotedVT;
6506 if (MulType == MVT::i32) {
6507 DemotedVT = MVT::i16;
6508 } else {
6509 DemotedVT = MVT::i32;
6510 }
6511
6512 // Truncate the operands to the correct size. Note that these are just for
6513 // type consistency and will (likely) be eliminated in later phases.
6514 SDValue TruncLHS =
6515 DCI.DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: DemotedVT, Operand: LHS);
6516 SDValue TruncRHS =
6517 DCI.DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: DemotedVT, Operand: RHS);
6518
6519 unsigned Opc;
6520 if (Signed) {
6521 Opc = NVPTXISD::MUL_WIDE_SIGNED;
6522 } else {
6523 Opc = NVPTXISD::MUL_WIDE_UNSIGNED;
6524 }
6525
6526 return DCI.DAG.getNode(Opcode: Opc, DL, VT: MulType, N1: TruncLHS, N2: TruncRHS);
6527}
6528
6529static bool isConstOne(const SDValue &Operand) {
6530 const auto *Const = dyn_cast<ConstantSDNode>(Val: Operand);
6531 return Const && Const->getZExtValue() == 1;
6532}
6533
6534static SDValue matchMADConstOnePattern(SDValue Add) {
6535 if (Add->getOpcode() != ISD::ADD)
6536 return SDValue();
6537
6538 if (isConstOne(Operand: Add->getOperand(Num: 0)))
6539 return Add->getOperand(Num: 1);
6540
6541 if (isConstOne(Operand: Add->getOperand(Num: 1)))
6542 return Add->getOperand(Num: 0);
6543
6544 return SDValue();
6545}
6546
6547static SDValue combineMADConstOne(SDValue X, SDValue Add, EVT VT, SDLoc DL,
6548 TargetLowering::DAGCombinerInfo &DCI) {
6549
6550 if (SDValue Y = matchMADConstOnePattern(Add)) {
6551 SDValue Mul = DCI.DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: X, N2: Y);
6552 return DCI.DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Mul, N2: X);
6553 }
6554
6555 return SDValue();
6556}
6557
6558static SDValue combineMulSelectConstOne(SDValue X, SDValue Select, EVT VT,
6559 SDLoc DL,
6560 TargetLowering::DAGCombinerInfo &DCI) {
6561 if (Select->getOpcode() != ISD::SELECT)
6562 return SDValue();
6563
6564 SDValue Cond = Select->getOperand(Num: 0);
6565
6566 unsigned ConstOpNo;
6567 if (isConstOne(Operand: Select->getOperand(Num: 1)))
6568 ConstOpNo = 1;
6569 else if (isConstOne(Operand: Select->getOperand(Num: 2)))
6570 ConstOpNo = 2;
6571 else
6572 return SDValue();
6573
6574 SDValue Y = Select->getOperand(Num: (ConstOpNo == 1) ? 2 : 1);
6575
6576 // Do not combine if the resulting sequence is not obviously profitable.
6577 if (!matchMADConstOnePattern(Add: Y))
6578 return SDValue();
6579
6580 SDValue NewMul = DCI.DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: X, N2: Y);
6581
6582 return DCI.DAG.getNode(Opcode: ISD::SELECT, DL, VT, N1: Cond,
6583 N2: (ConstOpNo == 1) ? X : NewMul,
6584 N3: (ConstOpNo == 1) ? NewMul : X);
6585}
6586
6587static SDValue
6588PerformMULCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
6589 TargetLowering::DAGCombinerInfo &DCI) {
6590
6591 EVT VT = N0.getValueType();
6592 if (VT.isVector())
6593 return SDValue();
6594
6595 if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
6596 return SDValue();
6597
6598 SDLoc DL(N);
6599
6600 // (mul x, (add y, 1)) -> (add (mul x, y), x)
6601 if (SDValue Res = combineMADConstOne(X: N0, Add: N1, VT, DL, DCI))
6602 return Res;
6603 if (SDValue Res = combineMADConstOne(X: N1, Add: N0, VT, DL, DCI))
6604 return Res;
6605
6606 // (mul x, (select y, 1)) -> (select (mul x, y), x)
6607 if (SDValue Res = combineMulSelectConstOne(X: N0, Select: N1, VT, DL, DCI))
6608 return Res;
6609 if (SDValue Res = combineMulSelectConstOne(X: N1, Select: N0, VT, DL, DCI))
6610 return Res;
6611
6612 return SDValue();
6613}
6614
6615/// PerformMULCombine - Runs PTX-specific DAG combine patterns on MUL nodes.
6616static SDValue PerformMULCombine(SDNode *N,
6617 TargetLowering::DAGCombinerInfo &DCI,
6618 CodeGenOptLevel OptLevel) {
6619 if (OptLevel == CodeGenOptLevel::None)
6620 return SDValue();
6621
6622 if (SDValue Ret = TryMULWIDECombine(N, DCI))
6623 return Ret;
6624
6625 SDValue N0 = N->getOperand(Num: 0);
6626 SDValue N1 = N->getOperand(Num: 1);
6627 return PerformMULCombineWithOperands(N, N0, N1, DCI);
6628}
6629
6630/// PerformSHLCombine - Runs PTX-specific DAG combine patterns on SHL nodes.
6631static SDValue PerformSHLCombine(SDNode *N,
6632 TargetLowering::DAGCombinerInfo &DCI,
6633 CodeGenOptLevel OptLevel) {
6634 if (OptLevel > CodeGenOptLevel::None) {
6635 // Try mul.wide combining at OptLevel > 0
6636 if (SDValue Ret = TryMULWIDECombine(N, DCI))
6637 return Ret;
6638 }
6639
6640 return SDValue();
6641}
6642
6643static SDValue PerformSETCCCombine(SDNode *N,
6644 TargetLowering::DAGCombinerInfo &DCI,
6645 unsigned int SmVersion) {
6646 EVT CCType = N->getValueType(ResNo: 0);
6647 SDValue A = N->getOperand(Num: 0);
6648 SDValue B = N->getOperand(Num: 1);
6649
6650 EVT AType = A.getValueType();
6651 if (!(CCType == MVT::v2i1 && (AType == MVT::v2f16 || AType == MVT::v2bf16)))
6652 return SDValue();
6653
6654 if (A.getValueType() == MVT::v2bf16 && SmVersion < 90)
6655 return SDValue();
6656
6657 SDLoc DL(N);
6658 // setp.f16x2 returns two scalar predicates, which we need to
6659 // convert back to v2i1. The returned result will be scalarized by
6660 // the legalizer, but the comparison will remain a single vector
6661 // instruction.
6662 SDValue CCNode = DCI.DAG.getNode(
6663 Opcode: A.getValueType() == MVT::v2f16 ? NVPTXISD::SETP_F16X2
6664 : NVPTXISD::SETP_BF16X2,
6665 DL, VTList: DCI.DAG.getVTList(VT1: MVT::i1, VT2: MVT::i1), Ops: {A, B, N->getOperand(Num: 2)});
6666 return DCI.DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: CCType, N1: CCNode.getValue(R: 0),
6667 N2: CCNode.getValue(R: 1));
6668}
6669
/// Rewrite extract_vector_elt of a small (16/32/64-bit) vector as a bitcast
/// of the whole vector to an integer, a right shift to move the selected
/// element to the low bits, and a truncate to the element width.
static SDValue PerformEXTRACTCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI) {
  SDValue Vector = N->getOperand(Num: 0);
  // Look through FREEZE to find the underlying source vector.
  if (Vector->getOpcode() == ISD::FREEZE)
    Vector = Vector->getOperand(Num: 0);
  SDLoc DL(N);
  EVT VectorVT = Vector.getValueType();
  if (Vector->getOpcode() == ISD::LOAD && VectorVT.isSimple() &&
      IsPTXVectorType(VT: VectorVT.getSimpleVT()))
    return SDValue(); // Native vector loads already combine nicely w/
                      // extract_vector_elt.
  // Don't mess with singletons or packed types (v2*32, v2*16, v4i8 and v8i8),
  // we already handle them OK.
  if (VectorVT.getVectorNumElements() == 1 ||
      NVPTX::isPackedVectorTy(VT: VectorVT) || VectorVT == MVT::v8i8)
    return SDValue();

  // Don't mess with undef values as sra may be simplified to 0, not undef.
  if (Vector->isUndef() || ISD::allOperandsUndef(N: Vector.getNode()))
    return SDValue();

  uint64_t VectorBits = VectorVT.getSizeInBits();
  // We only handle the types we can extract in-register.
  if (!(VectorBits == 16 || VectorBits == 32 || VectorBits == 64))
    return SDValue();

  ConstantSDNode *Index = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
  // Index == 0 is handled by generic DAG combiner.
  if (!Index || Index->getZExtValue() == 0)
    return SDValue();

  // IVT: an integer covering the whole vector; EltIVT: the element width as
  // an integer type.
  MVT IVT = MVT::getIntegerVT(BitWidth: VectorBits);
  EVT EltVT = VectorVT.getVectorElementType();
  EVT EltIVT = EltVT.changeTypeToInteger();
  uint64_t EltBits = EltVT.getScalarSizeInBits();

  // (trunc (sra (bitcast Vector), Index * EltBits)): shift the selected
  // element down to the low bits, then truncate to element width.
  SDValue Result = DCI.DAG.getNode(
      Opcode: ISD::TRUNCATE, DL, VT: EltIVT,
      Operand: DCI.DAG.getNode(
          Opcode: ISD::SRA, DL, VT: IVT, N1: DCI.DAG.getNode(Opcode: ISD::BITCAST, DL, VT: IVT, Operand: Vector),
          N2: DCI.DAG.getConstant(Val: Index->getZExtValue() * EltBits, DL, VT: IVT)));

  // If element has non-integer type, bitcast it back to the expected type.
  if (EltVT != EltIVT)
    Result = DCI.DAG.getNode(Opcode: ISD::BITCAST, DL, VT: EltVT, Operand: Result);
  // Past legalizer, we may need to extend i8 -> i16 to match the register type.
  if (EltVT != N->getValueType(ResNo: 0))
    Result = DCI.DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: N->getValueType(ResNo: 0), Operand: Result);

  return Result;
}
6721
/// Transform patterns like:
/// (select (ugt shift_amt, BitWidth-1), 0, (srl/shl x, shift_amt))
/// (select (ult shift_amt, BitWidth), (srl/shl x, shift_amt), 0)
/// Into:
/// (NVPTXISD::SRL_CLAMP x, shift_amt) or (NVPTXISD::SHL_CLAMP x, shift_amt)
///
/// These patterns arise from C/C++ code like `shift >= 32 ? 0 : x >> shift`
/// which guards against undefined behavior. PTX shr/shl instructions clamp
/// shift amounts >= BitWidth to produce 0 for logical shifts, making the
/// guard redundant.
///
/// Note: We only handle SRL and SHL, not SRA, because arithmetic right
/// shifts could produce 0 or -1 when shift >= BitWidth.
/// Note: We don't handle uge or ule. These don't appear because of
/// canonicalization.
static SDValue PerformSELECTShiftCombine(SDNode *N,
                                         TargetLowering::DAGCombinerInfo &DCI) {
  // This combine only runs after DAG legalization.
  if (!DCI.isAfterLegalizeDAG())
    return SDValue();

  using namespace SDPatternMatch;
  unsigned BitWidth = N->getValueType(ResNo: 0).getSizeInBits();
  SDValue ShiftAmt, ShiftOp;

  // Match logical shifts where the shift amount in the guard matches the shift
  // amount in the operation. The amount inside the shift may be a truncated
  // copy of the guard's value, hence m_TruncOrSelf around m_Deferred.
  auto LogicalShift =
      m_AllOf(preds: m_Value(N&: ShiftOp),
              preds: m_AnyOf(preds: m_Srl(L: m_Value(), R: m_TruncOrSelf(Op: m_Deferred(V&: ShiftAmt))),
                      preds: m_Shl(L: m_Value(), R: m_TruncOrSelf(Op: m_Deferred(V&: ShiftAmt)))));

  // shift_amt > BitWidth-1 ? 0 : shift_op
  bool MatchedUGT =
      sd_match(N, P: m_Select(Cond: m_SetCC(LHS: m_Value(N&: ShiftAmt),
                                   RHS: m_SpecificInt(V: APInt(BitWidth, BitWidth - 1)),
                                   CC: m_SpecificCondCode(CC: ISD::SETUGT)),
                           T: m_Zero(), F: LogicalShift));
  // shift_amt < BitWidth ? shift_op : 0
  bool MatchedULT =
      !MatchedUGT &&
      sd_match(N, P: m_Select(Cond: m_SetCC(LHS: m_Value(N&: ShiftAmt),
                                   RHS: m_SpecificInt(V: APInt(BitWidth, BitWidth)),
                                   CC: m_SpecificCondCode(CC: ISD::SETULT)),
                           T: LogicalShift, F: m_Zero()));

  if (!MatchedUGT && !MatchedULT)
    return SDValue();

  // Return a clamp shift operation, which has the same semantics as PTX shift.
  unsigned ClampOpc = ShiftOp.getOpcode() == ISD::SRL ? NVPTXISD::SRL_CLAMP
                                                      : NVPTXISD::SHL_CLAMP;
  return DCI.DAG.getNode(Opcode: ClampOpc, DL: SDLoc(N), VT: ShiftOp.getValueType(),
                         N1: ShiftOp.getOperand(i: 0), N2: ShiftOp.getOperand(i: 1));
}
6776
6777static SDValue PerformVSELECTCombine(SDNode *N,
6778 TargetLowering::DAGCombinerInfo &DCI) {
6779 SDValue VA = N->getOperand(Num: 1);
6780 EVT VectorVT = VA.getValueType();
6781 if (VectorVT != MVT::v4i8)
6782 return SDValue();
6783
6784 // We need to split vselect into individual per-element operations Because we
6785 // use BFE/BFI instruction for byte extraction/insertion, we do end up with
6786 // 32-bit values, so we may as well do comparison as i32 to avoid conversions
6787 // to/from i16 normally used for i8 values.
6788 SmallVector<SDValue, 4> E;
6789 SDLoc DL(N);
6790 SDValue VCond = N->getOperand(Num: 0);
6791 SDValue VB = N->getOperand(Num: 2);
6792 for (int I = 0; I < 4; ++I) {
6793 SDValue C = DCI.DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i1, N1: VCond,
6794 N2: DCI.DAG.getConstant(Val: I, DL, VT: MVT::i32));
6795 SDValue EA = DCI.DAG.getAnyExtOrTrunc(
6796 Op: DCI.DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i8, N1: VA,
6797 N2: DCI.DAG.getConstant(Val: I, DL, VT: MVT::i32)),
6798 DL, VT: MVT::i32);
6799 SDValue EB = DCI.DAG.getAnyExtOrTrunc(
6800 Op: DCI.DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i8, N1: VB,
6801 N2: DCI.DAG.getConstant(Val: I, DL, VT: MVT::i32)),
6802 DL, VT: MVT::i32);
6803 E.push_back(Elt: DCI.DAG.getAnyExtOrTrunc(
6804 Op: DCI.DAG.getNode(Opcode: ISD::SELECT, DL, VT: MVT::i32, N1: C, N2: EA, N3: EB), DL, VT: MVT::i8));
6805 }
6806 return DCI.DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: MVT::v4i8, Ops: E);
6807}
6808
/// Fold a packed v2x16 BUILD_VECTOR of two i16 values truncated from i32
/// sources into one PRMT that selects the relevant bytes of the original i32
/// operands; a srl-by-16 of a source is absorbed into the byte selector.
static SDValue
PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
  auto VT = N->getValueType(ResNo: 0);
  if (!DCI.isAfterLegalizeDAG() ||
      // only process v2*16 types
      !(NVPTX::isPackedVectorTy(VT) && VT.is32BitVector() &&
        VT.getVectorNumElements() == 2))
    return SDValue();

  auto Op0 = N->getOperand(Num: 0);
  auto Op1 = N->getOperand(Num: 1);

  // Start out by assuming we want to take the lower 2 bytes of each i32
  // operand. (Selector nibbles 0x1,0x0 pick bytes 1,0 of the first source;
  // 0x5,0x4 pick bytes 1,0 of the second.)
  uint64_t Op0Bytes = 0x10;
  uint64_t Op1Bytes = 0x54;

  // Process both operands with the same logic below, mutating the SDValue and
  // its byte selector in place through these pointers.
  std::pair<SDValue *, uint64_t *> OpData[2] = {{&Op0, &Op0Bytes},
                                                {&Op1, &Op1Bytes}};

  // Check that each operand is an i16, truncated from an i32 operand. We'll
  // select individual bytes from those original operands. Optionally, fold in a
  // shift right of that original operand.
  for (auto &[Op, OpBytes] : OpData) {
    // Eat up any bitcast
    if (Op->getOpcode() == ISD::BITCAST)
      *Op = Op->getOperand(i: 0);

    if (!(Op->getValueType() == MVT::i16 && Op->getOpcode() == ISD::TRUNCATE &&
          Op->getOperand(i: 0).getValueType() == MVT::i32))
      return SDValue();

    // If the truncate has multiple uses, this optimization can increase
    // register pressure
    if (!Op->hasOneUse())
      return SDValue();

    // Strip the truncate: the PRMT reads bytes from the wide i32 directly.
    *Op = Op->getOperand(i: 0);

    // Optionally, fold in a shift-right of the original operand and let permute
    // pick the two higher bytes of the original value directly.
    if (Op->getOpcode() == ISD::SRL && isa<ConstantSDNode>(Val: Op->getOperand(i: 1))) {
      if (cast<ConstantSDNode>(Val: Op->getOperand(i: 1))->getZExtValue() == 16) {
        // Shift the PRMT byte selector to pick upper bytes from each respective
        // value, instead of the lower ones: 0x10 -> 0x32, 0x54 -> 0x76
        assert((*OpBytes == 0x10 || *OpBytes == 0x54) &&
               "PRMT selector values out of range");
        *OpBytes += 0x22;
        *Op = Op->getOperand(i: 0);
      }
    }
  }

  SDLoc DL(N);
  auto &DAG = DCI.DAG;

  // Combine the two per-operand byte selectors into one 16-bit PRMT selector.
  auto PRMT =
      getPRMT(A: DAG.getBitcast(VT: MVT::i32, V: Op0), B: DAG.getBitcast(VT: MVT::i32, V: Op1),
              Selector: (Op1Bytes << 8) | Op0Bytes, DL, DAG);
  return DAG.getBitcast(VT, V: PRMT);
}
6870
6871static SDValue combineADDRSPACECAST(SDNode *N,
6872 TargetLowering::DAGCombinerInfo &DCI) {
6873 auto *ASCN1 = cast<AddrSpaceCastSDNode>(Val: N);
6874
6875 if (auto *ASCN2 = dyn_cast<AddrSpaceCastSDNode>(Val: ASCN1->getOperand(Num: 0))) {
6876 assert(ASCN2->getDestAddressSpace() == ASCN1->getSrcAddressSpace());
6877
6878 // Fold asc[B -> A](asc[A -> B](x)) -> x
6879 if (ASCN1->getDestAddressSpace() == ASCN2->getSrcAddressSpace())
6880 return ASCN2->getOperand(Num: 0);
6881 }
6882
6883 return SDValue();
6884}
6885
// Given a constant selector value and a prmt mode, return the selector value
// normalized to the generic prmt mode. See the PTX ISA documentation for more
// details:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt
static APInt getPRMTSelector(const APInt &Selector, unsigned Mode) {
  assert(Selector.getBitWidth() == 32 && "PRMT must have i32 operands");

  // Generic mode: the selector is already canonical.
  if (Mode == NVPTX::PTXPrmtMode::NONE)
    return Selector;

  // The specialized modes only consume the low 2 bits of the selector.
  const unsigned V = Selector.trunc(width: 2).getZExtValue();

  // Pack four 4-bit byte indices (low to high) into one i32 selector.
  const auto GetSelector = [](unsigned S0, unsigned S1, unsigned S2,
                              unsigned S3) {
    return APInt(32, S0 | (S1 << 4) | (S2 << 8) | (S3 << 12));
  };

  switch (Mode) {
  case NVPTX::PTXPrmtMode::F4E:
    // Forward 4-extract: four consecutive bytes starting at V.
    return GetSelector(V, V + 1, V + 2, V + 3);
  case NVPTX::PTXPrmtMode::B4E:
    // Backward 4-extract: bytes descending from V, wrapping modulo 8.
    return GetSelector(V, (V - 1) & 7, (V - 2) & 7, (V - 3) & 7);
  case NVPTX::PTXPrmtMode::RC8:
    // Replicate-8: broadcast byte V to all four result bytes.
    return GetSelector(V, V, V, V);
  case NVPTX::PTXPrmtMode::ECL:
    // Edge-clamp left: indices clamped upward toward byte 3.
    return GetSelector(V, std::max(a: V, b: 1U), std::max(a: V, b: 2U), 3U);
  case NVPTX::PTXPrmtMode::ECR:
    // Edge-clamp right: indices clamped downward from byte 0.
    return GetSelector(0, std::min(a: V, b: 1U), std::min(a: V, b: 2U), V);
  case NVPTX::PTXPrmtMode::RC16: {
    // Replicate-16: broadcast the 16-bit half chosen by V's low bit.
    unsigned V1 = (V & 1) << 1;
    return GetSelector(V1, V1 + 1, V1, V1 + 1);
  }
  default:
    llvm_unreachable("Invalid PRMT mode");
  }
}
6922
// Constant-fold a PRMT: select four bytes out of the 64-bit field {b, a}
// according to the (normalized) selector value.
static APInt computePRMT(APInt A, APInt B, APInt Selector, unsigned Mode) {
  assert(A.getBitWidth() == 32 && B.getBitWidth() == 32 &&
         Selector.getBitWidth() == 32 && "PRMT must have i32 operands");
  // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
  APInt BitField = B.concat(NewLSB: A);
  APInt SelectorVal = getPRMTSelector(Selector, Mode);
  APInt Result(32, 0);
  // Result byte I is chosen by the I-th 4-bit field of the selector: the low
  // 3 bits index into the 8-byte field, the high bit requests sign mode.
  for (unsigned I : llvm::seq(Size: 4U)) {
    APInt Sel = SelectorVal.extractBits(numBits: 4, bitPosition: I * 4);
    unsigned Idx = Sel.getLoBits(numBits: 3).getZExtValue();
    unsigned Sign = Sel.getHiBits(numBits: 1).getZExtValue();
    APInt Byte = BitField.extractBits(numBits: 8, bitPosition: Idx * 8);
    // Sign mode: replicate the source byte's sign bit across the whole result
    // byte (ashr by the full 8-bit width fills with the sign bit).
    if (Sign)
      Byte = Byte.ashr(ShiftAmt: 8);
    Result.insertBits(SubBits: Byte, bitPosition: I * 8);
  }
  return Result;
}
6941
6942static SDValue combinePRMT(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
6943 CodeGenOptLevel OptLevel) {
6944 if (OptLevel == CodeGenOptLevel::None)
6945 return SDValue();
6946
6947 // Constant fold PRMT
6948 if (isa<ConstantSDNode>(Val: N->getOperand(Num: 0)) &&
6949 isa<ConstantSDNode>(Val: N->getOperand(Num: 1)) &&
6950 isa<ConstantSDNode>(Val: N->getOperand(Num: 2)))
6951 return DCI.DAG.getConstant(Val: computePRMT(A: N->getConstantOperandAPInt(Num: 0),
6952 B: N->getConstantOperandAPInt(Num: 1),
6953 Selector: N->getConstantOperandAPInt(Num: 2),
6954 Mode: N->getConstantOperandVal(Num: 3)),
6955 DL: SDLoc(N), VT: N->getValueType(ResNo: 0));
6956 return SDValue();
6957}
6958
// During call lowering we wrap the return values in a ProxyReg node which
// depend on the chain value produced by the completed call. This ensures that
// the full call is emitted in cases where libcalls are used to legalize
// operations. To improve the functioning of other DAG combines we pull all
// operations we can through one of these nodes, ensuring that the ProxyReg
// directly wraps a load. That is:
//
// (ProxyReg (zext (load retval0))) => (zext (ProxyReg (load retval0)))
//
// Returns the rebuilt value with the ProxyReg pushed down towards the loads
// at its leaves, or an empty SDValue if any leaf cannot be handled.
static SDValue sinkProxyReg(SDValue R, SDValue Chain,
                            TargetLowering::DAGCombinerInfo &DCI) {
  switch (R.getOpcode()) {
  // Unary ops: sink into the sole operand, then re-create the op on top.
  case ISD::TRUNCATE:
  case ISD::ANY_EXTEND:
  case ISD::SIGN_EXTEND:
  case ISD::ZERO_EXTEND:
  case ISD::BITCAST: {
    if (SDValue V = sinkProxyReg(R: R.getOperand(i: 0), Chain, DCI))
      return DCI.DAG.getNode(Opcode: R.getOpcode(), DL: SDLoc(R), VT: R.getValueType(), Operand: V);
    return SDValue();
  }
  // Binary ops: both operands must be sinkable for the rewrite to apply.
  case ISD::SHL:
  case ISD::SRL:
  case ISD::SRA:
  case ISD::OR: {
    if (SDValue A = sinkProxyReg(R: R.getOperand(i: 0), Chain, DCI))
      if (SDValue B = sinkProxyReg(R: R.getOperand(i: 1), Chain, DCI))
        return DCI.DAG.getNode(Opcode: R.getOpcode(), DL: SDLoc(R), VT: R.getValueType(), N1: A, N2: B);
    return SDValue();
  }
  // Constants need no ProxyReg at all; pass them through unchanged.
  case ISD::Constant:
    return R;
  // Loads are the desired leaves: wrap them directly in a ProxyReg that
  // depends on the completed call's chain.
  case ISD::LOAD:
  case NVPTXISD::LoadV2:
  case NVPTXISD::LoadV4: {
    return DCI.DAG.getNode(Opcode: NVPTXISD::ProxyReg, DL: SDLoc(R), VT: R.getValueType(),
                           Ops: {Chain, R});
  }
  case ISD::BUILD_VECTOR: {
    // Only rewrite vector forms after legalization (presumably to avoid
    // interfering with type legalization — NOTE(review): confirm).
    if (DCI.isBeforeLegalize())
      return SDValue();

    // Every element of the vector must be sinkable.
    SmallVector<SDValue, 16> Ops;
    for (auto &Op : R->ops()) {
      SDValue V = sinkProxyReg(R: Op, Chain, DCI);
      if (!V)
        return SDValue();
      Ops.push_back(Elt: V);
    }
    return DCI.DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL: SDLoc(R), VT: R.getValueType(), Ops);
  }
  case ISD::EXTRACT_VECTOR_ELT: {
    if (DCI.isBeforeLegalize())
      return SDValue();

    if (SDValue V = sinkProxyReg(R: R.getOperand(i: 0), Chain, DCI))
      return DCI.DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: SDLoc(R),
                             VT: R.getValueType(), N1: V, N2: R.getOperand(i: 1));
    return SDValue();
  }
  default:
    // Any other opcode cannot be safely pulled through the ProxyReg.
    return SDValue();
  }
}
7023
7024static unsigned getF16SubOpc(Intrinsic::ID AddIntrinsicID) {
7025 switch (AddIntrinsicID) {
7026 default:
7027 break;
7028 case Intrinsic::nvvm_add_rn_sat_f16:
7029 case Intrinsic::nvvm_add_rn_sat_v2f16:
7030 return NVPTXISD::SUB_RN_SAT;
7031 case Intrinsic::nvvm_add_rn_ftz_sat_f16:
7032 case Intrinsic::nvvm_add_rn_ftz_sat_v2f16:
7033 return NVPTXISD::SUB_RN_FTZ_SAT;
7034 }
7035 llvm_unreachable("Invalid F16 add intrinsic");
7036}
7037
7038static SDValue combineF16AddWithNeg(SDNode *N, SelectionDAG &DAG,
7039 Intrinsic::ID AddIntrinsicID) {
7040 SDValue Op1 = N->getOperand(Num: 1);
7041 SDValue Op2 = N->getOperand(Num: 2);
7042
7043 SDValue SubOp1, SubOp2;
7044
7045 if (Op1.getOpcode() == ISD::FNEG) {
7046 SubOp1 = Op2;
7047 SubOp2 = Op1.getOperand(i: 0);
7048 } else if (Op2.getOpcode() == ISD::FNEG) {
7049 SubOp1 = Op1;
7050 SubOp2 = Op2.getOperand(i: 0);
7051 } else {
7052 return SDValue();
7053 }
7054
7055 SDLoc DL(N);
7056 return DAG.getNode(Opcode: getF16SubOpc(AddIntrinsicID), DL, VT: N->getValueType(ResNo: 0),
7057 N1: SubOp1, N2: SubOp2);
7058}
7059
7060static SDValue combineIntrinsicWOChain(SDNode *N,
7061 TargetLowering::DAGCombinerInfo &DCI,
7062 const NVPTXSubtarget &STI) {
7063 unsigned IID = N->getConstantOperandVal(Num: 0);
7064
7065 switch (IID) {
7066 default:
7067 break;
7068 case Intrinsic::nvvm_add_rn_sat_f16:
7069 case Intrinsic::nvvm_add_rn_ftz_sat_f16:
7070 case Intrinsic::nvvm_add_rn_sat_v2f16:
7071 case Intrinsic::nvvm_add_rn_ftz_sat_v2f16:
7072 return combineF16AddWithNeg(N, DAG&: DCI.DAG, AddIntrinsicID: IID);
7073 }
7074 return SDValue();
7075}
7076
7077static SDValue combineProxyReg(SDNode *N,
7078 TargetLowering::DAGCombinerInfo &DCI) {
7079
7080 SDValue Chain = N->getOperand(Num: 0);
7081 SDValue Reg = N->getOperand(Num: 1);
7082
7083 // If the ProxyReg is not wrapping a load, try to pull the operations through
7084 // the ProxyReg.
7085 if (Reg.getOpcode() != ISD::LOAD) {
7086 if (SDValue V = sinkProxyReg(R: Reg, Chain, DCI))
7087 return V;
7088 }
7089
7090 return SDValue();
7091}
7092
/// Target DAG-combine hook: dispatches each opcode NVPTX cares about to its
/// dedicated combine helper. Returns an empty SDValue when no combine
/// applies, so generic combining proceeds.
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                               DAGCombinerInfo &DCI) const {
  // Several helpers bail out at -O0; the opt level is threaded through.
  CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
  switch (N->getOpcode()) {
  default:
    break;
  case ISD::ADD:
    return PerformADDCombine(N, DCI, OptLevel);
  case ISD::ADDRSPACECAST:
    return combineADDRSPACECAST(N, DCI);
  case ISD::SIGN_EXTEND:
  case ISD::ZERO_EXTEND:
    return combineMulWide(N, DCI, OptLevel);
  case ISD::BUILD_VECTOR:
    return PerformBUILD_VECTORCombine(N, DCI);
  case ISD::EXTRACT_VECTOR_ELT:
    return PerformEXTRACTCombine(N, DCI);
  case ISD::FADD:
    return PerformFADDCombine(N, DCI, OptLevel);
  case ISD::FMA:
  case ISD::FMUL:
  case ISD::FSUB:
    return PerformScalarizeV2F32Op(N, DCI);
  case ISD::FMAXNUM:
  case ISD::FMINNUM:
  case ISD::FMAXIMUM:
  case ISD::FMINIMUM:
  case ISD::FMAXIMUMNUM:
  case ISD::FMINIMUMNUM:
    // Min/max lowering depends on PTX/SM feature availability.
    return PerformFMinMaxCombine(N, DCI, PTXVersion: STI.getPTXVersion(),
                                 SmVersion: STI.getSmVersion());
  case ISD::LOAD:
  case NVPTXISD::LoadV2:
  case NVPTXISD::LoadV4:
    return combineLOAD(N, DCI, STI);
  case ISD::MUL:
    return PerformMULCombine(N, DCI, OptLevel);
  case NVPTXISD::PRMT:
    return combinePRMT(N, DCI, OptLevel);
  case NVPTXISD::ProxyReg:
    return combineProxyReg(N, DCI);
  case ISD::SETCC:
    return PerformSETCCCombine(N, DCI, SmVersion: STI.getSmVersion());
  case ISD::SHL:
    return PerformSHLCombine(N, DCI, OptLevel);
  case ISD::SREM:
  case ISD::UREM:
    return PerformREMCombine(N, DCI, OptLevel);
  case ISD::STORE:
  case NVPTXISD::StoreV2:
  case NVPTXISD::StoreV4:
    return combineSTORE(N, DCI, STI);
  case ISD::SELECT:
    return PerformSELECTShiftCombine(N, DCI);
  case ISD::VSELECT:
    return PerformVSELECTCombine(N, DCI);
  case ISD::INTRINSIC_WO_CHAIN:
    return combineIntrinsicWOChain(N, DCI, STI);
  }
  return SDValue();
}
7154
7155static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
7156 SmallVectorImpl<SDValue> &Results) {
7157 // Handle bitcasting to v2i8 without hitting the default promotion
7158 // strategy which goes through stack memory.
7159 SDValue Op(Node, 0);
7160 EVT ToVT = Op->getValueType(ResNo: 0);
7161 if (ToVT != MVT::v2i8) {
7162 return;
7163 }
7164
7165 // Bitcast to i16 and unpack elements into a vector
7166 SDLoc DL(Node);
7167 SDValue AsInt = DAG.getBitcast(VT: MVT::i16, V: Op->getOperand(Num: 0));
7168 SDValue Vec0 = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i8, Operand: AsInt);
7169 SDValue Const8 = DAG.getConstant(Val: 8, DL, VT: MVT::i16);
7170 SDValue Vec1 =
7171 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i8,
7172 Operand: DAG.getNode(Opcode: ISD::SRL, DL, VT: MVT::i16, Ops: {AsInt, Const8}));
7173 Results.push_back(
7174 Elt: DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: MVT::v2i8, Ops: {Vec0, Vec1}));
7175}
7176
/// Custom result legalization for INTRINSIC_W_CHAIN nodes. Appends the
/// replacement value(s) followed by the output chain to Results; returns
/// without touching Results to fall back to default legalization.
static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
                                     SmallVectorImpl<SDValue> &Results) {
  SDValue Chain = N->getOperand(Num: 0);
  SDValue Intrin = N->getOperand(Num: 1);
  SDLoc DL(N);

  // Get the intrinsic ID
  unsigned IntrinNo = Intrin.getNode()->getAsZExtVal();
  switch (IntrinNo) {
  default:
    return;
  case Intrinsic::nvvm_ldu_global_i:
  case Intrinsic::nvvm_ldu_global_f:
  case Intrinsic::nvvm_ldu_global_p: {
    EVT ResVT = N->getValueType(ResNo: 0);

    if (ResVT.isVector()) {
      // Vector LDG/LDU

      unsigned NumElts = ResVT.getVectorNumElements();
      EVT EltVT = ResVT.getVectorElementType();

      // Since LDU/LDG are target nodes, we cannot rely on DAG type
      // legalization.
      // Therefore, we must ensure the type is legal. For i1 and i8, we set the
      // loaded type to i16 and propagate the "real" type as the memory type.
      bool NeedTrunc = false;
      if (EltVT.getSizeInBits() < 16) {
        EltVT = MVT::i16;
        NeedTrunc = true;
      }

      unsigned Opcode = 0;
      SDVTList LdResVTs;

      // Only 2- and 4-element vectors have LDU variants; anything else
      // falls back to default legalization.
      switch (NumElts) {
      default:
        return;
      case 2:
        Opcode = NVPTXISD::LDUV2;
        LdResVTs = DAG.getVTList(VT1: EltVT, VT2: EltVT, VT3: MVT::Other);
        break;
      case 4: {
        Opcode = NVPTXISD::LDUV4;
        EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
        LdResVTs = DAG.getVTList(VTs: ListVTs);
        break;
      }
      }

      SmallVector<SDValue, 8> OtherOps;

      // Copy regular operands

      OtherOps.push_back(Elt: Chain); // Chain
      // Skip operand 1 (intrinsic ID)
      // Others
      OtherOps.append(in_start: N->op_begin() + 2, in_end: N->op_end());

      MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);

      // Keep the original (possibly sub-16-bit) element type as the memory
      // type while loading at the widened type.
      SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, dl: DL, VTList: LdResVTs, Ops: OtherOps,
                                              MemVT: MemSD->getMemoryVT(),
                                              MMO: MemSD->getMemOperand());

      SmallVector<SDValue, 4> ScalarRes;

      // Truncate each widened scalar result back to the requested element
      // type before rebuilding the vector.
      for (unsigned i = 0; i < NumElts; ++i) {
        SDValue Res = NewLD.getValue(R: i);
        if (NeedTrunc)
          Res =
              DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: ResVT.getVectorElementType(), Operand: Res);
        ScalarRes.push_back(Elt: Res);
      }

      SDValue LoadChain = NewLD.getValue(R: NumElts);

      SDValue BuildVec =
          DAG.getBuildVector(VT: ResVT, DL, Ops: ScalarRes);

      Results.push_back(Elt: BuildVec);
      Results.push_back(Elt: LoadChain);
    } else {
      // i8 LDG/LDU
      assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
             "Custom handling of non-i8 ldu/ldg?");

      // Just copy all operands as-is
      SmallVector<SDValue, 4> Ops(N->ops());

      // Force output to i16
      SDVTList LdResVTs = DAG.getVTList(VT1: MVT::i16, VT2: MVT::Other);

      MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);

      // We make sure the memory type is i8, which will be used during isel
      // to select the proper instruction.
      SDValue NewLD =
          DAG.getMemIntrinsicNode(Opcode: ISD::INTRINSIC_W_CHAIN, dl: DL, VTList: LdResVTs, Ops,
                                  MemVT: MVT::i8, MMO: MemSD->getMemOperand());

      Results.push_back(Elt: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i8,
                                    Operand: NewLD.getValue(R: 0)));
      Results.push_back(Elt: NewLD.getValue(R: 1));
    }
    return;
  }

  // tcgen05.ld variants: lowered by a shared helper that returns the
  // replacement value and its chain (no replacement on failure).
  case Intrinsic::nvvm_tcgen05_ld_16x64b_x4:
  case Intrinsic::nvvm_tcgen05_ld_16x64b_x8:
  case Intrinsic::nvvm_tcgen05_ld_16x64b_x16:
  case Intrinsic::nvvm_tcgen05_ld_16x64b_x32:
  case Intrinsic::nvvm_tcgen05_ld_16x64b_x64:
  case Intrinsic::nvvm_tcgen05_ld_16x64b_x128:
  case Intrinsic::nvvm_tcgen05_ld_32x32b_x4:
  case Intrinsic::nvvm_tcgen05_ld_32x32b_x8:
  case Intrinsic::nvvm_tcgen05_ld_32x32b_x16:
  case Intrinsic::nvvm_tcgen05_ld_32x32b_x32:
  case Intrinsic::nvvm_tcgen05_ld_32x32b_x64:
  case Intrinsic::nvvm_tcgen05_ld_32x32b_x128:
  case Intrinsic::nvvm_tcgen05_ld_16x128b_x2:
  case Intrinsic::nvvm_tcgen05_ld_16x128b_x4:
  case Intrinsic::nvvm_tcgen05_ld_16x128b_x8:
  case Intrinsic::nvvm_tcgen05_ld_16x128b_x16:
  case Intrinsic::nvvm_tcgen05_ld_16x128b_x32:
  case Intrinsic::nvvm_tcgen05_ld_16x128b_x64:
  case Intrinsic::nvvm_tcgen05_ld_16x256b_x1:
  case Intrinsic::nvvm_tcgen05_ld_16x256b_x2:
  case Intrinsic::nvvm_tcgen05_ld_16x256b_x4:
  case Intrinsic::nvvm_tcgen05_ld_16x256b_x8:
  case Intrinsic::nvvm_tcgen05_ld_16x256b_x16:
  case Intrinsic::nvvm_tcgen05_ld_16x256b_x32:
    if (auto Res = lowerTcgen05Ld(N, DAG)) {
      Results.push_back(Elt: Res->first);
      Results.push_back(Elt: Res->second);
    }
    return;

  // 16x32bx2 variants carry an extra offset operand.
  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x4:
  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x8:
  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x16:
  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x32:
  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x64:
  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x128:
    if (auto Res = lowerTcgen05Ld(N, DAG, /*HasOffset=*/true)) {
      Results.push_back(Elt: Res->first);
      Results.push_back(Elt: Res->second);
    }
    return;

  // tcgen05.ld.red variants produce three results.
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x8_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x8_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x64_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x64_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x4_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x4_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x32_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x32_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x16_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x16_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x128_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x128_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x8_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x8_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x64_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x64_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x4_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x4_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x32_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x32_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x16_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x16_f32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x128_i32:
  case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x128_f32:
    if (auto Res = lowerTcgen05LdRed(N, DAG)) {
      Results.push_back(Elt: std::get<0>(t&: *Res));
      Results.push_back(Elt: std::get<1>(t&: *Res));
      Results.push_back(Elt: std::get<2>(t&: *Res));
    }
    return;
  }
}
7359
7360static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
7361 SmallVectorImpl<SDValue> &Results) {
7362 // Change the CopyFromReg to output 2 64-bit results instead of a 128-bit
7363 // result so that it can pass the legalization
7364 SDLoc DL(N);
7365 SDValue Chain = N->getOperand(Num: 0);
7366 SDValue Reg = N->getOperand(Num: 1);
7367 SDValue Glue = N->getOperand(Num: 2);
7368
7369 assert(Reg.getValueType() == MVT::i128 &&
7370 "Custom lowering for CopyFromReg with 128-bit reg only");
7371 SmallVector<EVT, 4> ResultsType = {MVT::i64, MVT::i64, N->getValueType(ResNo: 1),
7372 N->getValueType(ResNo: 2)};
7373 SmallVector<SDValue, 3> NewOps = {Chain, Reg, Glue};
7374
7375 SDValue NewValue = DAG.getNode(Opcode: ISD::CopyFromReg, DL, ResultTys: ResultsType, Ops: NewOps);
7376 SDValue Pair = DAG.getNode(Opcode: ISD::BUILD_PAIR, DL, VT: MVT::i128,
7377 Ops: {NewValue.getValue(R: 0), NewValue.getValue(R: 1)});
7378
7379 Results.push_back(Elt: Pair);
7380 Results.push_back(Elt: NewValue.getValue(R: 2));
7381 Results.push_back(Elt: NewValue.getValue(R: 3));
7382}
7383
7384static void replaceProxyReg(SDNode *N, SelectionDAG &DAG,
7385 const TargetLowering &TLI,
7386 SmallVectorImpl<SDValue> &Results) {
7387 SDValue Chain = N->getOperand(Num: 0);
7388 SDValue Reg = N->getOperand(Num: 1);
7389
7390 MVT VT = TLI.getRegisterType(Context&: *DAG.getContext(), VT: Reg.getValueType());
7391
7392 SDValue NewReg = DAG.getAnyExtOrTrunc(Op: Reg, DL: SDLoc(N), VT);
7393 SDValue NewProxy =
7394 DAG.getNode(Opcode: NVPTXISD::ProxyReg, DL: SDLoc(N), VT, Ops: {Chain, NewReg});
7395 SDValue Res = DAG.getAnyExtOrTrunc(Op: NewProxy, DL: SDLoc(N), VT: N->getValueType(ResNo: 0));
7396
7397 Results.push_back(Elt: Res);
7398}
7399
/// Legalize a 128-bit atomic swap or cmpxchg by splitting each i128 value
/// operand into two i64 halves and emitting the NVPTX b128 atomic node.
/// Diagnoses (and emits an UNDEF result) when the subtarget lacks support.
static void replaceAtomicSwap128(SDNode *N, SelectionDAG &DAG,
                                 const NVPTXSubtarget &STI,
                                 SmallVectorImpl<SDValue> &Results) {
  assert(N->getValueType(0) == MVT::i128 &&
         "Custom lowering for atomic128 only supports i128");

  AtomicSDNode *AN = cast<AtomicSDNode>(Val: N);
  SDLoc dl(N);

  if (!STI.hasAtomSwap128()) {
    // Report the unsupported target feature and still push a well-formed
    // (UNDEF value + chain) replacement so legalization can proceed.
    DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
        DAG.getMachineFunction().getFunction(),
        "Support for b128 atomics introduced in PTX ISA version 8.3 and "
        "requires target sm_90.",
        dl.getDebugLoc()));

    Results.push_back(Elt: DAG.getUNDEF(VT: MVT::i128));
    Results.push_back(Elt: AN->getOperand(Num: 0)); // Chain
    return;
  }

  // Split every i128 value operand (swap value, and for cmpxchg the
  // expected/new values) into {lo, hi} i64 pairs.
  SmallVector<SDValue, 6> Ops;
  Ops.push_back(Elt: AN->getOperand(Num: 0)); // Chain
  Ops.push_back(Elt: AN->getOperand(Num: 1)); // Ptr
  for (const auto &Op : AN->ops().drop_front(N: 2)) {
    // Low part
    Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_ELEMENT, DL: dl, VT: MVT::i64, N1: Op,
                                N2: DAG.getIntPtrConstant(Val: 0, DL: dl)));
    // High part
    Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_ELEMENT, DL: dl, VT: MVT::i64, N1: Op,
                                N2: DAG.getIntPtrConstant(Val: 1, DL: dl)));
  }
  unsigned Opcode = N->getOpcode() == ISD::ATOMIC_SWAP
                        ? NVPTXISD::ATOMIC_SWAP_B128
                        : NVPTXISD::ATOMIC_CMP_SWAP_B128;
  SDVTList Tys = DAG.getVTList(VT1: MVT::i64, VT2: MVT::i64, VT3: MVT::Other);
  SDValue Result = DAG.getMemIntrinsicNode(Opcode, dl, VTList: Tys, Ops, MemVT: MVT::i128,
                                           MMO: AN->getMemOperand());
  // Rebuild the i128 result from the two i64 halves and forward the chain.
  Results.push_back(Elt: DAG.getNode(Opcode: ISD::BUILD_PAIR, DL: dl, VT: MVT::i128,
                                {Result.getValue(R: 0), Result.getValue(R: 1)}));
  Results.push_back(Elt: Result.getValue(R: 2));
}
7442
/// Custom legalization of illegal result types: dispatch each opcode to its
/// replacement helper, which appends the legalized values to Results.
void NVPTXTargetLowering::ReplaceNodeResults(
    SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
  switch (N->getOpcode()) {
  default:
    report_fatal_error(reason: "Unhandled custom legalization");
  case ISD::BITCAST:
    ReplaceBITCAST(Node: N, DAG, Results);
    return;
  case ISD::LOAD:
  case ISD::MLOAD:
    replaceLoadVector(N, DAG, Results, STI);
    return;
  case ISD::INTRINSIC_W_CHAIN:
    ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
    return;
  case ISD::CopyFromReg:
    // Only reached for 128-bit registers (asserted in the helper).
    ReplaceCopyFromReg_128(N, DAG, Results);
    return;
  case NVPTXISD::ProxyReg:
    replaceProxyReg(N, DAG, TLI: *this, Results);
    return;
  case ISD::ATOMIC_CMP_SWAP:
  case ISD::ATOMIC_SWAP:
    replaceAtomicSwap128(N, DAG, STI, Results);
    return;
  }
}
7470
/// Decide how AtomicExpandPass should lower an atomicrmw: None means the
/// operation/type/width combination is natively supported, CmpXChg means it
/// must be emulated with a compare-exchange loop.
NVPTXTargetLowering::AtomicExpansionKind
NVPTXTargetLowering::shouldExpandAtomicRMWInIR(const AtomicRMWInst *AI) const {
  Type *Ty = AI->getValOperand()->getType();

  if (AI->isFloatingPointOperation()) {
    // Among FP operations, only FAdd can be native, gated on type and
    // SM/PTX version; everything else is emulated.
    if (AI->getOperation() == AtomicRMWInst::BinOp::FAdd) {
      if (Ty->isHalfTy() && STI.getSmVersion() >= 70 &&
          STI.getPTXVersion() >= 63)
        return AtomicExpansionKind::None;
      if (Ty->isBFloatTy() && STI.getSmVersion() >= 90 &&
          STI.getPTXVersion() >= 78)
        return AtomicExpansionKind::None;
      if (Ty->isFloatTy())
        return AtomicExpansionKind::None;
      if (Ty->isDoubleTy() && STI.hasAtomAddF64())
        return AtomicExpansionKind::None;
    }
    return AtomicExpansionKind::CmpXChg;
  }

  assert(Ty->isIntegerTy() && "Ty should be integer at this point");
  const unsigned BitWidth = cast<IntegerType>(Val: Ty)->getBitWidth();

  switch (AI->getOperation()) {
  default:
    return AtomicExpansionKind::CmpXChg;
  case AtomicRMWInst::BinOp::Xchg:
    // 128-bit exchange is handled natively (unlike the bitwise ops it
    // otherwise shares handling with below).
    if (BitWidth == 128)
      return AtomicExpansionKind::None;
    [[fallthrough]];
  case AtomicRMWInst::BinOp::And:
  case AtomicRMWInst::BinOp::Or:
  case AtomicRMWInst::BinOp::Xor:
    switch (BitWidth) {
    case 8:
    case 16:
      return AtomicExpansionKind::CmpXChg;
    case 32:
      return AtomicExpansionKind::None;
    case 64:
      // 64-bit bitwise atomics need subtarget support.
      if (STI.hasAtomBitwise64())
        return AtomicExpansionKind::None;
      return AtomicExpansionKind::CmpXChg;
    case 128:
      return AtomicExpansionKind::CmpXChg;
    default:
      llvm_unreachable("unsupported width encountered");
    }
  case AtomicRMWInst::BinOp::Add:
  case AtomicRMWInst::BinOp::Sub:
  case AtomicRMWInst::BinOp::Max:
  case AtomicRMWInst::BinOp::Min:
  case AtomicRMWInst::BinOp::UMax:
  case AtomicRMWInst::BinOp::UMin:
    switch (BitWidth) {
    case 8:
    case 16:
      return AtomicExpansionKind::CmpXChg;
    case 32:
      return AtomicExpansionKind::None;
    case 64:
      // 64-bit arithmetic/min-max atomics need subtarget support.
      if (STI.hasAtomMinMax64())
        return AtomicExpansionKind::None;
      return AtomicExpansionKind::CmpXChg;
    case 128:
      return AtomicExpansionKind::CmpXChg;
    default:
      llvm_unreachable("unsupported width encountered");
    }
  case AtomicRMWInst::BinOp::UIncWrap:
  case AtomicRMWInst::BinOp::UDecWrap:
    switch (BitWidth) {
    case 32:
      return AtomicExpansionKind::None;
    case 8:
    case 16:
    case 64:
    case 128:
      return AtomicExpansionKind::CmpXChg;
    default:
      llvm_unreachable("unsupported width encountered");
    }
  }

  // Unreachable: every case of the switch above returns (it has a default).
  // Kept so compilers that cannot prove exhaustiveness do not warn about a
  // missing return.
  return AtomicExpansionKind::CmpXChg;
}
7557
7558bool NVPTXTargetLowering::shouldInsertFencesForAtomic(
7559 const Instruction *I) const {
7560 // This function returns true iff the operation is emulated using a CAS-loop,
7561 // or if it has the memory order seq_cst (which is not natively supported in
7562 // the PTX `atom` instruction).
7563 //
7564 // atomicrmw and cmpxchg instructions not efficiently supported by PTX
7565 // are lowered to CAS emulation loops that preserve their memory order,
7566 // syncscope, and volatile semantics. For PTX, it is more efficient to use
7567 // atom.cas.relaxed.sco instructions within the loop, and fences before and
7568 // after the loop to restore order.
7569 //
7570 // Atomic instructions efficiently supported by PTX are lowered to
7571 // `atom.<op>.<sem>.<scope` instruction with their corresponding memory order
7572 // and scope. Since PTX does not support seq_cst, we emulate it by lowering to
7573 // a fence.sc followed by an atom according to the PTX atomics ABI
7574 // https://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/atomic-abi.html
7575 if (auto *CI = dyn_cast<AtomicCmpXchgInst>(Val: I))
7576 return (cast<IntegerType>(Val: CI->getCompareOperand()->getType())
7577 ->getBitWidth() < STI.getMinCmpXchgSizeInBits()) ||
7578 CI->getMergedOrdering() == AtomicOrdering::SequentiallyConsistent;
7579 if (auto *RI = dyn_cast<AtomicRMWInst>(Val: I))
7580 return shouldExpandAtomicRMWInIR(AI: RI) == AtomicExpansionKind::CmpXChg ||
7581 RI->getOrdering() == AtomicOrdering::SequentiallyConsistent;
7582 return false;
7583}
7584
AtomicOrdering NVPTXTargetLowering::atomicOperationOrderAfterFenceSplit(
    const Instruction *I) const {
  // If the operation is emulated by a CAS-loop, we lower the instruction to
  // atom.<op>.relaxed, since AtomicExpandPass will insert fences for enforcing
  // the correct memory ordering around the CAS loop.
  //
  // When the operation is not emulated, but the memory order is seq_cst,
  // we must lower to "fence.sc.<scope>; atom.<op>.acquire.<scope>;" to conform
  // to the PTX atomics ABI.
  // https://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/atomic-abi.html
  // For such cases, emitLeadingFence() will separately insert the leading
  // "fence.sc.<scope>;". Here, we only set the memory order to acquire.
  //
  // Otherwise, the operation is not emulated, and the memory order is not
  // seq_cst. In this case, the LLVM memory order is natively supported by the
  // PTX `atom` instruction, and we just lower to the corresponding
  // `atom.<op>.relaxed|acquire|release|acq_rel". For such cases, this function
  // will NOT be called.
  // prerequisite: shouldInsertFencesForAtomic() should have returned `true` for
  // I before its memory order was modified.

  // Native (non-emulated) seq_cst ops: keep acquire semantics on the atom;
  // the leading fence.sc supplies the rest of the seq_cst guarantee.
  if (auto *CI = dyn_cast<AtomicCmpXchgInst>(Val: I);
      CI && CI->getMergedOrdering() == AtomicOrdering::SequentiallyConsistent &&
      cast<IntegerType>(Val: CI->getCompareOperand()->getType())->getBitWidth() >=
          STI.getMinCmpXchgSizeInBits())
    return AtomicOrdering::Acquire;
  else if (auto *RI = dyn_cast<AtomicRMWInst>(Val: I);
           RI && RI->getOrdering() == AtomicOrdering::SequentiallyConsistent &&
           shouldExpandAtomicRMWInIR(AI: RI) == AtomicExpansionKind::None)
    return AtomicOrdering::Acquire;

  // CAS-loop-emulated operations: relaxed; AtomicExpandPass inserts the
  // surrounding fences (see comment above).
  return AtomicOrdering::Monotonic;
}
7617
7618Instruction *NVPTXTargetLowering::emitLeadingFence(IRBuilderBase &Builder,
7619 Instruction *Inst,
7620 AtomicOrdering Ord) const {
7621 // prerequisite: shouldInsertFencesForAtomic() should have returned `true` for
7622 // `Inst` before its memory order was modified. We cannot enforce this with an
7623 // assert, because AtomicExpandPass will have modified the memory order
7624 // between the initial call to shouldInsertFencesForAtomic() and the call to
7625 // this function.
7626 if (!isa<AtomicCmpXchgInst>(Val: Inst) && !isa<AtomicRMWInst>(Val: Inst))
7627 return TargetLoweringBase::emitLeadingFence(Builder, Inst, Ord);
7628
7629 // Specialize for cmpxchg and atomicrmw
7630 auto SSID = getAtomicSyncScopeID(I: Inst);
7631 assert(SSID.has_value() && "Expected an atomic operation");
7632
7633 if (isReleaseOrStronger(AO: Ord))
7634 return Builder.CreateFence(Ordering: Ord == AtomicOrdering::SequentiallyConsistent
7635 ? AtomicOrdering::SequentiallyConsistent
7636 : AtomicOrdering::Release,
7637 SSID: SSID.value());
7638
7639 return nullptr;
7640}
7641
7642Instruction *NVPTXTargetLowering::emitTrailingFence(IRBuilderBase &Builder,
7643 Instruction *Inst,
7644 AtomicOrdering Ord) const {
7645 // prerequisite: shouldInsertFencesForAtomic() should have returned `true` for
7646 // `Inst` before its memory order was modified. See `emitLeadingFence` for why
7647 // this cannot be enforced with an assert. Specialize for cmpxchg and
7648 // atomicrmw
7649 auto *CI = dyn_cast<AtomicCmpXchgInst>(Val: Inst);
7650 auto *RI = dyn_cast<AtomicRMWInst>(Val: Inst);
7651 if (!CI && !RI)
7652 return TargetLoweringBase::emitTrailingFence(Builder, Inst, Ord);
7653
7654 auto SSID = getAtomicSyncScopeID(I: Inst);
7655 assert(SSID.has_value() && "Expected an atomic operation");
7656
7657 bool IsEmulated =
7658 CI ? cast<IntegerType>(Val: CI->getCompareOperand()->getType())
7659 ->getBitWidth() < STI.getMinCmpXchgSizeInBits()
7660 : shouldExpandAtomicRMWInIR(AI: RI) == AtomicExpansionKind::CmpXChg;
7661
7662 if (isAcquireOrStronger(AO: Ord) && IsEmulated)
7663 return Builder.CreateFence(Ordering: AtomicOrdering::Acquire, SSID: SSID.value());
7664
7665 return nullptr;
7666}
7667
7668// Rather than default to SINT when both UINT and SINT are custom, we only
7669// change the opcode when UINT is not legal and SINT is. UINT is preferred when
7670// both are custom since unsigned CVT instructions can lead to slightly better
7671// SASS code with fewer instructions.
7672unsigned NVPTXTargetLowering::getPreferredFPToIntOpcode(unsigned Op, EVT FromVT,
7673 EVT ToVT) const {
7674 if (isOperationLegal(Op, VT: ToVT))
7675 return Op;
7676 switch (Op) {
7677 case ISD::FP_TO_UINT:
7678 if (isOperationLegal(Op: ISD::FP_TO_SINT, VT: ToVT))
7679 return ISD::FP_TO_SINT;
7680 break;
7681 case ISD::STRICT_FP_TO_UINT:
7682 if (isOperationLegal(Op: ISD::STRICT_FP_TO_SINT, VT: ToVT))
7683 return ISD::STRICT_FP_TO_SINT;
7684 break;
7685 case ISD::VP_FP_TO_UINT:
7686 if (isOperationLegal(Op: ISD::VP_FP_TO_SINT, VT: ToVT))
7687 return ISD::VP_FP_TO_SINT;
7688 break;
7689 default:
7690 break;
7691 }
7692 return Op;
7693}
7694
// Pin NVPTXTargetObjectFile's vtables to this file by defining its (defaulted)
// virtual destructor out-of-line here.
NVPTXTargetObjectFile::~NVPTXTargetObjectFile() = default;
7697
// Every global is placed in the single data section, regardless of the
// SectionKind computed for it.
MCSection *NVPTXTargetObjectFile::SelectSectionForGlobal(
    const GlobalObject *GO, SectionKind Kind, const TargetMachine &TM) const {
  return getDataSection();
}
7702
/// Compute known bits for an NVPTXISD::PRMT node with a constant selector by
/// permuting the known bits of its two i32 operands the same way computePRMT
/// permutes their values.
static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
                                    const SelectionDAG &DAG, unsigned Depth) {
  SDValue A = Op.getOperand(i: 0);
  SDValue B = Op.getOperand(i: 1);
  ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(Val: Op.getOperand(i: 2));
  unsigned Mode = Op.getConstantOperandVal(i: 3);

  // Without a constant selector we cannot tell which bytes are selected.
  if (!Selector)
    return;

  KnownBits AKnown = DAG.computeKnownBits(Op: A, Depth);
  KnownBits BKnown = DAG.computeKnownBits(Op: B, Depth);

  // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
  assert(AKnown.getBitWidth() == 32 && BKnown.getBitWidth() == 32 &&
         "PRMT must have i32 operands");
  assert(Known.getBitWidth() == 32 && "PRMT must have i32 result");
  KnownBits BitField = BKnown.concat(Lo: AKnown);

  // Each selector nibble picks a source byte (low three bits) and may
  // sign-replicate it (high bit), mirroring computePRMT.
  APInt SelectorVal = getPRMTSelector(Selector: Selector->getAPIntValue(), Mode);
  for (unsigned I : llvm::seq(Size: 4)) {
    APInt Sel = SelectorVal.extractBits(numBits: 4, bitPosition: I * 4);
    unsigned Idx = Sel.getLoBits(numBits: 3).getZExtValue();
    unsigned Sign = Sel.getHiBits(numBits: 1).getZExtValue();
    KnownBits Byte = BitField.extractBits(NumBits: 8, BitPosition: Idx * 8);
    if (Sign)
      Byte = KnownBits::ashr(LHS: Byte, RHS: 8);
    Known.insertBits(SubBits: Byte, BitPosition: I * 8);
  }
}
7733
/// Compute known bits for NVPTX vector-load nodes: for a non-sign-extending
/// load with a scalar result, all bits above the loaded element width are
/// known to be zero.
static void computeKnownBitsForLoadV(const SDValue Op, KnownBits &Known) {
  MemSDNode *LD = cast<MemSDNode>(Val: Op);

  // We can't do anything without knowing the sign bit.
  // (The extension kind is encoded as the load's last operand.)
  auto ExtType = LD->getConstantOperandVal(Num: LD->getNumOperands() - 1);
  if (ExtType == ISD::SEXTLOAD)
    return;

  // ExtLoading to vector types is weird and may not work well with known bits.
  auto DestVT = LD->getValueType(ResNo: 0);
  if (DestVT.isVector())
    return;

  assert(Known.getBitWidth() == DestVT.getSizeInBits());
  auto ElementBitWidth = NVPTXDAGToDAGISel::getFromTypeWidthForLoad(Mem: LD);
  Known.Zero.setHighBits(Known.getBitWidth() - ElementBitWidth);
}
7751
7752void NVPTXTargetLowering::computeKnownBitsForTargetNode(
7753 const SDValue Op, KnownBits &Known, const APInt &DemandedElts,
7754 const SelectionDAG &DAG, unsigned Depth) const {
7755 Known.resetAll();
7756
7757 switch (Op.getOpcode()) {
7758 case NVPTXISD::PRMT:
7759 computeKnownBitsForPRMT(Op, Known, DAG, Depth);
7760 break;
7761 case NVPTXISD::LoadV2:
7762 case NVPTXISD::LoadV4:
7763 case NVPTXISD::LoadV8:
7764 computeKnownBitsForLoadV(Op, Known);
7765 break;
7766 default:
7767 break;
7768 }
7769}
7770
7771static std::pair<APInt, APInt> getPRMTDemandedBits(const APInt &SelectorVal,
7772 const APInt &DemandedBits) {
7773 APInt DemandedLHS = APInt(32, 0);
7774 APInt DemandedRHS = APInt(32, 0);
7775
7776 for (unsigned I : llvm::seq(Size: 4)) {
7777 if (DemandedBits.extractBits(numBits: 8, bitPosition: I * 8).isZero())
7778 continue;
7779
7780 APInt Sel = SelectorVal.extractBits(numBits: 4, bitPosition: I * 4);
7781 unsigned Idx = Sel.getLoBits(numBits: 3).getZExtValue();
7782 unsigned Sign = Sel.getHiBits(numBits: 1).getZExtValue();
7783
7784 APInt &Src = Idx < 4 ? DemandedLHS : DemandedRHS;
7785 unsigned ByteStart = (Idx % 4) * 8;
7786 if (Sign)
7787 Src.setBit(ByteStart + 7);
7788 else
7789 Src.setBits(loBit: ByteStart, hiBit: ByteStart + 8);
7790 }
7791
7792 return {DemandedLHS, DemandedRHS};
7793}
7794
7795// Replace undef with 0 as this is easier for other optimizations such as
7796// known bits.
7797static SDValue canonicalizePRMTInput(SDValue Op, SelectionDAG &DAG) {
7798 if (!Op)
7799 return SDValue();
7800 if (Op.isUndef())
7801 return DAG.getConstant(Val: 0, DL: SDLoc(), VT: MVT::i32);
7802 return Op;
7803}
7804
7805static SDValue simplifyDemandedBitsForPRMT(SDValue PRMT,
7806 const APInt &DemandedBits,
7807 SelectionDAG &DAG,
7808 const TargetLowering &TLI,
7809 unsigned Depth) {
7810 assert(PRMT.getOpcode() == NVPTXISD::PRMT);
7811 SDValue Op0 = PRMT.getOperand(i: 0);
7812 SDValue Op1 = PRMT.getOperand(i: 1);
7813 auto *SelectorConst = dyn_cast<ConstantSDNode>(Val: PRMT.getOperand(i: 2));
7814 if (!SelectorConst)
7815 return SDValue();
7816
7817 unsigned Mode = PRMT.getConstantOperandVal(i: 3);
7818 const APInt Selector = getPRMTSelector(Selector: SelectorConst->getAPIntValue(), Mode);
7819
7820 // Try to simplify the PRMT to one of the inputs if the used bytes are all
7821 // from the same input in the correct order.
7822 const unsigned LeadingBytes = DemandedBits.countLeadingZeros() / 8;
7823 const unsigned SelBits = (4 - LeadingBytes) * 4;
7824 if (Selector.getLoBits(numBits: SelBits) == APInt(32, 0x3210).getLoBits(numBits: SelBits))
7825 return Op0;
7826 if (Selector.getLoBits(numBits: SelBits) == APInt(32, 0x7654).getLoBits(numBits: SelBits))
7827 return Op1;
7828
7829 auto [DemandedLHS, DemandedRHS] = getPRMTDemandedBits(SelectorVal: Selector, DemandedBits);
7830
7831 // Attempt to avoid multi-use ops if we don't need anything from them.
7832 SDValue DemandedOp0 =
7833 TLI.SimplifyMultipleUseDemandedBits(Op: Op0, DemandedBits: DemandedLHS, DAG, Depth: Depth + 1);
7834 SDValue DemandedOp1 =
7835 TLI.SimplifyMultipleUseDemandedBits(Op: Op1, DemandedBits: DemandedRHS, DAG, Depth: Depth + 1);
7836
7837 DemandedOp0 = canonicalizePRMTInput(Op: DemandedOp0, DAG);
7838 DemandedOp1 = canonicalizePRMTInput(Op: DemandedOp1, DAG);
7839 if ((DemandedOp0 && DemandedOp0 != Op0) ||
7840 (DemandedOp1 && DemandedOp1 != Op1)) {
7841 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
7842 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
7843 return getPRMT(A: Op0, B: Op1, Selector: Selector.getZExtValue(), DL: SDLoc(PRMT), DAG);
7844 }
7845
7846 return SDValue();
7847}
7848
7849bool NVPTXTargetLowering::SimplifyDemandedBitsForTargetNode(
7850 SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
7851 KnownBits &Known, TargetLoweringOpt &TLO, unsigned Depth) const {
7852 Known.resetAll();
7853
7854 switch (Op.getOpcode()) {
7855 case NVPTXISD::PRMT:
7856 if (SDValue Result = simplifyDemandedBitsForPRMT(PRMT: Op, DemandedBits, DAG&: TLO.DAG,
7857 TLI: *this, Depth)) {
7858 TLO.CombineTo(O: Op, N: Result);
7859 return true;
7860 }
7861 break;
7862 default:
7863 break;
7864 }
7865
7866 computeKnownBitsForTargetNode(Op, Known, DemandedElts, DAG: TLO.DAG, Depth);
7867 return false;
7868}
7869