1//===-- NVPTXISelLowering.cpp - NVPTX DAG Lowering Implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the interfaces that NVPTX uses to lower LLVM code into a
10// selection DAG.
11//
12//===----------------------------------------------------------------------===//
13
14#include "NVPTXISelLowering.h"
15#include "MCTargetDesc/NVPTXBaseInfo.h"
16#include "NVPTX.h"
17#include "NVPTXISelDAGToDAG.h"
18#include "NVPTXSelectionDAGInfo.h"
19#include "NVPTXSubtarget.h"
20#include "NVPTXTargetMachine.h"
21#include "NVPTXTargetObjectFile.h"
22#include "NVPTXUtilities.h"
23#include "NVVMProperties.h"
24#include "llvm/ADT/APFloat.h"
25#include "llvm/ADT/APInt.h"
26#include "llvm/ADT/STLExtras.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/ADT/StringRef.h"
29#include "llvm/CodeGen/Analysis.h"
30#include "llvm/CodeGen/ISDOpcodes.h"
31#include "llvm/CodeGen/MachineFunction.h"
32#include "llvm/CodeGen/MachineJumpTableInfo.h"
33#include "llvm/CodeGen/MachineMemOperand.h"
34#include "llvm/CodeGen/SDPatternMatch.h"
35#include "llvm/CodeGen/SelectionDAG.h"
36#include "llvm/CodeGen/SelectionDAGNodes.h"
37#include "llvm/CodeGen/TargetCallingConv.h"
38#include "llvm/CodeGen/TargetLowering.h"
39#include "llvm/CodeGen/ValueTypes.h"
40#include "llvm/CodeGenTypes/MachineValueType.h"
41#include "llvm/IR/Argument.h"
42#include "llvm/IR/Attributes.h"
43#include "llvm/IR/Constants.h"
44#include "llvm/IR/DataLayout.h"
45#include "llvm/IR/DerivedTypes.h"
46#include "llvm/IR/DiagnosticInfo.h"
47#include "llvm/IR/FPEnv.h"
48#include "llvm/IR/Function.h"
49#include "llvm/IR/GlobalValue.h"
50#include "llvm/IR/IRBuilder.h"
51#include "llvm/IR/Instruction.h"
52#include "llvm/IR/Instructions.h"
53#include "llvm/IR/IntrinsicsNVPTX.h"
54#include "llvm/IR/Module.h"
55#include "llvm/IR/Type.h"
56#include "llvm/IR/Value.h"
57#include "llvm/Support/Alignment.h"
58#include "llvm/Support/AtomicOrdering.h"
59#include "llvm/Support/Casting.h"
60#include "llvm/Support/CodeGen.h"
61#include "llvm/Support/CommandLine.h"
62#include "llvm/Support/ErrorHandling.h"
63#include "llvm/Support/KnownBits.h"
64#include "llvm/Support/NVPTXAddrSpace.h"
65#include "llvm/Support/raw_ostream.h"
66#include "llvm/Target/TargetMachine.h"
67#include "llvm/Target/TargetOptions.h"
68#include <algorithm>
69#include <cassert>
70#include <cmath>
71#include <cstdint>
72#include <iterator>
73#include <optional>
74#include <string>
75#include <tuple>
76#include <utility>
77#include <vector>
78
79#define DEBUG_TYPE "nvptx-lower"
80
81using namespace llvm;
82
83static cl::opt<bool> sched4reg(
84 "nvptx-sched4reg",
85 cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(Val: false));
86
87static cl::opt<unsigned> FMAContractLevelOpt(
88 "nvptx-fma-level", cl::Hidden,
89 cl::desc("NVPTX Specific: FMA contraction (0: don't do it"
90 " 1: do it 2: do it aggressively"),
91 cl::init(Val: 2));
92
93static cl::opt<NVPTX::DivPrecisionLevel> UsePrecDivF32(
94 "nvptx-prec-divf32", cl::Hidden,
95 cl::desc(
96 "NVPTX Specific: Override the precision of the lowering for f32 fdiv"),
97 cl::values(
98 clEnumValN(NVPTX::DivPrecisionLevel::Approx, "0", "Use div.approx"),
99 clEnumValN(NVPTX::DivPrecisionLevel::Full, "1", "Use div.full"),
100 clEnumValN(NVPTX::DivPrecisionLevel::IEEE754, "2",
101 "Use IEEE Compliant F32 div.rnd if available (default)"),
102 clEnumValN(NVPTX::DivPrecisionLevel::IEEE754_NoFTZ, "3",
103 "Use IEEE Compliant F32 div.rnd if available, no FTZ")),
104 cl::init(Val: NVPTX::DivPrecisionLevel::IEEE754));
105
106static cl::opt<bool> UsePrecSqrtF32(
107 "nvptx-prec-sqrtf32", cl::Hidden,
108 cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."),
109 cl::init(Val: true));
110
111// PTX atom.add.f32 has fixed FTZ behavior that may not match the function's
112// (see shouldExpandAtomicRMWInIR), so by default we fall back to a CAS loop
113// when they disagree. This flag is an escape hatch to use atom.add anyway,
114// trading correct denormal handling for the speed of the native instruction.
115static cl::opt<bool> AllowFTZAtomics(
116 "nvptx-allow-ftz-atomics", cl::Hidden,
117 cl::desc("NVPTX Specific: Lower atomicrmw fadd to atom.add even when its "
118 "FTZ behavior does not match the function's denormal mode."),
119 cl::init(Val: false));
120
121/// Whereas CUDA's implementation (see libdevice) uses ex2.approx for exp2(), it
122/// does NOT use lg2.approx for log2, so this is disabled by default.
123static cl::opt<bool> UseApproxLog2F32(
124 "nvptx-approx-log2f32",
125 cl::desc("NVPTX Specific: whether to use lg2.approx for log2"),
126 cl::init(Val: false));
127
128NVPTX::DivPrecisionLevel
129NVPTXTargetLowering::getDivF32Level(const MachineFunction &MF,
130 const SDNode &N) const {
131 // If nvptx-prec-div32=N is used on the command-line, always honor it
132 if (UsePrecDivF32.getNumOccurrences() > 0)
133 return UsePrecDivF32;
134
135 const SDNodeFlags Flags = N.getFlags();
136 if (Flags.hasApproximateFuncs())
137 return NVPTX::DivPrecisionLevel::Approx;
138
139 return NVPTX::DivPrecisionLevel::IEEE754;
140}
141
142bool NVPTXTargetLowering::usePrecSqrtF32(const SDNode *N) const {
143 // If nvptx-prec-sqrtf32 is used on the command-line, always honor it
144 if (UsePrecSqrtF32.getNumOccurrences() > 0)
145 return UsePrecSqrtF32;
146
147 if (N) {
148 const SDNodeFlags Flags = N->getFlags();
149 if (Flags.hasApproximateFuncs())
150 return false;
151 }
152
153 return true;
154}
155
156bool NVPTXTargetLowering::useF32FTZ(const MachineFunction &MF) const {
157 return MF.getDenormalMode(FPType: APFloat::IEEEsingle()).Output ==
158 DenormalMode::PreserveSign;
159}
160
161static bool IsPTXVectorType(MVT VT) {
162 switch (VT.SimpleTy) {
163 default:
164 return false;
165 case MVT::v2i1:
166 case MVT::v4i1:
167 case MVT::v2i8:
168 case MVT::v4i8:
169 case MVT::v8i8: // <2 x i8x4>
170 case MVT::v16i8: // <4 x i8x4>
171 case MVT::v2i16:
172 case MVT::v4i16:
173 case MVT::v8i16: // <4 x i16x2>
174 case MVT::v2i32:
175 case MVT::v4i32:
176 case MVT::v2i64:
177 case MVT::v2f16:
178 case MVT::v4f16:
179 case MVT::v8f16: // <4 x f16x2>
180 case MVT::v2bf16:
181 case MVT::v4bf16:
182 case MVT::v8bf16: // <4 x bf16x2>
183 case MVT::v2f32:
184 case MVT::v4f32:
185 case MVT::v2f64:
186 case MVT::v4i64:
187 case MVT::v4f64:
188 case MVT::v8i32:
189 case MVT::v8f32:
190 case MVT::v16f16: // <8 x f16x2>
191 case MVT::v16bf16: // <8 x bf16x2>
192 case MVT::v16i16: // <8 x i16x2>
193 case MVT::v32i8: // <8 x i8x4>
194 return true;
195 }
196}
197
198// When legalizing vector loads/stores, this function is called, which does two
199// things:
200// 1. Determines Whether the vector is something we want to custom lower,
201// std::nullopt is returned if we do not want to custom lower it.
202// 2. If we do want to handle it, returns two parameters:
203// - unsigned int NumElts - The number of elements in the final vector
204// - EVT EltVT - The type of the elements in the final vector
205static std::optional<std::pair<unsigned int, MVT>>
206getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
207 unsigned AddressSpace) {
208 const bool CanLowerTo256Bit = STI.has256BitVectorLoadStore(AS: AddressSpace);
209
210 if (CanLowerTo256Bit && VectorEVT.isScalarInteger() &&
211 VectorEVT.getSizeInBits() == 256)
212 return {{4, MVT::i64}};
213
214 if (!VectorEVT.isSimple())
215 return std::nullopt;
216 const MVT VectorVT = VectorEVT.getSimpleVT();
217
218 if (!VectorVT.isVector()) {
219 if (VectorVT == MVT::i128 || VectorVT == MVT::f128)
220 return {{2, MVT::i64}};
221 return std::nullopt;
222 }
223
224 const MVT EltVT = VectorVT.getVectorElementType();
225 const unsigned NumElts = VectorVT.getVectorNumElements();
226
227 // The size of the PTX virtual register that holds a packed type.
228 unsigned PackRegSize;
229
230 // We only handle "native" vector sizes for now, e.g. <4 x double> is not
231 // legal. We can (and should) split that into 2 stores of <2 x double> here
232 // but I'm leaving that as a TODO for now.
233 switch (VectorVT.SimpleTy) {
234 default:
235 return std::nullopt;
236
237 case MVT::v4i64:
238 case MVT::v4f64:
239 // This is a "native" vector type iff the address space is global and the
240 // target supports 256-bit loads/stores
241 if (!CanLowerTo256Bit)
242 return std::nullopt;
243 [[fallthrough]];
244 case MVT::v2i8:
245 case MVT::v2i64:
246 case MVT::v2f64:
247 // This is a "native" vector type
248 return std::pair(NumElts, EltVT);
249
250 case MVT::v16f16: // <8 x f16x2>
251 case MVT::v16bf16: // <8 x bf16x2>
252 case MVT::v16i16: // <8 x i16x2>
253 case MVT::v32i8: // <8 x i8x4>
254 // This can be upsized into a "native" vector type iff the address space is
255 // global and the target supports 256-bit loads/stores.
256 if (!CanLowerTo256Bit)
257 return std::nullopt;
258 [[fallthrough]];
259 case MVT::v2i16: // <1 x i16x2>
260 case MVT::v2f16: // <1 x f16x2>
261 case MVT::v2bf16: // <1 x bf16x2>
262 case MVT::v4i8: // <1 x i8x4>
263 case MVT::v4i16: // <2 x i16x2>
264 case MVT::v4f16: // <2 x f16x2>
265 case MVT::v4bf16: // <2 x bf16x2>
266 case MVT::v8i8: // <2 x i8x4>
267 case MVT::v8f16: // <4 x f16x2>
268 case MVT::v8bf16: // <4 x bf16x2>
269 case MVT::v8i16: // <4 x i16x2>
270 case MVT::v16i8: // <4 x i8x4>
271 PackRegSize = 32;
272 break;
273
274 case MVT::v8f32: // <4 x f32x2>
275 case MVT::v8i32: // <4 x i32x2>
276 // This is a "native" vector type iff the address space is global and the
277 // target supports 256-bit loads/stores
278 if (!CanLowerTo256Bit)
279 return std::nullopt;
280 [[fallthrough]];
281 case MVT::v2f32: // <1 x f32x2>
282 case MVT::v4f32: // <2 x f32x2>
283 case MVT::v2i32: // <1 x i32x2>
284 case MVT::v4i32: // <2 x i32x2>
285 if (!STI.hasF32x2Instructions())
286 return std::pair(NumElts, EltVT);
287 PackRegSize = 64;
288 break;
289 }
290
291 // If we reach here, then we can pack 2 or more elements into a single 32-bit
292 // or 64-bit PTX register and treat the vector as a new vector containing
293 // packed elements.
294
295 // Number of elements to pack in one word.
296 const unsigned NPerReg = PackRegSize / EltVT.getSizeInBits();
297
298 return std::pair(NumElts / NPerReg, MVT::getVectorVT(VT: EltVT, NumElements: NPerReg));
299}
300
301/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
302/// legal-ish MVTs that compose it. Unlike ComputeValueVTs, this will legalize
303/// the types as required by the calling convention (with special handling for
304/// i8s).
305/// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
306/// same number of types as the Ins/Outs arrays in LowerFormalArguments,
307/// LowerCall, and LowerReturn.
308static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
309 LLVMContext &Ctx, CallingConv::ID CallConv,
310 Type *Ty, SmallVectorImpl<EVT> &ValueVTs,
311 SmallVectorImpl<uint64_t> &Offsets,
312 uint64_t StartingOffset = 0) {
313 SmallVector<EVT, 16> TempVTs;
314 SmallVector<uint64_t, 16> TempOffsets;
315 ComputeValueVTs(TLI, DL, Ty, ValueVTs&: TempVTs, /*MemVTs=*/nullptr, FixedOffsets: &TempOffsets,
316 StartingOffset);
317
318 for (const auto [VT, Off] : zip(t&: TempVTs, u&: TempOffsets)) {
319 MVT RegisterVT = TLI.getRegisterTypeForCallingConv(Context&: Ctx, CC: CallConv, VT);
320 unsigned NumRegs = TLI.getNumRegistersForCallingConv(Context&: Ctx, CC: CallConv, VT);
321
322 // Since we actually can load/store b8, we need to ensure that we'll use
323 // the original sized type for any i8s or i8 vectors.
324 if (VT.getScalarType() == MVT::i8) {
325 if (RegisterVT == MVT::i16)
326 RegisterVT = MVT::i8;
327 else if (RegisterVT == MVT::v2i16)
328 RegisterVT = MVT::v2i8;
329 else
330 assert(RegisterVT == MVT::v4i8 &&
331 "Expected v4i8, v2i16, or i16 for i8 RegisterVT");
332 }
333
334 // TODO: This is horribly incorrect for cases where the vector elements are
335 // not a multiple of bytes (ex i1) and legal or i8. However, this problem
336 // has existed for as long as NVPTX has and no one has complained, so we'll
337 // leave it for now.
338 for (unsigned I : seq(Size: NumRegs)) {
339 ValueVTs.push_back(Elt: RegisterVT);
340 Offsets.push_back(Elt: Off + I * RegisterVT.getStoreSize());
341 }
342 }
343}
344
345// We return an EVT that can hold N VTs
346// If the VT is a vector, the resulting EVT is a flat vector with the same
347// element type as VT's element type.
348static EVT getVectorizedVT(EVT VT, unsigned N, LLVMContext &C) {
349 if (N == 1)
350 return VT;
351
352 return VT.isVector() ? EVT::getVectorVT(Context&: C, VT: VT.getScalarType(),
353 NumElements: VT.getVectorNumElements() * N)
354 : EVT::getVectorVT(Context&: C, VT, NumElements: N);
355}
356
357static SDValue getExtractVectorizedValue(SDValue V, unsigned I, EVT VT,
358 const SDLoc &dl, SelectionDAG &DAG) {
359 if (V.getValueType() == VT) {
360 assert(I == 0 && "Index must be 0 for scalar value");
361 return V;
362 }
363
364 if (!VT.isVector())
365 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT, N1: V,
366 N2: DAG.getVectorIdxConstant(Val: I, DL: dl));
367
368 return DAG.getNode(
369 Opcode: ISD::EXTRACT_SUBVECTOR, DL: dl, VT, N1: V,
370 N2: DAG.getVectorIdxConstant(Val: I * VT.getVectorNumElements(), DL: dl));
371}
372
373template <typename T>
374static inline SDValue getBuildVectorizedValue(unsigned N, const SDLoc &dl,
375 SelectionDAG &DAG, T GetElement) {
376 if (N == 1)
377 return GetElement(0);
378
379 SmallVector<SDValue, 8> Values;
380 for (const unsigned I : llvm::seq(Size: N)) {
381 SDValue Val = GetElement(I);
382 if (Val.getValueType().isVector())
383 DAG.ExtractVectorElements(Op: Val, Args&: Values);
384 else
385 Values.push_back(Elt: Val);
386 }
387
388 EVT VT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: Values[0].getValueType(),
389 NumElements: Values.size());
390 return DAG.getBuildVector(VT, DL: dl, Ops: Values);
391}
392
393/// PromoteScalarIntegerPTX
394/// Used to make sure the arguments/returns are suitable for passing
395/// and promote them to a larger size if they're not.
396///
397/// The promoted type is placed in \p PromoteVT if the function returns true.
398static EVT promoteScalarIntegerPTX(const EVT VT) {
399 if (VT.isScalarInteger()) {
400 switch (PowerOf2Ceil(A: VT.getFixedSizeInBits())) {
401 default:
402 llvm_unreachable(
403 "Promotion is not suitable for scalars of size larger than 64-bits");
404 case 1:
405 return MVT::i1;
406 case 2:
407 case 4:
408 case 8:
409 return MVT::i8;
410 case 16:
411 return MVT::i16;
412 case 32:
413 return MVT::i32;
414 case 64:
415 return MVT::i64;
416 }
417 }
418 return VT;
419}
420
421// Check whether we can merge loads/stores of some of the pieces of a
422// flattened function parameter or return value into a single vector
423// load/store.
424//
425// The flattened parameter is represented as a list of EVTs and
426// offsets, and the whole structure is aligned to ParamAlignment. This
427// function determines whether we can load/store pieces of the
428// parameter starting at index Idx using a single vectorized op of
429// size AccessSize. If so, it returns the number of param pieces
430// covered by the vector op. Otherwise, it returns 1.
431template <typename T>
432static unsigned canMergeParamLoadStoresStartingAt(
433 unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs,
434 const SmallVectorImpl<T> &Offsets, Align ParamAlignment) {
435
436 // Can't vectorize if param alignment is not sufficient.
437 if (ParamAlignment < AccessSize)
438 return 1;
439 // Can't vectorize if offset is not aligned.
440 if (Offsets[Idx] & (AccessSize - 1))
441 return 1;
442
443 EVT EltVT = ValueVTs[Idx];
444 unsigned EltSize = EltVT.getStoreSize();
445
446 // Element is too large to vectorize.
447 if (EltSize >= AccessSize)
448 return 1;
449
450 unsigned NumElts = AccessSize / EltSize;
451 // Can't vectorize if AccessBytes if not a multiple of EltSize.
452 if (AccessSize != EltSize * NumElts)
453 return 1;
454
455 // We don't have enough elements to vectorize.
456 if (Idx + NumElts > ValueVTs.size())
457 return 1;
458
459 // PTX ISA can only deal with 2- and 4-element vector ops.
460 if (NumElts != 4 && NumElts != 2)
461 return 1;
462
463 for (unsigned j = Idx + 1; j < Idx + NumElts; ++j) {
464 // Types do not match.
465 if (ValueVTs[j] != EltVT)
466 return 1;
467
468 // Elements are not contiguous.
469 if (Offsets[j] - Offsets[j - 1] != EltSize)
470 return 1;
471 }
472 // OK. We can vectorize ValueVTs[i..i+NumElts)
473 return NumElts;
474}
475
476// Computes whether and how we can vectorize the loads/stores of a
477// flattened function parameter or return value.
478//
479// The flattened parameter is represented as the list of ValueVTs and
480// Offsets, and is aligned to ParamAlignment bytes. We return a vector
481// of the same size as ValueVTs indicating how each piece should be
482// loaded/stored (i.e. as a scalar, or as part of a vector
483// load/store).
484template <typename T>
485static SmallVector<unsigned, 16>
486VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
487 const SmallVectorImpl<T> &Offsets, Align ParamAlignment,
488 bool IsVAArg = false) {
489 // Set vector size to match ValueVTs and mark all elements as
490 // scalars by default.
491
492 if (IsVAArg)
493 return SmallVector<unsigned>(ValueVTs.size(), 1);
494
495 SmallVector<unsigned, 16> VectorInfo;
496
497 const auto GetNumElts = [&](unsigned I) -> unsigned {
498 for (const unsigned AccessSize : {16, 8, 4, 2}) {
499 const unsigned NumElts = canMergeParamLoadStoresStartingAt(
500 I, AccessSize, ValueVTs, Offsets, ParamAlignment);
501 assert((NumElts == 1 || NumElts == 2 || NumElts == 4) &&
502 "Unexpected vectorization size");
503 if (NumElts != 1)
504 return NumElts;
505 }
506 return 1;
507 };
508
509 // Check what we can vectorize using 128/64/32-bit accesses.
510 for (unsigned I = 0, E = ValueVTs.size(); I != E;) {
511 const unsigned NumElts = GetNumElts(I);
512 VectorInfo.push_back(Elt: NumElts);
513 I += NumElts;
514 }
515 assert(std::accumulate(VectorInfo.begin(), VectorInfo.end(), 0u) ==
516 ValueVTs.size());
517 return VectorInfo;
518}
519
520// NVPTXTargetLowering Constructor.
521NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
522 const NVPTXSubtarget &STI)
523 : TargetLowering(TM, STI), nvTM(&TM), STI(STI), GlobalUniqueCallSite(0) {
  // Always lower memset, memcpy, and memmove intrinsics to load/store
  // instructions, rather than generating calls to memset, memcpy, or memmove.
527 MaxStoresPerMemset = MaxStoresPerMemsetOptSize = (unsigned)0xFFFFFFFF;
528 MaxStoresPerMemcpy = MaxStoresPerMemcpyOptSize = (unsigned) 0xFFFFFFFF;
529 MaxStoresPerMemmove = MaxStoresPerMemmoveOptSize = (unsigned) 0xFFFFFFFF;
530
531 setBooleanContents(ZeroOrNegativeOneBooleanContent);
532 setBooleanVectorContents(ZeroOrNegativeOneBooleanContent);
533
534 // Jump is Expensive. Don't create extra control flow for 'and', 'or'
535 // condition branches.
536 setJumpIsExpensive(true);
537
538 // Wide divides are _very_ slow. Try to reduce the width of the divide if
539 // possible.
540 addBypassSlowDiv(SlowBitWidth: 64, FastBitWidth: 32);
541
542 // By default, use the Source scheduling
543 if (sched4reg)
544 setSchedulingPreference(Sched::RegPressure);
545 else
546 setSchedulingPreference(Sched::Source);
547
548 auto setFP16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
549 LegalizeAction NoF16Action) {
550 bool IsOpSupported = STI.allowFP16Math();
551 switch (Op) {
552 // Several FP16 instructions are available on sm_80 only.
553 case ISD::FMINNUM:
554 case ISD::FMAXNUM:
555 case ISD::FMAXNUM_IEEE:
556 case ISD::FMINNUM_IEEE:
557 case ISD::FMAXIMUM:
558 case ISD::FMINIMUM:
559 case ISD::FMAXIMUMNUM:
560 case ISD::FMINIMUMNUM:
561 IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
562 break;
563 case ISD::FEXP2:
564 IsOpSupported &= STI.getSmVersion() >= 75 && STI.getPTXVersion() >= 70;
565 break;
566 }
567 setOperationAction(Op, VT, Action: IsOpSupported ? Action : NoF16Action);
568 };
569
570 auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
571 LegalizeAction NoBF16Action) {
572 bool IsOpSupported = STI.hasNativeBF16Support(Opcode: Op);
573 setOperationAction(
574 Op, VT, Action: IsOpSupported ? Action : NoBF16Action);
575 };
576
577 auto setI16x2OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
578 LegalizeAction NoI16x2Action) {
579 bool IsOpSupported = false;
580 // instructions are available on sm_90 only
581 switch (Op) {
582 case ISD::ADD:
583 case ISD::SMAX:
584 case ISD::SMIN:
585 case ISD::UMIN:
586 case ISD::UMAX:
587 IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 80;
588 break;
589 }
590 setOperationAction(Op, VT, Action: IsOpSupported ? Action : NoI16x2Action);
591 };
592
593 addRegisterClass(VT: MVT::i1, RC: &NVPTX::B1RegClass);
594 addRegisterClass(VT: MVT::i16, RC: &NVPTX::B16RegClass);
595 addRegisterClass(VT: MVT::v2i16, RC: &NVPTX::B32RegClass);
596 addRegisterClass(VT: MVT::v4i8, RC: &NVPTX::B32RegClass);
597 addRegisterClass(VT: MVT::i32, RC: &NVPTX::B32RegClass);
598 addRegisterClass(VT: MVT::i64, RC: &NVPTX::B64RegClass);
599 addRegisterClass(VT: MVT::f32, RC: &NVPTX::B32RegClass);
600 addRegisterClass(VT: MVT::f64, RC: &NVPTX::B64RegClass);
601 addRegisterClass(VT: MVT::f16, RC: &NVPTX::B16RegClass);
602 addRegisterClass(VT: MVT::v2f16, RC: &NVPTX::B32RegClass);
603 addRegisterClass(VT: MVT::bf16, RC: &NVPTX::B16RegClass);
604 addRegisterClass(VT: MVT::v2bf16, RC: &NVPTX::B32RegClass);
605
606 if (STI.hasF32x2Instructions()) {
607 addRegisterClass(VT: MVT::v2f32, RC: &NVPTX::B64RegClass);
608 addRegisterClass(VT: MVT::v2i32, RC: &NVPTX::B64RegClass);
609 }
610
611 // Conversion to/from FP16/FP16x2 is always legal.
612 setOperationAction(Op: ISD::BUILD_VECTOR, VT: MVT::v2f16, Action: Custom);
613 setOperationAction(Op: ISD::EXTRACT_VECTOR_ELT, VT: MVT::v2f16, Action: Custom);
614 setOperationAction(Op: ISD::INSERT_VECTOR_ELT, VT: MVT::v2f16, Action: Expand);
615 setOperationAction(Op: ISD::VECTOR_SHUFFLE, VT: MVT::v2f16, Action: Expand);
616
617 setOperationAction(Op: ISD::READCYCLECOUNTER, VT: MVT::i64, Action: Legal);
618 if (STI.getSmVersion() >= 30 && STI.getPTXVersion() > 31)
619 setOperationAction(Op: ISD::READSTEADYCOUNTER, VT: MVT::i64, Action: Legal);
620
621 setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote);
622 setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand);
623
624 // Conversion to/from BFP16/BFP16x2 is always legal.
625 setOperationAction(Op: ISD::BUILD_VECTOR, VT: MVT::v2bf16, Action: Custom);
626 setOperationAction(Op: ISD::EXTRACT_VECTOR_ELT, VT: MVT::v2bf16, Action: Custom);
627 setOperationAction(Op: ISD::INSERT_VECTOR_ELT, VT: MVT::v2bf16, Action: Expand);
628 setOperationAction(Op: ISD::VECTOR_SHUFFLE, VT: MVT::v2bf16, Action: Expand);
629
630 setBF16OperationAction(ISD::SETCC, MVT::v2bf16, Legal, Expand);
631 setBF16OperationAction(ISD::SETCC, MVT::bf16, Legal, Promote);
632 if (getOperationAction(Op: ISD::SETCC, VT: MVT::bf16) == Promote)
633 AddPromotedToType(Opc: ISD::SETCC, OrigVT: MVT::bf16, DestVT: MVT::f32);
634
635 // Conversion to/from i16/i16x2 is always legal.
636 setOperationAction(Op: ISD::BUILD_VECTOR, VT: MVT::v2i16, Action: Custom);
637 setOperationAction(Op: ISD::EXTRACT_VECTOR_ELT, VT: MVT::v2i16, Action: Custom);
638 setOperationAction(Op: ISD::INSERT_VECTOR_ELT, VT: MVT::v2i16, Action: Expand);
639 setOperationAction(Op: ISD::VECTOR_SHUFFLE, VT: MVT::v2i16, Action: Expand);
640
641 setOperationAction(Op: ISD::BUILD_VECTOR, VT: MVT::v4i8, Action: Custom);
642 setOperationAction(Op: ISD::EXTRACT_VECTOR_ELT, VT: MVT::v4i8, Action: Custom);
643 setOperationAction(Op: ISD::INSERT_VECTOR_ELT, VT: MVT::v4i8, Action: Custom);
644 setOperationAction(Op: ISD::VECTOR_SHUFFLE, VT: MVT::v4i8, Action: Custom);
645
646 // No support for these operations with v2f32/v2i32
647 setOperationAction(Ops: ISD::INSERT_VECTOR_ELT, VTs: {MVT::v2f32, MVT::v2i32}, Action: Expand);
648 setOperationAction(Ops: ISD::VECTOR_SHUFFLE, VTs: {MVT::v2f32, MVT::v2i32}, Action: Expand);
649
650 setOperationAction(Op: ISD::TRUNCATE, VT: MVT::v2i16, Action: Expand);
651 setOperationAction(Ops: {ISD::ANY_EXTEND, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND},
652 VT: MVT::v2i32, Action: Expand);
653
654 // Need custom lowering in case the index is dynamic.
655 if (STI.hasF32x2Instructions())
656 setOperationAction(Ops: ISD::EXTRACT_VECTOR_ELT, VTs: {MVT::v2f32, MVT::v2i32},
657 Action: Custom);
658
659 // Custom conversions to/from v2i8.
660 setOperationAction(Op: ISD::BITCAST, VT: MVT::v2i8, Action: Custom);
661
662 // Only logical ops can be done on v4i8/v2i32 directly, others must be done
663 // elementwise.
664 setOperationAction(
665 Ops: {ISD::ABS, ISD::ADD, ISD::ADDC, ISD::ADDE,
666 ISD::BITREVERSE, ISD::CTLZ, ISD::CTPOP, ISD::CTTZ,
667 ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FSHL, ISD::FSHR,
668 ISD::MUL, ISD::MULHS, ISD::MULHU, ISD::PARITY,
669 ISD::ROTL, ISD::ROTR, ISD::SADDO, ISD::SADDO_CARRY,
670 ISD::SADDSAT, ISD::SDIV, ISD::SDIVREM, ISD::SELECT_CC,
671 ISD::SETCC, ISD::SHL, ISD::SINT_TO_FP, ISD::SMAX,
672 ISD::SMIN, ISD::SMULO, ISD::SMUL_LOHI, ISD::SRA,
673 ISD::SREM, ISD::SRL, ISD::SSHLSAT, ISD::SSUBO,
674 ISD::SSUBO_CARRY, ISD::SSUBSAT, ISD::SUB, ISD::SUBC,
675 ISD::SUBE, ISD::UADDO, ISD::UADDO_CARRY, ISD::UADDSAT,
676 ISD::UDIV, ISD::UDIVREM, ISD::UINT_TO_FP, ISD::UMAX,
677 ISD::UMIN, ISD::UMULO, ISD::UMUL_LOHI, ISD::UREM,
678 ISD::USHLSAT, ISD::USUBO, ISD::USUBO_CARRY, ISD::VSELECT,
679 ISD::USUBSAT},
680 VTs: {MVT::v4i8, MVT::v2i32}, Action: Expand);
681
682 // Operations not directly supported by NVPTX.
683 for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
684 MVT::v2f32, MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16,
685 MVT::v4i8, MVT::i32, MVT::v2i32, MVT::i64}) {
686 setOperationAction(Op: ISD::SELECT_CC, VT, Action: Expand);
687 setOperationAction(Op: ISD::BR_CC, VT, Action: Expand);
688 }
689
690 // We don't want ops like FMINIMUM or UMAX to be lowered to SETCC+VSELECT.
691 setOperationAction(Ops: ISD::VSELECT, VTs: {MVT::v2f32, MVT::v2i32}, Action: Expand);
692
693 // Some SIGN_EXTEND_INREG can be done using cvt instruction.
694 // For others we will expand to a SHL/SRA pair.
695 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i64, Action: Legal);
696 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i32, Action: Legal);
697 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i16, Action: Legal);
698 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i8 , Action: Legal);
699 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i1, Action: Expand);
700 setOperationAction(Ops: ISD::SIGN_EXTEND_INREG, VTs: {MVT::v2i16, MVT::v2i32}, Action: Expand);
701
702 setOperationAction(Op: ISD::SHL_PARTS, VT: MVT::i32 , Action: Custom);
703 setOperationAction(Op: ISD::SRA_PARTS, VT: MVT::i32 , Action: Custom);
704 setOperationAction(Op: ISD::SRL_PARTS, VT: MVT::i32 , Action: Custom);
705 setOperationAction(Op: ISD::SHL_PARTS, VT: MVT::i64 , Action: Custom);
706 setOperationAction(Op: ISD::SRA_PARTS, VT: MVT::i64 , Action: Custom);
707 setOperationAction(Op: ISD::SRL_PARTS, VT: MVT::i64 , Action: Custom);
708
709 setOperationAction(Op: ISD::BITREVERSE, VT: MVT::i32, Action: Legal);
710 setOperationAction(Op: ISD::BITREVERSE, VT: MVT::i64, Action: Legal);
711
712 setOperationAction(Ops: {ISD::ROTL, ISD::ROTR},
713 VTs: {MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64},
714 Action: Expand);
715
716 if (STI.hasHWROT32()) {
717 setOperationAction(Ops: {ISD::FSHL, ISD::FSHR}, VT: MVT::i32, Action: Legal);
718 setOperationAction(Ops: {ISD::ROTL, ISD::ROTR, ISD::FSHL, ISD::FSHR}, VT: MVT::i64,
719 Action: Custom);
720 }
721
722 setOperationAction(Op: ISD::BR_JT, VT: MVT::Other, Action: STI.hasBrx() ? Legal : Expand);
723 setOperationAction(Op: ISD::BRIND, VT: MVT::Other, Action: Expand);
724
725 // We want to legalize constant related memmove and memcopy
726 // intrinsics.
727 setOperationAction(Op: ISD::INTRINSIC_W_CHAIN, VT: MVT::Other, Action: Custom);
728
729 // FP extload/truncstore is not legal in PTX. We need to expand all these.
730 for (auto FloatVTs :
731 {MVT::fp_valuetypes(), MVT::fp_fixedlen_vector_valuetypes()}) {
732 for (MVT ValVT : FloatVTs) {
733 for (MVT MemVT : FloatVTs) {
734 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT, MemVT, Action: Expand);
735 setTruncStoreAction(ValVT, MemVT, Action: Expand);
736 }
737 }
738 }
739
740 // To improve CodeGen we'll legalize any-extend loads to zext loads. This is
741 // how they'll be lowered in ISel anyway, and by doing this a little earlier
742 // we allow for more DAG combine opportunities.
743 for (auto IntVTs :
744 {MVT::integer_valuetypes(), MVT::integer_fixedlen_vector_valuetypes()})
745 for (MVT ValVT : IntVTs)
746 for (MVT MemVT : IntVTs)
747 if (isTypeLegal(VT: ValVT))
748 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT, MemVT, Action: Custom);
749
750 // PTX does not support load / store predicate registers
751 setOperationAction(Ops: {ISD::LOAD, ISD::STORE}, VT: MVT::i1, Action: Custom);
752 for (MVT VT : MVT::integer_valuetypes()) {
753 setLoadExtAction(ExtTypes: {ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}, ValVT: VT, MemVT: MVT::i1,
754 Action: Promote);
755 setTruncStoreAction(ValVT: VT, MemVT: MVT::i1, Action: Expand);
756 }
757
758 // Disable generations of extload/truncstore for v2i32/v2i16/v2i8. The generic
759 // expansion for these nodes when they are unaligned is incorrect if the
760 // type is a vector.
761 //
762 // TODO: Fix the generic expansion for these nodes found in
763 // TargetLowering::expandUnalignedLoad/Store.
764 setLoadExtAction(ExtTypes: {ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, ValVT: MVT::v2i16,
765 MemVT: MVT::v2i8, Action: Expand);
766 setLoadExtAction(ExtTypes: {ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, ValVT: MVT::v2i32,
767 MemVTs: {MVT::v2i8, MVT::v2i16}, Action: Expand);
768 setTruncStoreAction(ValVT: MVT::v2i16, MemVT: MVT::v2i8, Action: Expand);
769 setTruncStoreAction(ValVT: MVT::v2i32, MemVT: MVT::v2i16, Action: Expand);
770 setTruncStoreAction(ValVT: MVT::v2i32, MemVT: MVT::v2i8, Action: Expand);
771
772 // Register custom handling for illegal type loads/stores. We'll try to custom
773 // lower almost all illegal types and logic in the lowering will discard cases
774 // we can't handle.
775 setOperationAction(Ops: {ISD::LOAD, ISD::STORE}, VTs: {MVT::i128, MVT::i256, MVT::f128},
776 Action: Custom);
777 for (MVT VT : MVT::fixedlen_vector_valuetypes())
778 if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
779 setOperationAction(Ops: {ISD::STORE, ISD::LOAD, ISD::MSTORE, ISD::MLOAD}, VT,
780 Action: Custom);
781
782 // Custom legalization for LDU intrinsics.
783 // TODO: The logic to lower these is not very robust and we should rewrite it.
784 // Perhaps LDU should not be represented as an intrinsic at all.
785 setOperationAction(Op: ISD::INTRINSIC_W_CHAIN, VT: MVT::i8, Action: Custom);
786 for (MVT VT : MVT::fixedlen_vector_valuetypes())
787 if (IsPTXVectorType(VT))
788 setOperationAction(Op: ISD::INTRINSIC_W_CHAIN, VT, Action: Custom);
789
790 setCondCodeAction(CCs: {ISD::SETNE, ISD::SETEQ, ISD::SETUGE, ISD::SETULE,
791 ISD::SETUGT, ISD::SETULT, ISD::SETGT, ISD::SETLT,
792 ISD::SETGE, ISD::SETLE},
793 VT: MVT::i1, Action: Expand);
794
795 // This is legal in NVPTX
796 setOperationAction(Op: ISD::ConstantFP, VT: MVT::f64, Action: Legal);
797 setOperationAction(Op: ISD::ConstantFP, VT: MVT::f32, Action: Legal);
798 setOperationAction(Op: ISD::ConstantFP, VT: MVT::f16, Action: Legal);
799 setOperationAction(Op: ISD::ConstantFP, VT: MVT::bf16, Action: Legal);
800
801 setOperationAction(Ops: ISD::DYNAMIC_STACKALLOC, VTs: {MVT::i32, MVT::i64}, Action: Custom);
802 setOperationAction(Ops: {ISD::STACKRESTORE, ISD::STACKSAVE}, VT: MVT::Other, Action: Custom);
803
804 // TRAP can be lowered to PTX trap
805 setOperationAction(Op: ISD::TRAP, VT: MVT::Other, Action: Legal);
806 // DEBUGTRAP can be lowered to PTX brkpt
807 setOperationAction(Op: ISD::DEBUGTRAP, VT: MVT::Other, Action: Legal);
808
809 // Support varargs.
810 setOperationAction(Op: ISD::VASTART, VT: MVT::Other, Action: Custom);
811 setOperationAction(Op: ISD::VAARG, VT: MVT::Other, Action: Custom);
812 setOperationAction(Op: ISD::VACOPY, VT: MVT::Other, Action: Expand);
813 setOperationAction(Op: ISD::VAEND, VT: MVT::Other, Action: Expand);
814
815 setOperationAction(Ops: {ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX},
816 VTs: {MVT::i16, MVT::i32, MVT::i64}, Action: Legal);
817 // PTX abs.s is undefined for INT_MIN, so ISD::ABS (which requires
818 // abs(INT_MIN) == INT_MIN) must be expanded. ABS_MIN_POISON matches
819 // PTX abs semantics since INT_MIN input is poison/undefined.
820 setOperationAction(Ops: ISD::ABS, VTs: {MVT::i16, MVT::i32, MVT::i64}, Action: Expand);
821 setOperationAction(Ops: ISD::ABS_MIN_POISON, VTs: {MVT::i16, MVT::i32, MVT::i64},
822 Action: Legal);
823
824 setOperationAction(Ops: {ISD::CTPOP, ISD::CTLZ, ISD::CTLZ_ZERO_POISON}, VT: MVT::i16,
825 Action: Promote);
826 setOperationAction(Ops: {ISD::CTPOP, ISD::CTLZ}, VT: MVT::i32, Action: Legal);
827 setOperationAction(Ops: {ISD::CTPOP, ISD::CTLZ}, VT: MVT::i64, Action: Custom);
828
829 setI16x2OperationAction(ISD::ABS_MIN_POISON, MVT::v2i16, Legal, Custom);
830 setI16x2OperationAction(ISD::SMIN, MVT::v2i16, Legal, Custom);
831 setI16x2OperationAction(ISD::SMAX, MVT::v2i16, Legal, Custom);
832 setI16x2OperationAction(ISD::UMIN, MVT::v2i16, Legal, Custom);
833 setI16x2OperationAction(ISD::UMAX, MVT::v2i16, Legal, Custom);
834 setI16x2OperationAction(ISD::CTPOP, MVT::v2i16, Legal, Expand);
835 setI16x2OperationAction(ISD::CTLZ, MVT::v2i16, Legal, Expand);
836
837 setI16x2OperationAction(ISD::ADD, MVT::v2i16, Legal, Custom);
838 setI16x2OperationAction(ISD::SUB, MVT::v2i16, Legal, Custom);
839 setI16x2OperationAction(ISD::MUL, MVT::v2i16, Legal, Custom);
840 setI16x2OperationAction(ISD::SHL, MVT::v2i16, Legal, Custom);
841 setI16x2OperationAction(ISD::SREM, MVT::v2i16, Legal, Custom);
842 setI16x2OperationAction(ISD::UREM, MVT::v2i16, Legal, Custom);
843
844 // Other arithmetic and logic ops are unsupported.
845 setOperationAction(Ops: {ISD::SDIV, ISD::UDIV, ISD::SRA, ISD::SRL, ISD::MULHS,
846 ISD::MULHU, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
847 ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::SETCC},
848 VTs: {MVT::v2i16, MVT::v2i32}, Action: Expand);
849
850 // v2i32 is not supported for any arithmetic operations
851 setOperationAction(Ops: {ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX,
852 ISD::CTPOP, ISD::CTLZ, ISD::ADD, ISD::SUB, ISD::MUL,
853 ISD::SHL, ISD::SRA, ISD::SRL, ISD::OR, ISD::AND, ISD::XOR,
854 ISD::SREM, ISD::UREM},
855 VT: MVT::v2i32, Action: Expand);
856
857 setOperationAction(Op: ISD::ADDC, VT: MVT::i32, Action: Legal);
858 setOperationAction(Op: ISD::ADDE, VT: MVT::i32, Action: Legal);
859 setOperationAction(Op: ISD::SUBC, VT: MVT::i32, Action: Legal);
860 setOperationAction(Op: ISD::SUBE, VT: MVT::i32, Action: Legal);
861 if (STI.getPTXVersion() >= 43) {
862 setOperationAction(Op: ISD::ADDC, VT: MVT::i64, Action: Legal);
863 setOperationAction(Op: ISD::ADDE, VT: MVT::i64, Action: Legal);
864 setOperationAction(Op: ISD::SUBC, VT: MVT::i64, Action: Legal);
865 setOperationAction(Op: ISD::SUBE, VT: MVT::i64, Action: Legal);
866 }
867
868 setOperationAction(Op: ISD::CTTZ, VT: MVT::i16, Action: Expand);
869 setOperationAction(Ops: ISD::CTTZ, VTs: {MVT::v2i16, MVT::v2i32}, Action: Expand);
870 setOperationAction(Op: ISD::CTTZ, VT: MVT::i32, Action: Expand);
871 setOperationAction(Op: ISD::CTTZ, VT: MVT::i64, Action: Expand);
872
873 // PTX does not directly support SELP of i1, so promote to i32 first
874 setOperationAction(Op: ISD::SELECT, VT: MVT::i1, Action: Custom);
875
876 // PTX cannot multiply two i64s in a single instruction.
877 setOperationAction(Op: ISD::SMUL_LOHI, VT: MVT::i64, Action: Expand);
878 setOperationAction(Op: ISD::UMUL_LOHI, VT: MVT::i64, Action: Expand);
879
880 // We have some custom DAG combine patterns for these nodes
881 setTargetDAGCombine({ISD::ADD,
882 ISD::AND,
883 ISD::EXTRACT_VECTOR_ELT,
884 ISD::FADD,
885 ISD::FMAXNUM,
886 ISD::FMINNUM,
887 ISD::FMAXIMUM,
888 ISD::FMINIMUM,
889 ISD::FMAXIMUMNUM,
890 ISD::FMINIMUMNUM,
891 ISD::MUL,
892 ISD::SELECT,
893 ISD::SHL,
894 ISD::SREM,
895 ISD::UREM,
896 ISD::VSELECT,
897 ISD::BUILD_VECTOR,
898 ISD::ADDRSPACECAST,
899 ISD::LOAD,
900 ISD::STORE,
901 ISD::ZERO_EXTEND,
902 ISD::SIGN_EXTEND,
903 ISD::INTRINSIC_WO_CHAIN});
904
905 // If the vector operands require register coalescing, scalarize instead
906 if (STI.hasF32x2Instructions())
907 setTargetDAGCombine({ISD::FMA, ISD::FMUL, ISD::FSUB});
908
909 // setcc for f16x2 and bf16x2 needs special handling to prevent
910 // legalizer's attempt to scalarize it due to v2i1 not being legal.
911 if (STI.allowFP16Math() || STI.hasBF16Math())
912 setTargetDAGCombine(ISD::SETCC);
913
914 // Vector reduction operations. These may be turned into shuffle or tree
915 // reductions depending on what instructions are available for each type.
916 for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
917 MVT EltVT = VT.getVectorElementType();
918 if (EltVT == MVT::f32 || EltVT == MVT::f64) {
919 setOperationAction(Ops: {ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
920 ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
921 VT, Action: Custom);
922 }
923 }
924
925 // Promote fp16 arithmetic if fp16 hardware isn't available or the
926 // user passed --nvptx-no-fp16-math. The flag is useful because,
927 // although sm_53+ GPUs have some sort of FP16 support in
928 // hardware, only sm_53 and sm_60 have full implementation. Others
929 // only have token amount of hardware and are likely to run faster
930 // by using fp32 units instead.
931 for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
932 setFP16OperationAction(Op, MVT::f16, Legal, Promote);
933 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
934 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
935 // bf16 must be promoted to f32.
936 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
937 if (getOperationAction(Op, VT: MVT::bf16) == Promote)
938 AddPromotedToType(Opc: Op, OrigVT: MVT::bf16, DestVT: MVT::f32);
939 setOperationAction(Op, VT: MVT::v2f32,
940 Action: STI.hasF32x2Instructions() ? Legal : Expand);
941 }
942
943 // On SM80, we select add/mul/sub as fma to avoid promotion to float
944 for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
945 for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
946 if (!STI.hasNativeBF16Support(Opcode: Op) && STI.hasNativeBF16Support(Opcode: ISD::FMA)) {
947 setOperationAction(Op, VT, Action: Custom);
948 }
949 }
950 }
951
952 // f16/f16x2 neg was introduced in PTX 60, SM_53.
953 const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
954 STI.getPTXVersion() >= 60 &&
955 STI.allowFP16Math();
956 for (const auto &VT : {MVT::f16, MVT::v2f16})
957 setOperationAction(Op: ISD::FNEG, VT,
958 Action: IsFP16FP16x2NegAvailable ? Legal : Expand);
959
960 setBF16OperationAction(ISD::FNEG, MVT::bf16, Legal, Expand);
961 setBF16OperationAction(ISD::FNEG, MVT::v2bf16, Legal, Expand);
962 setOperationAction(Op: ISD::FNEG, VT: MVT::v2f32, Action: Expand);
963 // (would be) Library functions.
964
965 // These map to conversion instructions for scalar FP types.
966 for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
967 ISD::FROUNDEVEN, ISD::FTRUNC}) {
968 setOperationAction(Op, VT: MVT::f16, Action: Legal);
969 setOperationAction(Op, VT: MVT::f32, Action: Legal);
970 setOperationAction(Op, VT: MVT::f64, Action: Legal);
971 setOperationAction(Op, VT: MVT::v2f16, Action: Expand);
972 setOperationAction(Op, VT: MVT::v2bf16, Action: Expand);
973 setOperationAction(Op, VT: MVT::v2f32, Action: Expand);
974 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
975 if (getOperationAction(Op, VT: MVT::bf16) == Promote)
976 AddPromotedToType(Opc: Op, OrigVT: MVT::bf16, DestVT: MVT::f32);
977 }
978
979 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71) {
980 setOperationAction(Op: ISD::BF16_TO_FP, VT: MVT::f32, Action: Expand);
981 }
982 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
983 for (MVT VT : {MVT::bf16, MVT::f32, MVT::f64}) {
984 setOperationAction(Op: ISD::FP_EXTEND, VT, Action: Custom);
985 setOperationAction(Op: ISD::FP_ROUND, VT, Action: Custom);
986 }
987 }
988
989 // Expand v2f32 = fp_extend
990 setOperationAction(Op: ISD::FP_EXTEND, VT: MVT::v2f32, Action: Expand);
991 // Expand v2[b]f16 = fp_round v2f32
992 setOperationAction(Ops: ISD::FP_ROUND, VTs: {MVT::v2bf16, MVT::v2f16}, Action: Expand);
993
994 // sm_80 only has conversions between f32 and bf16. Custom lower all other
995 // bf16 conversions.
996 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
997 for (MVT VT : {MVT::i1, MVT::i16, MVT::i32, MVT::i64}) {
998 setOperationAction(
999 Ops: {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
1000 VT, Action: Custom);
1001 }
1002 setOperationAction(
1003 Ops: {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
1004 VT: MVT::bf16, Action: Custom);
1005 }
1006
1007 setOperationAction(Ops: {ISD::FP_TO_SINT, ISD::FP_TO_UINT}, VT: MVT::i1, Action: Custom);
1008 setOperationAction(Op: ISD::FROUND, VT: MVT::f16, Action: Promote);
1009 setOperationAction(Op: ISD::FROUND, VT: MVT::v2f16, Action: Expand);
1010 setOperationAction(Op: ISD::FROUND, VT: MVT::v2bf16, Action: Expand);
1011 setOperationAction(Op: ISD::FROUND, VT: MVT::f32, Action: Custom);
1012 setOperationAction(Op: ISD::FROUND, VT: MVT::f64, Action: Custom);
1013 setOperationAction(Op: ISD::FROUND, VT: MVT::bf16, Action: Promote);
1014 AddPromotedToType(Opc: ISD::FROUND, OrigVT: MVT::bf16, DestVT: MVT::f32);
1015
1016 setOperationAction(Ops: {ISD::LROUND, ISD::LLROUND}, VTs: {MVT::f32, MVT::f64}, Action: Expand);
1017
1018 // 'Expand' implements FCOPYSIGN without calling an external library.
1019 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::f16, Action: Expand);
1020 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::v2f16, Action: Expand);
1021 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::bf16, Action: Expand);
1022 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::v2bf16, Action: Expand);
1023 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::f32, Action: Custom);
1024 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::f64, Action: Custom);
1025
1026 // These map to corresponding instructions for f32/f64. f16 must be
1027 // promoted to f32. v2f16 is expanded to f16, which is then promoted
1028 // to f32.
1029 for (const auto &Op :
1030 {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS, ISD::FTANH}) {
1031 setOperationAction(Op, VT: MVT::f16, Action: Promote);
1032 setOperationAction(Op, VT: MVT::f32, Action: Legal);
1033 // only div/rem/sqrt are legal for f64
1034 if (Op == ISD::FDIV || Op == ISD::FREM || Op == ISD::FSQRT) {
1035 setOperationAction(Op, VT: MVT::f64, Action: Legal);
1036 }
1037 setOperationAction(Ops: Op, VTs: {MVT::v2f16, MVT::v2bf16, MVT::v2f32}, Action: Expand);
1038 setOperationAction(Op, VT: MVT::bf16, Action: Promote);
1039 AddPromotedToType(Opc: Op, OrigVT: MVT::bf16, DestVT: MVT::f32);
1040 }
1041 setOperationAction(Ops: ISD::FREM, VTs: {MVT::f32, MVT::f64}, Action: Custom);
1042
1043 setOperationAction(Ops: ISD::FABS, VTs: {MVT::f32, MVT::f64}, Action: Legal);
1044 setOperationAction(Op: ISD::FABS, VT: MVT::v2f32, Action: Expand);
1045 if (STI.getPTXVersion() >= 65) {
1046 setFP16OperationAction(ISD::FABS, MVT::f16, Legal, Promote);
1047 setFP16OperationAction(ISD::FABS, MVT::v2f16, Legal, Expand);
1048 } else {
1049 setOperationAction(Op: ISD::FABS, VT: MVT::f16, Action: Promote);
1050 setOperationAction(Op: ISD::FABS, VT: MVT::v2f16, Action: Expand);
1051 }
1052 setBF16OperationAction(ISD::FABS, MVT::v2bf16, Legal, Expand);
1053 setBF16OperationAction(ISD::FABS, MVT::bf16, Legal, Promote);
1054 if (getOperationAction(Op: ISD::FABS, VT: MVT::bf16) == Promote)
1055 AddPromotedToType(Opc: ISD::FABS, OrigVT: MVT::bf16, DestVT: MVT::f32);
1056
1057 for (const auto &Op :
1058 {ISD::FMINNUM, ISD::FMAXNUM, ISD::FMINIMUMNUM, ISD::FMAXIMUMNUM}) {
1059 setOperationAction(Op, VT: MVT::f32, Action: Legal);
1060 setOperationAction(Op, VT: MVT::f64, Action: Legal);
1061 setFP16OperationAction(Op, MVT::f16, Legal, Promote);
1062 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
1063 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
1064 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
1065 if (getOperationAction(Op, VT: MVT::bf16) == Promote)
1066 AddPromotedToType(Opc: Op, OrigVT: MVT::bf16, DestVT: MVT::f32);
1067 setOperationAction(Op, VT: MVT::v2f32, Action: Expand);
1068 }
1069 bool SupportsF32MinMaxNaN =
1070 STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
1071 for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
1072 setOperationAction(Op, VT: MVT::f32, Action: SupportsF32MinMaxNaN ? Legal : Expand);
1073 setFP16OperationAction(Op, MVT::f16, Legal, Expand);
1074 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
1075 setBF16OperationAction(Op, MVT::bf16, Legal, Expand);
1076 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
1077 setOperationAction(Op, VT: MVT::v2f32, Action: Expand);
1078 }
1079
1080 // Custom lowering for inline asm with 128-bit operands
1081 setOperationAction(Op: ISD::CopyToReg, VT: MVT::i128, Action: Custom);
1082 setOperationAction(Op: ISD::CopyFromReg, VT: MVT::i128, Action: Custom);
1083
1084 // FEXP2 support:
1085 // - f32
1086 // - f16/f16x2 (sm_70+, PTX 7.0+)
1087 // - bf16/bf16x2 (sm_90+, PTX 7.8+)
1088 // When f16/bf16 types aren't supported, they are promoted/expanded to f32.
1089 setOperationAction(Op: ISD::FEXP2, VT: MVT::f32, Action: Legal);
1090 setOperationAction(Op: ISD::FEXP2, VT: MVT::v2f32, Action: Expand);
1091 setFP16OperationAction(ISD::FEXP2, MVT::f16, Legal, Promote);
1092 setFP16OperationAction(ISD::FEXP2, MVT::v2f16, Legal, Expand);
1093 setBF16OperationAction(ISD::FEXP2, MVT::bf16, Legal, Promote);
1094 setBF16OperationAction(ISD::FEXP2, MVT::v2bf16, Legal, Expand);
1095
1096 // FLOG2 supports f32 only
1097 // f16/bf16 types aren't supported, but they are promoted/expanded to f32.
1098 if (UseApproxLog2F32) {
1099 setOperationAction(Op: ISD::FLOG2, VT: MVT::f32, Action: Legal);
1100 setOperationPromotedToType(Opc: ISD::FLOG2, OrigVT: MVT::f16, DestVT: MVT::f32);
1101 setOperationPromotedToType(Opc: ISD::FLOG2, OrigVT: MVT::bf16, DestVT: MVT::f32);
1102 setOperationAction(Ops: ISD::FLOG2, VTs: {MVT::v2f16, MVT::v2bf16, MVT::v2f32},
1103 Action: Expand);
1104 }
1105
1106 setOperationAction(Ops: ISD::ADDRSPACECAST, VTs: {MVT::i32, MVT::i64}, Action: Custom);
1107
1108 setOperationAction(Ops: ISD::ATOMIC_LOAD_SUB, VTs: {MVT::i32, MVT::i64}, Action: Expand);
1109
1110 // atom.b128 is legal in PTX but since we don't represent i128 as a legal
1111 // type, we need to custom lower it.
1112 setOperationAction(Ops: {ISD::ATOMIC_CMP_SWAP, ISD::ATOMIC_SWAP}, VT: MVT::i128,
1113 Action: Custom);
1114
1115 // Now deduce the information based on the above mentioned
1116 // actions
1117 computeRegisterProperties(TRI: STI.getRegisterInfo());
1118
1119 // PTX support for 16-bit CAS is emulated. Only use 32+
1120 setMinCmpXchgSizeInBits(STI.getMinCmpXchgSizeInBits());
1121 setMaxAtomicSizeInBitsSupported(STI.hasAtomSwap128() ? 128 : 64);
1122 setMaxDivRemBitWidthSupported(64);
1123
1124 // Custom lowering for tcgen05.ld vector operands
1125 setOperationAction(Ops: ISD::INTRINSIC_W_CHAIN,
1126 VTs: {MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
1127 MVT::v32i32, MVT::v64i32, MVT::v128i32, MVT::v2f32,
1128 MVT::v4f32, MVT::v8f32, MVT::v16f32, MVT::v32f32,
1129 MVT::v64f32, MVT::v128f32},
1130 Action: Custom);
1131
1132 // Custom lowering for tcgen05.st vector operands
1133 setOperationAction(Ops: ISD::INTRINSIC_VOID,
1134 VTs: {MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
1135 MVT::v32i32, MVT::v64i32, MVT::v128i32, MVT::Other},
1136 Action: Custom);
1137
1138 // Enable custom lowering for the following:
1139 // * MVT::i128 - clusterlaunchcontrol
1140 // * MVT::i32 - prmt
1141 // * MVT::v4f32 - cvt_rs fp{4/6/8}x4 intrinsics
1142 // * MVT::Other - internal.addrspace.wrap
1143 setOperationAction(Ops: ISD::INTRINSIC_WO_CHAIN,
1144 VTs: {MVT::i32, MVT::i128, MVT::v4f32, MVT::Other}, Action: Custom);
1145
1146 // Custom lowering for bswap
1147 setOperationAction(Ops: ISD::BSWAP, VTs: {MVT::i16, MVT::i32, MVT::i64, MVT::v2i16},
1148 Action: Custom);
1149}
1150
1151TargetLoweringBase::LegalizeTypeAction
1152NVPTXTargetLowering::getPreferredVectorAction(MVT VT) const {
1153 if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
1154 VT.getScalarType() == MVT::i1)
1155 return TypeSplitVector;
1156 return TargetLoweringBase::getPreferredVectorAction(VT);
1157}
1158
1159SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
1160 int Enabled, int &ExtraSteps,
1161 bool &UseOneConst,
1162 bool Reciprocal) const {
1163 if (!(Enabled == ReciprocalEstimate::Enabled ||
1164 (Enabled == ReciprocalEstimate::Unspecified && !usePrecSqrtF32())))
1165 return SDValue();
1166
1167 if (ExtraSteps == ReciprocalEstimate::Unspecified)
1168 ExtraSteps = 0;
1169
1170 SDLoc DL(Operand);
1171 EVT VT = Operand.getValueType();
1172 bool Ftz = useF32FTZ(MF: DAG.getMachineFunction());
1173
1174 auto MakeIntrinsicCall = [&](Intrinsic::ID IID) {
1175 return DAG.getNode(Opcode: ISD::INTRINSIC_WO_CHAIN, DL, VT,
1176 N1: DAG.getConstant(Val: IID, DL, VT: MVT::i32), N2: Operand);
1177 };
1178
1179 // The sqrt and rsqrt refinement processes assume we always start out with an
1180 // approximation of the rsqrt. Therefore, if we're going to do any refinement
1181 // (i.e. ExtraSteps > 0), we must return an rsqrt. But if we're *not* doing
1182 // any refinement, we must return a regular sqrt.
1183 if (Reciprocal || ExtraSteps > 0) {
1184 if (VT == MVT::f32)
1185 return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_rsqrt_approx_ftz_f
1186 : Intrinsic::nvvm_rsqrt_approx_f);
1187 else if (VT == MVT::f64)
1188 return MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d);
1189 else
1190 return SDValue();
1191 } else {
1192 if (VT == MVT::f32)
1193 return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f
1194 : Intrinsic::nvvm_sqrt_approx_f);
1195 else {
1196 // There's no sqrt.approx.f64 instruction, so we emit
1197 // reciprocal(rsqrt(x)). This is faster than
1198 // select(x == 0, 0, x * rsqrt(x)). (In fact, it's faster than plain
1199 // x * rsqrt(x).)
1200 return DAG.getNode(
1201 Opcode: ISD::INTRINSIC_WO_CHAIN, DL, VT,
1202 N1: DAG.getConstant(Val: Intrinsic::nvvm_rcp_approx_ftz_d, DL, VT: MVT::i32),
1203 N2: MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d));
1204 }
1205 }
1206}
1207
1208std::string NVPTXTargetLowering::getPrototype(
1209 const DataLayout &DL, Type *RetTy, const ArgListTy &Args,
1210 const SmallVectorImpl<ISD::OutputArg> &Outs,
1211 std::optional<unsigned> FirstVAArg, const CallBase &CB,
1212 unsigned UniqueCallSite) const {
1213 auto PtrVT = getPointerTy(DL);
1214
1215 std::string Prototype;
1216 raw_string_ostream O(Prototype);
1217 O << "prototype_" << UniqueCallSite << " : .callprototype ";
1218
1219 if (RetTy->isVoidTy()) {
1220 O << "()";
1221 } else {
1222 O << "(";
1223 if (shouldPassAsArray(Ty: RetTy)) {
1224 const Align RetAlign =
1225 getPTXParamAlign(CB: &CB, Ty: RetTy, AttrIdx: AttributeList::ReturnIndex, DL);
1226 O << ".param .align " << RetAlign.value() << " .b8 _["
1227 << DL.getTypeAllocSize(Ty: RetTy) << "]";
1228 } else if (RetTy->isFloatingPointTy() || RetTy->isIntegerTy()) {
1229 unsigned size = 0;
1230 if (auto *ITy = dyn_cast<IntegerType>(Val: RetTy)) {
1231 size = ITy->getBitWidth();
1232 } else {
1233 assert(RetTy->isFloatingPointTy() &&
1234 "Floating point type expected here");
1235 size = RetTy->getPrimitiveSizeInBits();
1236 }
1237 // PTX ABI requires all scalar return values to be at least 32
1238 // bits in size. fp16 normally uses .b16 as its storage type in
1239 // PTX, so its size must be adjusted here, too.
1240 size = promoteScalarArgumentSize(size);
1241
1242 O << ".param .b" << size << " _";
1243 } else if (isa<PointerType>(Val: RetTy)) {
1244 O << ".param .b" << PtrVT.getSizeInBits() << " _";
1245 } else {
1246 llvm_unreachable("Unknown return type");
1247 }
1248 O << ") ";
1249 }
1250 O << "_ (";
1251
1252 bool first = true;
1253
1254 const unsigned NumArgs = FirstVAArg.value_or(u: Args.size());
1255 auto AllOuts = ArrayRef(Outs);
1256 for (const unsigned I : llvm::seq(Size: NumArgs)) {
1257 const auto ArgOuts =
1258 AllOuts.take_while(Pred: [I](auto O) { return O.OrigArgIndex == I; });
1259 AllOuts = AllOuts.drop_front(N: ArgOuts.size());
1260
1261 Type *Ty = Args[I].Ty;
1262 if (!first) {
1263 O << ", ";
1264 }
1265 first = false;
1266
1267 if (ArgOuts[0].Flags.isByVal()) {
1268 // Indirect calls need strict ABI alignment so we disable optimizations by
1269 // not providing a function to optimize.
1270 Type *ETy = Args[I].IndirectType;
1271 Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
1272 Align ParamByValAlign =
1273 getDeviceByValParamAlign(/*F=*/nullptr, ArgTy: ETy, InitialAlign, DL);
1274
1275 O << ".param .align " << ParamByValAlign.value() << " .b8 _["
1276 << ArgOuts[0].Flags.getByValSize() << "]";
1277 } else {
1278 if (shouldPassAsArray(Ty)) {
1279 Align ParamAlign =
1280 getPTXParamAlign(CB: &CB, Ty, AttrIdx: I + AttributeList::FirstArgIndex, DL);
1281 O << ".param .align " << ParamAlign.value() << " .b8 _["
1282 << DL.getTypeAllocSize(Ty) << "]";
1283 continue;
1284 }
1285 // i8 types in IR will be i16 types in SDAG
1286 assert((getValueType(DL, Ty) == ArgOuts[0].VT ||
1287 (getValueType(DL, Ty) == MVT::i8 && ArgOuts[0].VT == MVT::i16)) &&
1288 "type mismatch between callee prototype and arguments");
1289 // scalar type
1290 unsigned sz = 0;
1291 if (auto *ITy = dyn_cast<IntegerType>(Val: Ty)) {
1292 sz = promoteScalarArgumentSize(size: ITy->getBitWidth());
1293 } else if (isa<PointerType>(Val: Ty)) {
1294 sz = PtrVT.getSizeInBits();
1295 } else {
1296 sz = Ty->getPrimitiveSizeInBits();
1297 }
1298 O << ".param .b" << sz << " _";
1299 }
1300 }
1301
1302 if (FirstVAArg)
1303 O << (first ? "" : ",") << " .param .align "
1304 << STI.getMaxRequiredAlignment() << " .b8 _[]";
1305 O << ")";
1306 if (shouldEmitPTXNoReturn(V: &CB, TM: *nvTM))
1307 O << " .noreturn";
1308 O << ";";
1309
1310 return Prototype;
1311}
1312
1313static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
1314 const DataLayout &DL,
1315 const TargetLowering &TL) {
1316 if (Ptr->getOpcode() == ISD::FrameIndex) {
1317 auto Ty = TL.getPointerTy(DL, AS: ADDRESS_SPACE_LOCAL);
1318 Ptr = DAG.getAddrSpaceCast(dl: SDLoc(), VT: Ty, Ptr, SrcAS: ADDRESS_SPACE_GENERIC,
1319 DestAS: ADDRESS_SPACE_LOCAL);
1320
1321 return MachinePointerInfo(ADDRESS_SPACE_LOCAL);
1322 }
1323
1324 // Peel of an addrspacecast to generic and load directly from the specific
1325 // address space.
1326 if (Ptr->getOpcode() == ISD::ADDRSPACECAST) {
1327 const auto *ASC = cast<AddrSpaceCastSDNode>(Val&: Ptr);
1328 if (ASC->getDestAddressSpace() == ADDRESS_SPACE_GENERIC) {
1329 Ptr = ASC->getOperand(Num: 0);
1330 return MachinePointerInfo(ASC->getSrcAddressSpace());
1331 }
1332 }
1333
1334 return MachinePointerInfo();
1335}
1336
1337static ISD::NodeType getExtOpcode(const ISD::ArgFlagsTy &Flags) {
1338 if (Flags.isSExt())
1339 return ISD::SIGN_EXTEND;
1340 if (Flags.isZExt())
1341 return ISD::ZERO_EXTEND;
1342 return ISD::ANY_EXTEND;
1343}
1344
1345static SDValue correctParamType(SDValue V, EVT ExpectedVT,
1346 ISD::ArgFlagsTy Flags, SelectionDAG &DAG,
1347 SDLoc dl) {
1348 const EVT ActualVT = V.getValueType();
1349 assert((ActualVT == ExpectedVT ||
1350 (ExpectedVT.isInteger() && ActualVT.isInteger())) &&
1351 "Non-integer argument type size mismatch");
1352 if (ExpectedVT.bitsGT(VT: ActualVT))
1353 return DAG.getNode(Opcode: getExtOpcode(Flags), DL: dl, VT: ExpectedVT, Operand: V);
1354 if (ExpectedVT.bitsLT(VT: ActualVT))
1355 return DAG.getNode(Opcode: ISD::TRUNCATE, DL: dl, VT: ExpectedVT, Operand: V);
1356
1357 return V;
1358}
1359
1360SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
1361 SmallVectorImpl<SDValue> &InVals) const {
1362
1363 if (CLI.IsVarArg && (STI.getPTXVersion() < 60 || STI.getSmVersion() < 30))
1364 report_fatal_error(
1365 reason: "Support for variadic functions (unsized array parameter) introduced "
1366 "in PTX ISA version 6.0 and requires target sm_30.");
1367
1368 SelectionDAG &DAG = CLI.DAG;
1369 SDLoc dl = CLI.DL;
1370 const SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
1371 SDValue Callee = CLI.Callee;
1372 ArgListTy &Args = CLI.getArgs();
1373 Type *RetTy = CLI.RetTy;
1374 const CallBase *CB = CLI.CB;
1375 const DataLayout &DL = DAG.getDataLayout();
1376 LLVMContext &Ctx = *DAG.getContext();
1377
1378 const auto GetI32 = [&](const unsigned I) {
1379 return DAG.getConstant(Val: I, DL: dl, VT: MVT::i32);
1380 };
1381
1382 const unsigned UniqueCallSite = GlobalUniqueCallSite++;
1383 const SDValue CallChain = CLI.Chain;
1384 const SDValue StartChain =
1385 DAG.getCALLSEQ_START(Chain: CallChain, InSize: UniqueCallSite, OutSize: 0, DL: dl);
1386 SDValue DeclareGlue = StartChain.getValue(R: 1);
1387
1388 SmallVector<SDValue, 16> CallPrereqs{StartChain};
1389
1390 const auto MakeDeclareScalarParam = [&](SDValue Symbol, unsigned Size) {
1391 // PTX ABI requires integral types to be at least 32 bits in size. FP16 is
1392 // loaded/stored using i16, so it's handled here as well.
1393 const unsigned SizeBits = promoteScalarArgumentSize(size: Size * 8);
1394 SDValue Declare =
1395 DAG.getNode(Opcode: NVPTXISD::DeclareScalarParam, DL: dl, ResultTys: {MVT::Other, MVT::Glue},
1396 Ops: {StartChain, Symbol, GetI32(SizeBits), DeclareGlue});
1397 CallPrereqs.push_back(Elt: Declare);
1398 DeclareGlue = Declare.getValue(R: 1);
1399 return Declare;
1400 };
1401
1402 const auto MakeDeclareArrayParam = [&](SDValue Symbol, Align Align,
1403 unsigned Size) {
1404 SDValue Declare = DAG.getNode(
1405 Opcode: NVPTXISD::DeclareArrayParam, DL: dl, ResultTys: {MVT::Other, MVT::Glue},
1406 Ops: {StartChain, Symbol, GetI32(Align.value()), GetI32(Size), DeclareGlue});
1407 CallPrereqs.push_back(Elt: Declare);
1408 DeclareGlue = Declare.getValue(R: 1);
1409 return Declare;
1410 };
1411
1412 // Variadic arguments.
1413 //
1414 // Normally, for each argument, we declare a param scalar or a param
1415 // byte array in the .param space, and store the argument value to that
1416 // param scalar or array starting at offset 0.
1417 //
1418 // In the case of the first variadic argument, we declare a vararg byte array
1419 // with size 0. The exact size of this array isn't known at this point, so
1420 // it'll be patched later. All the variadic arguments will be stored to this
1421 // array at a certain offset (which gets tracked by 'VAOffset'). The offset is
1422 // initially set to 0, so it can be used for non-variadic arguments (which use
1423 // 0 offset) to simplify the code.
1424 //
1425 // After all vararg is processed, 'VAOffset' holds the size of the
1426 // vararg byte array.
1427 assert((CLI.IsVarArg || CLI.Args.size() == CLI.NumFixedArgs) &&
1428 "Non-VarArg function with extra arguments");
1429
1430 const unsigned FirstVAArg = CLI.NumFixedArgs; // position of first variadic
1431 unsigned VAOffset = 0; // current offset in the param array
1432
1433 const SDValue VADeclareParam =
1434 CLI.Args.size() > FirstVAArg
1435 ? MakeDeclareArrayParam(getCallParamSymbol(DAG, I: FirstVAArg, T: MVT::i32),
1436 Align(STI.getMaxRequiredAlignment()), 0)
1437 : SDValue();
1438
1439 // Args.size() and Outs.size() need not match.
1440 // Outs.size() will be larger
1441 // * if there is an aggregate argument with multiple fields (each field
1442 // showing up separately in Outs)
1443 // * if there is a vector argument with more than typical vector-length
1444 // elements (generally if more than 4) where each vector element is
1445 // individually present in Outs.
1446 // So a different index should be used for indexing into Outs/OutVals.
1447 // See similar issue in LowerFormalArguments.
1448 auto AllOuts = ArrayRef(CLI.Outs);
1449 auto AllOutVals = ArrayRef(CLI.OutVals);
1450 assert(AllOuts.size() == AllOutVals.size() &&
1451 "Outs and OutVals must be the same size");
1452 // Declare the .params or .reg need to pass values
1453 // to the function
1454 for (const auto E : llvm::enumerate(First&: Args)) {
1455 const auto ArgI = E.index();
1456 const auto Arg = E.value();
1457 const auto ArgOuts =
1458 AllOuts.take_while(Pred: [&](auto O) { return O.OrigArgIndex == ArgI; });
1459 const auto ArgOutVals = AllOutVals.take_front(N: ArgOuts.size());
1460 AllOuts = AllOuts.drop_front(N: ArgOuts.size());
1461 AllOutVals = AllOutVals.drop_front(N: ArgOuts.size());
1462
1463 const bool IsVAArg = (ArgI >= FirstVAArg);
1464 const bool IsByVal = Arg.IsByVal;
1465
1466 const SDValue ParamSymbol =
1467 getCallParamSymbol(DAG, I: IsVAArg ? FirstVAArg : ArgI, T: MVT::i32);
1468
1469 assert((!IsByVal || Arg.IndirectType) &&
1470 "byval arg must have indirect type");
1471 Type *ETy = (IsByVal ? Arg.IndirectType : Arg.Ty);
1472
1473 const Align ArgAlign = [&]() {
1474 if (IsByVal) {
1475 // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
1476 // so we don't need to worry whether it's naturally aligned or not.
1477 // See TargetLowering::LowerCallTo().
1478 const Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
1479 return getDeviceByValParamAlign(F: CB->getCalledFunction(), ArgTy: ETy,
1480 InitialAlign, DL);
1481 }
1482 return getPTXParamAlign(CB, Ty: Arg.Ty, AttrIdx: ArgI + AttributeList::FirstArgIndex,
1483 DL);
1484 }();
1485
1486 const unsigned TySize = DL.getTypeAllocSize(Ty: ETy);
1487 assert((!IsByVal || TySize == ArgOuts[0].Flags.getByValSize()) &&
1488 "type size mismatch");
1489
1490 const SDValue ArgDeclare = [&]() {
1491 if (IsVAArg)
1492 return VADeclareParam;
1493
1494 if (IsByVal || shouldPassAsArray(Ty: Arg.Ty))
1495 return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TySize);
1496
1497 assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
1498 assert((ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint()) &&
1499 "Only int and float types are supported as non-array arguments");
1500
1501 return MakeDeclareScalarParam(ParamSymbol, TySize);
1502 }();
1503
1504 if (IsByVal) {
1505 assert(ArgOutVals.size() == 1 && "We must pass only one value as byval");
1506 SDValue SrcPtr = ArgOutVals[0];
1507 const auto PointerInfo = refinePtrAS(Ptr&: SrcPtr, DAG, DL, TL: *this);
1508 const Align BaseSrcAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
1509
1510 if (IsVAArg)
1511 VAOffset = alignTo(Size: VAOffset, A: ArgAlign);
1512
1513 SmallVector<EVT, 4> ValueVTs, MemVTs;
1514 SmallVector<TypeSize, 4> Offsets;
1515 ComputeValueVTs(TLI: *this, DL, Ty: ETy, ValueVTs, MemVTs: &MemVTs, Offsets: &Offsets);
1516
1517 unsigned J = 0;
1518 const auto VI = VectorizePTXValueVTs(ValueVTs: MemVTs, Offsets, ParamAlignment: ArgAlign, IsVAArg);
1519 for (const unsigned NumElts : VI) {
1520 EVT LoadVT = getVectorizedVT(VT: MemVTs[J], N: NumElts, C&: Ctx);
1521 Align SrcAlign = commonAlignment(A: BaseSrcAlign, Offset: Offsets[J]);
1522 SDValue SrcAddr = DAG.getObjectPtrOffset(SL: dl, Ptr: SrcPtr, Offset: Offsets[J]);
1523 SDValue SrcLoad =
1524 DAG.getLoad(VT: LoadVT, dl, Chain: CallChain, Ptr: SrcAddr, PtrInfo: PointerInfo, Alignment: SrcAlign);
1525
1526 TypeSize ParamOffset = Offsets[J].getWithIncrement(RHS: VAOffset);
1527 Align ParamAlign = commonAlignment(A: ArgAlign, Offset: ParamOffset);
1528 SDValue ParamAddr =
1529 DAG.getObjectPtrOffset(SL: dl, Ptr: ParamSymbol, Offset: ParamOffset);
1530 SDValue StoreParam = DAG.getStore(
1531 Chain: ArgDeclare, dl, Val: SrcLoad, Ptr: ParamAddr,
1532 PtrInfo: MachinePointerInfo(NVPTX::AddressSpace::DeviceParam), Alignment: ParamAlign);
1533 CallPrereqs.push_back(Elt: StoreParam);
1534
1535 J += NumElts;
1536 }
1537 if (IsVAArg)
1538 VAOffset += TySize;
1539 } else {
1540 SmallVector<EVT, 16> VTs;
1541 SmallVector<uint64_t, 16> Offsets;
1542 ComputePTXValueVTs(TLI: *this, DL, Ctx, CallConv: CLI.CallConv, Ty: Arg.Ty, ValueVTs&: VTs, Offsets,
1543 StartingOffset: VAOffset);
1544 assert(VTs.size() == Offsets.size() && "Size mismatch");
1545 assert(VTs.size() == ArgOuts.size() && "Size mismatch");
1546
1547 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
1548 // than 32-bits are sign extended or zero extended, depending on
1549 // whether they are signed or unsigned types. This case applies
1550 // only to scalar parameters and not to aggregate values.
1551 const bool ExtendIntegerParam =
1552 Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty: Arg.Ty) < 32;
1553
1554 const auto GetStoredValue = [&](const unsigned I) {
1555 SDValue StVal = ArgOutVals[I];
1556 assert(promoteScalarIntegerPTX(StVal.getValueType()) ==
1557 StVal.getValueType() &&
1558 "OutVal type should always be legal");
1559
1560 const EVT VTI = promoteScalarIntegerPTX(VT: VTs[I]);
1561 const EVT StoreVT =
1562 ExtendIntegerParam ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
1563
1564 return correctParamType(V: StVal, ExpectedVT: StoreVT, Flags: ArgOuts[I].Flags, DAG, dl);
1565 };
1566
1567 unsigned J = 0;
1568 const auto VI = VectorizePTXValueVTs(ValueVTs: VTs, Offsets, ParamAlignment: ArgAlign, IsVAArg);
1569 for (const unsigned NumElts : VI) {
1570 const EVT EltVT = promoteScalarIntegerPTX(VT: VTs[J]);
1571
1572 unsigned Offset;
1573 if (IsVAArg) {
1574 // TODO: We may need to support vector types that can be passed
1575 // as scalars in variadic arguments.
1576 assert(NumElts == 1 &&
1577 "Vectorization should be disabled for vaargs.");
1578
1579 // Align each part of the variadic argument to their type.
1580 VAOffset = alignTo(Size: VAOffset, A: DAG.getEVTAlign(MemoryVT: EltVT));
1581 Offset = VAOffset;
1582
1583 const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
1584 VAOffset += DL.getTypeAllocSize(Ty: TheStoreType.getTypeForEVT(Context&: Ctx));
1585 } else {
1586 assert(VAOffset == 0 && "VAOffset must be 0 for non-VA args");
1587 Offset = Offsets[J];
1588 }
1589
1590 SDValue Ptr =
1591 DAG.getObjectPtrOffset(SL: dl, Ptr: ParamSymbol, Offset: TypeSize::getFixed(ExactSize: Offset));
1592
1593 const MaybeAlign CurrentAlign = ExtendIntegerParam
1594 ? MaybeAlign(std::nullopt)
1595 : commonAlignment(A: ArgAlign, Offset);
1596
1597 SDValue Val =
1598 getBuildVectorizedValue(N: NumElts, dl, DAG, GetElement: [&](unsigned K) {
1599 return GetStoredValue(J + K);
1600 });
1601
1602 SDValue StoreParam = DAG.getStore(
1603 Chain: ArgDeclare, dl, Val, Ptr,
1604 PtrInfo: MachinePointerInfo(NVPTX::AddressSpace::DeviceParam), Alignment: CurrentAlign);
1605 CallPrereqs.push_back(Elt: StoreParam);
1606
1607 J += NumElts;
1608 }
1609 }
1610 }
1611
1612 // Handle Result
1613 if (!Ins.empty()) {
1614 const SDValue RetSymbol = DAG.getExternalSymbol(Sym: "retval0", VT: MVT::i32);
1615 const unsigned ResultSize = DL.getTypeAllocSize(Ty: RetTy);
1616 if (shouldPassAsArray(Ty: RetTy)) {
1617 const Align RetAlign =
1618 getPTXParamAlign(CB, Ty: RetTy, AttrIdx: AttributeList::ReturnIndex, DL);
1619 MakeDeclareArrayParam(RetSymbol, RetAlign, ResultSize);
1620 } else {
1621 MakeDeclareScalarParam(RetSymbol, ResultSize);
1622 }
1623 }
1624
1625 // Set the size of the vararg param byte array if the callee is a variadic
1626 // function and the variadic part is not empty.
1627 if (VADeclareParam) {
1628 SDValue DeclareParamOps[] = {VADeclareParam.getOperand(i: 0),
1629 VADeclareParam.getOperand(i: 1),
1630 VADeclareParam.getOperand(i: 2), GetI32(VAOffset),
1631 VADeclareParam.getOperand(i: 4)};
1632 DAG.MorphNodeTo(N: VADeclareParam.getNode(), Opc: VADeclareParam.getOpcode(),
1633 VTs: VADeclareParam->getVTList(), Ops: DeclareParamOps);
1634 }
1635
1636 const auto *Func = dyn_cast<GlobalAddressSDNode>(Val: Callee.getNode());
1637 const auto *CalleeF = Func ? dyn_cast<Function>(Val: Func->getGlobal()) : nullptr;
1638
1639 // If the type of the callsite does not match that of the function, convert
1640 // the callsite to an indirect call.
1641 const bool ConvertToIndirectCall =
1642 CalleeF && CB->getFunctionType() != CalleeF->getFunctionType();
1643
1644 // Both indirect calls and libcalls have nullptr Func. In order to distinguish
1645 // between them we must rely on the call site value which is valid for
1646 // indirect calls but is always null for libcalls.
1647 const bool IsIndirectCall = (!Func && CB) || ConvertToIndirectCall;
1648
1649 if (isa<ExternalSymbolSDNode>(Val: Callee)) {
1650 Function* CalleeFunc = nullptr;
1651
1652 // Try to find the callee in the current module.
1653 Callee = DAG.getSymbolFunctionGlobalAddress(Op: Callee, TargetFunction: &CalleeFunc);
1654 assert(CalleeFunc != nullptr && "Libcall callee must be set.");
1655
1656 // Set the "libcall callee" attribute to indicate that the function
1657 // must always have a declaration.
1658 CalleeFunc->addFnAttr(Kind: "nvptx-libcall-callee", Val: "true");
1659 }
1660
1661 if (IsIndirectCall) {
1662 // This is indirect function call case : PTX requires a prototype of the
1663 // form
1664 // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
1665 // to be emitted, and the label has to used as the last arg of call
1666 // instruction.
1667 // The prototype is embedded in a string and put as the operand for a
1668 // CallPrototype SDNode which will print out to the value of the string.
1669 const bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
1670 std::string Proto =
1671 getPrototype(DL, RetTy, Args, Outs: CLI.Outs,
1672 FirstVAArg: HasVAArgs ? std::optional(FirstVAArg) : std::nullopt, CB: *CB,
1673 UniqueCallSite);
1674 const char *ProtoStr = nvTM->getStrPool().save(S: Proto).data();
1675 const SDValue PrototypeDeclare = DAG.getNode(
1676 Opcode: NVPTXISD::CallPrototype, DL: dl, VT: MVT::Other,
1677 Ops: {StartChain, DAG.getTargetExternalSymbol(Sym: ProtoStr, VT: MVT::i32)});
1678 CallPrereqs.push_back(Elt: PrototypeDeclare);
1679 }
1680
1681 const bool IsUnknownIntrinsic =
1682 CalleeF && CalleeF->isIntrinsic() &&
1683 CalleeF->getIntrinsicID() == Intrinsic::not_intrinsic;
1684 if (IsUnknownIntrinsic) {
1685 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
1686 DAG.getMachineFunction().getFunction(),
1687 "call to unknown intrinsic '" + CalleeF->getName() +
1688 "' cannot be lowered by the NVPTX backend",
1689 dl.getDebugLoc()));
1690 }
1691
1692 const unsigned Proto = IsIndirectCall ? UniqueCallSite : 0;
1693 const unsigned NumArgs =
1694 std::min<unsigned>(a: CLI.NumFixedArgs + 1, b: Args.size());
1695 /// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns,
1696 /// NumParams, Callee, Proto)
1697 const SDValue CallToken = DAG.getTokenFactor(DL: dl, Vals&: CallPrereqs);
1698 const SDValue Call = DAG.getNode(
1699 Opcode: NVPTXISD::CALL, DL: dl, VT: MVT::Other,
1700 Ops: {CallToken, GetI32(CLI.IsConvergent), GetI32(IsIndirectCall),
1701 GetI32(Ins.empty() ? 0 : 1), GetI32(NumArgs), Callee, GetI32(Proto)});
1702
1703 SmallVector<SDValue, 16> LoadChains{Call};
1704 SmallVector<SDValue, 16> ProxyRegOps;
1705 if (!Ins.empty()) {
1706 SmallVector<EVT, 16> VTs;
1707 SmallVector<uint64_t, 16> Offsets;
1708 ComputePTXValueVTs(TLI: *this, DL, Ctx, CallConv: CLI.CallConv, Ty: RetTy, ValueVTs&: VTs, Offsets);
1709 assert(VTs.size() == Ins.size() && "Bad value decomposition");
1710
1711 const Align RetAlign =
1712 getPTXParamAlign(CB, Ty: RetTy, AttrIdx: AttributeList::ReturnIndex, DL);
1713 const SDValue RetSymbol = DAG.getExternalSymbol(Sym: "retval0", VT: MVT::i32);
1714
1715 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
1716 // 32-bits are sign extended or zero extended, depending on whether
1717 // they are signed or unsigned types.
1718 const bool ExtendIntegerRetVal =
1719 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty: RetTy) < 32;
1720
1721 unsigned I = 0;
1722 const auto VI = VectorizePTXValueVTs(ValueVTs: VTs, Offsets, ParamAlignment: RetAlign);
1723 for (const unsigned NumElts : VI) {
1724 const MaybeAlign CurrentAlign =
1725 ExtendIntegerRetVal ? MaybeAlign(std::nullopt)
1726 : commonAlignment(A: RetAlign, Offset: Offsets[I]);
1727
1728 const EVT VTI = promoteScalarIntegerPTX(VT: VTs[I]);
1729 const EVT LoadVT =
1730 ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
1731 const EVT VecVT = getVectorizedVT(VT: LoadVT, N: NumElts, C&: Ctx);
1732 SDValue Ptr =
1733 DAG.getObjectPtrOffset(SL: dl, Ptr: RetSymbol, Offset: TypeSize::getFixed(ExactSize: Offsets[I]));
1734
1735 SDValue R = DAG.getLoad(
1736 VT: VecVT, dl, Chain: Call, Ptr,
1737 PtrInfo: MachinePointerInfo(NVPTX::AddressSpace::DeviceParam), Alignment: CurrentAlign);
1738
1739 LoadChains.push_back(Elt: R.getValue(R: 1));
1740 for (const unsigned J : llvm::seq(Size: NumElts))
1741 ProxyRegOps.push_back(Elt: getExtractVectorizedValue(V: R, I: J, VT: LoadVT, dl, DAG));
1742 I += NumElts;
1743 }
1744 }
1745
1746 const SDValue EndToken = DAG.getTokenFactor(DL: dl, Vals&: LoadChains);
1747 const SDValue CallEnd = DAG.getCALLSEQ_END(Chain: EndToken, Size1: UniqueCallSite,
1748 Size2: UniqueCallSite + 1, Glue: SDValue(), DL: dl);
1749
1750 // Append ProxyReg instructions to the chain to make sure that `callseq_end`
1751 // will not get lost. Otherwise, during libcalls expansion, the nodes can become
1752 // dangling.
1753 for (const auto [I, Reg] : llvm::enumerate(First&: ProxyRegOps)) {
1754 SDValue Proxy =
1755 DAG.getNode(Opcode: NVPTXISD::ProxyReg, DL: dl, VT: Reg.getValueType(), Ops: {CallEnd, Reg});
1756 SDValue Ret = correctParamType(V: Proxy, ExpectedVT: Ins[I].VT, Flags: Ins[I].Flags, DAG, dl);
1757 InVals.push_back(Elt: Ret);
1758 }
1759
1760 // set IsTailCall to false for now, until we figure out how to express
1761 // tail call optimization in PTX
1762 CLI.IsTailCall = false;
1763 return CallEnd;
1764}
1765
1766SDValue NVPTXTargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op,
1767 SelectionDAG &DAG) const {
1768
1769 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1770 const Function &Fn = DAG.getMachineFunction().getFunction();
1771
1772 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
1773 Fn,
1774 "Support for dynamic alloca introduced in PTX ISA version 7.3 and "
1775 "requires target sm_52.",
1776 SDLoc(Op).getDebugLoc()));
1777 auto Ops = {DAG.getConstant(Val: 0, DL: SDLoc(), VT: Op.getValueType()),
1778 Op.getOperand(i: 0)};
1779 return DAG.getMergeValues(Ops, dl: SDLoc());
1780 }
1781
1782 SDLoc DL(Op.getNode());
1783 SDValue Chain = Op.getOperand(i: 0);
1784 SDValue Size = Op.getOperand(i: 1);
1785 uint64_t Align = Op.getConstantOperandVal(i: 2);
1786
1787 // The alignment on a ISD::DYNAMIC_STACKALLOC node may be 0 to indicate that
1788 // the default stack alignment should be used.
1789 if (Align == 0)
1790 Align = DAG.getSubtarget().getFrameLowering()->getStackAlign().value();
1791
1792 // The size for ptx alloca instruction is 64-bit for m64 and 32-bit for m32.
1793 const MVT LocalVT = getPointerTy(DL: DAG.getDataLayout(), AS: ADDRESS_SPACE_LOCAL);
1794
1795 SDValue Alloc =
1796 DAG.getNode(Opcode: NVPTXISD::DYNAMIC_STACKALLOC, DL, ResultTys: {LocalVT, MVT::Other},
1797 Ops: {Chain, DAG.getZExtOrTrunc(Op: Size, DL, VT: LocalVT),
1798 DAG.getTargetConstant(Val: Align, DL, VT: MVT::i32)});
1799
1800 SDValue ASC = DAG.getAddrSpaceCast(
1801 dl: DL, VT: Op.getValueType(), Ptr: Alloc, SrcAS: ADDRESS_SPACE_LOCAL, DestAS: ADDRESS_SPACE_GENERIC);
1802
1803 return DAG.getMergeValues(Ops: {ASC, SDValue(Alloc.getNode(), 1)}, dl: DL);
1804}
1805
1806SDValue NVPTXTargetLowering::LowerSTACKRESTORE(SDValue Op,
1807 SelectionDAG &DAG) const {
1808 SDLoc DL(Op.getNode());
1809 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1810 const Function &Fn = DAG.getMachineFunction().getFunction();
1811
1812 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
1813 Fn,
1814 "Support for stackrestore requires PTX ISA version >= 7.3 and target "
1815 ">= sm_52.",
1816 DL.getDebugLoc()));
1817 return Op.getOperand(i: 0);
1818 }
1819
1820 const MVT LocalVT = getPointerTy(DL: DAG.getDataLayout(), AS: ADDRESS_SPACE_LOCAL);
1821 SDValue Chain = Op.getOperand(i: 0);
1822 SDValue Ptr = Op.getOperand(i: 1);
1823 SDValue ASC = DAG.getAddrSpaceCast(dl: DL, VT: LocalVT, Ptr, SrcAS: ADDRESS_SPACE_GENERIC,
1824 DestAS: ADDRESS_SPACE_LOCAL);
1825 return DAG.getNode(Opcode: NVPTXISD::STACKRESTORE, DL, VT: MVT::Other, Ops: {Chain, ASC});
1826}
1827
1828SDValue NVPTXTargetLowering::LowerSTACKSAVE(SDValue Op,
1829 SelectionDAG &DAG) const {
1830 SDLoc DL(Op.getNode());
1831 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1832 const Function &Fn = DAG.getMachineFunction().getFunction();
1833
1834 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
1835 Fn,
1836 "Support for stacksave requires PTX ISA version >= 7.3 and target >= "
1837 "sm_52.",
1838 DL.getDebugLoc()));
1839 auto Ops = {DAG.getConstant(Val: 0, DL, VT: Op.getValueType()), Op.getOperand(i: 0)};
1840 return DAG.getMergeValues(Ops, dl: DL);
1841 }
1842
1843 const MVT LocalVT = getPointerTy(DL: DAG.getDataLayout(), AS: ADDRESS_SPACE_LOCAL);
1844 SDValue Chain = Op.getOperand(i: 0);
1845 SDValue SS =
1846 DAG.getNode(Opcode: NVPTXISD::STACKSAVE, DL, ResultTys: {LocalVT, MVT::Other}, Ops: Chain);
1847 SDValue ASC = DAG.getAddrSpaceCast(
1848 dl: DL, VT: Op.getValueType(), Ptr: SS, SrcAS: ADDRESS_SPACE_LOCAL, DestAS: ADDRESS_SPACE_GENERIC);
1849 return DAG.getMergeValues(Ops: {ASC, SDValue(SS.getNode(), 1)}, dl: DL);
1850}
1851
1852// By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
1853// (see LegalizeDAG.cpp). This is slow and uses local memory.
1854// We use extract/insert/build vector just as what LegalizeOp() does in llvm 2.5
1855SDValue
1856NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
1857 SDNode *Node = Op.getNode();
1858 SDLoc dl(Node);
1859 SmallVector<SDValue, 8> Ops;
1860 unsigned NumOperands = Node->getNumOperands();
1861 for (unsigned i = 0; i < NumOperands; ++i) {
1862 SDValue SubOp = Node->getOperand(Num: i);
1863 EVT VVT = SubOp.getNode()->getValueType(ResNo: 0);
1864 EVT EltVT = VVT.getVectorElementType();
1865 unsigned NumSubElem = VVT.getVectorNumElements();
1866 for (unsigned j = 0; j < NumSubElem; ++j) {
1867 Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT: EltVT, N1: SubOp,
1868 N2: DAG.getIntPtrConstant(Val: j, DL: dl)));
1869 }
1870 }
1871 return DAG.getBuildVector(VT: Node->getValueType(ResNo: 0), DL: dl, Ops);
1872}
1873
1874static SDValue getPRMT(SDValue A, SDValue B, SDValue Selector, SDLoc DL,
1875 SelectionDAG &DAG,
1876 unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
1877 assert(A.getValueType() == MVT::i32 && B.getValueType() == MVT::i32 &&
1878 Selector.getValueType() == MVT::i32 && "PRMT must have i32 operands");
1879 return DAG.getNode(Opcode: NVPTXISD::PRMT, DL, VT: MVT::i32,
1880 Ops: {A, B, Selector, DAG.getConstant(Val: Mode, DL, VT: MVT::i32)});
1881}
1882
1883static SDValue getPRMT(SDValue A, SDValue B, uint64_t Selector, SDLoc DL,
1884 SelectionDAG &DAG,
1885 unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
1886 return getPRMT(A, B, Selector: DAG.getConstant(Val: Selector, DL, VT: MVT::i32), DL, DAG, Mode);
1887}
1888
1889/// Reduces the elements using the scalar operations provided. The operations
1890/// are sorted descending in number of inputs they take. The flags on the
1891/// original reduction operation will be propagated to each scalar operation.
1892/// Nearby elements are grouped in tree reduction, unlike the shuffle reduction
1893/// used in ExpandReductions and SelectionDAG.
1894static SDValue buildTreeReduction(
1895 const SmallVector<SDValue> &Elements, EVT EltTy,
1896 ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
1897 const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
1898 // Build the reduction tree at each level, starting with all the elements.
1899 SmallVector<SDValue> Level = Elements;
1900
1901 unsigned OpIdx = 0;
1902 while (Level.size() > 1) {
1903 // Try to reduce this level using the current operator.
1904 const auto [Op, NumInputs] = Ops[OpIdx];
1905
1906 // Build the next level by partially reducing all elements.
1907 SmallVector<SDValue> ReducedLevel;
1908 unsigned I = 0, E = Level.size();
1909 for (; I + NumInputs <= E; I += NumInputs) {
1910 // Reduce elements in groups of [NumInputs], as much as possible.
1911 ReducedLevel.push_back(Elt: DAG.getNode(
1912 Opcode: Op, DL, VT: EltTy, Ops: ArrayRef<SDValue>(Level).slice(N: I, M: NumInputs), Flags));
1913 }
1914
1915 if (I < E) {
1916 // Handle leftover elements.
1917
1918 if (ReducedLevel.empty()) {
1919 // We didn't reduce anything at this level. We need to pick a smaller
1920 // operator.
1921 ++OpIdx;
1922 assert(OpIdx < Ops.size() && "no smaller operators for reduction");
1923 continue;
1924 }
1925
1926 // We reduced some things but there's still more left, meaning the
1927 // operator's number of inputs doesn't evenly divide this level size. Move
1928 // these elements to the next level.
1929 for (; I < E; ++I)
1930 ReducedLevel.push_back(Elt: Level[I]);
1931 }
1932
1933 // Process the next level.
1934 Level = ReducedLevel;
1935 }
1936
1937 return *Level.begin();
1938}
1939
1940// Get scalar reduction opcode
1941static ISD::NodeType getScalarOpcodeForReduction(unsigned ReductionOpcode) {
1942 switch (ReductionOpcode) {
1943 case ISD::VECREDUCE_FMAX:
1944 return ISD::FMAXNUM;
1945 case ISD::VECREDUCE_FMIN:
1946 return ISD::FMINNUM;
1947 case ISD::VECREDUCE_FMAXIMUM:
1948 return ISD::FMAXIMUM;
1949 case ISD::VECREDUCE_FMINIMUM:
1950 return ISD::FMINIMUM;
1951 default:
1952 llvm_unreachable("unhandled reduction opcode");
1953 }
1954}
1955
1956/// Get 3-input scalar reduction opcode
1957static std::optional<unsigned>
1958getScalar3OpcodeForReduction(unsigned ReductionOpcode) {
1959 switch (ReductionOpcode) {
1960 case ISD::VECREDUCE_FMAX:
1961 return NVPTXISD::FMAXNUM3;
1962 case ISD::VECREDUCE_FMIN:
1963 return NVPTXISD::FMINNUM3;
1964 case ISD::VECREDUCE_FMAXIMUM:
1965 return NVPTXISD::FMAXIMUM3;
1966 case ISD::VECREDUCE_FMINIMUM:
1967 return NVPTXISD::FMINIMUM3;
1968 default:
1969 return std::nullopt;
1970 }
1971}
1972
1973/// Lower reductions to either a sequence of operations or a tree if
1974/// reassociations are allowed. This method will use larger operations like
1975/// max3/min3 when the target supports them.
1976SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
1977 SelectionDAG &DAG) const {
1978 SDLoc DL(Op);
1979 const SDNodeFlags Flags = Op->getFlags();
1980 SDValue Vector = Op.getOperand(i: 0);
1981
1982 const unsigned Opcode = Op->getOpcode();
1983 const EVT EltTy = Vector.getValueType().getVectorElementType();
1984
1985 // Whether we can use 3-input min/max when expanding the reduction.
1986 const bool CanUseMinMax3 =
1987 EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
1988 STI.getPTXVersion() >= 88 &&
1989 (Opcode == ISD::VECREDUCE_FMAX || Opcode == ISD::VECREDUCE_FMIN ||
1990 Opcode == ISD::VECREDUCE_FMAXIMUM || Opcode == ISD::VECREDUCE_FMINIMUM);
1991
1992 // A list of SDNode opcodes with equivalent semantics, sorted descending by
1993 // number of inputs they take.
1994 SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
1995
1996 if (auto Opcode3Elem = getScalar3OpcodeForReduction(ReductionOpcode: Opcode);
1997 CanUseMinMax3 && Opcode3Elem)
1998 ScalarOps.push_back(Elt: {*Opcode3Elem, 3});
1999 ScalarOps.push_back(Elt: {getScalarOpcodeForReduction(ReductionOpcode: Opcode), 2});
2000
2001 SmallVector<SDValue> Elements;
2002 DAG.ExtractVectorElements(Op: Vector, Args&: Elements);
2003
2004 return buildTreeReduction(Elements, EltTy, Ops: ScalarOps, DL, Flags, DAG);
2005}
2006
2007SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
2008 // Handle bitcasting from v2i8 without hitting the default promotion
2009 // strategy which goes through stack memory.
2010 EVT FromVT = Op->getOperand(Num: 0)->getValueType(ResNo: 0);
2011 if (FromVT != MVT::v2i8) {
2012 return Op;
2013 }
2014
2015 // Pack vector elements into i16 and bitcast to final type
2016 SDLoc DL(Op);
2017 SDValue Vec0 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i8,
2018 N1: Op->getOperand(Num: 0), N2: DAG.getIntPtrConstant(Val: 0, DL));
2019 SDValue Vec1 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i8,
2020 N1: Op->getOperand(Num: 0), N2: DAG.getIntPtrConstant(Val: 1, DL));
2021 SDValue Extend0 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: MVT::i16, Operand: Vec0);
2022 SDValue Extend1 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: MVT::i16, Operand: Vec1);
2023 SDValue Const8 = DAG.getConstant(Val: 8, DL, VT: MVT::i16);
2024 SDValue AsInt = DAG.getNode(
2025 Opcode: ISD::OR, DL, VT: MVT::i16,
2026 Ops: {Extend0, DAG.getNode(Opcode: ISD::SHL, DL, VT: MVT::i16, Ops: {Extend1, Const8})});
2027 EVT ToVT = Op->getValueType(ResNo: 0);
2028 return DAG.getBitcast(VT: ToVT, V: AsInt);
2029}
2030
2031// We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it
2032// would get lowered as two constant loads and vector-packing move.
2033// Instead we want just a constant move:
2034// mov.b32 %r2, 0x40003C00
2035SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
2036 SelectionDAG &DAG) const {
2037 EVT VT = Op->getValueType(ResNo: 0);
2038 if (!(NVPTX::isPackedVectorTy(VT) && VT.is32BitVector()))
2039 return Op;
2040 SDLoc DL(Op);
2041
2042 if (!llvm::all_of(Range: Op->ops(), P: [](SDValue Operand) {
2043 return Operand->isUndef() || isa<ConstantSDNode>(Val: Operand) ||
2044 isa<ConstantFPSDNode>(Val: Operand);
2045 })) {
2046 if (VT != MVT::v4i8)
2047 return Op;
2048 // Lower non-const v4i8 vector as byte-wise constructed i32, which allows us
2049 // to optimize calculation of constant parts.
2050 auto GetPRMT = [&](const SDValue Left, const SDValue Right, bool Cast,
2051 uint64_t SelectionValue) -> SDValue {
2052 SDValue L = Left;
2053 SDValue R = Right;
2054 if (Cast) {
2055 L = DAG.getAnyExtOrTrunc(Op: L, DL, VT: MVT::i32);
2056 R = DAG.getAnyExtOrTrunc(Op: R, DL, VT: MVT::i32);
2057 }
2058 return getPRMT(A: L, B: R, Selector: SelectionValue, DL, DAG);
2059 };
2060 auto PRMT__10 = GetPRMT(Op->getOperand(Num: 0), Op->getOperand(Num: 1), true, 0x3340);
2061 auto PRMT__32 = GetPRMT(Op->getOperand(Num: 2), Op->getOperand(Num: 3), true, 0x3340);
2062 auto PRMT3210 = GetPRMT(PRMT__10, PRMT__32, false, 0x5410);
2063 return DAG.getBitcast(VT, V: PRMT3210);
2064 }
2065
2066 // Get value or the Nth operand as an APInt(32). Undef values treated as 0.
2067 auto GetOperand = [](SDValue Op, int N) -> APInt {
2068 const SDValue &Operand = Op->getOperand(Num: N);
2069 EVT VT = Op->getValueType(ResNo: 0);
2070 if (Operand->isUndef())
2071 return APInt(32, 0);
2072 APInt Value;
2073 if (VT == MVT::v2f16 || VT == MVT::v2bf16)
2074 Value = cast<ConstantFPSDNode>(Val: Operand)->getValueAPF().bitcastToAPInt();
2075 else if (VT == MVT::v2i16 || VT == MVT::v4i8)
2076 Value = Operand->getAsAPIntVal();
2077 else
2078 llvm_unreachable("Unsupported type");
2079 // i8 values are carried around as i16, so we need to zero out upper bits,
2080 // so they do not get in the way of combining individual byte values
2081 if (VT == MVT::v4i8)
2082 Value = Value.trunc(width: 8);
2083 return Value.zext(width: 32);
2084 };
2085
2086 // Construct a 32-bit constant by shifting into place smaller values
2087 // (elements of the vector type VT).
2088 // For example, if VT has 2 elements, then N == 2:
2089 // ShiftAmount = 32 / N = 16
2090 // Value |= Op0 (b16) << 0
2091 // Value |= Op1 (b16) << 16
2092 // If N == 4:
2093 // ShiftAmount = 32 / N = 8
2094 // Value |= Op0 (b8) << 0
2095 // Value |= Op1 (b8) << 8
2096 // Value |= Op2 (b8) << 16
2097 // Value |= Op3 (b8) << 24
2098 // ...etc
2099 APInt Value(32, 0);
2100 const unsigned NumElements = VT.getVectorNumElements();
2101 assert(32 % NumElements == 0 && "must evenly divide bit length");
2102 const unsigned ShiftAmount = 32 / NumElements;
2103 for (unsigned ElementNo : seq(Size: NumElements))
2104 Value |= GetOperand(Op, ElementNo).shl(shiftAmt: ElementNo * ShiftAmount);
2105 SDValue Const = DAG.getConstant(Val: Value, DL, VT: MVT::i32);
2106 return DAG.getNode(Opcode: ISD::BITCAST, DL, VT: Op->getValueType(ResNo: 0), Operand: Const);
2107}
2108
2109SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
2110 SelectionDAG &DAG) const {
2111 SDValue Index = Op->getOperand(Num: 1);
2112 SDValue Vector = Op->getOperand(Num: 0);
2113 SDLoc DL(Op);
2114 EVT VectorVT = Vector.getValueType();
2115
2116 if (VectorVT == MVT::v4i8) {
2117 SDValue Selector = DAG.getNode(Opcode: ISD::OR, DL, VT: MVT::i32,
2118 N1: DAG.getZExtOrTrunc(Op: Index, DL, VT: MVT::i32),
2119 N2: DAG.getConstant(Val: 0x7770, DL, VT: MVT::i32));
2120 SDValue PRMT = getPRMT(A: DAG.getBitcast(VT: MVT::i32, V: Vector),
2121 B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector, DL, DAG);
2122 SDValue Ext = DAG.getAnyExtOrTrunc(Op: PRMT, DL, VT: Op->getValueType(ResNo: 0));
2123 SDNodeFlags Flags;
2124 Flags.setNoSignedWrap(Ext.getScalarValueSizeInBits() > 8);
2125 Flags.setNoUnsignedWrap(Ext.getScalarValueSizeInBits() >= 8);
2126 Ext->setFlags(Flags);
2127 return Ext;
2128 }
2129
2130 // Constant index will be matched by tablegen.
2131 if (isa<ConstantSDNode>(Val: Index.getNode()))
2132 return Op;
2133
2134 // Extract individual elements and select one of them.
2135 assert(NVPTX::isPackedVectorTy(VectorVT) &&
2136 VectorVT.getVectorNumElements() == 2 && "Unexpected vector type.");
2137 EVT EltVT = VectorVT.getVectorElementType();
2138
2139 SDLoc dl(Op.getNode());
2140 SDValue E0 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT: EltVT, N1: Vector,
2141 N2: DAG.getIntPtrConstant(Val: 0, DL: dl));
2142 SDValue E1 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT: EltVT, N1: Vector,
2143 N2: DAG.getIntPtrConstant(Val: 1, DL: dl));
2144 return DAG.getSelectCC(DL: dl, LHS: Index, RHS: DAG.getIntPtrConstant(Val: 0, DL: dl), True: E0, False: E1,
2145 Cond: ISD::CondCode::SETEQ);
2146}
2147
2148SDValue NVPTXTargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
2149 SelectionDAG &DAG) const {
2150 SDValue Vector = Op->getOperand(Num: 0);
2151 EVT VectorVT = Vector.getValueType();
2152
2153 if (VectorVT != MVT::v4i8)
2154 return Op;
2155 SDLoc DL(Op);
2156 SDValue Value = Op->getOperand(Num: 1);
2157 if (Value->isUndef())
2158 return Vector;
2159
2160 SDValue Index = Op->getOperand(Num: 2);
2161
2162 SDValue BFI =
2163 DAG.getNode(Opcode: NVPTXISD::BFI, DL, VT: MVT::i32,
2164 Ops: {DAG.getZExtOrTrunc(Op: Value, DL, VT: MVT::i32), Vector,
2165 DAG.getNode(Opcode: ISD::MUL, DL, VT: MVT::i32,
2166 N1: DAG.getZExtOrTrunc(Op: Index, DL, VT: MVT::i32),
2167 N2: DAG.getConstant(Val: 8, DL, VT: MVT::i32)),
2168 DAG.getConstant(Val: 8, DL, VT: MVT::i32)});
2169 return DAG.getNode(Opcode: ISD::BITCAST, DL, VT: Op->getValueType(ResNo: 0), Operand: BFI);
2170}
2171
2172SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
2173 SelectionDAG &DAG) const {
2174 SDValue V1 = Op.getOperand(i: 0);
2175 EVT VectorVT = V1.getValueType();
2176 if (VectorVT != MVT::v4i8 || Op.getValueType() != MVT::v4i8)
2177 return Op;
2178
2179 // Lower shuffle to PRMT instruction.
2180 const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Val: Op.getNode());
2181 SDValue V2 = Op.getOperand(i: 1);
2182 uint32_t Selector = 0;
2183 for (auto I : llvm::enumerate(First: SVN->getMask())) {
2184 if (I.value() != -1) // -1 is a placeholder for undef.
2185 Selector |= (I.value() << (I.index() * 4));
2186 }
2187
2188 SDLoc DL(Op);
2189 SDValue PRMT = getPRMT(A: DAG.getBitcast(VT: MVT::i32, V: V1),
2190 B: DAG.getBitcast(VT: MVT::i32, V: V2), Selector, DL, DAG);
2191 return DAG.getBitcast(VT: Op.getValueType(), V: PRMT);
2192}
2193/// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
2194/// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
2195/// amount, or
2196/// 2) returns two i64 values and take a 2 x i64 value to shift plus a shift
2197/// amount.
2198SDValue NVPTXTargetLowering::LowerShiftRightParts(SDValue Op,
2199 SelectionDAG &DAG) const {
2200 assert(Op.getNumOperands() == 3 && "Not a double-shift!");
2201 assert(Op.getOpcode() == ISD::SRA_PARTS || Op.getOpcode() == ISD::SRL_PARTS);
2202
2203 EVT VT = Op.getValueType();
2204 unsigned VTBits = VT.getSizeInBits();
2205 SDLoc dl(Op);
2206 SDValue ShOpLo = Op.getOperand(i: 0);
2207 SDValue ShOpHi = Op.getOperand(i: 1);
2208 SDValue ShAmt = Op.getOperand(i: 2);
2209 unsigned Opc = (Op.getOpcode() == ISD::SRA_PARTS) ? ISD::SRA : ISD::SRL;
2210
2211 if (VTBits == 32 && STI.getSmVersion() >= 35) {
2212 // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
2213 // {dHi, dLo} = {aHi, aLo} >> Amt
2214 // dHi = aHi >> Amt
2215 // dLo = shf.r.clamp aLo, aHi, Amt
2216
2217 SDValue Hi = DAG.getNode(Opcode: Opc, DL: dl, VT, N1: ShOpHi, N2: ShAmt);
2218 SDValue Lo =
2219 DAG.getNode(Opcode: NVPTXISD::FSHR_CLAMP, DL: dl, VT, N1: ShOpHi, N2: ShOpLo, N3: ShAmt);
2220
2221 SDValue Ops[2] = { Lo, Hi };
2222 return DAG.getMergeValues(Ops, dl);
2223 }
2224 else {
2225 // {dHi, dLo} = {aHi, aLo} >> Amt
2226 // - if (Amt>=size) then
2227 // dLo = aHi >> (Amt-size)
2228 // dHi = aHi >> Amt (this is either all 0 or all 1)
2229 // else
2230 // dLo = (aLo >>logic Amt) | (aHi << (size-Amt))
2231 // dHi = aHi >> Amt
2232
2233 SDValue RevShAmt = DAG.getNode(Opcode: ISD::SUB, DL: dl, VT: MVT::i32,
2234 N1: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32),
2235 N2: ShAmt);
2236 SDValue Tmp1 = DAG.getNode(Opcode: ISD::SRL, DL: dl, VT, N1: ShOpLo, N2: ShAmt);
2237 SDValue ExtraShAmt = DAG.getNode(Opcode: ISD::SUB, DL: dl, VT: MVT::i32, N1: ShAmt,
2238 N2: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32));
2239 SDValue Tmp2 = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpHi, N2: RevShAmt);
2240 SDValue FalseVal = DAG.getNode(Opcode: ISD::OR, DL: dl, VT, N1: Tmp1, N2: Tmp2);
2241 SDValue TrueVal = DAG.getNode(Opcode: Opc, DL: dl, VT, N1: ShOpHi, N2: ExtraShAmt);
2242
2243 SDValue Cmp = DAG.getSetCC(DL: dl, VT: MVT::i1, LHS: ShAmt,
2244 RHS: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32),
2245 Cond: ISD::SETGE);
2246 SDValue Hi = DAG.getNode(Opcode: Opc, DL: dl, VT, N1: ShOpHi, N2: ShAmt);
2247 SDValue Lo = DAG.getNode(Opcode: ISD::SELECT, DL: dl, VT, N1: Cmp, N2: TrueVal, N3: FalseVal);
2248
2249 SDValue Ops[2] = { Lo, Hi };
2250 return DAG.getMergeValues(Ops, dl);
2251 }
2252}
2253
2254/// LowerShiftLeftParts - Lower SHL_PARTS, which
2255/// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
2256/// amount, or
2257/// 2) returns two i64 values and take a 2 x i64 value to shift plus a shift
2258/// amount.
2259SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op,
2260 SelectionDAG &DAG) const {
2261 assert(Op.getNumOperands() == 3 && "Not a double-shift!");
2262 assert(Op.getOpcode() == ISD::SHL_PARTS);
2263
2264 EVT VT = Op.getValueType();
2265 unsigned VTBits = VT.getSizeInBits();
2266 SDLoc dl(Op);
2267 SDValue ShOpLo = Op.getOperand(i: 0);
2268 SDValue ShOpHi = Op.getOperand(i: 1);
2269 SDValue ShAmt = Op.getOperand(i: 2);
2270
2271 if (VTBits == 32 && STI.getSmVersion() >= 35) {
2272 // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
2273 // {dHi, dLo} = {aHi, aLo} << Amt
2274 // dHi = shf.l.clamp aLo, aHi, Amt
2275 // dLo = aLo << Amt
2276
2277 SDValue Hi =
2278 DAG.getNode(Opcode: NVPTXISD::FSHL_CLAMP, DL: dl, VT, N1: ShOpHi, N2: ShOpLo, N3: ShAmt);
2279 SDValue Lo = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpLo, N2: ShAmt);
2280
2281 SDValue Ops[2] = { Lo, Hi };
2282 return DAG.getMergeValues(Ops, dl);
2283 }
2284 else {
2285 // {dHi, dLo} = {aHi, aLo} << Amt
2286 // - if (Amt>=size) then
2287 // dLo = aLo << Amt (all 0)
2288 // dLo = aLo << (Amt-size)
2289 // else
2290 // dLo = aLo << Amt
2291 // dHi = (aHi << Amt) | (aLo >> (size-Amt))
2292
2293 SDValue RevShAmt = DAG.getNode(Opcode: ISD::SUB, DL: dl, VT: MVT::i32,
2294 N1: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32),
2295 N2: ShAmt);
2296 SDValue Tmp1 = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpHi, N2: ShAmt);
2297 SDValue ExtraShAmt = DAG.getNode(Opcode: ISD::SUB, DL: dl, VT: MVT::i32, N1: ShAmt,
2298 N2: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32));
2299 SDValue Tmp2 = DAG.getNode(Opcode: ISD::SRL, DL: dl, VT, N1: ShOpLo, N2: RevShAmt);
2300 SDValue FalseVal = DAG.getNode(Opcode: ISD::OR, DL: dl, VT, N1: Tmp1, N2: Tmp2);
2301 SDValue TrueVal = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpLo, N2: ExtraShAmt);
2302
2303 SDValue Cmp = DAG.getSetCC(DL: dl, VT: MVT::i1, LHS: ShAmt,
2304 RHS: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32),
2305 Cond: ISD::SETGE);
2306 SDValue Lo = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpLo, N2: ShAmt);
2307 SDValue Hi = DAG.getNode(Opcode: ISD::SELECT, DL: dl, VT, N1: Cmp, N2: TrueVal, N3: FalseVal);
2308
2309 SDValue Ops[2] = { Lo, Hi };
2310 return DAG.getMergeValues(Ops, dl);
2311 }
2312}
2313
2314/// If the types match, convert the generic copysign to the NVPTXISD version,
2315/// otherwise bail ensuring that mismatched cases are properly expaned.
2316SDValue NVPTXTargetLowering::LowerFCOPYSIGN(SDValue Op,
2317 SelectionDAG &DAG) const {
2318 EVT VT = Op.getValueType();
2319 SDLoc DL(Op);
2320
2321 SDValue In1 = Op.getOperand(i: 0);
2322 SDValue In2 = Op.getOperand(i: 1);
2323 EVT SrcVT = In2.getValueType();
2324
2325 if (!SrcVT.bitsEq(VT))
2326 return SDValue();
2327
2328 return DAG.getNode(Opcode: NVPTXISD::FCOPYSIGN, DL, VT, N1: In1, N2: In2);
2329}
2330
2331SDValue NVPTXTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const {
2332 EVT VT = Op.getValueType();
2333
2334 if (VT == MVT::f32)
2335 return LowerFROUND32(Op, DAG);
2336
2337 if (VT == MVT::f64)
2338 return LowerFROUND64(Op, DAG);
2339
2340 llvm_unreachable("unhandled type");
2341}
2342
2343// This is the the rounding method used in CUDA libdevice in C like code:
2344// float roundf(float A)
2345// {
2346// float RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f));
2347// RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
2348// return abs(A) < 0.5 ? (float)(int)A : RoundedA;
2349// }
2350SDValue NVPTXTargetLowering::LowerFROUND32(SDValue Op,
2351 SelectionDAG &DAG) const {
2352 SDLoc SL(Op);
2353 SDValue A = Op.getOperand(i: 0);
2354 EVT VT = Op.getValueType();
2355
2356 SDValue AbsA = DAG.getNode(Opcode: ISD::FABS, DL: SL, VT, Operand: A);
2357
2358 // RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f))
2359 SDValue Bitcast = DAG.getNode(Opcode: ISD::BITCAST, DL: SL, VT: MVT::i32, Operand: A);
2360 const unsigned SignBitMask = 0x80000000;
2361 SDValue Sign = DAG.getNode(Opcode: ISD::AND, DL: SL, VT: MVT::i32, N1: Bitcast,
2362 N2: DAG.getConstant(Val: SignBitMask, DL: SL, VT: MVT::i32));
2363 const unsigned PointFiveInBits = 0x3F000000;
2364 SDValue PointFiveWithSignRaw =
2365 DAG.getNode(Opcode: ISD::OR, DL: SL, VT: MVT::i32, N1: Sign,
2366 N2: DAG.getConstant(Val: PointFiveInBits, DL: SL, VT: MVT::i32));
2367 SDValue PointFiveWithSign =
2368 DAG.getNode(Opcode: ISD::BITCAST, DL: SL, VT, Operand: PointFiveWithSignRaw);
2369 SDValue AdjustedA = DAG.getNode(Opcode: ISD::FADD, DL: SL, VT, N1: A, N2: PointFiveWithSign);
2370 SDValue RoundedA = DAG.getNode(Opcode: ISD::FTRUNC, DL: SL, VT, Operand: AdjustedA);
2371
2372 // RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
2373 EVT SetCCVT = getSetCCResultType(DL: DAG.getDataLayout(), Ctx&: *DAG.getContext(), VT);
2374 SDValue IsLarge =
2375 DAG.getSetCC(DL: SL, VT: SetCCVT, LHS: AbsA, RHS: DAG.getConstantFP(Val: pow(x: 2.0, y: 23.0), DL: SL, VT),
2376 Cond: ISD::SETOGT);
2377 RoundedA = DAG.getNode(Opcode: ISD::SELECT, DL: SL, VT, N1: IsLarge, N2: A, N3: RoundedA);
2378
2379 // return abs(A) < 0.5 ? (float)(int)A : RoundedA;
2380 SDValue IsSmall =DAG.getSetCC(DL: SL, VT: SetCCVT, LHS: AbsA,
2381 RHS: DAG.getConstantFP(Val: 0.5, DL: SL, VT), Cond: ISD::SETOLT);
2382 SDValue RoundedAForSmallA = DAG.getNode(Opcode: ISD::FTRUNC, DL: SL, VT, Operand: A);
2383 return DAG.getNode(Opcode: ISD::SELECT, DL: SL, VT, N1: IsSmall, N2: RoundedAForSmallA, N3: RoundedA);
2384}
2385
2386// The implementation of round(double) is similar to that of round(float) in
2387// that they both separate the value range into three regions and use a method
2388// specific to the region to round the values. However, round(double) first
2389// calculates the round of the absolute value and then adds the sign back while
2390// round(float) directly rounds the value with sign.
2391SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
2392 SelectionDAG &DAG) const {
2393 SDLoc SL(Op);
2394 SDValue A = Op.getOperand(i: 0);
2395 EVT VT = Op.getValueType();
2396
2397 SDValue AbsA = DAG.getNode(Opcode: ISD::FABS, DL: SL, VT, Operand: A);
2398
2399 // double RoundedA = (double) (int) (abs(A) + 0.5f);
2400 SDValue AdjustedA = DAG.getNode(Opcode: ISD::FADD, DL: SL, VT, N1: AbsA,
2401 N2: DAG.getConstantFP(Val: 0.5, DL: SL, VT));
2402 SDValue RoundedA = DAG.getNode(Opcode: ISD::FTRUNC, DL: SL, VT, Operand: AdjustedA);
2403
2404 // RoundedA = abs(A) < 0.5 ? (double)0 : RoundedA;
2405 EVT SetCCVT = getSetCCResultType(DL: DAG.getDataLayout(), Ctx&: *DAG.getContext(), VT);
2406 SDValue IsSmall =DAG.getSetCC(DL: SL, VT: SetCCVT, LHS: AbsA,
2407 RHS: DAG.getConstantFP(Val: 0.5, DL: SL, VT), Cond: ISD::SETOLT);
2408 RoundedA = DAG.getNode(Opcode: ISD::SELECT, DL: SL, VT, N1: IsSmall,
2409 N2: DAG.getConstantFP(Val: 0, DL: SL, VT),
2410 N3: RoundedA);
2411
2412 // Add sign to rounded_A
2413 RoundedA = DAG.getNode(Opcode: ISD::FCOPYSIGN, DL: SL, VT, N1: RoundedA, N2: A);
2414 DAG.getNode(Opcode: ISD::FTRUNC, DL: SL, VT, Operand: A);
2415
2416 // RoundedA = abs(A) > 0x1.0p52 ? A : RoundedA;
2417 SDValue IsLarge =
2418 DAG.getSetCC(DL: SL, VT: SetCCVT, LHS: AbsA, RHS: DAG.getConstantFP(Val: pow(x: 2.0, y: 52.0), DL: SL, VT),
2419 Cond: ISD::SETOGT);
2420 return DAG.getNode(Opcode: ISD::SELECT, DL: SL, VT, N1: IsLarge, N2: A, N3: RoundedA);
2421}
2422
2423static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) {
2424 EVT VT = N->getValueType(ResNo: 0);
2425 EVT NVT = MVT::f32;
2426 if (VT.isVector()) {
2427 NVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NVT, EC: VT.getVectorElementCount());
2428 }
2429 SDLoc DL(N);
2430 SDValue Tmp0 = DAG.getFPExtendOrRound(Op: N->getOperand(Num: 0), DL, VT: NVT);
2431 SDValue Tmp1 = DAG.getFPExtendOrRound(Op: N->getOperand(Num: 1), DL, VT: NVT);
2432 SDValue Res = DAG.getNode(Opcode: N->getOpcode(), DL, VT: NVT, N1: Tmp0, N2: Tmp1, Flags: N->getFlags());
2433 return DAG.getFPExtendOrRound(Op: Res, DL, VT);
2434}
2435
2436SDValue NVPTXTargetLowering::PromoteBinOpIfF32FTZ(SDValue Op,
2437 SelectionDAG &DAG) const {
2438 if (useF32FTZ(MF: DAG.getMachineFunction())) {
2439 return PromoteBinOpToF32(N: Op.getNode(), DAG);
2440 }
2441 return Op;
2442}
2443
2444SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
2445 SelectionDAG &DAG) const {
2446 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2447
2448 if (Op.getValueType() == MVT::bf16) {
2449 SDLoc Loc(Op);
2450 return DAG.getNode(
2451 Opcode: ISD::FP_ROUND, DL: Loc, VT: MVT::bf16,
2452 N1: DAG.getNode(Opcode: Op.getOpcode(), DL: Loc, VT: MVT::f32, Operand: Op.getOperand(i: 0)),
2453 N2: DAG.getIntPtrConstant(Val: 0, DL: Loc, /*isTarget=*/true));
2454 }
2455
2456 // Everything else is considered legal.
2457 return Op;
2458}
2459
2460SDValue NVPTXTargetLowering::LowerFP_TO_INT(SDValue Op,
2461 SelectionDAG &DAG) const {
2462 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2463
2464 if (Op.getOperand(i: 0).getValueType() == MVT::bf16) {
2465 SDLoc Loc(Op);
2466 return DAG.getNode(
2467 Opcode: Op.getOpcode(), DL: Loc, VT: Op.getValueType(),
2468 Operand: DAG.getNode(Opcode: ISD::FP_EXTEND, DL: Loc, VT: MVT::f32, Operand: Op.getOperand(i: 0)));
2469 }
2470
2471 // Everything else is considered legal.
2472 return Op;
2473}
2474
2475SDValue NVPTXTargetLowering::LowerFP_ROUND(SDValue Op,
2476 SelectionDAG &DAG) const {
2477 EVT NarrowVT = Op.getValueType();
2478 SDValue Wide = Op.getOperand(i: 0);
2479 EVT WideVT = Wide.getValueType();
2480 if (NarrowVT.getScalarType() == MVT::bf16) {
2481 const TargetLowering *TLI = STI.getTargetLowering();
2482 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 70) {
2483 return TLI->expandFP_ROUND(Node: Op.getNode(), DAG);
2484 }
2485 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
2486 // This combination was the first to support f32 -> bf16.
2487 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70) {
2488 if (WideVT.getScalarType() == MVT::f32) {
2489 return Op;
2490 }
2491 if (WideVT.getScalarType() == MVT::f64) {
2492 SDLoc Loc(Op);
2493 // Round-inexact-to-odd f64 to f32, then do the final rounding using
2494 // the hardware f32 -> bf16 instruction.
2495 SDValue rod = TLI->expandRoundInexactToOdd(
2496 ResultVT: WideVT.changeElementType(Context&: *DAG.getContext(), EltVT: MVT::f32), Op: Wide, DL: Loc,
2497 DAG);
2498 return DAG.getFPExtendOrRound(Op: rod, DL: Loc, VT: NarrowVT);
2499 }
2500 }
2501 return TLI->expandFP_ROUND(Node: Op.getNode(), DAG);
2502 }
2503 }
2504
2505 // Everything else is considered legal.
2506 return Op;
2507}
2508
2509SDValue NVPTXTargetLowering::LowerFP_EXTEND(SDValue Op,
2510 SelectionDAG &DAG) const {
2511 SDValue Narrow = Op.getOperand(i: 0);
2512 EVT NarrowVT = Narrow.getValueType();
2513 EVT WideVT = Op.getValueType();
2514 if (NarrowVT.getScalarType() == MVT::bf16) {
2515 if (WideVT.getScalarType() == MVT::f32 &&
2516 (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71)) {
2517 SDLoc Loc(Op);
2518 return DAG.getNode(Opcode: ISD::BF16_TO_FP, DL: Loc, VT: WideVT, Operand: Narrow);
2519 }
2520 if (WideVT.getScalarType() == MVT::f64 &&
2521 (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
2522 EVT F32 = NarrowVT.changeElementType(Context&: *DAG.getContext(), EltVT: MVT::f32);
2523 SDLoc Loc(Op);
2524 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 71) {
2525 Op = DAG.getNode(Opcode: ISD::FP_EXTEND, DL: Loc, VT: F32, Operand: Narrow);
2526 } else {
2527 Op = DAG.getNode(Opcode: ISD::BF16_TO_FP, DL: Loc, VT: F32, Operand: Narrow);
2528 }
2529 return DAG.getNode(Opcode: ISD::FP_EXTEND, DL: Loc, VT: WideVT, Operand: Op);
2530 }
2531 }
2532
2533 // Everything else is considered legal.
2534 return Op;
2535}
2536
2537static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) {
2538 SDLoc DL(Op);
2539 if (Op.getValueType() != MVT::v2i16)
2540 return Op;
2541 EVT EltVT = Op.getValueType().getVectorElementType();
2542 SmallVector<SDValue> VecElements;
2543 for (int I = 0, E = Op.getValueType().getVectorNumElements(); I < E; I++) {
2544 SmallVector<SDValue> ScalarArgs;
2545 llvm::transform(Range: Op->ops(), d_first: std::back_inserter(x&: ScalarArgs),
2546 F: [&](const SDUse &O) {
2547 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT,
2548 N1: O.get(), N2: DAG.getIntPtrConstant(Val: I, DL));
2549 });
2550 VecElements.push_back(Elt: DAG.getNode(Opcode: Op.getOpcode(), DL, VT: EltVT, Ops: ScalarArgs));
2551 }
2552 SDValue V =
2553 DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: Op.getValueType(), Ops: VecElements);
2554 return V;
2555}
2556
2557static SDValue lowerTcgen05St(SDValue Op, SelectionDAG &DAG,
2558 bool hasOffset = false) {
2559 // skip lowering if the vector operand is already legalized
2560 if (!Op->getOperand(Num: hasOffset ? 4 : 3).getValueType().isVector())
2561 return Op;
2562
2563 SDNode *N = Op.getNode();
2564 SDLoc DL(N);
2565 SmallVector<SDValue, 32> Ops;
2566
2567 // split the vector argument
2568 for (size_t I = 0; I < N->getNumOperands(); I++) {
2569 SDValue Val = N->getOperand(Num: I);
2570 EVT ValVT = Val.getValueType();
2571 if (ValVT.isVector()) {
2572 EVT EltVT = ValVT.getVectorElementType();
2573 for (unsigned J = 0, NElts = ValVT.getVectorNumElements(); J < NElts; J++)
2574 Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Val,
2575 N2: DAG.getIntPtrConstant(Val: J, DL)));
2576 } else
2577 Ops.push_back(Elt: Val);
2578 }
2579
2580 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
2581 SDValue Tcgen05StNode =
2582 DAG.getMemIntrinsicNode(Opcode: ISD::INTRINSIC_VOID, dl: DL, VTList: N->getVTList(), Ops,
2583 MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());
2584
2585 return Tcgen05StNode;
2586}
2587
2588static SDValue lowerBSWAP(SDValue Op, SelectionDAG &DAG) {
2589 SDLoc DL(Op);
2590 SDValue Src = Op.getOperand(i: 0);
2591 EVT VT = Op.getValueType();
2592
2593 switch (VT.getSimpleVT().SimpleTy) {
2594 case MVT::i16: {
2595 SDValue Extended = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: MVT::i32, Operand: Src);
2596 SDValue Swapped =
2597 getPRMT(A: Extended, B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x7701, DL, DAG);
2598 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i16, Operand: Swapped);
2599 }
2600 case MVT::i32: {
2601 return getPRMT(A: Src, B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x0123, DL, DAG);
2602 }
2603 case MVT::v2i16: {
2604 SDValue Converted = DAG.getBitcast(VT: MVT::i32, V: Src);
2605 SDValue Swapped =
2606 getPRMT(A: Converted, B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x2301, DL, DAG);
2607 return DAG.getNode(Opcode: ISD::BITCAST, DL, VT: MVT::v2i16, Operand: Swapped);
2608 }
2609 case MVT::i64: {
2610 SDValue UnpackSrc =
2611 DAG.getNode(Opcode: NVPTXISD::UNPACK_VECTOR, DL, ResultTys: {MVT::i32, MVT::i32}, Ops: Src);
2612 SDValue SwappedLow =
2613 getPRMT(A: UnpackSrc.getValue(R: 0), B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x0123,
2614 DL, DAG);
2615 SDValue SwappedHigh =
2616 getPRMT(A: UnpackSrc.getValue(R: 1), B: DAG.getConstant(Val: 0, DL, VT: MVT::i32), Selector: 0x0123,
2617 DL, DAG);
2618 return DAG.getNode(Opcode: NVPTXISD::BUILD_VECTOR, DL, VT: MVT::i64,
2619 Ops: {SwappedHigh, SwappedLow});
2620 }
2621 default:
2622 llvm_unreachable("unsupported type for bswap");
2623 }
2624}
2625
2626static unsigned getTcgen05MMADisableOutputLane(unsigned IID) {
2627 switch (IID) {
2628 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
2629 return NVPTXISD::TCGEN05_MMA_SHARED_DISABLE_OUTPUT_LANE_CG1;
2630 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2:
2631 return NVPTXISD::TCGEN05_MMA_SHARED_DISABLE_OUTPUT_LANE_CG2;
2632 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1:
2633 return NVPTXISD::TCGEN05_MMA_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
2634 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2:
2635 return NVPTXISD::TCGEN05_MMA_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
2636 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1:
2637 return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG1;
2638 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2:
2639 return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG2;
2640 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1:
2641 return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
2642 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2:
2643 return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
2644 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift:
2645 return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
2646 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift:
2647 return NVPTXISD::TCGEN05_MMA_TENSOR_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
2648 case Intrinsic::
2649 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift:
2650 return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
2651 case Intrinsic::
2652 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift:
2653 return NVPTXISD::TCGEN05_MMA_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
2654 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1:
2655 return NVPTXISD::TCGEN05_MMA_SP_SHARED_DISABLE_OUTPUT_LANE_CG1;
2656 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2:
2657 return NVPTXISD::TCGEN05_MMA_SP_SHARED_DISABLE_OUTPUT_LANE_CG2;
2658 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1:
2659 return NVPTXISD::TCGEN05_MMA_SP_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
2660 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2:
2661 return NVPTXISD::TCGEN05_MMA_SP_SHARED_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
2662 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1:
2663 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG1;
2664 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2:
2665 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG2;
2666 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift:
2667 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
2668 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift:
2669 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
2670 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1:
2671 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1;
2672 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2:
2673 return NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2;
2674 case Intrinsic::
2675 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift:
2676 return NVPTXISD::
2677 TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1_ASHIFT;
2678 case Intrinsic::
2679 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift:
2680 return NVPTXISD::
2681 TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2_ASHIFT;
2682 };
2683 llvm_unreachable("unhandled tcgen05.mma.disable_output_lane intrinsic");
2684}
2685
2686static SDValue LowerTcgen05MMADisableOutputLane(SDValue Op, SelectionDAG &DAG) {
2687 SDNode *N = Op.getNode();
2688 SDLoc DL(N);
2689 unsigned IID = cast<ConstantSDNode>(Val: N->getOperand(Num: 1))->getZExtValue();
2690
2691 SmallVector<SDValue, 16> Ops;
2692 // split the vector argument
2693 for (size_t I = 0; I < N->getNumOperands(); I++) {
2694 if (I == 1)
2695 continue; // skip IID
2696 SDValue Val = N->getOperand(Num: I);
2697 EVT ValVT = Val.getValueType();
2698 if (ValVT.isVector()) {
2699 EVT EltVT = ValVT.getVectorElementType();
2700 for (unsigned J = 0, NElts = ValVT.getVectorNumElements(); J < NElts; J++)
2701 Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Val,
2702 N2: DAG.getIntPtrConstant(Val: J, DL)));
2703 } else
2704 Ops.push_back(Elt: Val);
2705 }
2706
2707 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
2708 SDValue Tcgen05MMANode = DAG.getMemIntrinsicNode(
2709 Opcode: getTcgen05MMADisableOutputLane(IID), dl: DL, VTList: N->getVTList(), Ops,
2710 MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());
2711
2712 return Tcgen05MMANode;
2713}
2714
2715// Lower vector return type of tcgen05.ld intrinsics
2716static std::optional<std::pair<SDValue, SDValue>>
2717lowerTcgen05Ld(SDNode *N, SelectionDAG &DAG, bool HasOffset = false) {
2718 SDLoc DL(N);
2719 EVT ResVT = N->getValueType(ResNo: 0);
2720 if (!ResVT.isVector())
2721 return {}; // already legalized.
2722
2723 const unsigned NumElts = ResVT.getVectorNumElements();
2724
2725 // Create the return type of the instructions
2726 SmallVector<EVT, 5> ListVTs;
2727 for (unsigned i = 0; i < NumElts; ++i)
2728 ListVTs.push_back(Elt: MVT::i32);
2729
2730 ListVTs.push_back(Elt: N->getValueType(ResNo: 1)); // Chain
2731
2732 SDVTList ResVTs = DAG.getVTList(VTs: ListVTs);
2733
2734 SmallVector<SDValue, 8> Ops{N->getOperand(Num: 0), N->getOperand(Num: 1),
2735 N->getOperand(Num: 2)};
2736
2737 if (HasOffset) {
2738 Ops.push_back(Elt: N->getOperand(Num: 3)); // offset
2739 Ops.push_back(Elt: N->getOperand(Num: 4)); // Pack flag
2740 } else
2741 Ops.push_back(Elt: N->getOperand(Num: 3)); // Pack flag
2742
2743 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
2744 SDValue NewNode =
2745 DAG.getMemIntrinsicNode(Opcode: ISD::INTRINSIC_W_CHAIN, dl: DL, VTList: ResVTs, Ops,
2746 MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());
2747
2748 // split the vector result
2749 SmallVector<SDValue, 4> ScalarRes;
2750 for (unsigned i = 0; i < NumElts; ++i) {
2751 SDValue Res = NewNode.getValue(R: i);
2752 ScalarRes.push_back(Elt: Res);
2753 }
2754
2755 SDValue Chain = NewNode.getValue(R: NumElts);
2756 SDValue BuildVector = DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: ResVT, Ops: ScalarRes);
2757 return {{BuildVector, Chain}};
2758}
2759
2760static SDValue reportInvalidTensormapReplaceUsage(SDValue Op, SelectionDAG &DAG,
2761 unsigned Val) {
2762 SDNode *N = Op.getNode();
2763 SDLoc DL(N);
2764
2765 const Function &Fn = DAG.getMachineFunction().getFunction();
2766
2767 unsigned AS = 0;
2768 if (auto *MemN = dyn_cast<MemIntrinsicSDNode>(Val: N))
2769 AS = MemN->getAddressSpace();
2770 Type *PtrTy = PointerType::get(C&: *DAG.getContext(), AddressSpace: AS);
2771 Module *M = DAG.getMachineFunction().getFunction().getParent();
2772
2773 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
2774 Fn,
2775 "Intrinsic " +
2776 Intrinsic::getName(Id: N->getConstantOperandVal(Num: 1), OverloadTys: {PtrTy}, M) +
2777 " with value " + Twine(Val) +
2778 " is not supported on the given target.",
2779 DL.getDebugLoc()));
2780 return Op.getOperand(i: 0);
2781}
2782
2783static SDValue lowerTensormapReplaceElemtype(SDValue Op, SelectionDAG &DAG) {
2784 SDNode *N = Op.getNode();
2785 SDLoc DL(N);
2786
2787 // immediate argument representing elemtype
2788 unsigned Val = N->getConstantOperandVal(Num: 3);
2789
2790 if (!DAG.getSubtarget<NVPTXSubtarget>().hasTensormapReplaceElemtypeSupport(
2791 value: Val))
2792 return reportInvalidTensormapReplaceUsage(Op, DAG, Val);
2793
2794 return Op;
2795}
2796
2797static SDValue lowerTensormapReplaceSwizzleMode(SDValue Op, SelectionDAG &DAG) {
2798 SDNode *N = Op.getNode();
2799 SDLoc DL(N);
2800
2801 // immediate argument representing swizzle mode
2802 unsigned Val = N->getConstantOperandVal(Num: 3);
2803
2804 if (!DAG.getSubtarget<NVPTXSubtarget>().hasTensormapReplaceSwizzleModeSupport(
2805 value: Val))
2806 return reportInvalidTensormapReplaceUsage(Op, DAG, Val);
2807
2808 return Op;
2809}
2810
2811static SDValue lowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
2812 SDNode *N = Op.getNode();
2813 SDValue Intrin = N->getOperand(Num: 1);
2814
2815 // Get the intrinsic ID
2816 unsigned IntrinNo = cast<ConstantSDNode>(Val: Intrin.getNode())->getZExtValue();
2817 switch (IntrinNo) {
2818 default:
2819 break;
2820 case Intrinsic::nvvm_tcgen05_st_16x64b_x2:
2821 case Intrinsic::nvvm_tcgen05_st_16x64b_x4:
2822 case Intrinsic::nvvm_tcgen05_st_16x64b_x8:
2823 case Intrinsic::nvvm_tcgen05_st_16x64b_x16:
2824 case Intrinsic::nvvm_tcgen05_st_16x64b_x32:
2825 case Intrinsic::nvvm_tcgen05_st_16x64b_x128:
2826 case Intrinsic::nvvm_tcgen05_st_16x128b_x1:
2827 case Intrinsic::nvvm_tcgen05_st_16x128b_x2:
2828 case Intrinsic::nvvm_tcgen05_st_16x128b_x4:
2829 case Intrinsic::nvvm_tcgen05_st_16x128b_x8:
2830 case Intrinsic::nvvm_tcgen05_st_16x128b_x16:
2831 case Intrinsic::nvvm_tcgen05_st_16x128b_x32:
2832 case Intrinsic::nvvm_tcgen05_st_16x128b_x64:
2833 case Intrinsic::nvvm_tcgen05_st_16x256b_x1:
2834 case Intrinsic::nvvm_tcgen05_st_16x256b_x2:
2835 case Intrinsic::nvvm_tcgen05_st_16x256b_x4:
2836 case Intrinsic::nvvm_tcgen05_st_16x256b_x8:
2837 case Intrinsic::nvvm_tcgen05_st_16x256b_x16:
2838 case Intrinsic::nvvm_tcgen05_st_16x256b_x32:
2839 case Intrinsic::nvvm_tcgen05_st_32x32b_x2:
2840 case Intrinsic::nvvm_tcgen05_st_32x32b_x4:
2841 case Intrinsic::nvvm_tcgen05_st_32x32b_x8:
2842 case Intrinsic::nvvm_tcgen05_st_32x32b_x16:
2843 case Intrinsic::nvvm_tcgen05_st_32x32b_x32:
2844 case Intrinsic::nvvm_tcgen05_st_16x64b_x64:
2845 case Intrinsic::nvvm_tcgen05_st_32x32b_x64:
2846 case Intrinsic::nvvm_tcgen05_st_32x32b_x128:
2847 return lowerTcgen05St(Op, DAG);
2848 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x2:
2849 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x4:
2850 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x8:
2851 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x16:
2852 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x32:
2853 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x64:
2854 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x128:
2855 return lowerTcgen05St(Op, DAG, /* hasOffset */ true);
2856 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
2857 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2:
2858 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1:
2859 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2:
2860 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1:
2861 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2:
2862 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1:
2863 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2:
2864 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1:
2865 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2:
2866 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1:
2867 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2:
2868 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1:
2869 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2:
2870 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1:
2871 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2:
2872 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift:
2873 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift:
2874 case Intrinsic::
2875 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift:
2876 case Intrinsic::
2877 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift:
2878 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift:
2879 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift:
2880 case Intrinsic::
2881 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift:
2882 case Intrinsic::
2883 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift:
2884 return LowerTcgen05MMADisableOutputLane(Op, DAG);
2885 case Intrinsic::nvvm_tensormap_replace_elemtype:
2886 return lowerTensormapReplaceElemtype(Op, DAG);
2887 case Intrinsic::nvvm_tensormap_replace_swizzle_mode:
2888 return lowerTensormapReplaceSwizzleMode(Op, DAG);
2889 }
2890 return Op;
2891}
2892
2893static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op,
2894 SelectionDAG &DAG) {
2895
2896 SDNode *N = Op.getNode();
2897 if (N->getOperand(Num: 1).getValueType() != MVT::i128) {
2898 // return, if the operand is already lowered
2899 return SDValue();
2900 }
2901
2902 unsigned IID =
2903 cast<ConstantSDNode>(Val: N->getOperand(Num: 0).getNode())->getZExtValue();
2904 auto Opcode = [&]() {
2905 switch (IID) {
2906 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:
2907 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_IS_CANCELED;
2908 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x:
2909 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_X;
2910 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y:
2911 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Y;
2912 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z:
2913 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z;
2914 default:
2915 llvm_unreachable("unsupported/unhandled intrinsic");
2916 }
2917 }();
2918
2919 SDLoc DL(N);
2920 SDValue TryCancelResponse = N->getOperand(Num: 1);
2921 SDValue Cast = DAG.getNode(Opcode: ISD::BITCAST, DL, VT: MVT::v2i64, Operand: TryCancelResponse);
2922 SDValue TryCancelResponse0 =
2923 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i64, N1: Cast,
2924 N2: DAG.getIntPtrConstant(Val: 0, DL));
2925 SDValue TryCancelResponse1 =
2926 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i64, N1: Cast,
2927 N2: DAG.getIntPtrConstant(Val: 1, DL));
2928
2929 return DAG.getNode(Opcode, DL, VTList: N->getVTList(),
2930 Ops: {TryCancelResponse0, TryCancelResponse1});
2931}
2932
2933static SDValue lowerCvtRSIntrinsics(SDValue Op, SelectionDAG &DAG) {
2934 SDNode *N = Op.getNode();
2935 SDLoc DL(N);
2936 SDValue F32Vec = N->getOperand(Num: 1);
2937 SDValue RBits = N->getOperand(Num: 2);
2938
2939 unsigned IntrinsicID = N->getConstantOperandVal(Num: 0);
2940
2941 // Extract the 4 float elements from the vector
2942 SmallVector<SDValue, 6> Ops;
2943 for (unsigned i = 0; i < 4; ++i)
2944 Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::f32, N1: F32Vec,
2945 N2: DAG.getIntPtrConstant(Val: i, DL)));
2946
2947 using NVPTX::PTXCvtMode::CvtMode;
2948
2949 auto [OpCode, RetTy, CvtModeFlag] =
2950 [&]() -> std::tuple<unsigned, MVT::SimpleValueType, uint32_t> {
2951 switch (IntrinsicID) {
2952 case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
2953 return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8,
2954 CvtMode::RS | CvtMode::RELU_FLAG};
2955 case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
2956 return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
2957 case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
2958 return {NVPTXISD::CVT_E5M2X4_F32X4_RS_SF, MVT::v4i8,
2959 CvtMode::RS | CvtMode::RELU_FLAG};
2960 case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
2961 return {NVPTXISD::CVT_E5M2X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
2962 case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
2963 return {NVPTXISD::CVT_E2M3X4_F32X4_RS_SF, MVT::v4i8,
2964 CvtMode::RS | CvtMode::RELU_FLAG};
2965 case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
2966 return {NVPTXISD::CVT_E2M3X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
2967 case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
2968 return {NVPTXISD::CVT_E3M2X4_F32X4_RS_SF, MVT::v4i8,
2969 CvtMode::RS | CvtMode::RELU_FLAG};
2970 case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
2971 return {NVPTXISD::CVT_E3M2X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
2972 case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
2973 return {NVPTXISD::CVT_E2M1X4_F32X4_RS_SF, MVT::i16,
2974 CvtMode::RS | CvtMode::RELU_FLAG};
2975 case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite:
2976 return {NVPTXISD::CVT_E2M1X4_F32X4_RS_SF, MVT::i16, CvtMode::RS};
2977 default:
2978 llvm_unreachable("unsupported/unhandled intrinsic");
2979 }
2980 }();
2981
2982 Ops.push_back(Elt: RBits);
2983 Ops.push_back(Elt: DAG.getConstant(Val: CvtModeFlag, DL, VT: MVT::i32));
2984
2985 return DAG.getNode(Opcode: OpCode, DL, VT: RetTy, Ops);
2986}
2987
2988static SDValue lowerPrmtIntrinsic(SDValue Op, SelectionDAG &DAG) {
2989 const unsigned Mode = [&]() {
2990 switch (Op->getConstantOperandVal(Num: 0)) {
2991 case Intrinsic::nvvm_prmt:
2992 return NVPTX::PTXPrmtMode::NONE;
2993 case Intrinsic::nvvm_prmt_b4e:
2994 return NVPTX::PTXPrmtMode::B4E;
2995 case Intrinsic::nvvm_prmt_ecl:
2996 return NVPTX::PTXPrmtMode::ECL;
2997 case Intrinsic::nvvm_prmt_ecr:
2998 return NVPTX::PTXPrmtMode::ECR;
2999 case Intrinsic::nvvm_prmt_f4e:
3000 return NVPTX::PTXPrmtMode::F4E;
3001 case Intrinsic::nvvm_prmt_rc16:
3002 return NVPTX::PTXPrmtMode::RC16;
3003 case Intrinsic::nvvm_prmt_rc8:
3004 return NVPTX::PTXPrmtMode::RC8;
3005 default:
3006 llvm_unreachable("unsupported/unhandled intrinsic");
3007 }
3008 }();
3009 SDLoc DL(Op);
3010 SDValue A = Op->getOperand(Num: 1);
3011 SDValue B = Op.getNumOperands() == 4 ? Op.getOperand(i: 2)
3012 : DAG.getConstant(Val: 0, DL, VT: MVT::i32);
3013 SDValue Selector = (Op->op_end() - 1)->get();
3014 return getPRMT(A, B, Selector, DL, DAG, Mode);
3015}
3016
3017#define TCGEN05_LD_RED_INTR(SHAPE, NUM, TYPE) \
3018 Intrinsic::nvvm_tcgen05_ld_red_##SHAPE##_x##NUM##_##TYPE
3019
3020#define TCGEN05_LD_RED_INST(SHAPE, NUM, TYPE) \
3021 NVPTXISD::TCGEN05_LD_RED_##SHAPE##_X##NUM##_##TYPE
3022
3023static unsigned getTcgen05LdRedID(Intrinsic::ID IID) {
3024 switch (IID) {
3025 case TCGEN05_LD_RED_INTR(32x32b, 2, f32):
3026 return TCGEN05_LD_RED_INST(32x32b, 2, F32);
3027 case TCGEN05_LD_RED_INTR(32x32b, 4, f32):
3028 return TCGEN05_LD_RED_INST(32x32b, 4, F32);
3029 case TCGEN05_LD_RED_INTR(32x32b, 8, f32):
3030 return TCGEN05_LD_RED_INST(32x32b, 8, F32);
3031 case TCGEN05_LD_RED_INTR(32x32b, 16, f32):
3032 return TCGEN05_LD_RED_INST(32x32b, 16, F32);
3033 case TCGEN05_LD_RED_INTR(32x32b, 32, f32):
3034 return TCGEN05_LD_RED_INST(32x32b, 32, F32);
3035 case TCGEN05_LD_RED_INTR(32x32b, 64, f32):
3036 return TCGEN05_LD_RED_INST(32x32b, 64, F32);
3037 case TCGEN05_LD_RED_INTR(32x32b, 128, f32):
3038 return TCGEN05_LD_RED_INST(32x32b, 128, F32);
3039 case TCGEN05_LD_RED_INTR(16x32bx2, 2, f32):
3040 return TCGEN05_LD_RED_INST(16x32bx2, 2, F32);
3041 case TCGEN05_LD_RED_INTR(16x32bx2, 4, f32):
3042 return TCGEN05_LD_RED_INST(16x32bx2, 4, F32);
3043 case TCGEN05_LD_RED_INTR(16x32bx2, 8, f32):
3044 return TCGEN05_LD_RED_INST(16x32bx2, 8, F32);
3045 case TCGEN05_LD_RED_INTR(16x32bx2, 16, f32):
3046 return TCGEN05_LD_RED_INST(16x32bx2, 16, F32);
3047 case TCGEN05_LD_RED_INTR(16x32bx2, 32, f32):
3048 return TCGEN05_LD_RED_INST(16x32bx2, 32, F32);
3049 case TCGEN05_LD_RED_INTR(16x32bx2, 64, f32):
3050 return TCGEN05_LD_RED_INST(16x32bx2, 64, F32);
3051 case TCGEN05_LD_RED_INTR(16x32bx2, 128, f32):
3052 return TCGEN05_LD_RED_INST(16x32bx2, 128, F32);
3053 case TCGEN05_LD_RED_INTR(32x32b, 2, i32):
3054 return TCGEN05_LD_RED_INST(32x32b, 2, I32);
3055 case TCGEN05_LD_RED_INTR(32x32b, 4, i32):
3056 return TCGEN05_LD_RED_INST(32x32b, 4, I32);
3057 case TCGEN05_LD_RED_INTR(32x32b, 8, i32):
3058 return TCGEN05_LD_RED_INST(32x32b, 8, I32);
3059 case TCGEN05_LD_RED_INTR(32x32b, 16, i32):
3060 return TCGEN05_LD_RED_INST(32x32b, 16, I32);
3061 case TCGEN05_LD_RED_INTR(32x32b, 32, i32):
3062 return TCGEN05_LD_RED_INST(32x32b, 32, I32);
3063 case TCGEN05_LD_RED_INTR(32x32b, 64, i32):
3064 return TCGEN05_LD_RED_INST(32x32b, 64, I32);
3065 case TCGEN05_LD_RED_INTR(32x32b, 128, i32):
3066 return TCGEN05_LD_RED_INST(32x32b, 128, I32);
3067 case TCGEN05_LD_RED_INTR(16x32bx2, 2, i32):
3068 return TCGEN05_LD_RED_INST(16x32bx2, 2, I32);
3069 case TCGEN05_LD_RED_INTR(16x32bx2, 4, i32):
3070 return TCGEN05_LD_RED_INST(16x32bx2, 4, I32);
3071 case TCGEN05_LD_RED_INTR(16x32bx2, 8, i32):
3072 return TCGEN05_LD_RED_INST(16x32bx2, 8, I32);
3073 case TCGEN05_LD_RED_INTR(16x32bx2, 16, i32):
3074 return TCGEN05_LD_RED_INST(16x32bx2, 16, I32);
3075 case TCGEN05_LD_RED_INTR(16x32bx2, 32, i32):
3076 return TCGEN05_LD_RED_INST(16x32bx2, 32, I32);
3077 case TCGEN05_LD_RED_INTR(16x32bx2, 64, i32):
3078 return TCGEN05_LD_RED_INST(16x32bx2, 64, I32);
3079 case TCGEN05_LD_RED_INTR(16x32bx2, 128, i32):
3080 return TCGEN05_LD_RED_INST(16x32bx2, 128, I32);
3081 default:
3082 llvm_unreachable("Invalid tcgen05.ld.red intrinsic ID");
3083 }
3084}
3085
3086// Lower vector return type of tcgen05.ld intrinsics
3087static std::optional<std::tuple<SDValue, SDValue, SDValue>>
3088lowerTcgen05LdRed(SDNode *N, SelectionDAG &DAG) {
3089 SDLoc DL(N);
3090 EVT ResVT = N->getValueType(ResNo: 0);
3091 if (!ResVT.isVector())
3092 return {}; // already legalized.
3093
3094 const unsigned NumElts = ResVT.getVectorNumElements();
3095
3096 // Create the return type of the instructions
3097 // +1 represents the reduction value
3098 SmallVector<EVT, 132> ListVTs{
3099 NumElts + 1,
3100 ResVT.getVectorElementType().isFloatingPoint() ? MVT::f32 : MVT::i32};
3101
3102 ListVTs.push_back(Elt: MVT::Other); // Chain
3103
3104 SDVTList ResVTs = DAG.getVTList(VTs: ListVTs);
3105
3106 // Prepare the Operands
3107 SmallVector<SDValue, 8> Ops{N->getOperand(Num: 0)}; // Chain
3108
3109 // skip IID at index 1
3110 for (unsigned i = 2; i < N->getNumOperands(); i++)
3111 Ops.push_back(Elt: N->getOperand(Num: i));
3112
3113 unsigned IID = cast<ConstantSDNode>(Val: N->getOperand(Num: 1))->getZExtValue();
3114 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
3115 SDValue NewNode =
3116 DAG.getMemIntrinsicNode(Opcode: getTcgen05LdRedID(IID), dl: DL, VTList: ResVTs, Ops,
3117 MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());
3118
3119 // Split vector result
3120 SmallVector<SDValue, 132> ScalarRes;
3121 for (unsigned i = 0; i < NumElts; ++i) {
3122 SDValue Res = NewNode.getValue(R: i);
3123 ScalarRes.push_back(Elt: Res);
3124 }
3125
3126 SDValue BuildVector = DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: ResVT, Ops: ScalarRes);
3127 SDValue RedResult = NewNode.getValue(R: NumElts);
3128 SDValue Chain = NewNode.getValue(R: NumElts + 1);
3129 return {{BuildVector, RedResult, Chain}};
3130}
3131
3132static SDValue lowerIntrinsicWChain(SDValue Op, SelectionDAG &DAG) {
3133 switch (Op->getConstantOperandVal(Num: 1)) {
3134 default:
3135 return Op;
3136
3137 // These tcgen05 intrinsics return a v2i32, which is legal, so we have to
3138 // lower them through LowerOperation() instead of ReplaceNodeResults().
3139 case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
3140 case Intrinsic::nvvm_tcgen05_ld_16x128b_x1:
3141 case Intrinsic::nvvm_tcgen05_ld_32x32b_x2:
3142 if (auto Res = lowerTcgen05Ld(N: Op.getNode(), DAG))
3143 return DAG.getMergeValues(Ops: {Res->first, Res->second}, dl: SDLoc(Op));
3144 return SDValue();
3145
3146 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2:
3147 if (auto Res = lowerTcgen05Ld(N: Op.getNode(), DAG, /*HasOffset=*/true))
3148 return DAG.getMergeValues(Ops: {Res->first, Res->second}, dl: SDLoc(Op));
3149 return SDValue();
3150
3151 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_f32:
3152 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_i32:
3153 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x2_f32:
3154 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x2_i32:
3155 if (auto Res = lowerTcgen05LdRed(N: Op.getNode(), DAG))
3156 return DAG.getMergeValues(
3157 Ops: {std::get<0>(t&: *Res), std::get<1>(t&: *Res), std::get<2>(t&: *Res)}, dl: SDLoc(Op));
3158 return SDValue();
3159 }
3160}
3161
3162static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG) {
3163 switch (Op->getConstantOperandVal(Num: 0)) {
3164 default:
3165 return Op;
3166 case Intrinsic::nvvm_prmt:
3167 case Intrinsic::nvvm_prmt_b4e:
3168 case Intrinsic::nvvm_prmt_ecl:
3169 case Intrinsic::nvvm_prmt_ecr:
3170 case Intrinsic::nvvm_prmt_f4e:
3171 case Intrinsic::nvvm_prmt_rc16:
3172 case Intrinsic::nvvm_prmt_rc8:
3173 return lowerPrmtIntrinsic(Op, DAG);
3174 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:
3175 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x:
3176 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y:
3177 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z:
3178 return LowerClusterLaunchControlQueryCancel(Op, DAG);
3179 case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
3180 case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
3181 case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
3182 case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
3183 case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
3184 case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
3185 case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
3186 case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
3187 case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite:
3188 case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
3189 return lowerCvtRSIntrinsics(Op, DAG);
3190 }
3191}
3192
3193// In PTX 64-bit CTLZ and CTPOP are supported, but they return a 32-bit value.
3194// Lower these into a node returning the correct type which is zero-extended
3195// back to the correct size.
3196static SDValue lowerCTLZCTPOP(SDValue Op, SelectionDAG &DAG) {
3197 SDValue V = Op->getOperand(Num: 0);
3198 assert(V.getValueType() == MVT::i64 &&
3199 "Unexpected CTLZ/CTPOP type to legalize");
3200
3201 SDLoc DL(Op);
3202 SDValue CT = DAG.getNode(Opcode: Op->getOpcode(), DL, VT: MVT::i32, Operand: V);
3203 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: MVT::i64, Operand: CT, Flags: SDNodeFlags::NonNeg);
3204}
3205
3206static SDValue expandFSH64(SDValue A, SDValue B, SDValue ShiftAmount, SDLoc DL,
3207 unsigned Opcode, SelectionDAG &DAG) {
3208 assert(A.getValueType() == MVT::i64 && B.getValueType() == MVT::i64);
3209
3210 const auto *AmtConst = dyn_cast<ConstantSDNode>(Val&: ShiftAmount);
3211 if (!AmtConst)
3212 return SDValue();
3213 const auto Amt = AmtConst->getZExtValue() & 63;
3214
3215 SDValue UnpackA =
3216 DAG.getNode(Opcode: NVPTXISD::UNPACK_VECTOR, DL, ResultTys: {MVT::i32, MVT::i32}, Ops: A);
3217 SDValue UnpackB =
3218 DAG.getNode(Opcode: NVPTXISD::UNPACK_VECTOR, DL, ResultTys: {MVT::i32, MVT::i32}, Ops: B);
3219
3220 // Arch is Little endiain: 0 = low bits, 1 = high bits
3221 SDValue ALo = UnpackA.getValue(R: 0);
3222 SDValue AHi = UnpackA.getValue(R: 1);
3223 SDValue BLo = UnpackB.getValue(R: 0);
3224 SDValue BHi = UnpackB.getValue(R: 1);
3225
3226 // The bitfeild consists of { AHi : ALo : BHi : BLo }
3227 //
3228 // * FSHL, Amt < 32 - The window will contain { AHi : ALo : BHi }
3229 // * FSHL, Amt >= 32 - The window will contain { ALo : BHi : BLo }
3230 // * FSHR, Amt < 32 - The window will contain { ALo : BHi : BLo }
3231 // * FSHR, Amt >= 32 - The window will contain { AHi : ALo : BHi }
3232 //
3233 // Note that Amt = 0 and Amt = 32 are special cases where 32-bit funnel shifts
3234 // are not needed at all. Amt = 0 is a no-op producing either A or B depending
3235 // on the direction. Amt = 32 can be implemented by a packing and unpacking
3236 // move to select and arrange the 32bit values. For simplicity, these cases
3237 // are not handled here explicitly and instead we rely on DAGCombiner to
3238 // remove the no-op funnel shifts we insert.
3239 auto [High, Mid, Low] = ((Opcode == ISD::FSHL) == (Amt < 32))
3240 ? std::make_tuple(args&: AHi, args&: ALo, args&: BHi)
3241 : std::make_tuple(args&: ALo, args&: BHi, args&: BLo);
3242
3243 SDValue NewAmt = DAG.getConstant(Val: Amt & 31, DL, VT: MVT::i32);
3244 SDValue RHi = DAG.getNode(Opcode, DL, VT: MVT::i32, Ops: {High, Mid, NewAmt});
3245 SDValue RLo = DAG.getNode(Opcode, DL, VT: MVT::i32, Ops: {Mid, Low, NewAmt});
3246
3247 return DAG.getNode(Opcode: NVPTXISD::BUILD_VECTOR, DL, VT: MVT::i64, Ops: {RLo, RHi});
3248}
3249
3250static SDValue lowerFSH(SDValue Op, SelectionDAG &DAG) {
3251 return expandFSH64(A: Op->getOperand(Num: 0), B: Op->getOperand(Num: 1), ShiftAmount: Op->getOperand(Num: 2),
3252 DL: SDLoc(Op), Opcode: Op->getOpcode(), DAG);
3253}
3254
3255static SDValue lowerROT(SDValue Op, SelectionDAG &DAG) {
3256 unsigned Opcode = Op->getOpcode() == ISD::ROTL ? ISD::FSHL : ISD::FSHR;
3257 return expandFSH64(A: Op->getOperand(Num: 0), B: Op->getOperand(Num: 0), ShiftAmount: Op->getOperand(Num: 1),
3258 DL: SDLoc(Op), Opcode, DAG);
3259}
3260
3261static SDValue lowerFREM(SDValue Op, SelectionDAG &DAG) {
3262 // Lower (frem x, y) into (sub x, (mul (ftrunc (div x, y)) y)),
3263 // i.e. "poor man's fmod()". When y is infinite, x is returned. This matches
3264 // the semantics of LLVM's frem.
3265 SDLoc DL(Op);
3266 SDValue X = Op->getOperand(Num: 0);
3267 SDValue Y = Op->getOperand(Num: 1);
3268 EVT Ty = Op.getValueType();
3269 SDNodeFlags Flags = Op->getFlags();
3270
3271 SDValue Div = DAG.getNode(Opcode: ISD::FDIV, DL, VT: Ty, N1: X, N2: Y, Flags);
3272 SDValue Trunc = DAG.getNode(Opcode: ISD::FTRUNC, DL, VT: Ty, Operand: Div, Flags);
3273 SDValue Mul = DAG.getNode(Opcode: ISD::FMUL, DL, VT: Ty, N1: Trunc, N2: Y,
3274 Flags: Flags | SDNodeFlags::AllowContract);
3275 SDValue Sub = DAG.getNode(Opcode: ISD::FSUB, DL, VT: Ty, N1: X, N2: Mul,
3276 Flags: Flags | SDNodeFlags::AllowContract);
3277
3278 if (Flags.hasNoInfs())
3279 return Sub;
3280
3281 // If Y is infinite, return X
3282 SDValue AbsY = DAG.getNode(Opcode: ISD::FABS, DL, VT: Ty, Operand: Y);
3283 SDValue Inf =
3284 DAG.getConstantFP(Val: APFloat::getInf(Sem: Ty.getFltSemantics()), DL, VT: Ty);
3285 SDValue IsInf = DAG.getSetCC(DL, VT: MVT::i1, LHS: AbsY, RHS: Inf, Cond: ISD::SETEQ);
3286 return DAG.getSelect(DL, VT: Ty, Cond: IsInf, LHS: X, RHS: Sub);
3287}
3288
3289static SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) {
3290 assert(Op.getValueType() == MVT::i1 && "Custom lowering enabled only for i1");
3291
3292 SDValue Cond = Op->getOperand(Num: 0);
3293 SDValue TrueVal = Op->getOperand(Num: 1);
3294 SDValue FalseVal = Op->getOperand(Num: 2);
3295 SDLoc DL(Op);
3296
3297 // If both operands are truncated, we push the select through the truncates.
3298 if (TrueVal.getOpcode() == ISD::TRUNCATE &&
3299 FalseVal.getOpcode() == ISD::TRUNCATE) {
3300 TrueVal = TrueVal.getOperand(i: 0);
3301 FalseVal = FalseVal.getOperand(i: 0);
3302
3303 EVT VT = TrueVal.getSimpleValueType().bitsLE(VT: FalseVal.getSimpleValueType())
3304 ? TrueVal.getValueType()
3305 : FalseVal.getValueType();
3306 TrueVal = DAG.getAnyExtOrTrunc(Op: TrueVal, DL, VT);
3307 FalseVal = DAG.getAnyExtOrTrunc(Op: FalseVal, DL, VT);
3308 SDValue Select = DAG.getSelect(DL, VT, Cond, LHS: TrueVal, RHS: FalseVal);
3309 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i1, Operand: Select);
3310 }
3311
3312 // Otherwise, expand the select into a series of logical operations. These
3313 // often can be folded into other operations either by us or ptxas.
3314 TrueVal = DAG.getFreeze(V: TrueVal);
3315 FalseVal = DAG.getFreeze(V: FalseVal);
3316 SDValue And1 = DAG.getNode(Opcode: ISD::AND, DL, VT: MVT::i1, N1: Cond, N2: TrueVal);
3317 SDValue NotCond = DAG.getNOT(DL, Val: Cond, VT: MVT::i1);
3318 SDValue And2 = DAG.getNode(Opcode: ISD::AND, DL, VT: MVT::i1, N1: NotCond, N2: FalseVal);
3319 SDValue Or = DAG.getNode(Opcode: ISD::OR, DL, VT: MVT::i1, N1: And1, N2: And2);
3320 return Or;
3321}
3322
3323static SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) {
3324 SDNode *N = Op.getNode();
3325
3326 SDValue Chain = N->getOperand(Num: 0);
3327 SDValue Val = N->getOperand(Num: 1);
3328 SDValue BasePtr = N->getOperand(Num: 2);
3329 SDValue Offset = N->getOperand(Num: 3);
3330 SDValue Mask = N->getOperand(Num: 4);
3331
3332 SDLoc DL(N);
3333 EVT ValVT = Val.getValueType();
3334 MemSDNode *MemSD = cast<MemSDNode>(Val: N);
3335 assert(ValVT.isVector() && "Masked vector store must have vector type");
3336 assert(MemSD->getAlign() >= DAG.getEVTAlign(ValVT) &&
3337 "Unexpected alignment for masked store");
3338
3339 unsigned Opcode = 0;
3340 switch (ValVT.getSimpleVT().SimpleTy) {
3341 default:
3342 llvm_unreachable("Unexpected masked vector store type");
3343 case MVT::v4i64:
3344 case MVT::v4f64: {
3345 Opcode = NVPTXISD::StoreV4;
3346 break;
3347 }
3348 case MVT::v8i32:
3349 case MVT::v8f32: {
3350 Opcode = NVPTXISD::StoreV8;
3351 break;
3352 }
3353 }
3354
3355 SmallVector<SDValue, 8> Ops;
3356
3357 // Construct the new SDNode. First operand is the chain.
3358 Ops.push_back(Elt: Chain);
3359
3360 // The next N operands are the values to store. Encode the mask into the
3361 // values using the sentinel register 0 to represent a masked-off element.
3362 assert(Mask.getValueType().isVector() &&
3363 Mask.getValueType().getVectorElementType() == MVT::i1 &&
3364 "Mask must be a vector of i1");
3365 assert(Mask.getOpcode() == ISD::BUILD_VECTOR &&
3366 "Mask expected to be a BUILD_VECTOR");
3367 assert(Mask.getValueType().getVectorNumElements() ==
3368 ValVT.getVectorNumElements() &&
3369 "Mask size must be the same as the vector size");
3370 for (auto [I, Op] : enumerate(First: Mask->ops())) {
3371 // Mask elements must be constants.
3372 if (Op.getNode()->getAsZExtVal() == 0) {
3373 // Append a sentinel register 0 to the Ops vector to represent a masked
3374 // off element, this will be handled in tablegen
3375 Ops.push_back(Elt: DAG.getRegister(Reg: MCRegister::NoRegister,
3376 VT: ValVT.getVectorElementType()));
3377 } else {
3378 // Extract the element from the vector to store
3379 SDValue ExtVal =
3380 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ValVT.getVectorElementType(),
3381 N1: Val, N2: DAG.getIntPtrConstant(Val: I, DL));
3382 Ops.push_back(Elt: ExtVal);
3383 }
3384 }
3385
3386 // Next, the pointer operand.
3387 Ops.push_back(Elt: BasePtr);
3388
3389 // Finally, the offset operand. We expect this to always be undef, and it will
3390 // be ignored in lowering, but to mirror the handling of the other vector
3391 // store instructions we include it in the new SDNode.
3392 assert(Offset.getOpcode() == ISD::UNDEF &&
3393 "Offset operand expected to be undef");
3394 Ops.push_back(Elt: Offset);
3395
3396 SDValue NewSt =
3397 DAG.getMemIntrinsicNode(Opcode, dl: DL, VTList: DAG.getVTList(VT: MVT::Other), Ops,
3398 MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());
3399
3400 return NewSt;
3401}
3402
3403SDValue
3404NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
3405 switch (Op.getOpcode()) {
3406 case ISD::RETURNADDR:
3407 return SDValue();
3408 case ISD::FRAMEADDR:
3409 return SDValue();
3410 case ISD::ADDRSPACECAST:
3411 return LowerADDRSPACECAST(Op, DAG);
3412 case ISD::INTRINSIC_W_CHAIN:
3413 return lowerIntrinsicWChain(Op, DAG);
3414 case ISD::INTRINSIC_WO_CHAIN:
3415 return lowerIntrinsicWOChain(Op, DAG);
3416 case ISD::INTRINSIC_VOID:
3417 return lowerIntrinsicVoid(Op, DAG);
3418 case ISD::BUILD_VECTOR:
3419 return LowerBUILD_VECTOR(Op, DAG);
3420 case ISD::BITCAST:
3421 return LowerBITCAST(Op, DAG);
3422 case ISD::EXTRACT_SUBVECTOR:
3423 return Op;
3424 case ISD::EXTRACT_VECTOR_ELT:
3425 return LowerEXTRACT_VECTOR_ELT(Op, DAG);
3426 case ISD::INSERT_VECTOR_ELT:
3427 return LowerINSERT_VECTOR_ELT(Op, DAG);
3428 case ISD::VECTOR_SHUFFLE:
3429 return LowerVECTOR_SHUFFLE(Op, DAG);
3430 case ISD::CONCAT_VECTORS:
3431 return LowerCONCAT_VECTORS(Op, DAG);
3432 case ISD::VECREDUCE_FMAX:
3433 case ISD::VECREDUCE_FMIN:
3434 case ISD::VECREDUCE_FMAXIMUM:
3435 case ISD::VECREDUCE_FMINIMUM:
3436 return LowerVECREDUCE(Op, DAG);
3437 case ISD::STORE:
3438 return LowerSTORE(Op, DAG);
3439 case ISD::MSTORE: {
3440 assert(STI.has256BitVectorLoadStore(
3441 cast<MemSDNode>(Op.getNode())->getAddressSpace()) &&
3442 "Masked store vector not supported on subtarget.");
3443 return lowerMSTORE(Op, DAG);
3444 }
3445 case ISD::LOAD:
3446 return LowerLOAD(Op, DAG);
3447 case ISD::MLOAD:
3448 return LowerMLOAD(Op, DAG);
3449 case ISD::SHL_PARTS:
3450 return LowerShiftLeftParts(Op, DAG);
3451 case ISD::SRA_PARTS:
3452 case ISD::SRL_PARTS:
3453 return LowerShiftRightParts(Op, DAG);
3454 case ISD::SELECT:
3455 return lowerSELECT(Op, DAG);
3456 case ISD::FROUND:
3457 return LowerFROUND(Op, DAG);
3458 case ISD::FCOPYSIGN:
3459 return LowerFCOPYSIGN(Op, DAG);
3460 case ISD::SINT_TO_FP:
3461 case ISD::UINT_TO_FP:
3462 return LowerINT_TO_FP(Op, DAG);
3463 case ISD::FP_TO_SINT:
3464 case ISD::FP_TO_UINT:
3465 // fptosi/fptoui to i1 truncate toward zero, so the only defined results
3466 // are {0,-1} (signed) and {0,1} (unsigned); every other input results in
3467 // poison. Thus we can simply lower to `x <= -1.0` or `x >= 1.0`.
3468 if (Op.getValueType() == MVT::i1) {
3469 SDLoc DL(Op);
3470 SDValue X = Op.getOperand(i: 0);
3471 bool IsSigned = Op.getOpcode() == ISD::FP_TO_SINT;
3472 return DAG.getSetCC(
3473 DL, VT: MVT::i1, LHS: X,
3474 RHS: DAG.getConstantFP(Val: IsSigned ? -1.0 : 1.0, DL, VT: X.getValueType()),
3475 Cond: IsSigned ? ISD::SETOLE : ISD::SETOGE);
3476 }
3477 return LowerFP_TO_INT(Op, DAG);
3478 case ISD::FP_ROUND:
3479 return LowerFP_ROUND(Op, DAG);
3480 case ISD::FP_EXTEND:
3481 return LowerFP_EXTEND(Op, DAG);
3482 case ISD::VAARG:
3483 return LowerVAARG(Op, DAG);
3484 case ISD::VASTART:
3485 return LowerVASTART(Op, DAG);
3486 case ISD::FSHL:
3487 case ISD::FSHR:
3488 return lowerFSH(Op, DAG);
3489 case ISD::ROTL:
3490 case ISD::ROTR:
3491 return lowerROT(Op, DAG);
3492 case ISD::ABS:
3493 case ISD::ABS_MIN_POISON:
3494 case ISD::SMIN:
3495 case ISD::SMAX:
3496 case ISD::UMIN:
3497 case ISD::UMAX:
3498 case ISD::ADD:
3499 case ISD::SUB:
3500 case ISD::MUL:
3501 case ISD::SHL:
3502 case ISD::SREM:
3503 case ISD::UREM:
3504 return LowerVectorArith(Op, DAG);
3505 case ISD::DYNAMIC_STACKALLOC:
3506 return LowerDYNAMIC_STACKALLOC(Op, DAG);
3507 case ISD::STACKRESTORE:
3508 return LowerSTACKRESTORE(Op, DAG);
3509 case ISD::STACKSAVE:
3510 return LowerSTACKSAVE(Op, DAG);
3511 case ISD::CopyToReg:
3512 return LowerCopyToReg_128(Op, DAG);
3513 case ISD::FADD:
3514 case ISD::FSUB:
3515 case ISD::FMUL:
3516 // Used only for bf16 on SM80, where we select fma for non-ftz operation
3517 return PromoteBinOpIfF32FTZ(Op, DAG);
3518 case ISD::CTPOP:
3519 case ISD::CTLZ:
3520 return lowerCTLZCTPOP(Op, DAG);
3521 case ISD::FREM:
3522 return lowerFREM(Op, DAG);
3523 case ISD::BSWAP:
3524 return lowerBSWAP(Op, DAG);
3525 default:
3526 llvm_unreachable("Custom lowering not defined for operation");
3527 }
3528}
3529
3530// This will prevent AsmPrinter from trying to print the jump tables itself.
3531unsigned NVPTXTargetLowering::getJumpTableEncoding() const {
3532 return MachineJumpTableInfo::EK_Inline;
3533}
3534
3535SDValue NVPTXTargetLowering::LowerADDRSPACECAST(SDValue Op,
3536 SelectionDAG &DAG) const {
3537 AddrSpaceCastSDNode *N = cast<AddrSpaceCastSDNode>(Val: Op.getNode());
3538 unsigned SrcAS = N->getSrcAddressSpace();
3539 unsigned DestAS = N->getDestAddressSpace();
3540 if (SrcAS != llvm::ADDRESS_SPACE_GENERIC &&
3541 DestAS != llvm::ADDRESS_SPACE_GENERIC) {
3542 // Shared and SharedCluster can be converted to each other through generic
3543 // space
3544 if ((SrcAS == llvm::ADDRESS_SPACE_SHARED &&
3545 DestAS == llvm::ADDRESS_SPACE_SHARED_CLUSTER) ||
3546 (SrcAS == llvm::ADDRESS_SPACE_SHARED_CLUSTER &&
3547 DestAS == llvm::ADDRESS_SPACE_SHARED)) {
3548 SDLoc DL(Op.getNode());
3549 const MVT GenerictVT =
3550 getPointerTy(DL: DAG.getDataLayout(), AS: ADDRESS_SPACE_GENERIC);
3551 SDValue GenericConversion = DAG.getAddrSpaceCast(
3552 dl: DL, VT: GenerictVT, Ptr: Op.getOperand(i: 0), SrcAS, DestAS: ADDRESS_SPACE_GENERIC);
3553 SDValue SharedClusterConversion =
3554 DAG.getAddrSpaceCast(dl: DL, VT: Op.getValueType(), Ptr: GenericConversion,
3555 SrcAS: ADDRESS_SPACE_GENERIC, DestAS);
3556 return SharedClusterConversion;
3557 }
3558
3559 return DAG.getUNDEF(VT: Op.getValueType());
3560 }
3561
3562 return Op;
3563}
3564
3565// This function is almost a copy of SelectionDAG::expandVAArg().
3566// The only diff is that this one produces loads from local address space.
3567SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
3568 const TargetLowering *TLI = STI.getTargetLowering();
3569 SDLoc DL(Op);
3570
3571 SDNode *Node = Op.getNode();
3572 const Value *V = cast<SrcValueSDNode>(Val: Node->getOperand(Num: 2))->getValue();
3573 EVT VT = Node->getValueType(ResNo: 0);
3574 auto *Ty = VT.getTypeForEVT(Context&: *DAG.getContext());
3575 SDValue Tmp1 = Node->getOperand(Num: 0);
3576 SDValue Tmp2 = Node->getOperand(Num: 1);
3577 const MaybeAlign MA(Node->getConstantOperandVal(Num: 3));
3578
3579 SDValue VAListLoad = DAG.getLoad(VT: TLI->getPointerTy(DL: DAG.getDataLayout()), dl: DL,
3580 Chain: Tmp1, Ptr: Tmp2, PtrInfo: MachinePointerInfo(V));
3581 SDValue VAList = VAListLoad;
3582
3583 if (MA && *MA > TLI->getMinStackArgumentAlignment()) {
3584 VAList = DAG.getNode(
3585 Opcode: ISD::ADD, DL, VT: VAList.getValueType(), N1: VAList,
3586 N2: DAG.getConstant(Val: MA->value() - 1, DL, VT: VAList.getValueType()));
3587
3588 VAList = DAG.getNode(Opcode: ISD::AND, DL, VT: VAList.getValueType(), N1: VAList,
3589 N2: DAG.getSignedConstant(Val: -(int64_t)MA->value(), DL,
3590 VT: VAList.getValueType()));
3591 }
3592
3593 // Increment the pointer, VAList, to the next vaarg
3594 Tmp1 = DAG.getNode(Opcode: ISD::ADD, DL, VT: VAList.getValueType(), N1: VAList,
3595 N2: DAG.getConstant(Val: DAG.getDataLayout().getTypeAllocSize(Ty),
3596 DL, VT: VAList.getValueType()));
3597
3598 // Store the incremented VAList to the legalized pointer
3599 Tmp1 = DAG.getStore(Chain: VAListLoad.getValue(R: 1), dl: DL, Val: Tmp1, Ptr: Tmp2,
3600 PtrInfo: MachinePointerInfo(V));
3601
3602 const Value *SrcV = Constant::getNullValue(
3603 Ty: PointerType::get(C&: *DAG.getContext(), AddressSpace: ADDRESS_SPACE_LOCAL));
3604
3605 // Load the actual argument out of the pointer VAList
3606 return DAG.getLoad(VT, dl: DL, Chain: Tmp1, Ptr: VAList, PtrInfo: MachinePointerInfo(SrcV));
3607}
3608
3609SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
3610 const TargetLowering *TLI = STI.getTargetLowering();
3611 SDLoc DL(Op);
3612 EVT PtrVT = TLI->getPointerTy(DL: DAG.getDataLayout());
3613
3614 // Store the address of unsized array <function>_vararg[] in the ap object.
3615 SDValue VAReg = getParamSymbol(DAG, /* vararg */ I: -1, T: PtrVT);
3616
3617 const Value *SV = cast<SrcValueSDNode>(Val: Op.getOperand(i: 2))->getValue();
3618 return DAG.getStore(Chain: Op.getOperand(i: 0), dl: DL, Val: VAReg, Ptr: Op.getOperand(i: 1),
3619 PtrInfo: MachinePointerInfo(SV));
3620}
3621
3622static std::pair<MemSDNode *, uint32_t>
3623convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG,
3624 const NVPTXSubtarget &STI) {
3625 SDValue Chain = N->getOperand(Num: 0);
3626 SDValue BasePtr = N->getOperand(Num: 1);
3627 SDValue Mask = N->getOperand(Num: 3);
3628 [[maybe_unused]] SDValue Passthru = N->getOperand(Num: 4);
3629
3630 SDLoc DL(N);
3631 EVT ResVT = N->getValueType(ResNo: 0);
3632 assert(ResVT.isVector() && "Masked vector load must have vector type");
3633 // While we only expect poison passthru vectors as an input to the backend,
3634 // when the legalization framework splits a poison vector in half, it creates
3635 // two undef vectors, so we can technically expect those too.
3636 assert((Passthru.getOpcode() == ISD::POISON ||
3637 Passthru.getOpcode() == ISD::UNDEF) &&
3638 "Passthru operand expected to be poison or undef");
3639
3640 // Extract the mask and convert it to a uint32_t representing the used bytes
3641 // of the entire vector load
3642 uint32_t UsedBytesMask = 0;
3643 uint32_t ElementSizeInBits = ResVT.getVectorElementType().getSizeInBits();
3644 assert(ElementSizeInBits % 8 == 0 && "Unexpected element size");
3645 uint32_t ElementSizeInBytes = ElementSizeInBits / 8;
3646 uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
3647
3648 for (SDValue Op : reverse(C: Mask->ops())) {
3649 // We technically only want to do this shift for every
3650 // iteration *but* the first, but in the first iteration UsedBytesMask is 0,
3651 // so this shift is a no-op.
3652 UsedBytesMask <<= ElementSizeInBytes;
3653
3654 // Mask elements must be constants.
3655 if (Op->getAsZExtVal() != 0)
3656 UsedBytesMask |= ElementMask;
3657 }
3658
3659 assert(UsedBytesMask != 0 && UsedBytesMask != UINT32_MAX &&
3660 "Unexpected masked load with elements masked all on or all off");
3661
3662 // Create a new load sd node to be handled normally by ReplaceLoadVector.
3663 MemSDNode *NewLD = cast<MemSDNode>(
3664 Val: DAG.getLoad(VT: ResVT, dl: DL, Chain, Ptr: BasePtr, MMO: N->getMemOperand()).getNode());
3665
3666 // If our subtarget does not support the used bytes mask pragma, "drop" the
3667 // mask by setting it to UINT32_MAX
3668 if (!STI.hasUsedBytesMaskPragma())
3669 UsedBytesMask = UINT32_MAX;
3670
3671 return {NewLD, UsedBytesMask};
3672}
3673
3674/// replaceLoadVector - Convert vector loads into multi-output scalar loads.
3675static std::optional<std::pair<SDValue, SDValue>>
3676replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
3677 MemSDNode *LD = cast<MemSDNode>(Val: N);
3678 const EVT ResVT = LD->getValueType(ResNo: 0);
3679 const EVT MemVT = LD->getMemoryVT();
3680
3681 // If we're doing sign/zero extension as part of the load, avoid lowering to
3682 // a LoadV node. TODO: consider relaxing this restriction.
3683 if (ResVT != MemVT)
3684 return std::nullopt;
3685
3686 const auto NumEltsAndEltVT =
3687 getVectorLoweringShape(VectorEVT: ResVT, STI, AddressSpace: LD->getAddressSpace());
3688 if (!NumEltsAndEltVT)
3689 return std::nullopt;
3690 const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
3691
3692 Align Alignment = LD->getAlign();
3693 const auto &TD = DAG.getDataLayout();
3694 Align PrefAlign = TD.getPrefTypeAlign(Ty: MemVT.getTypeForEVT(Context&: *DAG.getContext()));
3695 if (Alignment < PrefAlign) {
3696 // This load is not sufficiently aligned, so bail out and let this vector
3697 // load be scalarized. Note that we may still be able to emit smaller
3698 // vector loads. For example, if we are loading a <4 x float> with an
3699 // alignment of 8, this check will fail but the legalizer will try again
3700 // with 2 x <2 x float>, which will succeed with an alignment of 8.
3701 return std::nullopt;
3702 }
3703
3704 // If we have a masked load, convert it to a normal load now
3705 std::optional<uint32_t> UsedBytesMask = std::nullopt;
3706 if (LD->getOpcode() == ISD::MLOAD)
3707 std::tie(args&: LD, args&: UsedBytesMask) =
3708 convertMLOADToLoadWithUsedBytesMask(N: LD, DAG, STI);
3709
3710 // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
3711 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
3712 // loaded type to i16 and propagate the "real" type as the memory type.
3713 const MVT LoadEltVT = (EltVT.getSizeInBits() < 16) ? MVT::i16 : EltVT;
3714
3715 unsigned Opcode;
3716 switch (NumElts) {
3717 default:
3718 return std::nullopt;
3719 case 2:
3720 Opcode = NVPTXISD::LoadV2;
3721 break;
3722 case 4:
3723 Opcode = NVPTXISD::LoadV4;
3724 break;
3725 case 8:
3726 Opcode = NVPTXISD::LoadV8;
3727 break;
3728 }
3729 auto ListVTs = SmallVector<EVT, 9>(NumElts, LoadEltVT);
3730 ListVTs.push_back(Elt: MVT::Other);
3731 SDVTList LdResVTs = DAG.getVTList(VTs: ListVTs);
3732
3733 SDLoc DL(LD);
3734
3735 // Copy regular operands
3736 SmallVector<SDValue, 8> OtherOps(LD->ops());
3737
3738 OtherOps.push_back(
3739 Elt: DAG.getConstant(Val: UsedBytesMask.value_or(UINT32_MAX), DL, VT: MVT::i32));
3740
3741 // The select routine does not have access to the LoadSDNode instance, so
3742 // pass along the extension information
3743 OtherOps.push_back(
3744 Elt: DAG.getIntPtrConstant(Val: cast<LoadSDNode>(Val: LD)->getExtensionType(), DL));
3745
3746 SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, dl: DL, VTList: LdResVTs, Ops: OtherOps, MemVT,
3747 MMO: LD->getMemOperand());
3748
3749 SmallVector<SDValue> ScalarRes;
3750 if (EltVT.isVector()) {
3751 assert(EVT(EltVT.getVectorElementType()) == ResVT.getVectorElementType());
3752 assert(NumElts * EltVT.getVectorNumElements() ==
3753 ResVT.getVectorNumElements());
3754 // Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
3755 // into individual elements.
3756 for (const unsigned I : llvm::seq(Size: NumElts)) {
3757 SDValue SubVector = NewLD.getValue(R: I);
3758 DAG.ExtractVectorElements(Op: SubVector, Args&: ScalarRes);
3759 }
3760 } else {
3761 for (const unsigned I : llvm::seq(Size: NumElts)) {
3762 SDValue Res = NewLD.getValue(R: I);
3763 if (LoadEltVT != EltVT)
3764 Res = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: EltVT, Operand: Res);
3765 ScalarRes.push_back(Elt: Res);
3766 }
3767 }
3768
3769 SDValue LoadChain = NewLD.getValue(R: NumElts);
3770
3771 const MVT BuildVecVT =
3772 MVT::getVectorVT(VT: EltVT.getScalarType(), NumElements: ScalarRes.size());
3773 SDValue BuildVec = DAG.getBuildVector(VT: BuildVecVT, DL, Ops: ScalarRes);
3774 SDValue LoadValue = DAG.getBitcast(VT: ResVT, V: BuildVec);
3775
3776 return {{LoadValue, LoadChain}};
3777}
3778
3779static void replaceLoadVector(SDNode *N, SelectionDAG &DAG,
3780 SmallVectorImpl<SDValue> &Results,
3781 const NVPTXSubtarget &STI) {
3782 if (auto Res = replaceLoadVector(N, DAG, STI))
3783 Results.append(IL: {Res->first, Res->second});
3784}
3785
3786static SDValue lowerLoadVector(SDNode *N, SelectionDAG &DAG,
3787 const NVPTXSubtarget &STI) {
3788 if (auto Res = replaceLoadVector(N, DAG, STI))
3789 return DAG.getMergeValues(Ops: {Res->first, Res->second}, dl: SDLoc(N));
3790 return SDValue();
3791}
3792
3793// v = ld i1* addr
3794// =>
3795// v1 = ld i8* addr (-> i16)
3796// v = trunc i16 to i1
3797static SDValue lowerLOADi1(LoadSDNode *LD, SelectionDAG &DAG) {
3798 SDLoc dl(LD);
3799 assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
3800 assert(LD->getValueType(0) == MVT::i1 && "Custom lowering for i1 load only");
3801 SDValue newLD = DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl, VT: MVT::i16, Chain: LD->getChain(),
3802 Ptr: LD->getBasePtr(), PtrInfo: LD->getPointerInfo(),
3803 MemVT: MVT::i8, Alignment: LD->getAlign(),
3804 MMOFlags: LD->getMemOperand()->getFlags());
3805 SDValue result = DAG.getNode(Opcode: ISD::TRUNCATE, DL: dl, VT: MVT::i1, Operand: newLD);
3806 // The legalizer (the caller) is expecting two values from the legalized
3807 // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
3808 // in LegalizeDAG.cpp which also uses MergeValues.
3809 return DAG.getMergeValues(Ops: {result, LD->getChain()}, dl);
3810}
3811
3812SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
3813 LoadSDNode *LD = cast<LoadSDNode>(Val&: Op);
3814
3815 if (Op.getValueType() == MVT::i1)
3816 return lowerLOADi1(LD, DAG);
3817
3818 // To improve CodeGen we'll legalize any-extend loads to zext loads. This is
3819 // how they'll be lowered in ISel anyway, and by doing this a little earlier
3820 // we allow for more DAG combine opportunities.
3821 if (LD->getExtensionType() == ISD::EXTLOAD) {
3822 assert(LD->getValueType(0).isInteger() && LD->getMemoryVT().isInteger() &&
3823 "Unexpected fpext-load");
3824 return DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl: SDLoc(Op), VT: Op.getValueType(),
3825 Chain: LD->getChain(), Ptr: LD->getBasePtr(), MemVT: LD->getMemoryVT(),
3826 MMO: LD->getMemOperand());
3827 }
3828
3829 llvm_unreachable("Unexpected custom lowering for load");
3830}
3831
3832SDValue NVPTXTargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
3833 // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
3834 // masked loads of these types and have to handle them here.
3835 // v2f32 also needs to be handled here if the subtarget has f32x2
3836 // instructions, making it legal.
3837 //
3838 // Note: misaligned masked loads should never reach this point
3839 // because the override of isLegalMaskedLoad in NVPTXTargetTransformInfo.cpp
3840 // will validate alignment. Therefore, we do not need to special case handle
3841 // them here.
3842 EVT VT = Op.getValueType();
3843 if (NVPTX::isPackedVectorTy(VT)) {
3844 auto Result = convertMLOADToLoadWithUsedBytesMask(
3845 N: cast<MemSDNode>(Val: Op.getNode()), DAG, STI);
3846 MemSDNode *LD = std::get<0>(in&: Result);
3847 uint32_t UsedBytesMask = std::get<1>(in&: Result);
3848
3849 SDLoc DL(LD);
3850
3851 // Copy regular operands
3852 SmallVector<SDValue, 8> OtherOps(LD->ops());
3853
3854 OtherOps.push_back(Elt: DAG.getConstant(Val: UsedBytesMask, DL, VT: MVT::i32));
3855
3856 // We currently are not lowering extending loads, but pass the extension
3857 // type anyway as later handling expects it.
3858 OtherOps.push_back(
3859 Elt: DAG.getIntPtrConstant(Val: cast<LoadSDNode>(Val: LD)->getExtensionType(), DL));
3860 SDValue NewLD =
3861 DAG.getMemIntrinsicNode(Opcode: NVPTXISD::MLoad, dl: DL, VTList: LD->getVTList(), Ops: OtherOps,
3862 MemVT: LD->getMemoryVT(), MMO: LD->getMemOperand());
3863 return NewLD;
3864 }
3865 return SDValue();
3866}
3867
3868static SDValue lowerSTOREVector(SDValue Op, SelectionDAG &DAG,
3869 const NVPTXSubtarget &STI) {
3870 MemSDNode *N = cast<MemSDNode>(Val: Op.getNode());
3871 SDValue Val = N->getOperand(Num: 1);
3872 SDLoc DL(N);
3873 const EVT ValVT = Val.getValueType();
3874 const EVT MemVT = N->getMemoryVT();
3875
3876 // If we're truncating as part of the store, avoid lowering to a StoreV node.
3877 // TODO: consider relaxing this restriction.
3878 if (ValVT != MemVT)
3879 return SDValue();
3880
3881 const auto NumEltsAndEltVT =
3882 getVectorLoweringShape(VectorEVT: ValVT, STI, AddressSpace: N->getAddressSpace());
3883 if (!NumEltsAndEltVT)
3884 return SDValue();
3885 const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
3886
3887 const DataLayout &TD = DAG.getDataLayout();
3888
3889 Align Alignment = N->getAlign();
3890 Align PrefAlign = TD.getPrefTypeAlign(Ty: ValVT.getTypeForEVT(Context&: *DAG.getContext()));
3891 if (Alignment < PrefAlign) {
3892 // This store is not sufficiently aligned, so bail out and let this vector
3893 // store be scalarized. Note that we may still be able to emit smaller
3894 // vector stores. For example, if we are storing a <4 x float> with an
3895 // alignment of 8, this check will fail but the legalizer will try again
3896 // with 2 x <2 x float>, which will succeed with an alignment of 8.
3897 return SDValue();
3898 }
3899
3900 unsigned Opcode;
3901 switch (NumElts) {
3902 default:
3903 return SDValue();
3904 case 2:
3905 Opcode = NVPTXISD::StoreV2;
3906 break;
3907 case 4:
3908 Opcode = NVPTXISD::StoreV4;
3909 break;
3910 case 8:
3911 Opcode = NVPTXISD::StoreV8;
3912 break;
3913 }
3914
3915 SmallVector<SDValue, 8> Ops;
3916
3917 // First is the chain
3918 Ops.push_back(Elt: N->getOperand(Num: 0));
3919
3920 // Then the split values
3921 if (EltVT.isVector()) {
3922 assert(EVT(EltVT.getVectorElementType()) == ValVT.getVectorElementType());
3923 assert(NumElts * EltVT.getVectorNumElements() ==
3924 ValVT.getVectorNumElements());
3925 // Combine individual elements into v2[i,f,bf]16/v4i8 subvectors to be
3926 // stored as b32s
3927 const unsigned NumEltsPerSubVector = EltVT.getVectorNumElements();
3928 for (const unsigned I : llvm::seq(Size: NumElts)) {
3929 SmallVector<SDValue, 4> SubVectorElts;
3930 DAG.ExtractVectorElements(Op: Val, Args&: SubVectorElts, Start: I * NumEltsPerSubVector,
3931 Count: NumEltsPerSubVector);
3932 Ops.push_back(Elt: DAG.getBuildVector(VT: EltVT, DL, Ops: SubVectorElts));
3933 }
3934 } else {
3935 SDValue V = DAG.getBitcast(VT: MVT::getVectorVT(VT: EltVT, NumElements: NumElts), V: Val);
3936 for (const unsigned I : llvm::seq(Size: NumElts)) {
3937 SDValue ExtVal = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: V,
3938 N2: DAG.getIntPtrConstant(Val: I, DL));
3939
3940 // Since StoreV2 is a target node, we cannot rely on DAG type
3941 // legalization. Therefore, we must ensure the type is legal. For i1 and
3942 // i8, we set the stored type to i16 and propagate the "real" type as the
3943 // memory type.
3944 if (EltVT.getSizeInBits() < 16)
3945 ExtVal = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: MVT::i16, Operand: ExtVal);
3946 Ops.push_back(Elt: ExtVal);
3947 }
3948 }
3949
3950 // Then any remaining arguments
3951 Ops.append(in_start: N->op_begin() + 2, in_end: N->op_end());
3952
3953 SDValue NewSt =
3954 DAG.getMemIntrinsicNode(Opcode, dl: DL, VTList: DAG.getVTList(VT: MVT::Other), Ops,
3955 MemVT: N->getMemoryVT(), MMO: N->getMemOperand());
3956
3957 // return DCI.CombineTo(N, NewSt, true);
3958 return NewSt;
3959}
3960
3961SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
3962 StoreSDNode *Store = cast<StoreSDNode>(Val&: Op);
3963 EVT VT = Store->getMemoryVT();
3964
3965 if (VT == MVT::i1)
3966 return LowerSTOREi1(Op, DAG);
3967
3968 // Lower store of any other vector type, including v2f32 as we want to break
3969 // it apart since this is not a widely-supported type.
3970 return lowerSTOREVector(Op, DAG, STI);
3971}
3972
3973// st i1 v, addr
3974// =>
3975// v1 = zxt v to i16
3976// st.u8 i16, addr
3977SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
3978 SDNode *Node = Op.getNode();
3979 SDLoc dl(Node);
3980 StoreSDNode *ST = cast<StoreSDNode>(Val: Node);
3981 SDValue Tmp1 = ST->getChain();
3982 SDValue Tmp2 = ST->getBasePtr();
3983 SDValue Tmp3 = ST->getValue();
3984 assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
3985 Tmp3 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: dl, VT: MVT::i16, Operand: Tmp3);
3986 SDValue Result =
3987 DAG.getTruncStore(Chain: Tmp1, dl, Val: Tmp3, Ptr: Tmp2, PtrInfo: ST->getPointerInfo(), SVT: MVT::i8,
3988 Alignment: ST->getAlign(), MMOFlags: ST->getMemOperand()->getFlags());
3989 return Result;
3990}
3991
3992SDValue NVPTXTargetLowering::LowerCopyToReg_128(SDValue Op,
3993 SelectionDAG &DAG) const {
3994 // Change the CopyToReg to take in two 64-bit operands instead of a 128-bit
3995 // operand so that it can pass the legalization.
3996
3997 assert(Op.getOperand(1).getValueType() == MVT::i128 &&
3998 "Custom lowering for 128-bit CopyToReg only");
3999
4000 SDNode *Node = Op.getNode();
4001 SDLoc DL(Node);
4002
4003 SDValue Cast = DAG.getBitcast(VT: MVT::v2i64, V: Op->getOperand(Num: 2));
4004 SDValue Lo = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i64, N1: Cast,
4005 N2: DAG.getIntPtrConstant(Val: 0, DL));
4006 SDValue Hi = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i64, N1: Cast,
4007 N2: DAG.getIntPtrConstant(Val: 1, DL));
4008
4009 SmallVector<SDValue, 5> NewOps(Op->getNumOperands() + 1);
4010 SmallVector<EVT, 3> ResultsType(Node->values());
4011
4012 NewOps[0] = Op->getOperand(Num: 0); // Chain
4013 NewOps[1] = Op->getOperand(Num: 1); // Dst Reg
4014 NewOps[2] = Lo; // Lower 64-bit
4015 NewOps[3] = Hi; // Higher 64-bit
4016 if (Op.getNumOperands() == 4)
4017 NewOps[4] = Op->getOperand(Num: 3); // Glue if exists
4018
4019 return DAG.getNode(Opcode: ISD::CopyToReg, DL, ResultTys: ResultsType, Ops: NewOps);
4020}
4021
4022unsigned NVPTXTargetLowering::getNumRegisters(
4023 LLVMContext &Context, EVT VT,
4024 std::optional<MVT> RegisterVT = std::nullopt) const {
4025 if (VT == MVT::i128 && RegisterVT == MVT::i128)
4026 return 1;
4027 return TargetLoweringBase::getNumRegisters(Context, VT, RegisterVT);
4028}
4029
4030bool NVPTXTargetLowering::splitValueIntoRegisterParts(
4031 SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
4032 unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
4033 if (Val.getValueType() == MVT::i128 && NumParts == 1) {
4034 Parts[0] = Val;
4035 return true;
4036 }
4037 return false;
4038}
4039
4040// This creates target external symbol for a function parameter.
4041// Name of the symbol is composed from its index and the function name.
4042// Negative index corresponds to special parameter (unsized array) used for
4043// passing variable arguments.
4044SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int I,
4045 EVT T) const {
4046 StringRef SavedStr = nvTM->getStrPool().save(
4047 S: getParamName(F: &DAG.getMachineFunction().getFunction(), Idx: I));
4048 return DAG.getExternalSymbol(Sym: SavedStr.data(), VT: T);
4049}
4050
4051SDValue NVPTXTargetLowering::getCallParamSymbol(SelectionDAG &DAG, int I,
4052 EVT T) const {
4053 const StringRef SavedStr = nvTM->getStrPool().save(S: "param" + Twine(I));
4054 return DAG.getExternalSymbol(Sym: SavedStr.data(), VT: T);
4055}
4056
4057SDValue NVPTXTargetLowering::LowerFormalArguments(
4058 SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
4059 const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
4060 SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
4061 const DataLayout &DL = DAG.getDataLayout();
4062 LLVMContext &Ctx = *DAG.getContext();
4063 auto PtrVT = getPointerTy(DL: DAG.getDataLayout());
4064
4065 const Function &F = DAG.getMachineFunction().getFunction();
4066 const bool IsKernel = isKernelFunction(F);
4067
4068 SDValue Root = DAG.getRoot();
4069 SmallVector<SDValue, 16> OutChains;
4070
4071 // argTypes.size() (or theArgs.size()) and Ins.size() need not match.
4072 // Ins.size() will be larger
4073 // * if there is an aggregate argument with multiple fields (each field
4074 // showing up separately in Ins)
4075 // * if there is a vector argument with more than typical vector-length
4076 // elements (generally if more than 4) where each vector element is
4077 // individually present in Ins.
4078 // So a different index should be used for indexing into Ins.
4079 // See similar issue in LowerCall.
4080
4081 auto AllIns = ArrayRef(Ins);
4082 for (const auto &Arg : F.args()) {
4083 const auto ArgIns = AllIns.take_while(
4084 Pred: [&](auto I) { return I.OrigArgIndex == Arg.getArgNo(); });
4085 AllIns = AllIns.drop_front(N: ArgIns.size());
4086
4087 Type *Ty = Arg.getType();
4088
4089 if (ArgIns.empty())
4090 report_fatal_error(reason: "Empty parameter types are not supported");
4091
4092 if (Arg.use_empty()) {
4093 // argument is dead
4094 for (const auto &In : ArgIns) {
4095 assert(!In.Used && "Arg.use_empty() is true but Arg is used?");
4096 InVals.push_back(Elt: DAG.getUNDEF(VT: In.VT));
4097 }
4098 continue;
4099 }
4100
4101 SDValue ArgSymbol = getParamSymbol(DAG, I: Arg.getArgNo(), T: PtrVT);
4102
4103 // In the following cases, assign a node order of "i+1"
4104 // to newly created nodes. The SDNodes for params have to
4105 // appear in the same order as their order of appearance
4106 // in the original function. "i+1" holds that order.
4107 if (Arg.hasByValAttr()) {
4108 // Param has ByVal attribute
4109 // Return MoveParam(param symbol).
4110 // Ideally, the param symbol can be returned directly,
4111 // but when SDNode builder decides to use it in a CopyToReg(),
4112 // machine instruction fails because TargetExternalSymbol
4113 // (not lowered) is target dependent, and CopyToReg assumes
4114 // the source is lowered.
4115 assert(ArgIns.size() == 1 && "ByVal argument must be a pointer");
4116 const auto &ByvalIn = ArgIns[0];
4117 assert(getValueType(DL, Ty) == ByvalIn.VT &&
4118 "Ins type did not match function type");
4119 assert(ByvalIn.VT == PtrVT && "ByVal argument must be a pointer");
4120
4121 SDValue P;
4122 if (IsKernel) {
4123 assert(Arg.getType()->getPointerAddressSpace() ==
4124 ADDRESS_SPACE_ENTRY_PARAM &&
4125 "Kernel ByVal argument must be lowered to the param address "
4126 "space by NVPTXLowerArgs");
4127 P = ArgSymbol;
4128 P.getNode()->setIROrder(Arg.getArgNo() + 1);
4129 } else {
4130 P = DAG.getNode(Opcode: NVPTXISD::MoveParam, DL: dl, VT: ByvalIn.VT, Operand: ArgSymbol);
4131 P.getNode()->setIROrder(Arg.getArgNo() + 1);
4132 P = DAG.getAddrSpaceCast(dl, VT: ByvalIn.VT, Ptr: P, SrcAS: ADDRESS_SPACE_LOCAL,
4133 DestAS: ADDRESS_SPACE_GENERIC);
4134 }
4135 InVals.push_back(Elt: P);
4136 } else {
4137 SmallVector<EVT, 16> VTs;
4138 SmallVector<uint64_t, 16> Offsets;
4139 ComputePTXValueVTs(TLI: *this, DL, Ctx, CallConv, Ty, ValueVTs&: VTs, Offsets);
4140 assert(VTs.size() == ArgIns.size() && "Size mismatch");
4141 assert(VTs.size() == Offsets.size() && "Size mismatch");
4142
4143 const Align ArgAlign = getPTXParamAlign(
4144 F: &F, Ty, AttrIdx: Arg.getArgNo() + AttributeList::FirstArgIndex, DL);
4145
4146 unsigned I = 0;
4147 const auto VI = VectorizePTXValueVTs(ValueVTs: VTs, Offsets, ParamAlignment: ArgAlign);
4148 for (const unsigned NumElts : VI) {
4149 // i1 is loaded/stored as i8
4150 const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
4151 const EVT VecVT = getVectorizedVT(VT: LoadVT, N: NumElts, C&: Ctx);
4152
4153 SDValue VecAddr = DAG.getObjectPtrOffset(
4154 SL: dl, Ptr: ArgSymbol, Offset: TypeSize::getFixed(ExactSize: Offsets[I]));
4155
4156 const Align PartAlign = commonAlignment(A: ArgAlign, Offset: Offsets[I]);
4157 const unsigned AS = IsKernel ? NVPTX::AddressSpace::EntryParam
4158 : NVPTX::AddressSpace::DeviceParam;
4159 SDValue P = DAG.getLoad(VT: VecVT, dl, Chain: Root, Ptr: VecAddr,
4160 PtrInfo: MachinePointerInfo(AS), Alignment: PartAlign,
4161 MMOFlags: MachineMemOperand::MODereferenceable |
4162 MachineMemOperand::MOInvariant);
4163 P.getNode()->setIROrder(Arg.getArgNo() + 1);
4164 for (const unsigned J : llvm::seq(Size: NumElts)) {
4165 SDValue Elt = getExtractVectorizedValue(V: P, I: J, VT: LoadVT, dl, DAG);
4166
4167 Elt = correctParamType(V: Elt, ExpectedVT: ArgIns[I + J].VT, Flags: ArgIns[I + J].Flags,
4168 DAG, dl);
4169 InVals.push_back(Elt);
4170 }
4171 I += NumElts;
4172 }
4173 }
4174 }
4175
4176 if (!OutChains.empty())
4177 DAG.setRoot(DAG.getTokenFactor(DL: dl, Vals&: OutChains));
4178
4179 return Chain;
4180}
4181
4182SDValue
4183NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
4184 bool isVarArg,
4185 const SmallVectorImpl<ISD::OutputArg> &Outs,
4186 const SmallVectorImpl<SDValue> &OutVals,
4187 const SDLoc &dl, SelectionDAG &DAG) const {
4188 const Function &F = DAG.getMachineFunction().getFunction();
4189 Type *RetTy = F.getReturnType();
4190
4191 if (RetTy->isVoidTy()) {
4192 assert(OutVals.empty() && Outs.empty() && "Return value expected for void");
4193 return DAG.getNode(Opcode: NVPTXISD::RET_GLUE, DL: dl, VT: MVT::Other, Operand: Chain);
4194 }
4195
4196 const DataLayout &DL = DAG.getDataLayout();
4197 LLVMContext &Ctx = *DAG.getContext();
4198
4199 const SDValue RetSymbol = DAG.getExternalSymbol(Sym: "func_retval0", VT: MVT::i32);
4200 const auto RetAlign =
4201 getPTXParamAlign(F: &F, Ty: RetTy, AttrIdx: AttributeList::ReturnIndex, DL);
4202
4203 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
4204 // 32-bits are sign extended or zero extended, depending on whether
4205 // they are signed or unsigned types.
4206 const bool ExtendIntegerRetVal =
4207 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty: RetTy) < 32;
4208
4209 SmallVector<EVT, 16> VTs;
4210 SmallVector<uint64_t, 16> Offsets;
4211 ComputePTXValueVTs(TLI: *this, DL, Ctx, CallConv, Ty: RetTy, ValueVTs&: VTs, Offsets);
4212 assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
4213
4214 const auto GetRetVal = [&](unsigned I) -> SDValue {
4215 SDValue RetVal = OutVals[I];
4216 assert(promoteScalarIntegerPTX(RetVal.getValueType()) ==
4217 RetVal.getValueType() &&
4218 "OutVal type should always be legal");
4219
4220 const EVT VTI = promoteScalarIntegerPTX(VT: VTs[I]);
4221 const EVT StoreVT =
4222 ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
4223 return correctParamType(V: RetVal, ExpectedVT: StoreVT, Flags: Outs[I].Flags, DAG, dl);
4224 };
4225
4226 unsigned I = 0;
4227 const auto VI = VectorizePTXValueVTs(ValueVTs: VTs, Offsets, ParamAlignment: RetAlign);
4228 for (const unsigned NumElts : VI) {
4229 const MaybeAlign CurrentAlign = ExtendIntegerRetVal
4230 ? MaybeAlign(std::nullopt)
4231 : commonAlignment(A: RetAlign, Offset: Offsets[I]);
4232
4233 SDValue Val = getBuildVectorizedValue(
4234 N: NumElts, dl, DAG, GetElement: [&](unsigned K) { return GetRetVal(I + K); });
4235
4236 SDValue Ptr =
4237 DAG.getObjectPtrOffset(SL: dl, Ptr: RetSymbol, Offset: TypeSize::getFixed(ExactSize: Offsets[I]));
4238
4239 Chain = DAG.getStore(Chain, dl, Val, Ptr,
4240 PtrInfo: MachinePointerInfo(NVPTX::AddressSpace::DeviceParam),
4241 Alignment: CurrentAlign);
4242
4243 I += NumElts;
4244 }
4245
4246 return DAG.getNode(Opcode: NVPTXISD::RET_GLUE, DL: dl, VT: MVT::Other, Operand: Chain);
4247}
4248
4249void NVPTXTargetLowering::LowerAsmOperandForConstraint(
4250 SDValue Op, StringRef Constraint, std::vector<SDValue> &Ops,
4251 SelectionDAG &DAG) const {
4252 if (Constraint.size() > 1)
4253 return;
4254 TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
4255}
4256
// llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
// TgtMemIntrinsic because we need information that is only available in the
// "Value" type of the destination pointer; in particular, the address space
// information.
4262void NVPTXTargetLowering::getTgtMemIntrinsic(
4263 SmallVectorImpl<IntrinsicInfo> &Infos, const CallBase &I,
4264 MachineFunction &MF, unsigned Intrinsic) const {
4265 IntrinsicInfo Info;
4266 switch (Intrinsic) {
4267 default:
4268 return;
4269 case Intrinsic::nvvm_match_all_sync_i32p:
4270 case Intrinsic::nvvm_match_all_sync_i64p:
4271 Info.opc = ISD::INTRINSIC_W_CHAIN;
4272 // memVT is bogus. These intrinsics have IntrInaccessibleMemOnly attribute
4273 // in order to model data exchange with other threads, but perform no real
4274 // memory accesses.
4275 Info.memVT = MVT::i1;
4276
4277 // Our result depends on both our and other thread's arguments.
4278 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
4279 Infos.push_back(Elt: Info);
4280 return;
4281 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col:
4282 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row:
4283 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col_stride:
4284 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride:
4285 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col:
4286 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row:
4287 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride:
4288 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride:
4289 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col:
4290 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row:
4291 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride:
4292 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride:
4293 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col:
4294 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row:
4295 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride:
4296 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride:
4297 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col:
4298 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row:
4299 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride:
4300 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride:
4301 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col:
4302 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row:
4303 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride:
4304 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride: {
4305 Info.opc = ISD::INTRINSIC_W_CHAIN;
4306 Info.memVT = MVT::v8f16;
4307 Info.ptrVal = I.getArgOperand(i: 0);
4308 Info.offset = 0;
4309 Info.flags = MachineMemOperand::MOLoad;
4310 Info.align = Align(16);
4311 Infos.push_back(Elt: Info);
4312 return;
4313 }
4314 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col:
4315 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col_stride:
4316 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col_stride:
4317 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col:
4318 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row:
4319 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row_stride:
4320 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row_stride:
4321 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row:
4322 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col:
4323 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col_stride:
4324 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row:
4325 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row_stride:
4326 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col:
4327 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col_stride:
4328 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col_stride:
4329 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col:
4330 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row:
4331 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row_stride:
4332 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row_stride:
4333 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row:
4334 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col:
4335 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col_stride:
4336 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row:
4337 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row_stride: {
4338 Info.opc = ISD::INTRINSIC_W_CHAIN;
4339 Info.memVT = MVT::v2i32;
4340 Info.ptrVal = I.getArgOperand(i: 0);
4341 Info.offset = 0;
4342 Info.flags = MachineMemOperand::MOLoad;
4343 Info.align = Align(8);
4344 Infos.push_back(Elt: Info);
4345 return;
4346 }
4347
4348 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col:
4349 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col_stride:
4350 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col_stride:
4351 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col:
4352 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row:
4353 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row_stride:
4354 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row_stride:
4355 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row:
4356 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col:
4357 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col_stride:
4358 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row:
4359 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row_stride:
4360 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col:
4361 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col_stride:
4362 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row:
4363 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row_stride:
4364
4365 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col:
4366 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col_stride:
4367 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col_stride:
4368 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col:
4369 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row:
4370 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row_stride:
4371 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row_stride:
4372 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row:
4373 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col:
4374 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col_stride:
4375 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row:
4376 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row_stride:
4377 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col:
4378 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col_stride:
4379 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
4380 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride:
4381 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16:
4382 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16:
4383 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8:
4384 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64:
4385 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32:
4386 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64:
4387 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32: {
4388 Info.opc = ISD::INTRINSIC_W_CHAIN;
4389 Info.memVT = MVT::v4i32;
4390 Info.ptrVal = I.getArgOperand(i: 0);
4391 Info.offset = 0;
4392 Info.flags = MachineMemOperand::MOLoad;
4393 Info.align = Align(16);
4394 Infos.push_back(Elt: Info);
4395 return;
4396 }
4397
4398 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col:
4399 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col_stride:
4400 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col_stride:
4401 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col:
4402 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row:
4403 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row_stride:
4404 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row_stride:
4405 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row:
4406
4407 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col:
4408 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col_stride:
4409 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col_stride:
4410 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col:
4411 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row:
4412 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row_stride:
4413 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row_stride:
4414 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row:
4415 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row:
4416 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row_stride:
4417 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col:
4418 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col_stride:
4419 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row:
4420 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row_stride:
4421 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row_stride:
4422 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row:
4423 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col:
4424 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col_stride:
4425 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
4426 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col:
4427 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16:
4428 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16:
4429 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64:
4430 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32: {
4431 Info.opc = ISD::INTRINSIC_W_CHAIN;
4432 Info.memVT = MVT::i32;
4433 Info.ptrVal = I.getArgOperand(i: 0);
4434 Info.offset = 0;
4435 Info.flags = MachineMemOperand::MOLoad;
4436 Info.align = Align(4);
4437 Infos.push_back(Elt: Info);
4438 return;
4439 }
4440
4441 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col:
4442 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row:
4443 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride:
4444 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride:
4445 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col:
4446 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row:
4447 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride:
4448 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride:
4449 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col:
4450 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row:
4451 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride:
4452 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride: {
4453 Info.opc = ISD::INTRINSIC_W_CHAIN;
4454 Info.memVT = MVT::v4f16;
4455 Info.ptrVal = I.getArgOperand(i: 0);
4456 Info.offset = 0;
4457 Info.flags = MachineMemOperand::MOLoad;
4458 Info.align = Align(16);
4459 Infos.push_back(Elt: Info);
4460 return;
4461 }
4462
4463 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col:
4464 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row:
4465 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride:
4466 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride:
4467 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col:
4468 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row:
4469 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride:
4470 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride:
4471 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col:
4472 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row:
4473 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride:
4474 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride:
4475 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col:
4476 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row:
4477 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col_stride:
4478 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row_stride: {
4479 Info.opc = ISD::INTRINSIC_W_CHAIN;
4480 Info.memVT = MVT::v8f32;
4481 Info.ptrVal = I.getArgOperand(i: 0);
4482 Info.offset = 0;
4483 Info.flags = MachineMemOperand::MOLoad;
4484 Info.align = Align(16);
4485 Infos.push_back(Elt: Info);
4486 return;
4487 }
4488
4489 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col:
4490 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col_stride:
4491 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row:
4492 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row_stride:
4493
4494 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col:
4495 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col_stride:
4496 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row:
4497 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row_stride:
4498
4499 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col:
4500 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col_stride:
4501 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row:
4502 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row_stride:
4503 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col:
4504 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col_stride:
4505 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row:
4506 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row_stride:
4507 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col:
4508 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col_stride:
4509 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row:
4510 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row_stride: {
4511 Info.opc = ISD::INTRINSIC_W_CHAIN;
4512 Info.memVT = MVT::v8i32;
4513 Info.ptrVal = I.getArgOperand(i: 0);
4514 Info.offset = 0;
4515 Info.flags = MachineMemOperand::MOLoad;
4516 Info.align = Align(16);
4517 Infos.push_back(Elt: Info);
4518 return;
4519 }
4520
4521 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col:
4522 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col_stride:
4523 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row:
4524 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row_stride:
4525 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col:
4526 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col_stride:
4527 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
4528 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride:
4529 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16:
4530 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16:
4531 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8:
4532 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64:
4533 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32:
4534 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64:
4535 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32: {
4536 Info.opc = ISD::INTRINSIC_W_CHAIN;
4537 Info.memVT = MVT::v2i32;
4538 Info.ptrVal = I.getArgOperand(i: 0);
4539 Info.offset = 0;
4540 Info.flags = MachineMemOperand::MOLoad;
4541 Info.align = Align(8);
4542 Infos.push_back(Elt: Info);
4543 return;
4544 }
4545
4546 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col:
4547 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col_stride:
4548 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row:
4549 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row_stride:
4550
4551 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col:
4552 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col_stride:
4553 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row:
4554 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row_stride: {
4555 Info.opc = ISD::INTRINSIC_W_CHAIN;
4556 Info.memVT = MVT::f64;
4557 Info.ptrVal = I.getArgOperand(i: 0);
4558 Info.offset = 0;
4559 Info.flags = MachineMemOperand::MOLoad;
4560 Info.align = Align(8);
4561 Infos.push_back(Elt: Info);
4562 return;
4563 }
4564
4565 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col:
4566 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col_stride:
4567 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row:
4568 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row_stride: {
4569 Info.opc = ISD::INTRINSIC_W_CHAIN;
4570 Info.memVT = MVT::v2f64;
4571 Info.ptrVal = I.getArgOperand(i: 0);
4572 Info.offset = 0;
4573 Info.flags = MachineMemOperand::MOLoad;
4574 Info.align = Align(16);
4575 Infos.push_back(Elt: Info);
4576 return;
4577 }
4578
4579 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
4580 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
4581 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
4582 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride:
4583 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col:
4584 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row:
4585 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride:
4586 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride:
4587 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col:
4588 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row:
4589 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride:
4590 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride: {
4591 Info.opc = ISD::INTRINSIC_VOID;
4592 Info.memVT = MVT::v4f16;
4593 Info.ptrVal = I.getArgOperand(i: 0);
4594 Info.offset = 0;
4595 Info.flags = MachineMemOperand::MOStore;
4596 Info.align = Align(16);
4597 Infos.push_back(Elt: Info);
4598 return;
4599 }
4600
4601 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col:
4602 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row:
4603 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride:
4604 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride:
4605 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col:
4606 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row:
4607 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride:
4608 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride:
4609 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col:
4610 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row:
4611 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride:
4612 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride:
4613 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col:
4614 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row:
4615 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col_stride:
4616 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row_stride: {
4617 Info.opc = ISD::INTRINSIC_VOID;
4618 Info.memVT = MVT::v8f32;
4619 Info.ptrVal = I.getArgOperand(i: 0);
4620 Info.offset = 0;
4621 Info.flags = MachineMemOperand::MOStore;
4622 Info.align = Align(16);
4623 Infos.push_back(Elt: Info);
4624 return;
4625 }
4626
4627 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col:
4628 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col_stride:
4629 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row:
4630 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row_stride:
4631 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col:
4632 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col_stride:
4633 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row:
4634 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row_stride:
4635 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col:
4636 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col_stride:
4637 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row:
4638 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row_stride: {
4639 Info.opc = ISD::INTRINSIC_VOID;
4640 Info.memVT = MVT::v8i32;
4641 Info.ptrVal = I.getArgOperand(i: 0);
4642 Info.offset = 0;
4643 Info.flags = MachineMemOperand::MOStore;
4644 Info.align = Align(16);
4645 Infos.push_back(Elt: Info);
4646 return;
4647 }
4648
4649 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col:
4650 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col_stride:
4651 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row:
4652 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row_stride:
4653 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
4654 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
4655 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
4656 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride:
4657 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16:
4658 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16:
4659 case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: {
4660 Info.opc = ISD::INTRINSIC_VOID;
4661 Info.memVT = MVT::v2i32;
4662 Info.ptrVal = I.getArgOperand(i: 0);
4663 Info.offset = 0;
4664 Info.flags = MachineMemOperand::MOStore;
4665 Info.align = Align(8);
4666 Infos.push_back(Elt: Info);
4667 return;
4668 }
4669
4670 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col:
4671 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col_stride:
4672 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row:
4673 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row_stride: {
4674 Info.opc = ISD::INTRINSIC_VOID;
4675 Info.memVT = MVT::v2f64;
4676 Info.ptrVal = I.getArgOperand(i: 0);
4677 Info.offset = 0;
4678 Info.flags = MachineMemOperand::MOStore;
4679 Info.align = Align(16);
4680 Infos.push_back(Elt: Info);
4681 return;
4682 }
4683
4684 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16:
4685 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16:
4686 case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: {
4687 Info.opc = ISD::INTRINSIC_VOID;
4688 Info.memVT = MVT::i32;
4689 Info.ptrVal = I.getArgOperand(i: 0);
4690 Info.offset = 0;
4691 Info.flags = MachineMemOperand::MOStore;
4692 Info.align = Align(4);
4693 Infos.push_back(Elt: Info);
4694 return;
4695 }
4696
4697 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16:
4698 case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16:
4699 case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: {
4700 Info.opc = ISD::INTRINSIC_VOID;
4701 Info.memVT = MVT::v4i32;
4702 Info.ptrVal = I.getArgOperand(i: 0);
4703 Info.offset = 0;
4704 Info.flags = MachineMemOperand::MOStore;
4705 Info.align = Align(16);
4706 Infos.push_back(Elt: Info);
4707 return;
4708 }
4709
4710 case Intrinsic::nvvm_prefetch_tensormap: {
4711 auto &DL = I.getDataLayout();
4712 Info.opc = ISD::INTRINSIC_VOID;
4713 Info.memVT = getPointerTy(DL);
4714 Info.ptrVal = I.getArgOperand(i: 0);
4715 Info.offset = 0;
4716 Info.flags =
4717 MachineMemOperand::MOLoad | MachineMemOperand::MODereferenceable;
4718 Info.align.reset();
4719 Infos.push_back(Elt: Info);
4720 return;
4721 }
4722
4723 case Intrinsic::nvvm_tensormap_replace_global_address:
4724 case Intrinsic::nvvm_tensormap_replace_global_stride: {
4725 Info.opc = ISD::INTRINSIC_VOID;
4726 Info.memVT = MVT::i64;
4727 Info.ptrVal = I.getArgOperand(i: 0);
4728 Info.offset = 0;
4729 Info.flags = MachineMemOperand::MOStore;
4730 Info.align.reset();
4731 Infos.push_back(Elt: Info);
4732 return;
4733 }
4734
4735 case Intrinsic::nvvm_tensormap_replace_rank:
4736 case Intrinsic::nvvm_tensormap_replace_box_dim:
4737 case Intrinsic::nvvm_tensormap_replace_global_dim:
4738 case Intrinsic::nvvm_tensormap_replace_element_stride:
4739 case Intrinsic::nvvm_tensormap_replace_elemtype:
4740 case Intrinsic::nvvm_tensormap_replace_interleave_layout:
4741 case Intrinsic::nvvm_tensormap_replace_swizzle_mode:
4742 case Intrinsic::nvvm_tensormap_replace_swizzle_atomicity:
4743 case Intrinsic::nvvm_tensormap_replace_fill_mode: {
4744 Info.opc = ISD::INTRINSIC_VOID;
4745 Info.memVT = MVT::i32;
4746 Info.ptrVal = I.getArgOperand(i: 0);
4747 Info.offset = 0;
4748 Info.flags = MachineMemOperand::MOStore;
4749 Info.align.reset();
4750 Infos.push_back(Elt: Info);
4751 return;
4752 }
4753
4754 case Intrinsic::nvvm_ldu_global_i:
4755 case Intrinsic::nvvm_ldu_global_f:
4756 case Intrinsic::nvvm_ldu_global_p: {
4757 Info.opc = ISD::INTRINSIC_W_CHAIN;
4758 Info.memVT = getValueType(DL: I.getDataLayout(), Ty: I.getType());
4759 Info.ptrVal = I.getArgOperand(i: 0);
4760 Info.offset = 0;
4761 Info.flags = MachineMemOperand::MOLoad;
4762 Info.align = cast<ConstantInt>(Val: I.getArgOperand(i: 1))->getMaybeAlignValue();
4763
4764 Infos.push_back(Elt: Info);
4765 return;
4766 }
4767 case Intrinsic::nvvm_tex_1d_v4f32_s32:
4768 case Intrinsic::nvvm_tex_1d_v4f32_f32:
4769 case Intrinsic::nvvm_tex_1d_level_v4f32_f32:
4770 case Intrinsic::nvvm_tex_1d_grad_v4f32_f32:
4771 case Intrinsic::nvvm_tex_1d_array_v4f32_s32:
4772 case Intrinsic::nvvm_tex_1d_array_v4f32_f32:
4773 case Intrinsic::nvvm_tex_1d_array_level_v4f32_f32:
4774 case Intrinsic::nvvm_tex_1d_array_grad_v4f32_f32:
4775 case Intrinsic::nvvm_tex_2d_v4f32_s32:
4776 case Intrinsic::nvvm_tex_2d_v4f32_f32:
4777 case Intrinsic::nvvm_tex_2d_level_v4f32_f32:
4778 case Intrinsic::nvvm_tex_2d_grad_v4f32_f32:
4779 case Intrinsic::nvvm_tex_2d_array_v4f32_s32:
4780 case Intrinsic::nvvm_tex_2d_array_v4f32_f32:
4781 case Intrinsic::nvvm_tex_2d_array_level_v4f32_f32:
4782 case Intrinsic::nvvm_tex_2d_array_grad_v4f32_f32:
4783 case Intrinsic::nvvm_tex_3d_v4f32_s32:
4784 case Intrinsic::nvvm_tex_3d_v4f32_f32:
4785 case Intrinsic::nvvm_tex_3d_level_v4f32_f32:
4786 case Intrinsic::nvvm_tex_3d_grad_v4f32_f32:
4787 case Intrinsic::nvvm_tex_cube_v4f32_f32:
4788 case Intrinsic::nvvm_tex_cube_level_v4f32_f32:
4789 case Intrinsic::nvvm_tex_cube_array_v4f32_f32:
4790 case Intrinsic::nvvm_tex_cube_array_level_v4f32_f32:
4791 case Intrinsic::nvvm_tld4_r_2d_v4f32_f32:
4792 case Intrinsic::nvvm_tld4_g_2d_v4f32_f32:
4793 case Intrinsic::nvvm_tld4_b_2d_v4f32_f32:
4794 case Intrinsic::nvvm_tld4_a_2d_v4f32_f32:
4795 case Intrinsic::nvvm_tex_unified_1d_v4f32_s32:
4796 case Intrinsic::nvvm_tex_unified_1d_v4f32_f32:
4797 case Intrinsic::nvvm_tex_unified_1d_level_v4f32_f32:
4798 case Intrinsic::nvvm_tex_unified_1d_grad_v4f32_f32:
4799 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_s32:
4800 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_f32:
4801 case Intrinsic::nvvm_tex_unified_1d_array_level_v4f32_f32:
4802 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4f32_f32:
4803 case Intrinsic::nvvm_tex_unified_2d_v4f32_s32:
4804 case Intrinsic::nvvm_tex_unified_2d_v4f32_f32:
4805 case Intrinsic::nvvm_tex_unified_2d_level_v4f32_f32:
4806 case Intrinsic::nvvm_tex_unified_2d_grad_v4f32_f32:
4807 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_s32:
4808 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_f32:
4809 case Intrinsic::nvvm_tex_unified_2d_array_level_v4f32_f32:
4810 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4f32_f32:
4811 case Intrinsic::nvvm_tex_unified_3d_v4f32_s32:
4812 case Intrinsic::nvvm_tex_unified_3d_v4f32_f32:
4813 case Intrinsic::nvvm_tex_unified_3d_level_v4f32_f32:
4814 case Intrinsic::nvvm_tex_unified_3d_grad_v4f32_f32:
4815 case Intrinsic::nvvm_tex_unified_cube_v4f32_f32:
4816 case Intrinsic::nvvm_tex_unified_cube_level_v4f32_f32:
4817 case Intrinsic::nvvm_tex_unified_cube_array_v4f32_f32:
4818 case Intrinsic::nvvm_tex_unified_cube_array_level_v4f32_f32:
4819 case Intrinsic::nvvm_tex_unified_cube_grad_v4f32_f32:
4820 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4f32_f32:
4821 case Intrinsic::nvvm_tld4_unified_r_2d_v4f32_f32:
4822 case Intrinsic::nvvm_tld4_unified_g_2d_v4f32_f32:
4823 case Intrinsic::nvvm_tld4_unified_b_2d_v4f32_f32:
4824 case Intrinsic::nvvm_tld4_unified_a_2d_v4f32_f32:
4825 Info.opc = ISD::INTRINSIC_W_CHAIN;
4826 Info.memVT = MVT::v4f32;
4827 Info.ptrVal = nullptr;
4828 Info.offset = 0;
4829 Info.flags = MachineMemOperand::MOLoad;
4830 Info.align = Align(16);
4831 Infos.push_back(Elt: Info);
4832 return;
4833
4834 case Intrinsic::nvvm_tex_1d_v4s32_s32:
4835 case Intrinsic::nvvm_tex_1d_v4s32_f32:
4836 case Intrinsic::nvvm_tex_1d_level_v4s32_f32:
4837 case Intrinsic::nvvm_tex_1d_grad_v4s32_f32:
4838 case Intrinsic::nvvm_tex_1d_array_v4s32_s32:
4839 case Intrinsic::nvvm_tex_1d_array_v4s32_f32:
4840 case Intrinsic::nvvm_tex_1d_array_level_v4s32_f32:
4841 case Intrinsic::nvvm_tex_1d_array_grad_v4s32_f32:
4842 case Intrinsic::nvvm_tex_2d_v4s32_s32:
4843 case Intrinsic::nvvm_tex_2d_v4s32_f32:
4844 case Intrinsic::nvvm_tex_2d_level_v4s32_f32:
4845 case Intrinsic::nvvm_tex_2d_grad_v4s32_f32:
4846 case Intrinsic::nvvm_tex_2d_array_v4s32_s32:
4847 case Intrinsic::nvvm_tex_2d_array_v4s32_f32:
4848 case Intrinsic::nvvm_tex_2d_array_level_v4s32_f32:
4849 case Intrinsic::nvvm_tex_2d_array_grad_v4s32_f32:
4850 case Intrinsic::nvvm_tex_3d_v4s32_s32:
4851 case Intrinsic::nvvm_tex_3d_v4s32_f32:
4852 case Intrinsic::nvvm_tex_3d_level_v4s32_f32:
4853 case Intrinsic::nvvm_tex_3d_grad_v4s32_f32:
4854 case Intrinsic::nvvm_tex_cube_v4s32_f32:
4855 case Intrinsic::nvvm_tex_cube_level_v4s32_f32:
4856 case Intrinsic::nvvm_tex_cube_array_v4s32_f32:
4857 case Intrinsic::nvvm_tex_cube_array_level_v4s32_f32:
4858 case Intrinsic::nvvm_tex_cube_v4u32_f32:
4859 case Intrinsic::nvvm_tex_cube_level_v4u32_f32:
4860 case Intrinsic::nvvm_tex_cube_array_v4u32_f32:
4861 case Intrinsic::nvvm_tex_cube_array_level_v4u32_f32:
4862 case Intrinsic::nvvm_tex_1d_v4u32_s32:
4863 case Intrinsic::nvvm_tex_1d_v4u32_f32:
4864 case Intrinsic::nvvm_tex_1d_level_v4u32_f32:
4865 case Intrinsic::nvvm_tex_1d_grad_v4u32_f32:
4866 case Intrinsic::nvvm_tex_1d_array_v4u32_s32:
4867 case Intrinsic::nvvm_tex_1d_array_v4u32_f32:
4868 case Intrinsic::nvvm_tex_1d_array_level_v4u32_f32:
4869 case Intrinsic::nvvm_tex_1d_array_grad_v4u32_f32:
4870 case Intrinsic::nvvm_tex_2d_v4u32_s32:
4871 case Intrinsic::nvvm_tex_2d_v4u32_f32:
4872 case Intrinsic::nvvm_tex_2d_level_v4u32_f32:
4873 case Intrinsic::nvvm_tex_2d_grad_v4u32_f32:
4874 case Intrinsic::nvvm_tex_2d_array_v4u32_s32:
4875 case Intrinsic::nvvm_tex_2d_array_v4u32_f32:
4876 case Intrinsic::nvvm_tex_2d_array_level_v4u32_f32:
4877 case Intrinsic::nvvm_tex_2d_array_grad_v4u32_f32:
4878 case Intrinsic::nvvm_tex_3d_v4u32_s32:
4879 case Intrinsic::nvvm_tex_3d_v4u32_f32:
4880 case Intrinsic::nvvm_tex_3d_level_v4u32_f32:
4881 case Intrinsic::nvvm_tex_3d_grad_v4u32_f32:
4882 case Intrinsic::nvvm_tld4_r_2d_v4s32_f32:
4883 case Intrinsic::nvvm_tld4_g_2d_v4s32_f32:
4884 case Intrinsic::nvvm_tld4_b_2d_v4s32_f32:
4885 case Intrinsic::nvvm_tld4_a_2d_v4s32_f32:
4886 case Intrinsic::nvvm_tld4_r_2d_v4u32_f32:
4887 case Intrinsic::nvvm_tld4_g_2d_v4u32_f32:
4888 case Intrinsic::nvvm_tld4_b_2d_v4u32_f32:
4889 case Intrinsic::nvvm_tld4_a_2d_v4u32_f32:
4890 case Intrinsic::nvvm_tex_unified_1d_v4s32_s32:
4891 case Intrinsic::nvvm_tex_unified_1d_v4s32_f32:
4892 case Intrinsic::nvvm_tex_unified_1d_level_v4s32_f32:
4893 case Intrinsic::nvvm_tex_unified_1d_grad_v4s32_f32:
4894 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_s32:
4895 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_f32:
4896 case Intrinsic::nvvm_tex_unified_1d_array_level_v4s32_f32:
4897 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4s32_f32:
4898 case Intrinsic::nvvm_tex_unified_2d_v4s32_s32:
4899 case Intrinsic::nvvm_tex_unified_2d_v4s32_f32:
4900 case Intrinsic::nvvm_tex_unified_2d_level_v4s32_f32:
4901 case Intrinsic::nvvm_tex_unified_2d_grad_v4s32_f32:
4902 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_s32:
4903 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_f32:
4904 case Intrinsic::nvvm_tex_unified_2d_array_level_v4s32_f32:
4905 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4s32_f32:
4906 case Intrinsic::nvvm_tex_unified_3d_v4s32_s32:
4907 case Intrinsic::nvvm_tex_unified_3d_v4s32_f32:
4908 case Intrinsic::nvvm_tex_unified_3d_level_v4s32_f32:
4909 case Intrinsic::nvvm_tex_unified_3d_grad_v4s32_f32:
4910 case Intrinsic::nvvm_tex_unified_1d_v4u32_s32:
4911 case Intrinsic::nvvm_tex_unified_1d_v4u32_f32:
4912 case Intrinsic::nvvm_tex_unified_1d_level_v4u32_f32:
4913 case Intrinsic::nvvm_tex_unified_1d_grad_v4u32_f32:
4914 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_s32:
4915 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_f32:
4916 case Intrinsic::nvvm_tex_unified_1d_array_level_v4u32_f32:
4917 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4u32_f32:
4918 case Intrinsic::nvvm_tex_unified_2d_v4u32_s32:
4919 case Intrinsic::nvvm_tex_unified_2d_v4u32_f32:
4920 case Intrinsic::nvvm_tex_unified_2d_level_v4u32_f32:
4921 case Intrinsic::nvvm_tex_unified_2d_grad_v4u32_f32:
4922 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_s32:
4923 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_f32:
4924 case Intrinsic::nvvm_tex_unified_2d_array_level_v4u32_f32:
4925 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4u32_f32:
4926 case Intrinsic::nvvm_tex_unified_3d_v4u32_s32:
4927 case Intrinsic::nvvm_tex_unified_3d_v4u32_f32:
4928 case Intrinsic::nvvm_tex_unified_3d_level_v4u32_f32:
4929 case Intrinsic::nvvm_tex_unified_3d_grad_v4u32_f32:
4930 case Intrinsic::nvvm_tex_unified_cube_v4s32_f32:
4931 case Intrinsic::nvvm_tex_unified_cube_level_v4s32_f32:
4932 case Intrinsic::nvvm_tex_unified_cube_array_v4s32_f32:
4933 case Intrinsic::nvvm_tex_unified_cube_array_level_v4s32_f32:
4934 case Intrinsic::nvvm_tex_unified_cube_v4u32_f32:
4935 case Intrinsic::nvvm_tex_unified_cube_level_v4u32_f32:
4936 case Intrinsic::nvvm_tex_unified_cube_array_v4u32_f32:
4937 case Intrinsic::nvvm_tex_unified_cube_array_level_v4u32_f32:
4938 case Intrinsic::nvvm_tex_unified_cube_grad_v4s32_f32:
4939 case Intrinsic::nvvm_tex_unified_cube_grad_v4u32_f32:
4940 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4s32_f32:
4941 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4u32_f32:
4942 case Intrinsic::nvvm_tld4_unified_r_2d_v4s32_f32:
4943 case Intrinsic::nvvm_tld4_unified_g_2d_v4s32_f32:
4944 case Intrinsic::nvvm_tld4_unified_b_2d_v4s32_f32:
4945 case Intrinsic::nvvm_tld4_unified_a_2d_v4s32_f32:
4946 case Intrinsic::nvvm_tld4_unified_r_2d_v4u32_f32:
4947 case Intrinsic::nvvm_tld4_unified_g_2d_v4u32_f32:
4948 case Intrinsic::nvvm_tld4_unified_b_2d_v4u32_f32:
4949 case Intrinsic::nvvm_tld4_unified_a_2d_v4u32_f32:
4950 Info.opc = ISD::INTRINSIC_W_CHAIN;
4951 Info.memVT = MVT::v4i32;
4952 Info.ptrVal = nullptr;
4953 Info.offset = 0;
4954 Info.flags = MachineMemOperand::MOLoad;
4955 Info.align = Align(16);
4956 Infos.push_back(Elt: Info);
4957 return;
4958
4959 case Intrinsic::nvvm_suld_1d_i8_clamp:
4960 case Intrinsic::nvvm_suld_1d_v2i8_clamp:
4961 case Intrinsic::nvvm_suld_1d_v4i8_clamp:
4962 case Intrinsic::nvvm_suld_1d_array_i8_clamp:
4963 case Intrinsic::nvvm_suld_1d_array_v2i8_clamp:
4964 case Intrinsic::nvvm_suld_1d_array_v4i8_clamp:
4965 case Intrinsic::nvvm_suld_2d_i8_clamp:
4966 case Intrinsic::nvvm_suld_2d_v2i8_clamp:
4967 case Intrinsic::nvvm_suld_2d_v4i8_clamp:
4968 case Intrinsic::nvvm_suld_2d_array_i8_clamp:
4969 case Intrinsic::nvvm_suld_2d_array_v2i8_clamp:
4970 case Intrinsic::nvvm_suld_2d_array_v4i8_clamp:
4971 case Intrinsic::nvvm_suld_3d_i8_clamp:
4972 case Intrinsic::nvvm_suld_3d_v2i8_clamp:
4973 case Intrinsic::nvvm_suld_3d_v4i8_clamp:
4974 case Intrinsic::nvvm_suld_1d_i8_trap:
4975 case Intrinsic::nvvm_suld_1d_v2i8_trap:
4976 case Intrinsic::nvvm_suld_1d_v4i8_trap:
4977 case Intrinsic::nvvm_suld_1d_array_i8_trap:
4978 case Intrinsic::nvvm_suld_1d_array_v2i8_trap:
4979 case Intrinsic::nvvm_suld_1d_array_v4i8_trap:
4980 case Intrinsic::nvvm_suld_2d_i8_trap:
4981 case Intrinsic::nvvm_suld_2d_v2i8_trap:
4982 case Intrinsic::nvvm_suld_2d_v4i8_trap:
4983 case Intrinsic::nvvm_suld_2d_array_i8_trap:
4984 case Intrinsic::nvvm_suld_2d_array_v2i8_trap:
4985 case Intrinsic::nvvm_suld_2d_array_v4i8_trap:
4986 case Intrinsic::nvvm_suld_3d_i8_trap:
4987 case Intrinsic::nvvm_suld_3d_v2i8_trap:
4988 case Intrinsic::nvvm_suld_3d_v4i8_trap:
4989 case Intrinsic::nvvm_suld_1d_i8_zero:
4990 case Intrinsic::nvvm_suld_1d_v2i8_zero:
4991 case Intrinsic::nvvm_suld_1d_v4i8_zero:
4992 case Intrinsic::nvvm_suld_1d_array_i8_zero:
4993 case Intrinsic::nvvm_suld_1d_array_v2i8_zero:
4994 case Intrinsic::nvvm_suld_1d_array_v4i8_zero:
4995 case Intrinsic::nvvm_suld_2d_i8_zero:
4996 case Intrinsic::nvvm_suld_2d_v2i8_zero:
4997 case Intrinsic::nvvm_suld_2d_v4i8_zero:
4998 case Intrinsic::nvvm_suld_2d_array_i8_zero:
4999 case Intrinsic::nvvm_suld_2d_array_v2i8_zero:
5000 case Intrinsic::nvvm_suld_2d_array_v4i8_zero:
5001 case Intrinsic::nvvm_suld_3d_i8_zero:
5002 case Intrinsic::nvvm_suld_3d_v2i8_zero:
5003 case Intrinsic::nvvm_suld_3d_v4i8_zero:
5004 Info.opc = ISD::INTRINSIC_W_CHAIN;
5005 Info.memVT = MVT::i8;
5006 Info.ptrVal = nullptr;
5007 Info.offset = 0;
5008 Info.flags = MachineMemOperand::MOLoad;
5009 Info.align = Align(16);
5010 Infos.push_back(Elt: Info);
5011 return;
5012
5013 case Intrinsic::nvvm_suld_1d_i16_clamp:
5014 case Intrinsic::nvvm_suld_1d_v2i16_clamp:
5015 case Intrinsic::nvvm_suld_1d_v4i16_clamp:
5016 case Intrinsic::nvvm_suld_1d_array_i16_clamp:
5017 case Intrinsic::nvvm_suld_1d_array_v2i16_clamp:
5018 case Intrinsic::nvvm_suld_1d_array_v4i16_clamp:
5019 case Intrinsic::nvvm_suld_2d_i16_clamp:
5020 case Intrinsic::nvvm_suld_2d_v2i16_clamp:
5021 case Intrinsic::nvvm_suld_2d_v4i16_clamp:
5022 case Intrinsic::nvvm_suld_2d_array_i16_clamp:
5023 case Intrinsic::nvvm_suld_2d_array_v2i16_clamp:
5024 case Intrinsic::nvvm_suld_2d_array_v4i16_clamp:
5025 case Intrinsic::nvvm_suld_3d_i16_clamp:
5026 case Intrinsic::nvvm_suld_3d_v2i16_clamp:
5027 case Intrinsic::nvvm_suld_3d_v4i16_clamp:
5028 case Intrinsic::nvvm_suld_1d_i16_trap:
5029 case Intrinsic::nvvm_suld_1d_v2i16_trap:
5030 case Intrinsic::nvvm_suld_1d_v4i16_trap:
5031 case Intrinsic::nvvm_suld_1d_array_i16_trap:
5032 case Intrinsic::nvvm_suld_1d_array_v2i16_trap:
5033 case Intrinsic::nvvm_suld_1d_array_v4i16_trap:
5034 case Intrinsic::nvvm_suld_2d_i16_trap:
5035 case Intrinsic::nvvm_suld_2d_v2i16_trap:
5036 case Intrinsic::nvvm_suld_2d_v4i16_trap:
5037 case Intrinsic::nvvm_suld_2d_array_i16_trap:
5038 case Intrinsic::nvvm_suld_2d_array_v2i16_trap:
5039 case Intrinsic::nvvm_suld_2d_array_v4i16_trap:
5040 case Intrinsic::nvvm_suld_3d_i16_trap:
5041 case Intrinsic::nvvm_suld_3d_v2i16_trap:
5042 case Intrinsic::nvvm_suld_3d_v4i16_trap:
5043 case Intrinsic::nvvm_suld_1d_i16_zero:
5044 case Intrinsic::nvvm_suld_1d_v2i16_zero:
5045 case Intrinsic::nvvm_suld_1d_v4i16_zero:
5046 case Intrinsic::nvvm_suld_1d_array_i16_zero:
5047 case Intrinsic::nvvm_suld_1d_array_v2i16_zero:
5048 case Intrinsic::nvvm_suld_1d_array_v4i16_zero:
5049 case Intrinsic::nvvm_suld_2d_i16_zero:
5050 case Intrinsic::nvvm_suld_2d_v2i16_zero:
5051 case Intrinsic::nvvm_suld_2d_v4i16_zero:
5052 case Intrinsic::nvvm_suld_2d_array_i16_zero:
5053 case Intrinsic::nvvm_suld_2d_array_v2i16_zero:
5054 case Intrinsic::nvvm_suld_2d_array_v4i16_zero:
5055 case Intrinsic::nvvm_suld_3d_i16_zero:
5056 case Intrinsic::nvvm_suld_3d_v2i16_zero:
5057 case Intrinsic::nvvm_suld_3d_v4i16_zero:
5058 Info.opc = ISD::INTRINSIC_W_CHAIN;
5059 Info.memVT = MVT::i16;
5060 Info.ptrVal = nullptr;
5061 Info.offset = 0;
5062 Info.flags = MachineMemOperand::MOLoad;
5063 Info.align = Align(16);
5064 Infos.push_back(Elt: Info);
5065 return;
5066
5067 case Intrinsic::nvvm_suld_1d_i32_clamp:
5068 case Intrinsic::nvvm_suld_1d_v2i32_clamp:
5069 case Intrinsic::nvvm_suld_1d_v4i32_clamp:
5070 case Intrinsic::nvvm_suld_1d_array_i32_clamp:
5071 case Intrinsic::nvvm_suld_1d_array_v2i32_clamp:
5072 case Intrinsic::nvvm_suld_1d_array_v4i32_clamp:
5073 case Intrinsic::nvvm_suld_2d_i32_clamp:
5074 case Intrinsic::nvvm_suld_2d_v2i32_clamp:
5075 case Intrinsic::nvvm_suld_2d_v4i32_clamp:
5076 case Intrinsic::nvvm_suld_2d_array_i32_clamp:
5077 case Intrinsic::nvvm_suld_2d_array_v2i32_clamp:
5078 case Intrinsic::nvvm_suld_2d_array_v4i32_clamp:
5079 case Intrinsic::nvvm_suld_3d_i32_clamp:
5080 case Intrinsic::nvvm_suld_3d_v2i32_clamp:
5081 case Intrinsic::nvvm_suld_3d_v4i32_clamp:
5082 case Intrinsic::nvvm_suld_1d_i32_trap:
5083 case Intrinsic::nvvm_suld_1d_v2i32_trap:
5084 case Intrinsic::nvvm_suld_1d_v4i32_trap:
5085 case Intrinsic::nvvm_suld_1d_array_i32_trap:
5086 case Intrinsic::nvvm_suld_1d_array_v2i32_trap:
5087 case Intrinsic::nvvm_suld_1d_array_v4i32_trap:
5088 case Intrinsic::nvvm_suld_2d_i32_trap:
5089 case Intrinsic::nvvm_suld_2d_v2i32_trap:
5090 case Intrinsic::nvvm_suld_2d_v4i32_trap:
5091 case Intrinsic::nvvm_suld_2d_array_i32_trap:
5092 case Intrinsic::nvvm_suld_2d_array_v2i32_trap:
5093 case Intrinsic::nvvm_suld_2d_array_v4i32_trap:
5094 case Intrinsic::nvvm_suld_3d_i32_trap:
5095 case Intrinsic::nvvm_suld_3d_v2i32_trap:
5096 case Intrinsic::nvvm_suld_3d_v4i32_trap:
5097 case Intrinsic::nvvm_suld_1d_i32_zero:
5098 case Intrinsic::nvvm_suld_1d_v2i32_zero:
5099 case Intrinsic::nvvm_suld_1d_v4i32_zero:
5100 case Intrinsic::nvvm_suld_1d_array_i32_zero:
5101 case Intrinsic::nvvm_suld_1d_array_v2i32_zero:
5102 case Intrinsic::nvvm_suld_1d_array_v4i32_zero:
5103 case Intrinsic::nvvm_suld_2d_i32_zero:
5104 case Intrinsic::nvvm_suld_2d_v2i32_zero:
5105 case Intrinsic::nvvm_suld_2d_v4i32_zero:
5106 case Intrinsic::nvvm_suld_2d_array_i32_zero:
5107 case Intrinsic::nvvm_suld_2d_array_v2i32_zero:
5108 case Intrinsic::nvvm_suld_2d_array_v4i32_zero:
5109 case Intrinsic::nvvm_suld_3d_i32_zero:
5110 case Intrinsic::nvvm_suld_3d_v2i32_zero:
5111 case Intrinsic::nvvm_suld_3d_v4i32_zero:
5112 Info.opc = ISD::INTRINSIC_W_CHAIN;
5113 Info.memVT = MVT::i32;
5114 Info.ptrVal = nullptr;
5115 Info.offset = 0;
5116 Info.flags = MachineMemOperand::MOLoad;
5117 Info.align = Align(16);
5118 Infos.push_back(Elt: Info);
5119 return;
5120
5121 case Intrinsic::nvvm_suld_1d_i64_clamp:
5122 case Intrinsic::nvvm_suld_1d_v2i64_clamp:
5123 case Intrinsic::nvvm_suld_1d_array_i64_clamp:
5124 case Intrinsic::nvvm_suld_1d_array_v2i64_clamp:
5125 case Intrinsic::nvvm_suld_2d_i64_clamp:
5126 case Intrinsic::nvvm_suld_2d_v2i64_clamp:
5127 case Intrinsic::nvvm_suld_2d_array_i64_clamp:
5128 case Intrinsic::nvvm_suld_2d_array_v2i64_clamp:
5129 case Intrinsic::nvvm_suld_3d_i64_clamp:
5130 case Intrinsic::nvvm_suld_3d_v2i64_clamp:
5131 case Intrinsic::nvvm_suld_1d_i64_trap:
5132 case Intrinsic::nvvm_suld_1d_v2i64_trap:
5133 case Intrinsic::nvvm_suld_1d_array_i64_trap:
5134 case Intrinsic::nvvm_suld_1d_array_v2i64_trap:
5135 case Intrinsic::nvvm_suld_2d_i64_trap:
5136 case Intrinsic::nvvm_suld_2d_v2i64_trap:
5137 case Intrinsic::nvvm_suld_2d_array_i64_trap:
5138 case Intrinsic::nvvm_suld_2d_array_v2i64_trap:
5139 case Intrinsic::nvvm_suld_3d_i64_trap:
5140 case Intrinsic::nvvm_suld_3d_v2i64_trap:
5141 case Intrinsic::nvvm_suld_1d_i64_zero:
5142 case Intrinsic::nvvm_suld_1d_v2i64_zero:
5143 case Intrinsic::nvvm_suld_1d_array_i64_zero:
5144 case Intrinsic::nvvm_suld_1d_array_v2i64_zero:
5145 case Intrinsic::nvvm_suld_2d_i64_zero:
5146 case Intrinsic::nvvm_suld_2d_v2i64_zero:
5147 case Intrinsic::nvvm_suld_2d_array_i64_zero:
5148 case Intrinsic::nvvm_suld_2d_array_v2i64_zero:
5149 case Intrinsic::nvvm_suld_3d_i64_zero:
5150 case Intrinsic::nvvm_suld_3d_v2i64_zero:
5151 Info.opc = ISD::INTRINSIC_W_CHAIN;
5152 Info.memVT = MVT::i64;
5153 Info.ptrVal = nullptr;
5154 Info.offset = 0;
5155 Info.flags = MachineMemOperand::MOLoad;
5156 Info.align = Align(16);
5157 Infos.push_back(Elt: Info);
5158 return;
5159
5160 case Intrinsic::nvvm_tcgen05_ld_16x64b_x1:
5161 case Intrinsic::nvvm_tcgen05_ld_32x32b_x1:
5162 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x1: {
5163 Info.opc = ISD::INTRINSIC_W_CHAIN;
5164 Info.memVT = MVT::v1i32;
5165 Info.ptrVal = I.getArgOperand(i: 0);
5166 Info.offset = 0;
5167 Info.flags = MachineMemOperand::MOLoad;
5168 Info.align.reset();
5169 Infos.push_back(Elt: Info);
5170 return;
5171 }
5172
5173 case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
5174 case Intrinsic::nvvm_tcgen05_ld_16x128b_x1:
5175 case Intrinsic::nvvm_tcgen05_ld_32x32b_x2:
5176 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2:
5177 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_i32:
5178 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x2_i32: {
5179 Info.opc = ISD::INTRINSIC_W_CHAIN;
5180 Info.memVT = MVT::v2i32;
5181 Info.ptrVal = I.getArgOperand(i: 0);
5182 Info.offset = 0;
5183 Info.flags = MachineMemOperand::MOLoad;
5184 Info.align.reset();
5185 Infos.push_back(Elt: Info);
5186 return;
5187 }
5188
5189 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x2_f32:
5190 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x2_f32: {
5191 Info.opc = ISD::INTRINSIC_W_CHAIN;
5192 Info.memVT = MVT::v2f32;
5193 Info.ptrVal = I.getArgOperand(i: 0);
5194 Info.offset = 0;
5195 Info.flags = MachineMemOperand::MOLoad;
5196 Info.align.reset();
5197 Infos.push_back(Elt: Info);
5198 return;
5199 }
5200
5201 case Intrinsic::nvvm_tcgen05_ld_16x64b_x4:
5202 case Intrinsic::nvvm_tcgen05_ld_16x128b_x2:
5203 case Intrinsic::nvvm_tcgen05_ld_32x32b_x4:
5204 case Intrinsic::nvvm_tcgen05_ld_16x256b_x1:
5205 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x4:
5206 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x4_i32:
5207 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x4_i32: {
5208 Info.opc = ISD::INTRINSIC_W_CHAIN;
5209 Info.memVT = MVT::v4i32;
5210 Info.ptrVal = I.getArgOperand(i: 0);
5211 Info.offset = 0;
5212 Info.flags = MachineMemOperand::MOLoad;
5213 Info.align.reset();
5214 Infos.push_back(Elt: Info);
5215 return;
5216 }
5217
5218 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x4_f32:
5219 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x4_f32: {
5220 Info.opc = ISD::INTRINSIC_W_CHAIN;
5221 Info.memVT = MVT::v4f32;
5222 Info.ptrVal = I.getArgOperand(i: 0);
5223 Info.offset = 0;
5224 Info.flags = MachineMemOperand::MOLoad;
5225 Info.align.reset();
5226 Infos.push_back(Elt: Info);
5227 return;
5228 }
5229
5230 case Intrinsic::nvvm_tcgen05_ld_16x64b_x8:
5231 case Intrinsic::nvvm_tcgen05_ld_16x128b_x4:
5232 case Intrinsic::nvvm_tcgen05_ld_16x256b_x2:
5233 case Intrinsic::nvvm_tcgen05_ld_32x32b_x8:
5234 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x8:
5235 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x8_i32:
5236 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x8_i32: {
5237 Info.opc = ISD::INTRINSIC_W_CHAIN;
5238 Info.memVT = MVT::v8i32;
5239 Info.ptrVal = I.getArgOperand(i: 0);
5240 Info.offset = 0;
5241 Info.flags = MachineMemOperand::MOLoad;
5242 Info.align.reset();
5243 Infos.push_back(Elt: Info);
5244 return;
5245 }
5246
5247 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x8_f32:
5248 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x8_f32: {
5249 Info.opc = ISD::INTRINSIC_W_CHAIN;
5250 Info.memVT = MVT::v8f32;
5251 Info.ptrVal = I.getArgOperand(i: 0);
5252 Info.offset = 0;
5253 Info.flags = MachineMemOperand::MOLoad;
5254 Info.align.reset();
5255 Infos.push_back(Elt: Info);
5256 return;
5257 }
5258
5259 case Intrinsic::nvvm_tcgen05_ld_16x64b_x16:
5260 case Intrinsic::nvvm_tcgen05_ld_16x128b_x8:
5261 case Intrinsic::nvvm_tcgen05_ld_16x256b_x4:
5262 case Intrinsic::nvvm_tcgen05_ld_32x32b_x16:
5263 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x16:
5264 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x16_i32:
5265 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x16_i32: {
5266 Info.opc = ISD::INTRINSIC_W_CHAIN;
5267 Info.memVT = MVT::v16i32;
5268 Info.ptrVal = I.getArgOperand(i: 0);
5269 Info.offset = 0;
5270 Info.flags = MachineMemOperand::MOLoad;
5271 Info.align.reset();
5272 Infos.push_back(Elt: Info);
5273 return;
5274 }
5275
5276 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x16_f32:
5277 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x16_f32: {
5278 Info.opc = ISD::INTRINSIC_W_CHAIN;
5279 Info.memVT = MVT::v16f32;
5280 Info.ptrVal = I.getArgOperand(i: 0);
5281 Info.offset = 0;
5282 Info.flags = MachineMemOperand::MOLoad;
5283 Info.align.reset();
5284 Infos.push_back(Elt: Info);
5285 return;
5286 }
5287
5288 case Intrinsic::nvvm_tcgen05_ld_16x64b_x32:
5289 case Intrinsic::nvvm_tcgen05_ld_16x128b_x16:
5290 case Intrinsic::nvvm_tcgen05_ld_16x256b_x8:
5291 case Intrinsic::nvvm_tcgen05_ld_32x32b_x32:
5292 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x32:
5293 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x32_i32:
5294 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x32_i32: {
5295 Info.opc = ISD::INTRINSIC_W_CHAIN;
5296 Info.memVT = MVT::v32i32;
5297 Info.ptrVal = I.getArgOperand(i: 0);
5298 Info.offset = 0;
5299 Info.flags = MachineMemOperand::MOLoad;
5300 Info.align.reset();
5301 Infos.push_back(Elt: Info);
5302 return;
5303 }
5304
5305 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x32_f32:
5306 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x32_f32: {
5307 Info.opc = ISD::INTRINSIC_W_CHAIN;
5308 Info.memVT = MVT::v32f32;
5309 Info.ptrVal = I.getArgOperand(i: 0);
5310 Info.offset = 0;
5311 Info.flags = MachineMemOperand::MOLoad;
5312 Info.align.reset();
5313 Infos.push_back(Elt: Info);
5314 return;
5315 }
5316
5317 case Intrinsic::nvvm_tcgen05_ld_16x64b_x64:
5318 case Intrinsic::nvvm_tcgen05_ld_16x128b_x32:
5319 case Intrinsic::nvvm_tcgen05_ld_16x256b_x16:
5320 case Intrinsic::nvvm_tcgen05_ld_32x32b_x64:
5321 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x64:
5322 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x64_i32:
5323 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x64_i32: {
5324 Info.opc = ISD::INTRINSIC_W_CHAIN;
5325 Info.memVT = MVT::v64i32;
5326 Info.ptrVal = I.getArgOperand(i: 0);
5327 Info.offset = 0;
5328 Info.flags = MachineMemOperand::MOLoad;
5329 Info.align.reset();
5330 Infos.push_back(Elt: Info);
5331 return;
5332 }
5333
5334 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x64_f32:
5335 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x64_f32: {
5336 Info.opc = ISD::INTRINSIC_W_CHAIN;
5337 Info.memVT = MVT::v64f32;
5338 Info.ptrVal = I.getArgOperand(i: 0);
5339 Info.offset = 0;
5340 Info.flags = MachineMemOperand::MOLoad;
5341 Info.align.reset();
5342 Infos.push_back(Elt: Info);
5343 return;
5344 }
5345
5346 case Intrinsic::nvvm_tcgen05_ld_16x64b_x128:
5347 case Intrinsic::nvvm_tcgen05_ld_16x128b_x64:
5348 case Intrinsic::nvvm_tcgen05_ld_16x256b_x32:
5349 case Intrinsic::nvvm_tcgen05_ld_32x32b_x128:
5350 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x128:
5351 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x128_i32:
5352 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x128_i32: {
5353 Info.opc = ISD::INTRINSIC_W_CHAIN;
5354 Info.memVT = MVT::v128i32;
5355 Info.ptrVal = I.getArgOperand(i: 0);
5356 Info.offset = 0;
5357 Info.flags = MachineMemOperand::MOLoad;
5358 Info.align.reset();
5359 Infos.push_back(Elt: Info);
5360 return;
5361 }
5362
5363 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x128_f32:
5364 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x128_f32: {
5365 Info.opc = ISD::INTRINSIC_W_CHAIN;
5366 Info.memVT = MVT::v128f32;
5367 Info.ptrVal = I.getArgOperand(i: 0);
5368 Info.offset = 0;
5369 Info.flags = MachineMemOperand::MOLoad;
5370 Info.align.reset();
5371 Infos.push_back(Elt: Info);
5372 return;
5373 }
5374
5375 case Intrinsic::nvvm_tcgen05_st_16x64b_x1:
5376 case Intrinsic::nvvm_tcgen05_st_32x32b_x1:
5377 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x1: {
5378 Info.opc = ISD::INTRINSIC_VOID;
5379 Info.memVT = MVT::i32;
5380 Info.ptrVal = I.getArgOperand(i: 0);
5381 Info.offset = 0;
5382 Info.flags = MachineMemOperand::MOStore;
5383 Info.align.reset();
5384 Infos.push_back(Elt: Info);
5385 return;
5386 }
5387
5388 case Intrinsic::nvvm_tcgen05_st_16x64b_x2:
5389 case Intrinsic::nvvm_tcgen05_st_16x128b_x1:
5390 case Intrinsic::nvvm_tcgen05_st_32x32b_x2:
5391 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x2: {
5392 Info.opc = ISD::INTRINSIC_VOID;
5393 Info.memVT = MVT::v2i32;
5394 Info.ptrVal = I.getArgOperand(i: 0);
5395 Info.offset = 0;
5396 Info.flags = MachineMemOperand::MOStore;
5397 Info.align.reset();
5398 Infos.push_back(Elt: Info);
5399 return;
5400 }
5401
5402 case Intrinsic::nvvm_tcgen05_st_16x64b_x4:
5403 case Intrinsic::nvvm_tcgen05_st_16x128b_x2:
5404 case Intrinsic::nvvm_tcgen05_st_16x256b_x1:
5405 case Intrinsic::nvvm_tcgen05_st_32x32b_x4:
5406 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x4: {
5407 Info.opc = ISD::INTRINSIC_VOID;
5408 Info.memVT = MVT::v4i32;
5409 Info.ptrVal = I.getArgOperand(i: 0);
5410 Info.offset = 0;
5411 Info.flags = MachineMemOperand::MOStore;
5412 Info.align.reset();
5413 Infos.push_back(Elt: Info);
5414 return;
5415 }
5416
5417 case Intrinsic::nvvm_tcgen05_st_16x64b_x8:
5418 case Intrinsic::nvvm_tcgen05_st_16x128b_x4:
5419 case Intrinsic::nvvm_tcgen05_st_16x256b_x2:
5420 case Intrinsic::nvvm_tcgen05_st_32x32b_x8:
5421 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x8: {
5422 Info.opc = ISD::INTRINSIC_VOID;
5423 Info.memVT = MVT::v8i32;
5424 Info.ptrVal = I.getArgOperand(i: 0);
5425 Info.offset = 0;
5426 Info.flags = MachineMemOperand::MOStore;
5427 Info.align.reset();
5428 Infos.push_back(Elt: Info);
5429 return;
5430 }
5431
5432 case Intrinsic::nvvm_tcgen05_st_16x64b_x16:
5433 case Intrinsic::nvvm_tcgen05_st_16x128b_x8:
5434 case Intrinsic::nvvm_tcgen05_st_16x256b_x4:
5435 case Intrinsic::nvvm_tcgen05_st_32x32b_x16:
5436 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x16: {
5437 Info.opc = ISD::INTRINSIC_VOID;
5438 Info.memVT = MVT::v16i32;
5439 Info.ptrVal = I.getArgOperand(i: 0);
5440 Info.offset = 0;
5441 Info.flags = MachineMemOperand::MOStore;
5442 Info.align.reset();
5443 Infos.push_back(Elt: Info);
5444 return;
5445 }
5446
5447 case Intrinsic::nvvm_tcgen05_st_16x64b_x32:
5448 case Intrinsic::nvvm_tcgen05_st_16x128b_x16:
5449 case Intrinsic::nvvm_tcgen05_st_16x256b_x8:
5450 case Intrinsic::nvvm_tcgen05_st_32x32b_x32:
5451 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x32: {
5452 Info.opc = ISD::INTRINSIC_VOID;
5453 Info.memVT = MVT::v32i32;
5454 Info.ptrVal = I.getArgOperand(i: 0);
5455 Info.offset = 0;
5456 Info.flags = MachineMemOperand::MOStore;
5457 Info.align.reset();
5458 Infos.push_back(Elt: Info);
5459 return;
5460 }
5461
5462 case Intrinsic::nvvm_tcgen05_st_16x64b_x64:
5463 case Intrinsic::nvvm_tcgen05_st_16x128b_x32:
5464 case Intrinsic::nvvm_tcgen05_st_16x256b_x16:
5465 case Intrinsic::nvvm_tcgen05_st_32x32b_x64:
5466 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x64: {
5467 Info.opc = ISD::INTRINSIC_VOID;
5468 Info.memVT = MVT::v64i32;
5469 Info.ptrVal = I.getArgOperand(i: 0);
5470 Info.offset = 0;
5471 Info.flags = MachineMemOperand::MOStore;
5472 Info.align.reset();
5473 Infos.push_back(Elt: Info);
5474 return;
5475 }
5476
5477 case Intrinsic::nvvm_tcgen05_st_16x64b_x128:
5478 case Intrinsic::nvvm_tcgen05_st_16x128b_x64:
5479 case Intrinsic::nvvm_tcgen05_st_16x256b_x32:
5480 case Intrinsic::nvvm_tcgen05_st_32x32b_x128:
5481 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x128: {
5482 Info.opc = ISD::INTRINSIC_VOID;
5483 Info.memVT = MVT::v128i32;
5484 Info.ptrVal = I.getArgOperand(i: 0);
5485 Info.offset = 0;
5486 Info.flags = MachineMemOperand::MOStore;
5487 Info.align.reset();
5488 Infos.push_back(Elt: Info);
5489 return;
5490 }
5491 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
5492 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1:
5493 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1:
5494 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1:
5495 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1:
5496 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1:
5497 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift:
5498 case Intrinsic::
5499 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift:
5500 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1:
5501 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1:
5502 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift:
5503 case Intrinsic::
5504 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift: {
5505 // We are reading and writing back to TMem
5506 Info.opc = ISD::INTRINSIC_VOID;
5507 Info.memVT = MVT::v4i32;
5508 Info.ptrVal = I.getArgOperand(i: 0);
5509 Info.offset = 0;
5510 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
5511 Info.align = Align(16);
5512 Infos.push_back(Elt: Info);
5513 return;
5514 }
5515
5516 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2:
5517 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2:
5518 case Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2:
5519 case Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2:
5520 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2:
5521 case Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2:
5522 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2:
5523 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2:
5524 case Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift:
5525 case Intrinsic::
5526 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift:
5527 case Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift:
5528 case Intrinsic::
5529 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift: {
5530 // We are reading and writing back to TMem
5531 Info.opc = ISD::INTRINSIC_VOID;
5532 Info.memVT = MVT::v8i32;
5533 Info.ptrVal = I.getArgOperand(i: 0);
5534 Info.offset = 0;
5535 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
5536 Info.align = Align(16);
5537 Infos.push_back(Elt: Info);
5538 return;
5539 }
5540 }
5541}
5542
5543// Helper for getting a function parameter name. Name is composed from
5544// its index and the function name. Negative index corresponds to special
5545// parameter (unsized array) used for passing variable arguments.
5546std::string NVPTXTargetLowering::getParamName(const Function *F,
5547 int Idx) const {
5548 std::string ParamName;
5549 raw_string_ostream ParamStr(ParamName);
5550
5551 ParamStr << getTargetMachine().getSymbol(GV: F)->getName();
5552 if (Idx < 0)
5553 ParamStr << "_vararg";
5554 else
5555 ParamStr << "_param_" << Idx;
5556
5557 return ParamName;
5558}
5559
5560/// isLegalAddressingMode - Return true if the addressing mode represented
5561/// by AM is legal for this target, for a load/store of the specified type.
5562/// Used to guide target specific optimizations, like loop strength reduction
5563/// (LoopStrengthReduce.cpp) and memory optimization for address mode
5564/// (CodeGenPrepare.cpp)
5565bool NVPTXTargetLowering::isLegalAddressingMode(const DataLayout &DL,
5566 const AddrMode &AM, Type *Ty,
5567 unsigned AS, Instruction *I) const {
5568 // AddrMode - This represents an addressing mode of:
5569 // BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
5570 //
5571 // The legal address modes are
5572 // - [avar]
5573 // - [areg]
5574 // - [areg+immoff]
5575 // - [immAddr]
5576
5577 // immoff must fit in a signed 32-bit int
5578 if (!APInt(64, AM.BaseOffs).isSignedIntN(N: 32))
5579 return false;
5580
5581 if (AM.BaseGV)
5582 return !AM.BaseOffs && !AM.HasBaseReg && !AM.Scale;
5583
5584 switch (AM.Scale) {
5585 case 0: // "r", "r+i" or "i" is allowed
5586 break;
5587 case 1:
5588 if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
5589 return false;
5590 // Otherwise we have r+i.
5591 break;
5592 default:
5593 // No scale > 1 is allowed
5594 return false;
5595 }
5596 return true;
5597}
5598
5599//===----------------------------------------------------------------------===//
5600// NVPTX Inline Assembly Support
5601//===----------------------------------------------------------------------===//
5602
5603/// getConstraintType - Given a constraint letter, return the type of
5604/// constraint it is for this target.
5605NVPTXTargetLowering::ConstraintType
5606NVPTXTargetLowering::getConstraintType(StringRef Constraint) const {
5607 if (Constraint.size() == 1) {
5608 switch (Constraint[0]) {
5609 default:
5610 break;
5611 case 'b':
5612 case 'r':
5613 case 'h':
5614 case 'c':
5615 case 'l':
5616 case 'f':
5617 case 'd':
5618 case 'q':
5619 case '0':
5620 case 'N':
5621 return C_RegisterClass;
5622 }
5623 }
5624 return TargetLowering::getConstraintType(Constraint);
5625}
5626
5627std::pair<unsigned, const TargetRegisterClass *>
5628NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
5629 StringRef Constraint,
5630 MVT VT) const {
5631 if (Constraint.size() == 1) {
5632 switch (Constraint[0]) {
5633 case 'b':
5634 return std::make_pair(x: 0U, y: &NVPTX::B1RegClass);
5635 case 'c':
5636 case 'h':
5637 return std::make_pair(x: 0U, y: &NVPTX::B16RegClass);
5638 case 'r':
5639 case 'f':
5640 return std::make_pair(x: 0U, y: &NVPTX::B32RegClass);
5641 case 'l':
5642 case 'N':
5643 case 'd':
5644 return std::make_pair(x: 0U, y: &NVPTX::B64RegClass);
5645 case 'q': {
5646 if (STI.getSmVersion() < 70)
5647 report_fatal_error(reason: "Inline asm with 128 bit operands is only "
5648 "supported for sm_70 and higher!");
5649 return std::make_pair(x: 0U, y: &NVPTX::B128RegClass);
5650 }
5651 }
5652 }
5653 return TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
5654}
5655
5656//===----------------------------------------------------------------------===//
5657// NVPTX DAG Combining
5658//===----------------------------------------------------------------------===//
5659
5660bool NVPTXTargetLowering::allowFMA(MachineFunction &MF,
5661 CodeGenOptLevel OptLevel) const {
5662 // Always honor command-line argument
5663 if (FMAContractLevelOpt.getNumOccurrences() > 0)
5664 return FMAContractLevelOpt > 0;
5665
5666 // Do not contract if we're not optimizing the code.
5667 if (OptLevel == CodeGenOptLevel::None)
5668 return false;
5669
5670 // Honor TargetOptions flags that explicitly say fusion is okay.
5671 if (MF.getTarget().Options.AllowFPOpFusion == FPOpFusion::Fast)
5672 return true;
5673
5674 return false;
5675}
5676
5677static bool isConstZero(const SDValue &Operand) {
5678 const auto *Const = dyn_cast<ConstantSDNode>(Val: Operand);
5679 return Const && Const->getZExtValue() == 0;
5680}
5681
5682/// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
5683/// operands N0 and N1. This is a helper for PerformADDCombine that is
5684/// called with the default operands, and if that fails, with commuted
5685/// operands.
5686static SDValue
5687PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
5688 TargetLowering::DAGCombinerInfo &DCI) {
5689 EVT VT = N0.getValueType();
5690
5691 // Since integer multiply-add costs the same as integer multiply
5692 // but is more costly than integer add, do the fusion only when
5693 // the mul is only used in the add.
5694 // TODO: this may not be true for later architectures, consider relaxing this
5695 if (!N0.getNode()->hasOneUse())
5696 return SDValue();
5697
5698 // fold (add (select cond, 0, (mul a, b)), c)
5699 // -> (select cond, c, (add (mul a, b), c))
5700 //
5701 if (N0.getOpcode() == ISD::SELECT) {
5702 unsigned ZeroOpNum;
5703 if (isConstZero(Operand: N0->getOperand(Num: 1)))
5704 ZeroOpNum = 1;
5705 else if (isConstZero(Operand: N0->getOperand(Num: 2)))
5706 ZeroOpNum = 2;
5707 else
5708 return SDValue();
5709
5710 SDValue M = N0->getOperand(Num: (ZeroOpNum == 1) ? 2 : 1);
5711 if (M->getOpcode() != ISD::MUL || !M.getNode()->hasOneUse())
5712 return SDValue();
5713
5714 SDLoc DL(N);
5715 SDValue Mul =
5716 DCI.DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: M->getOperand(Num: 0), N2: M->getOperand(Num: 1));
5717 SDValue MAD = DCI.DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Mul, N2: N1);
5718 return DCI.DAG.getSelect(DL: SDLoc(N), VT, Cond: N0->getOperand(Num: 0),
5719 LHS: ((ZeroOpNum == 1) ? N1 : MAD),
5720 RHS: ((ZeroOpNum == 1) ? MAD : N1));
5721 }
5722
5723 return SDValue();
5724}
5725
5726SDValue NVPTXTargetLowering::performFADDCombineWithOperands(
5727 SDNode *N, SDValue N0, SDValue N1, TargetLowering::DAGCombinerInfo &DCI,
5728 CodeGenOptLevel OptLevel) const {
5729 EVT VT = N0.getValueType();
5730 if (N0.getOpcode() == ISD::FMUL) {
5731 if (!(allowFMA(MF&: DCI.DAG.getMachineFunction(), OptLevel) ||
5732 (N->getFlags().hasAllowContract() &&
5733 N0->getFlags().hasAllowContract())))
5734 return SDValue();
5735
5736 // For floating point:
5737 // Do the fusion only when the mul has less than 5 uses and all
5738 // are add.
5739 // The heuristic is that if a use is not an add, then that use
5740 // cannot be fused into fma, therefore mul is still needed anyway.
5741 // If there are more than 4 uses, even if they are all add, fusing
5742 // them will increase register pressue.
5743 //
5744 int numUses = 0;
5745 int nonAddCount = 0;
5746 for (const SDNode *User : N0.getNode()->users()) {
5747 numUses++;
5748 if (User->getOpcode() != ISD::FADD)
5749 ++nonAddCount;
5750 if (numUses >= 5)
5751 return SDValue();
5752 }
5753 if (nonAddCount) {
5754 int orderNo = N->getIROrder();
5755 int orderNo2 = N0.getNode()->getIROrder();
5756 // simple heuristics here for considering potential register
5757 // pressure, the logics here is that the differnce are used
5758 // to measure the distance between def and use, the longer distance
5759 // more likely cause register pressure.
5760 if (orderNo - orderNo2 < 500)
5761 return SDValue();
5762
5763 // Now, check if at least one of the FMUL's operands is live beyond the
5764 // node N, which guarantees that the FMA will not increase register
5765 // pressure at node N.
5766 bool opIsLive = false;
5767 const SDNode *left = N0.getOperand(i: 0).getNode();
5768 const SDNode *right = N0.getOperand(i: 1).getNode();
5769
5770 if (isa<ConstantSDNode>(Val: left) || isa<ConstantSDNode>(Val: right))
5771 opIsLive = true;
5772
5773 if (!opIsLive)
5774 for (const SDNode *User : left->users()) {
5775 int orderNo3 = User->getIROrder();
5776 if (orderNo3 > orderNo) {
5777 opIsLive = true;
5778 break;
5779 }
5780 }
5781
5782 if (!opIsLive)
5783 for (const SDNode *User : right->users()) {
5784 int orderNo3 = User->getIROrder();
5785 if (orderNo3 > orderNo) {
5786 opIsLive = true;
5787 break;
5788 }
5789 }
5790
5791 if (!opIsLive)
5792 return SDValue();
5793 }
5794
5795 return DCI.DAG.getNode(Opcode: ISD::FMA, DL: SDLoc(N), VT, N1: N0.getOperand(i: 0),
5796 N2: N0.getOperand(i: 1), N3: N1);
5797 }
5798
5799 return SDValue();
5800}
5801
5802/// Fold unpacking movs into a load by increasing the number of return values.
5803///
5804/// ex:
5805/// L: v2f16,ch = load <p>
5806/// a: f16 = extractelt L:0, 0
5807/// b: f16 = extractelt L:0, 1
5808/// use(a, b)
5809///
5810/// ...is turned into...
5811///
5812/// L: f16,f16,ch = LoadV2 <p>
5813/// use(L:0, L:1)
5814static SDValue
5815combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
5816 // Don't run this optimization before the legalizer
5817 if (!DCI.isAfterLegalizeDAG())
5818 return SDValue();
5819
5820 EVT ElementVT = N->getValueType(ResNo: 0);
5821 // Avoid non-packed types and v4i8
5822 if (!NVPTX::isPackedVectorTy(VT: ElementVT) || ElementVT == MVT::v4i8)
5823 return SDValue();
5824
5825 // Check whether all outputs are either used by an extractelt or are
5826 // glue/chain nodes
5827 if (!all_of(Range: N->uses(), P: [&](SDUse &U) {
5828 // Skip glue, chain nodes
5829 if (U.getValueType() == MVT::Glue || U.getValueType() == MVT::Other)
5830 return true;
5831 if (U.getUser()->getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
5832 if (N->getOpcode() != ISD::LOAD)
5833 return true;
5834 // Since this is an ISD::LOAD, check all extractelts are used. If
5835 // any are not used, we don't want to defeat another optimization that
5836 // will narrow the load.
5837 //
5838 // For example:
5839 //
5840 // L: v2f16,ch = load <p>
5841 // e0: f16 = extractelt L:0, 0
5842 // e1: f16 = extractelt L:0, 1 <-- unused
5843 // store e0
5844 //
5845 // Can be optimized by DAGCombiner to:
5846 //
5847 // L: f16,ch = load <p>
5848 // store L:0
5849 return !U.getUser()->use_empty();
5850 }
5851
5852 // Otherwise, this use prevents us from splitting a value.
5853 return false;
5854 }))
5855 return SDValue();
5856
5857 auto *LD = cast<MemSDNode>(Val: N);
5858 SDLoc DL(LD);
5859
5860 // the new opcode after we double the number of operands
5861 unsigned Opcode;
5862 SmallVector<SDValue> Operands(LD->ops());
5863 unsigned OldNumOutputs; // non-glue, non-chain outputs
5864 switch (LD->getOpcode()) {
5865 case ISD::LOAD:
5866 OldNumOutputs = 1;
5867 // Any packed type is legal, so the legalizer will not have lowered
5868 // ISD::LOAD -> NVPTXISD::Load (unless it's under-aligned). We have to do it
5869 // here.
5870 Opcode = NVPTXISD::LoadV2;
5871 // append a "full" used bytes mask operand right before the extension type
5872 // operand, signifying that all bytes are used.
5873 Operands.push_back(Elt: DCI.DAG.getConstant(UINT32_MAX, DL, VT: MVT::i32));
5874 Operands.push_back(Elt: DCI.DAG.getIntPtrConstant(
5875 Val: cast<LoadSDNode>(Val: LD)->getExtensionType(), DL));
5876 break;
5877 case NVPTXISD::LoadV2:
5878 OldNumOutputs = 2;
5879 Opcode = NVPTXISD::LoadV4;
5880 break;
5881 case NVPTXISD::LoadV4:
5882 // V8 is only supported for f32/i32. Don't forget, we're not changing the
5883 // load size here. This is already a 256-bit load.
5884 if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
5885 return SDValue();
5886 OldNumOutputs = 4;
5887 Opcode = NVPTXISD::LoadV8;
5888 break;
5889 case NVPTXISD::LoadV8:
5890 // PTX doesn't support the next doubling of outputs
5891 return SDValue();
5892 }
5893
5894 // the non-glue, non-chain outputs in the new load
5895 const unsigned NewNumOutputs = OldNumOutputs * 2;
5896 SmallVector<EVT> NewVTs(NewNumOutputs, ElementVT.getVectorElementType());
5897 // add remaining chain and glue values
5898 NewVTs.append(in_start: LD->value_begin() + OldNumOutputs, in_end: LD->value_end());
5899
5900 // Create the new load
5901 SDValue NewLoad = DCI.DAG.getMemIntrinsicNode(
5902 Opcode, dl: DL, VTList: DCI.DAG.getVTList(VTs: NewVTs), Ops: Operands, MemVT: LD->getMemoryVT(),
5903 MMO: LD->getMemOperand());
5904
5905 // Now we use a combination of BUILD_VECTORs and a MERGE_VALUES node to keep
5906 // the outputs the same. These nodes will be optimized away in later
5907 // DAGCombiner iterations.
5908 SmallVector<SDValue> Results;
5909 for (unsigned I : seq(Size: OldNumOutputs))
5910 Results.push_back(Elt: DCI.DAG.getBuildVector(
5911 VT: ElementVT, DL, Ops: {NewLoad.getValue(R: I * 2), NewLoad.getValue(R: I * 2 + 1)}));
5912 // Add remaining chain and glue nodes
5913 for (unsigned I : seq(Size: NewLoad->getNumValues() - NewNumOutputs))
5914 Results.push_back(Elt: NewLoad.getValue(R: NewNumOutputs + I));
5915
5916 return DCI.DAG.getMergeValues(Ops: Results, dl: DL);
5917}
5918
5919/// Fold packing movs into a store.
5920///
5921/// ex:
5922/// v1: v2f16 = BUILD_VECTOR a:f16, b:f16
5923/// v2: v2f16 = BUILD_VECTOR c:f16, d:f16
5924/// StoreV2 v1, v2
5925///
5926/// ...is turned into...
5927///
5928/// StoreV4 a, b, c, d
5929static SDValue combinePackingMovIntoStore(SDNode *N,
5930 TargetLowering::DAGCombinerInfo &DCI,
5931 unsigned Front, unsigned Back) {
5932 // We want to run this as late as possible since other optimizations may
5933 // eliminate the BUILD_VECTORs.
5934 if (!DCI.isAfterLegalizeDAG())
5935 return SDValue();
5936
5937 // Get the type of the operands being stored.
5938 EVT ElementVT = N->getOperand(Num: Front).getValueType();
5939
5940 // Avoid non-packed types and v4i8
5941 if (!NVPTX::isPackedVectorTy(VT: ElementVT) || ElementVT == MVT::v4i8)
5942 return SDValue();
5943
5944 auto *ST = cast<MemSDNode>(Val: N);
5945
5946 // The new opcode after we double the number of operands.
5947 unsigned Opcode;
5948 switch (N->getOpcode()) {
5949 case ISD::STORE:
5950 // Any packed type is legal, so the legalizer will not have lowered
5951 // ISD::STORE -> NVPTXISD::Store (unless it's under-aligned). We have to do
5952 // it here.
5953 Opcode = NVPTXISD::StoreV2;
5954 break;
5955 case NVPTXISD::StoreV2:
5956 Opcode = NVPTXISD::StoreV4;
5957 break;
5958 case NVPTXISD::StoreV4:
5959 // V8 is only supported for f32/i32. Don't forget, we're not changing the
5960 // store size here. This is already a 256-bit store.
5961 if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32)
5962 return SDValue();
5963 Opcode = NVPTXISD::StoreV8;
5964 break;
5965 case NVPTXISD::StoreV8:
5966 // PTX doesn't support the next doubling of operands
5967 return SDValue();
5968 default:
5969 llvm_unreachable("Unhandled store opcode");
5970 }
5971
5972 // Scan the operands and if they're all BUILD_VECTORs, we'll have gathered
5973 // their elements.
5974 SmallVector<SDValue, 4> Operands(N->ops().take_front(N: Front));
5975 for (SDValue BV : N->ops().drop_front(N: Front).drop_back(N: Back)) {
5976 if (BV.getOpcode() != ISD::BUILD_VECTOR)
5977 return SDValue();
5978
5979 // If the operand has multiple uses, this optimization can increase register
5980 // pressure.
5981 if (!BV.hasOneUse())
5982 return SDValue();
5983
5984 // DAGCombiner visits nodes bottom-up. Check the BUILD_VECTOR operands for
5985 // any signs they may be folded by some other pattern or rule.
5986 for (SDValue Op : BV->ops()) {
5987 // Peek through bitcasts
5988 if (Op.getOpcode() == ISD::BITCAST)
5989 Op = Op.getOperand(i: 0);
5990
5991 // This may be folded into a PRMT.
5992 if (Op.getValueType() == MVT::i16 && Op.getOpcode() == ISD::TRUNCATE &&
5993 Op->getOperand(Num: 0).getValueType() == MVT::i32)
5994 return SDValue();
5995
5996 // This may be folded into cvt.bf16x2
5997 if (Op.getOpcode() == ISD::FP_ROUND)
5998 return SDValue();
5999 }
6000 Operands.append(IL: {BV.getOperand(i: 0), BV.getOperand(i: 1)});
6001 }
6002 Operands.append(in_start: N->op_end() - Back, in_end: N->op_end());
6003
6004 // Now we replace the store
6005 return DCI.DAG.getMemIntrinsicNode(Opcode, dl: SDLoc(N), VTList: N->getVTList(), Ops: Operands,
6006 MemVT: ST->getMemoryVT(), MMO: ST->getMemOperand());
6007}
6008
6009static SDValue combineSTORE(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
6010 const NVPTXSubtarget &STI) {
6011
6012 if (DCI.isBeforeLegalize() && N->getOpcode() == ISD::STORE) {
6013 // Here is our chance to custom lower a store with a non-simple type.
6014 // Unfortunately, we can't do this in the legalizer because there is no
6015 // way to setOperationAction for an non-simple type.
6016 StoreSDNode *ST = cast<StoreSDNode>(Val: N);
6017 if (!ST->getValue().getValueType().isSimple())
6018 return lowerSTOREVector(Op: SDValue(ST, 0), DAG&: DCI.DAG, STI);
6019 }
6020
6021 return combinePackingMovIntoStore(N, DCI, Front: 1, Back: 2);
6022}
6023
6024static SDValue combineLOAD(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
6025 const NVPTXSubtarget &STI) {
6026 if (DCI.isBeforeLegalize() && N->getOpcode() == ISD::LOAD) {
6027 // Here is our chance to custom lower a load with a non-simple type.
6028 // Unfortunately, we can't do this in the legalizer because there is no
6029 // way to setOperationAction for an non-simple type.
6030 if (!N->getValueType(ResNo: 0).isSimple())
6031 return lowerLoadVector(N, DAG&: DCI.DAG, STI);
6032 }
6033
6034 return combineUnpackingMovIntoLoad(N, DCI);
6035}
6036
6037/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
6038///
6039static SDValue PerformADDCombine(SDNode *N,
6040 TargetLowering::DAGCombinerInfo &DCI,
6041 CodeGenOptLevel OptLevel) {
6042 if (OptLevel == CodeGenOptLevel::None)
6043 return SDValue();
6044
6045 SDValue N0 = N->getOperand(Num: 0);
6046 SDValue N1 = N->getOperand(Num: 1);
6047
6048 // Skip non-integer, non-scalar case
6049 EVT VT = N0.getValueType();
6050 if (VT.isVector() || VT != MVT::i32)
6051 return SDValue();
6052
6053 // First try with the default operand order.
6054 if (SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI))
6055 return Result;
6056
6057 // If that didn't work, try again with the operands commuted.
6058 return PerformADDCombineWithOperands(N, N0: N1, N1: N0, DCI);
6059}
6060
6061/// Check if a v2f32 BUILD_VECTOR provably packs values from non-adjacent
6062/// register pairs (non-coalescable).
6063static bool isNonCoalescableBuildVector(const SDValue &BV) {
6064 if (BV.getOpcode() != ISD::BUILD_VECTOR || BV.getValueType() != MVT::v2f32)
6065 return false;
6066
6067 SDValue Elt0 = BV.getOperand(i: 0);
6068 SDValue Elt1 = BV.getOperand(i: 1);
6069
6070 bool IsExt0 = Elt0.getOpcode() == ISD::EXTRACT_VECTOR_ELT;
6071 bool IsExt1 = Elt1.getOpcode() == ISD::EXTRACT_VECTOR_ELT;
6072
6073 // If neither element is an EXTRACT_VECTOR_ELT they are free-standing
6074 // scalars and the register allocator can still place them side-by-side.
6075 if (!IsExt0 && !IsExt1)
6076 return false;
6077
6078 // If exactly one element is an EXTRACT_VECTOR_ELT, the other is a scalar
6079 // that cannot generally occupy the adjacent register slot.
6080 if (IsExt0 != IsExt1)
6081 return true;
6082
6083 // At this point both sources are extracting from vectors. If they are from
6084 // different vectors, then the BUILD_VECTOR is non-coalescable.
6085 SDValue Src0 = Elt0.getOperand(i: 0);
6086 SDValue Src1 = Elt1.getOperand(i: 0);
6087 if (Src0 != Src1)
6088 return true;
6089
6090 auto *Idx0 = dyn_cast<ConstantSDNode>(Val: Elt0.getOperand(i: 1));
6091 auto *Idx1 = dyn_cast<ConstantSDNode>(Val: Elt1.getOperand(i: 1));
6092 // If both indices are dynamic they will be lowered to
6093 // loads and the vector will be spilled to local memory. The register
6094 // allocator can easily place the results in adjacent registers.
6095 if (!Idx0 && !Idx1)
6096 return false;
6097
6098 // If one index is dynamic and the other is constant, the value from the
6099 // constant load will result in an additional register to pair with the result
6100 // from the dynamic load. We consider this non-coalescable.
6101 if ((Idx0 && !Idx1) || (!Idx0 && Idx1))
6102 return true;
6103
6104 // Both are constant, adjacent pairs are coalescable
6105 return std::abs(i: Idx0->getSExtValue() - Idx1->getSExtValue()) != 1;
6106}
6107
6108/// Return true if FMUL v2f32 node \p N may be scalarized to fold each lane's
6109/// product into a scalar FMA.
6110bool NVPTXTargetLowering::mayFoldFMULIntoFMA(SDNode *N, MachineFunction &MF,
6111 CodeGenOptLevel OptLevel) const {
6112 if (N->getOpcode() != ISD::FMUL || N->getValueType(ResNo: 0) != MVT::v2f32)
6113 return false;
6114 const bool GlobalFMA = allowFMA(MF, OptLevel);
6115 if (!N->getFlags().hasAllowContract() && !GlobalFMA)
6116 return false;
6117
6118 const SDNode *FirstFAdd = nullptr;
6119 unsigned NumScalarFAdd = 0;
6120
6121 // Both lanes must feed unique FADDs
6122 for (SDNode *EE : N->users()) {
6123 if (NumScalarFAdd == 2)
6124 return false;
6125
6126 if (EE->getOpcode() != ISD::EXTRACT_VECTOR_ELT || !EE->hasOneUse() ||
6127 !isa<ConstantSDNode>(Val: EE->getOperand(Num: 1)))
6128 return false;
6129
6130 const SDNode *const FAdd = *EE->users().begin();
6131 if (FAdd->getOpcode() != ISD::FADD ||
6132 (!GlobalFMA && !FAdd->getFlags().hasAllowContract()))
6133 return false;
6134
6135 if (!FirstFAdd)
6136 FirstFAdd = FAdd;
6137 else if (FAdd == FirstFAdd)
6138 return false;
6139
6140 NumScalarFAdd++;
6141 }
6142
6143 return NumScalarFAdd == 2;
6144}
6145
6146/// Scalarize a v2f32 arithmetic node (FADD, FMUL, FSUB, FMA) when at least
6147/// one operand is a BUILD_VECTOR that repacks values from non-adjacent register
6148/// pairs. Without this combine the BUILD_VECTOR forces allocation of a
6149/// temporary 64-bit register, increasing register pressure.
6150///
6151/// Example - before:
6152/// t0: v2f32,v2f32,ch = LoadV2 ...
6153/// t1: f32 = extract_vector_elt t0, 0
6154/// t2: f32 = extract_vector_elt t0:1, 0
6155/// t3: v2f32 = BUILD_VECTOR t1, t2 ;; non-coalescable repack
6156/// t4: v2f32 = fma t_a, t3, t_c
6157///
6158/// After:
6159/// t0: v2f32,v2f32,ch = LoadV2 ...
6160/// t1: f32 = extract_vector_elt t0, 0
6161/// t2: f32 = extract_vector_elt t0:1, 0
6162/// a0: f32 = extract_vector_elt t_a, 0
6163/// a1: f32 = extract_vector_elt t_a, 1
6164/// c0: f32 = extract_vector_elt t_c, 0
6165/// c1: f32 = extract_vector_elt t_c, 1
6166/// r0: f32 = fma a0, t1, c0
6167/// r1: f32 = fma a1, t2, c1
6168/// t4: v2f32 = BUILD_VECTOR r0, r1
6169///
6170/// Also scalarizes an FMUL when all output lanes feed into scalar FADDs
6171/// to enable scalar FMA combining.
6172SDValue NVPTXTargetLowering::performScalarizeV2F32Op(
6173 SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
6174 CodeGenOptLevel OptLevel) const {
6175 EVT VT = N->getValueType(ResNo: 0);
6176 if (VT != MVT::v2f32)
6177 return SDValue();
6178
6179 if (none_of(Range: N->ops(), P: isNonCoalescableBuildVector) &&
6180 !mayFoldFMULIntoFMA(N, MF&: DCI.DAG.getMachineFunction(), OptLevel))
6181 return SDValue();
6182
6183 SelectionDAG &DAG = DCI.DAG;
6184 SDLoc DL(N);
6185 EVT EltVT = VT.getVectorElementType();
6186 unsigned Opc = N->getOpcode();
6187
6188 // For each operand, get the scalar element at the given index: if the operand
6189 // is a BUILD_VECTOR, grab the element directly; otherwise, emit an
6190 // EXTRACT_VECTOR_ELT.
6191 auto GetElement = [&](SDValue Op, unsigned Index) -> SDValue {
6192 if (Op.getOpcode() == ISD::BUILD_VECTOR)
6193 return Op.getOperand(i: Index);
6194 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Op,
6195 N2: DAG.getVectorIdxConstant(Val: Index, DL));
6196 };
6197
6198 // Build scalar operand lists for element 0 and element 1.
6199 SmallVector<SDValue, 3> Ops0, Ops1;
6200 for (const SDValue &Op : N->ops()) {
6201 Ops0.push_back(Elt: GetElement(Op, 0));
6202 Ops1.push_back(Elt: GetElement(Op, 1));
6203 }
6204
6205 SDValue Res0 = DAG.getNode(Opcode: Opc, DL, VT: EltVT, Ops: Ops0, Flags: N->getFlags());
6206 SDValue Res1 = DAG.getNode(Opcode: Opc, DL, VT: EltVT, Ops: Ops1, Flags: N->getFlags());
6207
6208 return DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT, N1: Res0, N2: Res1);
6209}
6210
6211/// Target-specific dag combine xforms for ISD::FADD.
6212SDValue
6213NVPTXTargetLowering::performFADDCombine(SDNode *N,
6214 TargetLowering::DAGCombinerInfo &DCI,
6215 CodeGenOptLevel OptLevel) const {
6216 if (SDValue Result = performScalarizeV2F32Op(N, DCI, OptLevel))
6217 return Result;
6218
6219 SDValue N0 = N->getOperand(Num: 0);
6220 SDValue N1 = N->getOperand(Num: 1);
6221
6222 EVT VT = N0.getValueType();
6223 if (VT.isVector() || !(VT == MVT::f32 || VT == MVT::f64))
6224 return SDValue();
6225
6226 // First try with the default operand order.
6227 if (SDValue Result = performFADDCombineWithOperands(N, N0, N1, DCI, OptLevel))
6228 return Result;
6229
6230 // If that didn't work, try again with the operands commuted.
6231 return performFADDCombineWithOperands(N, N0: N1, N1: N0, DCI, OptLevel);
6232}
6233
6234/// Get 3-input version of a 2-input min/max opcode
6235static unsigned getMinMax3Opcode(unsigned MinMax2Opcode) {
6236 switch (MinMax2Opcode) {
6237 case ISD::FMAXNUM:
6238 case ISD::FMAXIMUMNUM:
6239 return NVPTXISD::FMAXNUM3;
6240 case ISD::FMINNUM:
6241 case ISD::FMINIMUMNUM:
6242 return NVPTXISD::FMINNUM3;
6243 case ISD::FMAXIMUM:
6244 return NVPTXISD::FMAXIMUM3;
6245 case ISD::FMINIMUM:
6246 return NVPTXISD::FMINIMUM3;
6247 default:
6248 llvm_unreachable("Invalid 2-input min/max opcode");
6249 }
6250}
6251
6252/// PerformFMinMaxCombine - Combine (fmaxnum (fmaxnum a, b), c) into
6253/// (fmaxnum3 a, b, c). Also covers other llvm min/max intrinsics.
6254static SDValue PerformFMinMaxCombine(SDNode *N,
6255 TargetLowering::DAGCombinerInfo &DCI,
6256 unsigned PTXVersion, unsigned SmVersion) {
6257
6258 // 3-input min/max requires PTX 8.8+ and SM_100+, and only supports f32s
6259 EVT VT = N->getValueType(ResNo: 0);
6260 if (VT != MVT::f32 || PTXVersion < 88 || SmVersion < 100)
6261 return SDValue();
6262
6263 SDValue Op0 = N->getOperand(Num: 0);
6264 SDValue Op1 = N->getOperand(Num: 1);
6265 unsigned MinMaxOp2 = N->getOpcode();
6266 unsigned MinMaxOp3 = getMinMax3Opcode(MinMax2Opcode: MinMaxOp2);
6267
6268 if (Op0.getOpcode() == MinMaxOp2 && Op0.hasOneUse()) {
6269 // (maxnum (maxnum a, b), c) -> (maxnum3 a, b, c)
6270 SDValue A = Op0.getOperand(i: 0);
6271 SDValue B = Op0.getOperand(i: 1);
6272 SDValue C = Op1;
6273 return DCI.DAG.getNode(Opcode: MinMaxOp3, DL: SDLoc(N), VT, N1: A, N2: B, N3: C, Flags: N->getFlags());
6274 } else if (Op1.getOpcode() == MinMaxOp2 && Op1.hasOneUse()) {
6275 // (maxnum a, (maxnum b, c)) -> (maxnum3 a, b, c)
6276 SDValue A = Op0;
6277 SDValue B = Op1.getOperand(i: 0);
6278 SDValue C = Op1.getOperand(i: 1);
6279 return DCI.DAG.getNode(Opcode: MinMaxOp3, DL: SDLoc(N), VT, N1: A, N2: B, N3: C, Flags: N->getFlags());
6280 }
6281 return SDValue();
6282}
6283
6284static SDValue PerformREMCombine(SDNode *N,
6285 TargetLowering::DAGCombinerInfo &DCI,
6286 CodeGenOptLevel OptLevel) {
6287 assert(N->getOpcode() == ISD::SREM || N->getOpcode() == ISD::UREM);
6288
6289 // Don't do anything at less than -O2.
6290 if (OptLevel < CodeGenOptLevel::Default)
6291 return SDValue();
6292
6293 SelectionDAG &DAG = DCI.DAG;
6294 SDLoc DL(N);
6295 EVT VT = N->getValueType(ResNo: 0);
6296 bool IsSigned = N->getOpcode() == ISD::SREM;
6297 unsigned DivOpc = IsSigned ? ISD::SDIV : ISD::UDIV;
6298
6299 const SDValue &Num = N->getOperand(Num: 0);
6300 const SDValue &Den = N->getOperand(Num: 1);
6301
6302 for (const SDNode *U : Num->users()) {
6303 if (U->getOpcode() == DivOpc && U->getOperand(Num: 0) == Num &&
6304 U->getOperand(Num: 1) == Den) {
6305 // Num % Den -> Num - (Num / Den) * Den
6306 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Num,
6307 N2: DAG.getNode(Opcode: ISD::MUL, DL, VT,
6308 N1: DAG.getNode(Opcode: DivOpc, DL, VT, N1: Num, N2: Den),
6309 N2: Den));
6310 }
6311 }
6312 return SDValue();
6313}
6314
6315// sext (mul.iN nsw x, y) => mul.wide.sN x, y
6316// zext (mul.iN nuw x, y) => mul.wide.uN x, y
6317// sext (shl.iN nsw x, const) => mul.wide.sN x, (1 << const)
6318// zext (shl.iN nuw x, const) => mul.wide.uN x, (1 << const)
6319static SDValue combineSZExtToMulWide(SDNode *N,
6320 TargetLowering::DAGCombinerInfo &DCI,
6321 CodeGenOptLevel OptLevel) {
6322 assert(N->getOpcode() == ISD::SIGN_EXTEND ||
6323 N->getOpcode() == ISD::ZERO_EXTEND);
6324
6325 if (OptLevel == CodeGenOptLevel::None)
6326 return SDValue();
6327
6328 SDValue Op = N->getOperand(Num: 0);
6329 if (!Op.hasOneUse())
6330 return SDValue();
6331
6332 EVT ToVT = N->getValueType(ResNo: 0);
6333 EVT FromVT = Op.getValueType();
6334 if (!((ToVT == MVT::i32 && FromVT == MVT::i16) ||
6335 (ToVT == MVT::i64 && FromVT == MVT::i32)))
6336 return SDValue();
6337
6338 bool IsSigned = N->getOpcode() == ISD::SIGN_EXTEND;
6339 if ((IsSigned && !Op->getFlags().hasNoSignedWrap()) ||
6340 (!IsSigned && !Op->getFlags().hasNoUnsignedWrap()))
6341 return SDValue();
6342
6343 SDLoc DL(N);
6344 SDValue LHS = Op.getOperand(i: 0);
6345 SDValue RHS = Op.getOperand(i: 1);
6346 unsigned MulWideOpcode =
6347 IsSigned ? NVPTXISD::MUL_WIDE_SIGNED : NVPTXISD::MUL_WIDE_UNSIGNED;
6348 if (Op.getOpcode() == ISD::MUL) {
6349 return DCI.DAG.getNode(Opcode: MulWideOpcode, DL, VT: ToVT, N1: LHS, N2: RHS);
6350 } else if (Op.getOpcode() == ISD::SHL && isa<ConstantSDNode>(Val: RHS)) {
6351 const auto ShiftAmt = Op.getConstantOperandVal(i: 1);
6352 const auto MulVal = APInt(FromVT.getSizeInBits(), 1) << ShiftAmt;
6353
6354 // Note that the sext (shl nsw ...) case doesn't work if 1 << const
6355 // overflows to a negative value! The only valid input values in this
6356 // case are 0 and -1 (all other values yield poison because of the nsw),
6357 // and mul.wide.sN would give us the wrong sign for -1. We could use
6358 // mul.wide.uN, but since this is a weird case anyway, we might as well not
6359 // apply this transformation at all.
6360 if (IsSigned && MulVal.isNegative())
6361 return SDValue();
6362
6363 RHS = DCI.DAG.getConstant(Val: MulVal, DL, VT: FromVT);
6364 return DCI.DAG.getNode(Opcode: MulWideOpcode, DL, VT: ToVT, N1: LHS, N2: RHS);
6365 }
6366
6367 return SDValue();
6368}
6369
/// Signedness of a candidate mul.wide operand, as reported by
/// IsMulWideOperandDemotable().
enum OperandSignedness {
  Signed = 0,   // Operand is produced by a sign extension.
  Unsigned = 1, // Operand is produced by a zero extension.
  Unknown = 2   // Signedness could not be determined.
};
6375
6376/// IsMulWideOperandDemotable - Checks if the provided DAG node is an operand
6377/// that can be demoted to \p OptSize bits without loss of information. The
6378/// signedness of the operand, if determinable, is placed in \p S.
6379static bool IsMulWideOperandDemotable(SDValue Op,
6380 unsigned OptSize,
6381 OperandSignedness &S) {
6382 S = Unknown;
6383
6384 if (Op.getOpcode() == ISD::SIGN_EXTEND ||
6385 Op.getOpcode() == ISD::SIGN_EXTEND_INREG) {
6386 EVT OrigVT = Op.getOperand(i: 0).getValueType();
6387 if (OrigVT.getFixedSizeInBits() <= OptSize) {
6388 S = Signed;
6389 return true;
6390 }
6391 } else if (Op.getOpcode() == ISD::ZERO_EXTEND) {
6392 EVT OrigVT = Op.getOperand(i: 0).getValueType();
6393 if (OrigVT.getFixedSizeInBits() <= OptSize) {
6394 S = Unsigned;
6395 return true;
6396 }
6397 }
6398
6399 return false;
6400}
6401
6402/// AreMulWideOperandsDemotable - Checks if the given LHS and RHS operands can
6403/// be demoted to \p OptSize bits without loss of information. If the operands
6404/// contain a constant, it should appear as the RHS operand. The signedness of
6405/// the operands is placed in \p IsSigned.
6406static bool AreMulWideOperandsDemotable(SDValue LHS, SDValue RHS,
6407 unsigned OptSize,
6408 bool &IsSigned) {
6409 OperandSignedness LHSSign;
6410
6411 // The LHS operand must be a demotable op
6412 if (!IsMulWideOperandDemotable(Op: LHS, OptSize, S&: LHSSign))
6413 return false;
6414
6415 // We should have been able to determine the signedness from the LHS
6416 if (LHSSign == Unknown)
6417 return false;
6418
6419 IsSigned = (LHSSign == Signed);
6420
6421 // The RHS can be a demotable op or a constant
6422 if (ConstantSDNode *CI = dyn_cast<ConstantSDNode>(Val&: RHS)) {
6423 const APInt &Val = CI->getAPIntValue();
6424 if (LHSSign == Unsigned) {
6425 return Val.isIntN(N: OptSize);
6426 } else {
6427 return Val.isSignedIntN(N: OptSize);
6428 }
6429 } else {
6430 OperandSignedness RHSSign;
6431 if (!IsMulWideOperandDemotable(Op: RHS, OptSize, S&: RHSSign))
6432 return false;
6433
6434 return LHSSign == RHSSign;
6435 }
6436}
6437
6438/// TryMULWIDECombine - Attempt to replace a multiply of M bits with a multiply
6439/// of M/2 bits that produces an M-bit result (i.e. mul.wide). This transform
6440/// works on both multiply DAG nodes and SHL DAG nodes with a constant shift
6441/// amount.
6442static SDValue TryMULWIDECombine(SDNode *N,
6443 TargetLowering::DAGCombinerInfo &DCI) {
6444 EVT MulType = N->getValueType(ResNo: 0);
6445 if (MulType != MVT::i32 && MulType != MVT::i64) {
6446 return SDValue();
6447 }
6448
6449 SDLoc DL(N);
6450 unsigned OptSize = MulType.getSizeInBits() >> 1;
6451 SDValue LHS = N->getOperand(Num: 0);
6452 SDValue RHS = N->getOperand(Num: 1);
6453
6454 // Canonicalize the multiply so the constant (if any) is on the right
6455 if (N->getOpcode() == ISD::MUL) {
6456 if (isa<ConstantSDNode>(Val: LHS)) {
6457 std::swap(a&: LHS, b&: RHS);
6458 }
6459 }
6460
6461 // If we have a SHL, determine the actual multiply amount
6462 if (N->getOpcode() == ISD::SHL) {
6463 ConstantSDNode *ShlRHS = dyn_cast<ConstantSDNode>(Val&: RHS);
6464 if (!ShlRHS) {
6465 return SDValue();
6466 }
6467
6468 APInt ShiftAmt = ShlRHS->getAPIntValue();
6469 unsigned BitWidth = MulType.getSizeInBits();
6470 if (ShiftAmt.sge(RHS: 0) && ShiftAmt.slt(RHS: BitWidth)) {
6471 APInt MulVal = APInt(BitWidth, 1) << ShiftAmt;
6472 RHS = DCI.DAG.getConstant(Val: MulVal, DL, VT: MulType);
6473 } else {
6474 return SDValue();
6475 }
6476 }
6477
6478 bool Signed;
6479 // Verify that our operands are demotable
6480 if (!AreMulWideOperandsDemotable(LHS, RHS, OptSize, IsSigned&: Signed)) {
6481 return SDValue();
6482 }
6483
6484 EVT DemotedVT;
6485 if (MulType == MVT::i32) {
6486 DemotedVT = MVT::i16;
6487 } else {
6488 DemotedVT = MVT::i32;
6489 }
6490
6491 // Truncate the operands to the correct size. Note that these are just for
6492 // type consistency and will (likely) be eliminated in later phases.
6493 SDValue TruncLHS =
6494 DCI.DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: DemotedVT, Operand: LHS);
6495 SDValue TruncRHS =
6496 DCI.DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: DemotedVT, Operand: RHS);
6497
6498 unsigned Opc;
6499 if (Signed) {
6500 Opc = NVPTXISD::MUL_WIDE_SIGNED;
6501 } else {
6502 Opc = NVPTXISD::MUL_WIDE_UNSIGNED;
6503 }
6504
6505 return DCI.DAG.getNode(Opcode: Opc, DL, VT: MulType, N1: TruncLHS, N2: TruncRHS);
6506}
6507
6508static bool isConstOne(const SDValue &Operand) {
6509 const auto *Const = dyn_cast<ConstantSDNode>(Val: Operand);
6510 return Const && Const->getZExtValue() == 1;
6511}
6512
6513static SDValue matchMADConstOnePattern(SDValue Add) {
6514 if (Add->getOpcode() != ISD::ADD)
6515 return SDValue();
6516
6517 if (isConstOne(Operand: Add->getOperand(Num: 0)))
6518 return Add->getOperand(Num: 1);
6519
6520 if (isConstOne(Operand: Add->getOperand(Num: 1)))
6521 return Add->getOperand(Num: 0);
6522
6523 return SDValue();
6524}
6525
6526static SDValue combineMADConstOne(SDValue X, SDValue Add, EVT VT, SDLoc DL,
6527 TargetLowering::DAGCombinerInfo &DCI) {
6528
6529 if (SDValue Y = matchMADConstOnePattern(Add)) {
6530 SDValue Mul = DCI.DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: X, N2: Y);
6531 return DCI.DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Mul, N2: X);
6532 }
6533
6534 return SDValue();
6535}
6536
6537static SDValue combineMulSelectConstOne(SDValue X, SDValue Select, EVT VT,
6538 SDLoc DL,
6539 TargetLowering::DAGCombinerInfo &DCI) {
6540 if (Select->getOpcode() != ISD::SELECT)
6541 return SDValue();
6542
6543 SDValue Cond = Select->getOperand(Num: 0);
6544
6545 unsigned ConstOpNo;
6546 if (isConstOne(Operand: Select->getOperand(Num: 1)))
6547 ConstOpNo = 1;
6548 else if (isConstOne(Operand: Select->getOperand(Num: 2)))
6549 ConstOpNo = 2;
6550 else
6551 return SDValue();
6552
6553 SDValue Y = Select->getOperand(Num: (ConstOpNo == 1) ? 2 : 1);
6554
6555 // Do not combine if the resulting sequence is not obviously profitable.
6556 if (!matchMADConstOnePattern(Add: Y))
6557 return SDValue();
6558
6559 SDValue NewMul = DCI.DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: X, N2: Y);
6560
6561 return DCI.DAG.getNode(Opcode: ISD::SELECT, DL, VT, N1: Cond,
6562 N2: (ConstOpNo == 1) ? X : NewMul,
6563 N3: (ConstOpNo == 1) ? NewMul : X);
6564}
6565
6566static SDValue
6567PerformMULCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
6568 TargetLowering::DAGCombinerInfo &DCI) {
6569
6570 EVT VT = N0.getValueType();
6571 if (VT.isVector())
6572 return SDValue();
6573
6574 if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
6575 return SDValue();
6576
6577 SDLoc DL(N);
6578
6579 // (mul x, (add y, 1)) -> (add (mul x, y), x)
6580 if (SDValue Res = combineMADConstOne(X: N0, Add: N1, VT, DL, DCI))
6581 return Res;
6582 if (SDValue Res = combineMADConstOne(X: N1, Add: N0, VT, DL, DCI))
6583 return Res;
6584
6585 // (mul x, (select y, 1)) -> (select (mul x, y), x)
6586 if (SDValue Res = combineMulSelectConstOne(X: N0, Select: N1, VT, DL, DCI))
6587 return Res;
6588 if (SDValue Res = combineMulSelectConstOne(X: N1, Select: N0, VT, DL, DCI))
6589 return Res;
6590
6591 return SDValue();
6592}
6593
6594/// PerformMULCombine - Runs PTX-specific DAG combine patterns on MUL nodes.
6595static SDValue PerformMULCombine(SDNode *N,
6596 TargetLowering::DAGCombinerInfo &DCI,
6597 CodeGenOptLevel OptLevel) {
6598 if (OptLevel == CodeGenOptLevel::None)
6599 return SDValue();
6600
6601 if (SDValue Ret = TryMULWIDECombine(N, DCI))
6602 return Ret;
6603
6604 SDValue N0 = N->getOperand(Num: 0);
6605 SDValue N1 = N->getOperand(Num: 1);
6606 return PerformMULCombineWithOperands(N, N0, N1, DCI);
6607}
6608
6609/// PerformSHLCombine - Runs PTX-specific DAG combine patterns on SHL nodes.
6610static SDValue PerformSHLCombine(SDNode *N,
6611 TargetLowering::DAGCombinerInfo &DCI,
6612 CodeGenOptLevel OptLevel) {
6613 if (OptLevel > CodeGenOptLevel::None) {
6614 // Try mul.wide combining at OptLevel > 0
6615 if (SDValue Ret = TryMULWIDECombine(N, DCI))
6616 return Ret;
6617 }
6618
6619 return SDValue();
6620}
6621
6622static SDValue PerformSETCCCombine(SDNode *N,
6623 TargetLowering::DAGCombinerInfo &DCI,
6624 unsigned int SmVersion) {
6625 EVT CCType = N->getValueType(ResNo: 0);
6626 SDValue A = N->getOperand(Num: 0);
6627 SDValue B = N->getOperand(Num: 1);
6628
6629 EVT AType = A.getValueType();
6630 if (!(CCType == MVT::v2i1 && (AType == MVT::v2f16 || AType == MVT::v2bf16)))
6631 return SDValue();
6632
6633 if (A.getValueType() == MVT::v2bf16 && SmVersion < 90)
6634 return SDValue();
6635
6636 SDLoc DL(N);
6637 // setp.f16x2 returns two scalar predicates, which we need to
6638 // convert back to v2i1. The returned result will be scalarized by
6639 // the legalizer, but the comparison will remain a single vector
6640 // instruction.
6641 SDValue CCNode = DCI.DAG.getNode(
6642 Opcode: A.getValueType() == MVT::v2f16 ? NVPTXISD::SETP_F16X2
6643 : NVPTXISD::SETP_BF16X2,
6644 DL, VTList: DCI.DAG.getVTList(VT1: MVT::i1, VT2: MVT::i1), Ops: {A, B, N->getOperand(Num: 2)});
6645 return DCI.DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: CCType, N1: CCNode.getValue(R: 0),
6646 N2: CCNode.getValue(R: 1));
6647}
6648
6649static SDValue PerformEXTRACTCombine(SDNode *N,
6650 TargetLowering::DAGCombinerInfo &DCI) {
6651 SDValue Vector = peekThroughFreeze(V: N->getOperand(Num: 0));
6652 SDLoc DL(N);
6653 EVT VectorVT = Vector.getValueType();
6654 if (Vector->getOpcode() == ISD::LOAD && VectorVT.isSimple() &&
6655 IsPTXVectorType(VT: VectorVT.getSimpleVT()))
6656 return SDValue(); // Native vector loads already combine nicely w/
6657 // extract_vector_elt.
6658 // Don't mess with singletons or packed types (v2*32, v2*16, v4i8 and v8i8),
6659 // we already handle them OK.
6660 if (VectorVT.getVectorNumElements() == 1 ||
6661 NVPTX::isPackedVectorTy(VT: VectorVT) || VectorVT == MVT::v8i8)
6662 return SDValue();
6663
6664 // Don't mess with undef values as sra may be simplified to 0, not undef.
6665 if (Vector->isUndef() || ISD::allOperandsUndef(N: Vector.getNode()))
6666 return SDValue();
6667
6668 uint64_t VectorBits = VectorVT.getSizeInBits();
6669 // We only handle the types we can extract in-register.
6670 if (!(VectorBits == 16 || VectorBits == 32 || VectorBits == 64))
6671 return SDValue();
6672
6673 ConstantSDNode *Index = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
6674 // Index == 0 is handled by generic DAG combiner.
6675 if (!Index || Index->getZExtValue() == 0)
6676 return SDValue();
6677
6678 MVT IVT = MVT::getIntegerVT(BitWidth: VectorBits);
6679 EVT EltVT = VectorVT.getVectorElementType();
6680 EVT EltIVT = EltVT.changeTypeToInteger();
6681 uint64_t EltBits = EltVT.getScalarSizeInBits();
6682
6683 SDValue Result = DCI.DAG.getNode(
6684 Opcode: ISD::TRUNCATE, DL, VT: EltIVT,
6685 Operand: DCI.DAG.getNode(
6686 Opcode: ISD::SRA, DL, VT: IVT, N1: DCI.DAG.getNode(Opcode: ISD::BITCAST, DL, VT: IVT, Operand: Vector),
6687 N2: DCI.DAG.getConstant(Val: Index->getZExtValue() * EltBits, DL, VT: IVT)));
6688
6689 // If element has non-integer type, bitcast it back to the expected type.
6690 if (EltVT != EltIVT)
6691 Result = DCI.DAG.getNode(Opcode: ISD::BITCAST, DL, VT: EltVT, Operand: Result);
6692 // Past legalizer, we may need to extent i8 -> i16 to match the register type.
6693 if (EltVT != N->getValueType(ResNo: 0))
6694 Result = DCI.DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: N->getValueType(ResNo: 0), Operand: Result);
6695
6696 return Result;
6697}
6698
6699/// Transform patterns like:
6700/// (select (ugt shift_amt, BitWidth-1), 0, (srl/shl x, shift_amt))
6701/// (select (ult shift_amt, BitWidth), (srl/shl x, shift_amt), 0)
6702/// Into:
6703/// (NVPTXISD::SRL_CLAMP x, shift_amt) or (NVPTXISD::SHL_CLAMP x, shift_amt)
6704///
6705/// These patterns arise from code like `s >= 32 ? 0 : x >> s`. In LLVM,
6706/// over-shifting a value results in poison, but PTX shr/shl instructions clamp
6707/// the shift amount to BitWidth, making the guard redundant.
6708///
6709/// Note: We only handle SRL and SHL, not SRA, because arithmetic right shifts
6710/// can produce 0 or -1 when shift >= BitWidth.
6711/// Note: We don't handle uge or ule. These don't appear because of
6712/// canonicalization.
6713static SDValue PerformSELECTShiftCombine(SDNode *N,
6714 TargetLowering::DAGCombinerInfo &DCI) {
6715 if (!DCI.isAfterLegalizeDAG())
6716 return SDValue();
6717
6718 using namespace SDPatternMatch;
6719 unsigned BitWidth = N->getValueType(ResNo: 0).getSizeInBits();
6720 SDValue ShiftAmt, ShiftOp;
6721
6722 // Match logical shifts where the shift amount in the guard matches the shift
6723 // amount in the operation.
6724 auto LogicalShift =
6725 m_AllOf(preds: m_Value(N&: ShiftOp),
6726 preds: m_AnyOf(preds: m_Srl(L: m_Value(), R: m_TruncOrSelf(Op: m_Deferred(V&: ShiftAmt))),
6727 preds: m_Shl(L: m_Value(), R: m_TruncOrSelf(Op: m_Deferred(V&: ShiftAmt)))));
6728
6729 // shift_amt > BitWidth-1 ? 0 : shift_op
6730 bool MatchedUGT =
6731 sd_match(N, P: m_Select(Cond: m_SetCC(LHS: m_Value(N&: ShiftAmt),
6732 RHS: m_SpecificInt(V: APInt(BitWidth, BitWidth - 1)),
6733 CC: m_SpecificCondCode(CC: ISD::SETUGT)),
6734 T: m_Zero(), F: LogicalShift));
6735 // shift_amt < BitWidth ? shift_op : 0
6736 bool MatchedULT =
6737 !MatchedUGT &&
6738 sd_match(N, P: m_Select(Cond: m_SetCC(LHS: m_Value(N&: ShiftAmt),
6739 RHS: m_SpecificInt(V: APInt(BitWidth, BitWidth)),
6740 CC: m_SpecificCondCode(CC: ISD::SETULT)),
6741 T: LogicalShift, F: m_Zero()));
6742
6743 if (!MatchedUGT && !MatchedULT)
6744 return SDValue();
6745
6746 // In LLVM IR, the shift amount and the value-to-be-shifted are the same
6747 // type, whereas in PTX the shift amount is always i32. Therefore when
6748 // shifting types larger than i32, we can only do this transformation if we
6749 // know that the upper bits of the shift amount are known zero.
6750 SDValue ClampAmt = ShiftOp.getOperand(i: 1);
6751 unsigned ClampAmtBits = ClampAmt.getValueSizeInBits();
6752 if (ShiftAmt.getValueSizeInBits() > ClampAmtBits &&
6753 DCI.DAG.computeKnownBits(Op: ShiftAmt).countMaxActiveBits() > ClampAmtBits)
6754 return SDValue();
6755
6756 // Return a clamp shift operation, which has the same semantics as PTX shift.
6757 unsigned ClampOpc = ShiftOp.getOpcode() == ISD::SRL ? NVPTXISD::SRL_CLAMP
6758 : NVPTXISD::SHL_CLAMP;
6759 return DCI.DAG.getNode(Opcode: ClampOpc, DL: SDLoc(N), VT: ShiftOp.getValueType(),
6760 N1: ShiftOp.getOperand(i: 0), N2: ClampAmt);
6761}
6762
6763static SDValue PerformVSELECTCombine(SDNode *N,
6764 TargetLowering::DAGCombinerInfo &DCI) {
6765 SDValue VA = N->getOperand(Num: 1);
6766 EVT VectorVT = VA.getValueType();
6767 if (VectorVT != MVT::v4i8)
6768 return SDValue();
6769
6770 // We need to split vselect into individual per-element operations Because we
6771 // use BFE/BFI instruction for byte extraction/insertion, we do end up with
6772 // 32-bit values, so we may as well do comparison as i32 to avoid conversions
6773 // to/from i16 normally used for i8 values.
6774 SmallVector<SDValue, 4> E;
6775 SDLoc DL(N);
6776 SDValue VCond = N->getOperand(Num: 0);
6777 SDValue VB = N->getOperand(Num: 2);
6778 for (int I = 0; I < 4; ++I) {
6779 SDValue C = DCI.DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i1, N1: VCond,
6780 N2: DCI.DAG.getConstant(Val: I, DL, VT: MVT::i32));
6781 SDValue EA = DCI.DAG.getAnyExtOrTrunc(
6782 Op: DCI.DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i8, N1: VA,
6783 N2: DCI.DAG.getConstant(Val: I, DL, VT: MVT::i32)),
6784 DL, VT: MVT::i32);
6785 SDValue EB = DCI.DAG.getAnyExtOrTrunc(
6786 Op: DCI.DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i8, N1: VB,
6787 N2: DCI.DAG.getConstant(Val: I, DL, VT: MVT::i32)),
6788 DL, VT: MVT::i32);
6789 E.push_back(Elt: DCI.DAG.getAnyExtOrTrunc(
6790 Op: DCI.DAG.getNode(Opcode: ISD::SELECT, DL, VT: MVT::i32, N1: C, N2: EA, N3: EB), DL, VT: MVT::i8));
6791 }
6792 return DCI.DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: MVT::v4i8, Ops: E);
6793}
6794
6795static SDValue
6796PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
6797 auto VT = N->getValueType(ResNo: 0);
6798 if (!DCI.isAfterLegalizeDAG() ||
6799 // only process v2*16 types
6800 !(NVPTX::isPackedVectorTy(VT) && VT.is32BitVector() &&
6801 VT.getVectorNumElements() == 2))
6802 return SDValue();
6803
6804 auto Op0 = N->getOperand(Num: 0);
6805 auto Op1 = N->getOperand(Num: 1);
6806
6807 // Start out by assuming we want to take the lower 2 bytes of each i32
6808 // operand.
6809 uint64_t Op0Bytes = 0x10;
6810 uint64_t Op1Bytes = 0x54;
6811
6812 std::pair<SDValue *, uint64_t *> OpData[2] = {{&Op0, &Op0Bytes},
6813 {&Op1, &Op1Bytes}};
6814
6815 // Check that each operand is an i16, truncated from an i32 operand. We'll
6816 // select individual bytes from those original operands. Optionally, fold in a
6817 // shift right of that original operand.
6818 for (auto &[Op, OpBytes] : OpData) {
6819 // Eat up any bitcast
6820 if (Op->getOpcode() == ISD::BITCAST)
6821 *Op = Op->getOperand(i: 0);
6822
6823 if (!(Op->getValueType() == MVT::i16 && Op->getOpcode() == ISD::TRUNCATE &&
6824 Op->getOperand(i: 0).getValueType() == MVT::i32))
6825 return SDValue();
6826
6827 // If the truncate has multiple uses, this optimization can increase
6828 // register pressure
6829 if (!Op->hasOneUse())
6830 return SDValue();
6831
6832 *Op = Op->getOperand(i: 0);
6833
6834 // Optionally, fold in a shift-right of the original operand and let permute
6835 // pick the two higher bytes of the original value directly.
6836 if (Op->getOpcode() == ISD::SRL && isa<ConstantSDNode>(Val: Op->getOperand(i: 1))) {
6837 if (cast<ConstantSDNode>(Val: Op->getOperand(i: 1))->getZExtValue() == 16) {
6838 // Shift the PRMT byte selector to pick upper bytes from each respective
6839 // value, instead of the lower ones: 0x10 -> 0x32, 0x54 -> 0x76
6840 assert((*OpBytes == 0x10 || *OpBytes == 0x54) &&
6841 "PRMT selector values out of range");
6842 *OpBytes += 0x22;
6843 *Op = Op->getOperand(i: 0);
6844 }
6845 }
6846 }
6847
6848 SDLoc DL(N);
6849 auto &DAG = DCI.DAG;
6850
6851 auto PRMT =
6852 getPRMT(A: DAG.getBitcast(VT: MVT::i32, V: Op0), B: DAG.getBitcast(VT: MVT::i32, V: Op1),
6853 Selector: (Op1Bytes << 8) | Op0Bytes, DL, DAG);
6854 return DAG.getBitcast(VT, V: PRMT);
6855}
6856
6857static SDValue combineADDRSPACECAST(SDNode *N,
6858 TargetLowering::DAGCombinerInfo &DCI) {
6859 auto *ASCN1 = cast<AddrSpaceCastSDNode>(Val: N);
6860
6861 if (auto *ASCN2 = dyn_cast<AddrSpaceCastSDNode>(Val: ASCN1->getOperand(Num: 0))) {
6862 assert(ASCN2->getDestAddressSpace() == ASCN1->getSrcAddressSpace());
6863
6864 // Fold asc[B -> A](asc[A -> B](x)) -> x
6865 if (ASCN1->getDestAddressSpace() == ASCN2->getSrcAddressSpace())
6866 return ASCN2->getOperand(Num: 0);
6867 }
6868
6869 return SDValue();
6870}
6871
6872// Given a constant selector value and a prmt mode, return the selector value
6873// normalized to the generic prmt mode. See the PTX ISA documentation for more
6874// details:
6875// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt
6876static APInt getPRMTSelector(const APInt &Selector, unsigned Mode) {
6877 assert(Selector.getBitWidth() == 32 && "PRMT must have i32 operands");
6878
6879 if (Mode == NVPTX::PTXPrmtMode::NONE)
6880 return Selector;
6881
6882 const unsigned V = Selector.trunc(width: 2).getZExtValue();
6883
6884 const auto GetSelector = [](unsigned S0, unsigned S1, unsigned S2,
6885 unsigned S3) {
6886 return APInt(32, S0 | (S1 << 4) | (S2 << 8) | (S3 << 12));
6887 };
6888
6889 switch (Mode) {
6890 case NVPTX::PTXPrmtMode::F4E:
6891 return GetSelector(V, V + 1, V + 2, V + 3);
6892 case NVPTX::PTXPrmtMode::B4E:
6893 return GetSelector(V, (V - 1) & 7, (V - 2) & 7, (V - 3) & 7);
6894 case NVPTX::PTXPrmtMode::RC8:
6895 return GetSelector(V, V, V, V);
6896 case NVPTX::PTXPrmtMode::ECL:
6897 return GetSelector(V, std::max(a: V, b: 1U), std::max(a: V, b: 2U), 3U);
6898 case NVPTX::PTXPrmtMode::ECR:
6899 return GetSelector(0, std::min(a: V, b: 1U), std::min(a: V, b: 2U), V);
6900 case NVPTX::PTXPrmtMode::RC16: {
6901 unsigned V1 = (V & 1) << 1;
6902 return GetSelector(V1, V1 + 1, V1, V1 + 1);
6903 }
6904 default:
6905 llvm_unreachable("Invalid PRMT mode");
6906 }
6907}
6908
6909static APInt computePRMT(APInt A, APInt B, APInt Selector, unsigned Mode) {
6910 assert(A.getBitWidth() == 32 && B.getBitWidth() == 32 &&
6911 Selector.getBitWidth() == 32 && "PRMT must have i32 operands");
6912 // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
6913 APInt BitField = B.concat(NewLSB: A);
6914 APInt SelectorVal = getPRMTSelector(Selector, Mode);
6915 APInt Result(32, 0);
6916 for (unsigned I : llvm::seq(Size: 4U)) {
6917 APInt Sel = SelectorVal.extractBits(numBits: 4, bitPosition: I * 4);
6918 unsigned Idx = Sel.getLoBits(numBits: 3).getZExtValue();
6919 unsigned Sign = Sel.getHiBits(numBits: 1).getZExtValue();
6920 APInt Byte = BitField.extractBits(numBits: 8, bitPosition: Idx * 8);
6921 if (Sign)
6922 Byte = Byte.ashr(ShiftAmt: 8);
6923 Result.insertBits(SubBits: Byte, bitPosition: I * 8);
6924 }
6925 return Result;
6926}
6927
6928static SDValue combinePRMT(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
6929 CodeGenOptLevel OptLevel) {
6930 if (OptLevel == CodeGenOptLevel::None)
6931 return SDValue();
6932
6933 // Constant fold PRMT
6934 if (isa<ConstantSDNode>(Val: N->getOperand(Num: 0)) &&
6935 isa<ConstantSDNode>(Val: N->getOperand(Num: 1)) &&
6936 isa<ConstantSDNode>(Val: N->getOperand(Num: 2)))
6937 return DCI.DAG.getConstant(Val: computePRMT(A: N->getConstantOperandAPInt(Num: 0),
6938 B: N->getConstantOperandAPInt(Num: 1),
6939 Selector: N->getConstantOperandAPInt(Num: 2),
6940 Mode: N->getConstantOperandVal(Num: 3)),
6941 DL: SDLoc(N), VT: N->getValueType(ResNo: 0));
6942 return SDValue();
6943}
6944
6945// During call lowering we wrap the return values in a ProxyReg node which
6946// depend on the chain value produced by the completed call. This ensures that
6947// the full call is emitted in cases where libcalls are used to legalize
6948// operations. To improve the functioning of other DAG combines we pull all
6949// operations we can through one of these nodes, ensuring that the ProxyReg
6950// directly wraps a load. That is:
6951//
6952// (ProxyReg (zext (load retval0))) => (zext (ProxyReg (load retval0)))
6953//
6954static SDValue sinkProxyReg(SDValue R, SDValue Chain,
6955 TargetLowering::DAGCombinerInfo &DCI) {
6956 switch (R.getOpcode()) {
6957 case ISD::TRUNCATE:
6958 case ISD::ANY_EXTEND:
6959 case ISD::SIGN_EXTEND:
6960 case ISD::ZERO_EXTEND:
6961 case ISD::BITCAST: {
6962 if (SDValue V = sinkProxyReg(R: R.getOperand(i: 0), Chain, DCI))
6963 return DCI.DAG.getNode(Opcode: R.getOpcode(), DL: SDLoc(R), VT: R.getValueType(), Operand: V);
6964 return SDValue();
6965 }
6966 case ISD::SHL:
6967 case ISD::SRL:
6968 case ISD::SRA:
6969 case ISD::OR: {
6970 if (SDValue A = sinkProxyReg(R: R.getOperand(i: 0), Chain, DCI))
6971 if (SDValue B = sinkProxyReg(R: R.getOperand(i: 1), Chain, DCI))
6972 return DCI.DAG.getNode(Opcode: R.getOpcode(), DL: SDLoc(R), VT: R.getValueType(), N1: A, N2: B);
6973 return SDValue();
6974 }
6975 case ISD::Constant:
6976 return R;
6977 case ISD::LOAD:
6978 case NVPTXISD::LoadV2:
6979 case NVPTXISD::LoadV4: {
6980 return DCI.DAG.getNode(Opcode: NVPTXISD::ProxyReg, DL: SDLoc(R), VT: R.getValueType(),
6981 Ops: {Chain, R});
6982 }
6983 case ISD::BUILD_VECTOR: {
6984 if (DCI.isBeforeLegalize())
6985 return SDValue();
6986
6987 SmallVector<SDValue, 16> Ops;
6988 for (auto &Op : R->ops()) {
6989 SDValue V = sinkProxyReg(R: Op, Chain, DCI);
6990 if (!V)
6991 return SDValue();
6992 Ops.push_back(Elt: V);
6993 }
6994 return DCI.DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL: SDLoc(R), VT: R.getValueType(), Ops);
6995 }
6996 case ISD::EXTRACT_VECTOR_ELT: {
6997 if (DCI.isBeforeLegalize())
6998 return SDValue();
6999
7000 if (SDValue V = sinkProxyReg(R: R.getOperand(i: 0), Chain, DCI))
7001 return DCI.DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: SDLoc(R),
7002 VT: R.getValueType(), N1: V, N2: R.getOperand(i: 1));
7003 return SDValue();
7004 }
7005 default:
7006 return SDValue();
7007 }
7008}
7009
7010static unsigned getF16SubOpc(Intrinsic::ID AddIntrinsicID) {
7011 switch (AddIntrinsicID) {
7012 default:
7013 break;
7014 case Intrinsic::nvvm_add_rn_sat_f16:
7015 case Intrinsic::nvvm_add_rn_sat_v2f16:
7016 return NVPTXISD::SUB_RN_SAT;
7017 case Intrinsic::nvvm_add_rn_ftz_sat_f16:
7018 case Intrinsic::nvvm_add_rn_ftz_sat_v2f16:
7019 return NVPTXISD::SUB_RN_FTZ_SAT;
7020 }
7021 llvm_unreachable("Invalid F16 add intrinsic");
7022}
7023
7024static SDValue combineF16AddWithNeg(SDNode *N, SelectionDAG &DAG,
7025 Intrinsic::ID AddIntrinsicID) {
7026 SDValue Op1 = N->getOperand(Num: 1);
7027 SDValue Op2 = N->getOperand(Num: 2);
7028
7029 SDValue SubOp1, SubOp2;
7030
7031 if (Op1.getOpcode() == ISD::FNEG) {
7032 SubOp1 = Op2;
7033 SubOp2 = Op1.getOperand(i: 0);
7034 } else if (Op2.getOpcode() == ISD::FNEG) {
7035 SubOp1 = Op1;
7036 SubOp2 = Op2.getOperand(i: 0);
7037 } else {
7038 return SDValue();
7039 }
7040
7041 SDLoc DL(N);
7042 return DAG.getNode(Opcode: getF16SubOpc(AddIntrinsicID), DL, VT: N->getValueType(ResNo: 0),
7043 N1: SubOp1, N2: SubOp2);
7044}
7045
7046static SDValue combineIntrinsicWOChain(SDNode *N,
7047 TargetLowering::DAGCombinerInfo &DCI,
7048 const NVPTXSubtarget &STI) {
7049 unsigned IID = N->getConstantOperandVal(Num: 0);
7050
7051 switch (IID) {
7052 default:
7053 break;
7054 case Intrinsic::nvvm_add_rn_sat_f16:
7055 case Intrinsic::nvvm_add_rn_ftz_sat_f16:
7056 case Intrinsic::nvvm_add_rn_sat_v2f16:
7057 case Intrinsic::nvvm_add_rn_ftz_sat_v2f16:
7058 return combineF16AddWithNeg(N, DAG&: DCI.DAG, AddIntrinsicID: IID);
7059 }
7060 return SDValue();
7061}
7062
7063static SDValue combineProxyReg(SDNode *N,
7064 TargetLowering::DAGCombinerInfo &DCI) {
7065
7066 SDValue Chain = N->getOperand(Num: 0);
7067 SDValue Reg = N->getOperand(Num: 1);
7068
7069 // If the ProxyReg is not wrapping a load, try to pull the operations through
7070 // the ProxyReg.
7071 if (Reg.getOpcode() != ISD::LOAD) {
7072 if (SDValue V = sinkProxyReg(R: Reg, Chain, DCI))
7073 return V;
7074 }
7075
7076 return SDValue();
7077}
7078
7079SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
7080 DAGCombinerInfo &DCI) const {
7081 CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
7082 switch (N->getOpcode()) {
7083 default:
7084 break;
7085 case ISD::ADD:
7086 return PerformADDCombine(N, DCI, OptLevel);
7087 case ISD::ADDRSPACECAST:
7088 return combineADDRSPACECAST(N, DCI);
7089 case ISD::SIGN_EXTEND:
7090 case ISD::ZERO_EXTEND:
7091 return combineSZExtToMulWide(N, DCI, OptLevel);
7092 case ISD::BUILD_VECTOR:
7093 return PerformBUILD_VECTORCombine(N, DCI);
7094 case ISD::EXTRACT_VECTOR_ELT:
7095 return PerformEXTRACTCombine(N, DCI);
7096 case ISD::FADD:
7097 return performFADDCombine(N, DCI, OptLevel);
7098 case ISD::FMA:
7099 case ISD::FMUL:
7100 case ISD::FSUB:
7101 return performScalarizeV2F32Op(N, DCI, OptLevel);
7102 case ISD::FMAXNUM:
7103 case ISD::FMINNUM:
7104 case ISD::FMAXIMUM:
7105 case ISD::FMINIMUM:
7106 case ISD::FMAXIMUMNUM:
7107 case ISD::FMINIMUMNUM:
7108 return PerformFMinMaxCombine(N, DCI, PTXVersion: STI.getPTXVersion(),
7109 SmVersion: STI.getSmVersion());
7110 case ISD::LOAD:
7111 case NVPTXISD::LoadV2:
7112 case NVPTXISD::LoadV4:
7113 return combineLOAD(N, DCI, STI);
7114 case ISD::MUL:
7115 return PerformMULCombine(N, DCI, OptLevel);
7116 case NVPTXISD::PRMT:
7117 return combinePRMT(N, DCI, OptLevel);
7118 case NVPTXISD::ProxyReg:
7119 return combineProxyReg(N, DCI);
7120 case ISD::SETCC:
7121 return PerformSETCCCombine(N, DCI, SmVersion: STI.getSmVersion());
7122 case ISD::SHL:
7123 return PerformSHLCombine(N, DCI, OptLevel);
7124 case ISD::SREM:
7125 case ISD::UREM:
7126 return PerformREMCombine(N, DCI, OptLevel);
7127 case ISD::STORE:
7128 case NVPTXISD::StoreV2:
7129 case NVPTXISD::StoreV4:
7130 return combineSTORE(N, DCI, STI);
7131 case ISD::SELECT:
7132 return PerformSELECTShiftCombine(N, DCI);
7133 case ISD::VSELECT:
7134 return PerformVSELECTCombine(N, DCI);
7135 case ISD::INTRINSIC_WO_CHAIN:
7136 return combineIntrinsicWOChain(N, DCI, STI);
7137 }
7138 return SDValue();
7139}
7140
7141static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
7142 SmallVectorImpl<SDValue> &Results) {
7143 // Handle bitcasting to v2i8 without hitting the default promotion
7144 // strategy which goes through stack memory.
7145 SDValue Op(Node, 0);
7146 EVT ToVT = Op->getValueType(ResNo: 0);
7147 if (ToVT != MVT::v2i8) {
7148 return;
7149 }
7150
7151 // Bitcast to i16 and unpack elements into a vector
7152 SDLoc DL(Node);
7153 SDValue AsInt = DAG.getBitcast(VT: MVT::i16, V: Op->getOperand(Num: 0));
7154 SDValue Vec0 = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i8, Operand: AsInt);
7155 SDValue Const8 = DAG.getConstant(Val: 8, DL, VT: MVT::i16);
7156 SDValue Vec1 =
7157 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i8,
7158 Operand: DAG.getNode(Opcode: ISD::SRL, DL, VT: MVT::i16, Ops: {AsInt, Const8}));
7159 Results.push_back(
7160 Elt: DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: MVT::v2i8, Ops: {Vec0, Vec1}));
7161}
7162
7163static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
7164 SmallVectorImpl<SDValue> &Results) {
7165 SDValue Chain = N->getOperand(Num: 0);
7166 SDValue Intrin = N->getOperand(Num: 1);
7167 SDLoc DL(N);
7168
7169 // Get the intrinsic ID
7170 unsigned IntrinNo = Intrin.getNode()->getAsZExtVal();
7171 switch (IntrinNo) {
7172 default:
7173 return;
7174 case Intrinsic::nvvm_ldu_global_i:
7175 case Intrinsic::nvvm_ldu_global_f:
7176 case Intrinsic::nvvm_ldu_global_p: {
7177 EVT ResVT = N->getValueType(ResNo: 0);
7178
7179 if (ResVT.isVector()) {
7180 // Vector LDG/LDU
7181
7182 unsigned NumElts = ResVT.getVectorNumElements();
7183 EVT EltVT = ResVT.getVectorElementType();
7184
7185 // Since LDU/LDG are target nodes, we cannot rely on DAG type
7186 // legalization.
7187 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
7188 // loaded type to i16 and propagate the "real" type as the memory type.
7189 bool NeedTrunc = false;
7190 if (EltVT.getSizeInBits() < 16) {
7191 EltVT = MVT::i16;
7192 NeedTrunc = true;
7193 }
7194
7195 unsigned Opcode = 0;
7196 SDVTList LdResVTs;
7197
7198 switch (NumElts) {
7199 default:
7200 return;
7201 case 2:
7202 Opcode = NVPTXISD::LDUV2;
7203 LdResVTs = DAG.getVTList(VT1: EltVT, VT2: EltVT, VT3: MVT::Other);
7204 break;
7205 case 4: {
7206 Opcode = NVPTXISD::LDUV4;
7207 EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
7208 LdResVTs = DAG.getVTList(VTs: ListVTs);
7209 break;
7210 }
7211 }
7212
7213 SmallVector<SDValue, 8> OtherOps;
7214
7215 // Copy regular operands
7216
7217 OtherOps.push_back(Elt: Chain); // Chain
7218 // Skip operand 1 (intrinsic ID)
7219 // Others
7220 OtherOps.append(in_start: N->op_begin() + 2, in_end: N->op_end());
7221
7222 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
7223
7224 SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, dl: DL, VTList: LdResVTs, Ops: OtherOps,
7225 MemVT: MemSD->getMemoryVT(),
7226 MMO: MemSD->getMemOperand());
7227
7228 SmallVector<SDValue, 4> ScalarRes;
7229
7230 for (unsigned i = 0; i < NumElts; ++i) {
7231 SDValue Res = NewLD.getValue(R: i);
7232 if (NeedTrunc)
7233 Res =
7234 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: ResVT.getVectorElementType(), Operand: Res);
7235 ScalarRes.push_back(Elt: Res);
7236 }
7237
7238 SDValue LoadChain = NewLD.getValue(R: NumElts);
7239
7240 SDValue BuildVec =
7241 DAG.getBuildVector(VT: ResVT, DL, Ops: ScalarRes);
7242
7243 Results.push_back(Elt: BuildVec);
7244 Results.push_back(Elt: LoadChain);
7245 } else {
7246 // i8 LDG/LDU
7247 assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
7248 "Custom handling of non-i8 ldu/ldg?");
7249
7250 // Just copy all operands as-is
7251 SmallVector<SDValue, 4> Ops(N->ops());
7252
7253 // Force output to i16
7254 SDVTList LdResVTs = DAG.getVTList(VT1: MVT::i16, VT2: MVT::Other);
7255
7256 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
7257
7258 // We make sure the memory type is i8, which will be used during isel
7259 // to select the proper instruction.
7260 SDValue NewLD =
7261 DAG.getMemIntrinsicNode(Opcode: ISD::INTRINSIC_W_CHAIN, dl: DL, VTList: LdResVTs, Ops,
7262 MemVT: MVT::i8, MMO: MemSD->getMemOperand());
7263
7264 Results.push_back(Elt: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i8,
7265 Operand: NewLD.getValue(R: 0)));
7266 Results.push_back(Elt: NewLD.getValue(R: 1));
7267 }
7268 return;
7269 }
7270
7271 case Intrinsic::nvvm_tcgen05_ld_16x64b_x4:
7272 case Intrinsic::nvvm_tcgen05_ld_16x64b_x8:
7273 case Intrinsic::nvvm_tcgen05_ld_16x64b_x16:
7274 case Intrinsic::nvvm_tcgen05_ld_16x64b_x32:
7275 case Intrinsic::nvvm_tcgen05_ld_16x64b_x64:
7276 case Intrinsic::nvvm_tcgen05_ld_16x64b_x128:
7277 case Intrinsic::nvvm_tcgen05_ld_32x32b_x4:
7278 case Intrinsic::nvvm_tcgen05_ld_32x32b_x8:
7279 case Intrinsic::nvvm_tcgen05_ld_32x32b_x16:
7280 case Intrinsic::nvvm_tcgen05_ld_32x32b_x32:
7281 case Intrinsic::nvvm_tcgen05_ld_32x32b_x64:
7282 case Intrinsic::nvvm_tcgen05_ld_32x32b_x128:
7283 case Intrinsic::nvvm_tcgen05_ld_16x128b_x2:
7284 case Intrinsic::nvvm_tcgen05_ld_16x128b_x4:
7285 case Intrinsic::nvvm_tcgen05_ld_16x128b_x8:
7286 case Intrinsic::nvvm_tcgen05_ld_16x128b_x16:
7287 case Intrinsic::nvvm_tcgen05_ld_16x128b_x32:
7288 case Intrinsic::nvvm_tcgen05_ld_16x128b_x64:
7289 case Intrinsic::nvvm_tcgen05_ld_16x256b_x1:
7290 case Intrinsic::nvvm_tcgen05_ld_16x256b_x2:
7291 case Intrinsic::nvvm_tcgen05_ld_16x256b_x4:
7292 case Intrinsic::nvvm_tcgen05_ld_16x256b_x8:
7293 case Intrinsic::nvvm_tcgen05_ld_16x256b_x16:
7294 case Intrinsic::nvvm_tcgen05_ld_16x256b_x32:
7295 if (auto Res = lowerTcgen05Ld(N, DAG)) {
7296 Results.push_back(Elt: Res->first);
7297 Results.push_back(Elt: Res->second);
7298 }
7299 return;
7300
7301 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x4:
7302 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x8:
7303 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x16:
7304 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x32:
7305 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x64:
7306 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x128:
7307 if (auto Res = lowerTcgen05Ld(N, DAG, /*HasOffset=*/true)) {
7308 Results.push_back(Elt: Res->first);
7309 Results.push_back(Elt: Res->second);
7310 }
7311 return;
7312
7313 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x8_i32:
7314 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x8_f32:
7315 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x64_i32:
7316 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x64_f32:
7317 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x4_i32:
7318 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x4_f32:
7319 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x32_i32:
7320 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x32_f32:
7321 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x16_i32:
7322 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x16_f32:
7323 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x128_i32:
7324 case Intrinsic::nvvm_tcgen05_ld_red_32x32b_x128_f32:
7325 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x8_i32:
7326 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x8_f32:
7327 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x64_i32:
7328 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x64_f32:
7329 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x4_i32:
7330 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x4_f32:
7331 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x32_i32:
7332 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x32_f32:
7333 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x16_i32:
7334 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x16_f32:
7335 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x128_i32:
7336 case Intrinsic::nvvm_tcgen05_ld_red_16x32bx2_x128_f32:
7337 if (auto Res = lowerTcgen05LdRed(N, DAG)) {
7338 Results.push_back(Elt: std::get<0>(t&: *Res));
7339 Results.push_back(Elt: std::get<1>(t&: *Res));
7340 Results.push_back(Elt: std::get<2>(t&: *Res));
7341 }
7342 return;
7343 }
7344}
7345
7346static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
7347 SmallVectorImpl<SDValue> &Results) {
7348 // Change the CopyFromReg to output 2 64-bit results instead of a 128-bit
7349 // result so that it can pass the legalization
7350 SDLoc DL(N);
7351 SDValue Chain = N->getOperand(Num: 0);
7352 SDValue Reg = N->getOperand(Num: 1);
7353 SDValue Glue = N->getOperand(Num: 2);
7354
7355 assert(Reg.getValueType() == MVT::i128 &&
7356 "Custom lowering for CopyFromReg with 128-bit reg only");
7357 SmallVector<EVT, 4> ResultsType = {MVT::i64, MVT::i64, N->getValueType(ResNo: 1),
7358 N->getValueType(ResNo: 2)};
7359 SmallVector<SDValue, 3> NewOps = {Chain, Reg, Glue};
7360
7361 SDValue NewValue = DAG.getNode(Opcode: ISD::CopyFromReg, DL, ResultTys: ResultsType, Ops: NewOps);
7362 SDValue Pair = DAG.getNode(Opcode: ISD::BUILD_PAIR, DL, VT: MVT::i128,
7363 Ops: {NewValue.getValue(R: 0), NewValue.getValue(R: 1)});
7364
7365 Results.push_back(Elt: Pair);
7366 Results.push_back(Elt: NewValue.getValue(R: 2));
7367 Results.push_back(Elt: NewValue.getValue(R: 3));
7368}
7369
7370static void replaceProxyReg(SDNode *N, SelectionDAG &DAG,
7371 const TargetLowering &TLI,
7372 SmallVectorImpl<SDValue> &Results) {
7373 SDValue Chain = N->getOperand(Num: 0);
7374 SDValue Reg = N->getOperand(Num: 1);
7375
7376 MVT VT = TLI.getRegisterType(Context&: *DAG.getContext(), VT: Reg.getValueType());
7377
7378 SDValue NewReg = DAG.getAnyExtOrTrunc(Op: Reg, DL: SDLoc(N), VT);
7379 SDValue NewProxy =
7380 DAG.getNode(Opcode: NVPTXISD::ProxyReg, DL: SDLoc(N), VT, Ops: {Chain, NewReg});
7381 SDValue Res = DAG.getAnyExtOrTrunc(Op: NewProxy, DL: SDLoc(N), VT: N->getValueType(ResNo: 0));
7382
7383 Results.push_back(Elt: Res);
7384}
7385
7386static void replaceAtomicSwap128(SDNode *N, SelectionDAG &DAG,
7387 const NVPTXSubtarget &STI,
7388 SmallVectorImpl<SDValue> &Results) {
7389 assert(N->getValueType(0) == MVT::i128 &&
7390 "Custom lowering for atomic128 only supports i128");
7391
7392 AtomicSDNode *AN = cast<AtomicSDNode>(Val: N);
7393 SDLoc dl(N);
7394
7395 if (!STI.hasAtomSwap128()) {
7396 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
7397 DAG.getMachineFunction().getFunction(),
7398 "Support for b128 atomics introduced in PTX ISA version 8.3 and "
7399 "requires target sm_90.",
7400 dl.getDebugLoc()));
7401
7402 Results.push_back(Elt: DAG.getUNDEF(VT: MVT::i128));
7403 Results.push_back(Elt: AN->getOperand(Num: 0)); // Chain
7404 return;
7405 }
7406
7407 SmallVector<SDValue, 6> Ops;
7408 Ops.push_back(Elt: AN->getOperand(Num: 0)); // Chain
7409 Ops.push_back(Elt: AN->getOperand(Num: 1)); // Ptr
7410 for (const auto &Op : AN->ops().drop_front(N: 2)) {
7411 // Low part
7412 Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_ELEMENT, DL: dl, VT: MVT::i64, N1: Op,
7413 N2: DAG.getIntPtrConstant(Val: 0, DL: dl)));
7414 // High part
7415 Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_ELEMENT, DL: dl, VT: MVT::i64, N1: Op,
7416 N2: DAG.getIntPtrConstant(Val: 1, DL: dl)));
7417 }
7418 unsigned Opcode = N->getOpcode() == ISD::ATOMIC_SWAP
7419 ? NVPTXISD::ATOMIC_SWAP_B128
7420 : NVPTXISD::ATOMIC_CMP_SWAP_B128;
7421 SDVTList Tys = DAG.getVTList(VT1: MVT::i64, VT2: MVT::i64, VT3: MVT::Other);
7422 SDValue Result = DAG.getMemIntrinsicNode(Opcode, dl, VTList: Tys, Ops, MemVT: MVT::i128,
7423 MMO: AN->getMemOperand());
7424 Results.push_back(Elt: DAG.getNode(Opcode: ISD::BUILD_PAIR, DL: dl, VT: MVT::i128,
7425 Ops: {Result.getValue(R: 0), Result.getValue(R: 1)}));
7426 Results.push_back(Elt: Result.getValue(R: 2));
7427}
7428
7429void NVPTXTargetLowering::ReplaceNodeResults(
7430 SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
7431 switch (N->getOpcode()) {
7432 default:
7433 report_fatal_error(reason: "Unhandled custom legalization");
7434 case ISD::BITCAST:
7435 ReplaceBITCAST(Node: N, DAG, Results);
7436 return;
7437 case ISD::LOAD:
7438 case ISD::MLOAD:
7439 replaceLoadVector(N, DAG, Results, STI);
7440 return;
7441 case ISD::INTRINSIC_W_CHAIN:
7442 ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
7443 return;
7444 case ISD::CopyFromReg:
7445 ReplaceCopyFromReg_128(N, DAG, Results);
7446 return;
7447 case NVPTXISD::ProxyReg:
7448 replaceProxyReg(N, DAG, TLI: *this, Results);
7449 return;
7450 case ISD::ATOMIC_CMP_SWAP:
7451 case ISD::ATOMIC_SWAP:
7452 replaceAtomicSwap128(N, DAG, STI, Results);
7453 return;
7454 }
7455}
7456
7457NVPTXTargetLowering::AtomicExpansionKind
7458NVPTXTargetLowering::shouldExpandAtomicRMWInIR(const AtomicRMWInst *AI) const {
7459 Type *Ty = AI->getValOperand()->getType();
7460
7461 // Try to lower LLVM atomicrmw fadd to PTX atomic.add. This is complicated
7462 // by the weird FTZ behavior PTX atom.add has:
7463 // - atom.add.f32 on global memory flushes denormals
7464 // - atom.add.f32 on shared memory does not flush denormals
7465 // - atom.add.f16 and atomic.add.bf16 never flush denormals
7466 //
7467 // We lower to atom.add only if the function's FTZ behavior matches that of
7468 // atom.add; otherwise, we lower to a CAS loop. But we always allow
7469 // atomic.add.bf16; even though it never flushes denormals, we never flush
7470 // bf16 denormals when doing regular arithmetic, even when FTZ is enabled.
7471 if (AI->isFloatingPointOperation() &&
7472 AI->getOperation() == AtomicRMWInst::BinOp::FAdd) {
7473 const bool FTZ =
7474 AI->getFunction()->getDenormalMode(FPType: APFloat::IEEEsingle()).Output ==
7475 DenormalMode::PreserveSign;
7476
7477 // AllowFTZAtomics forces atom.add regardless of the FTZ mismatch.
7478 if (Ty->isFloatTy()) {
7479 bool UseNative = AllowFTZAtomics;
7480 switch (AI->getPointerAddressSpace()) {
7481 case llvm::ADDRESS_SPACE_GLOBAL:
7482 UseNative |= FTZ;
7483 break;
7484 case llvm::ADDRESS_SPACE_SHARED:
7485 case llvm::ADDRESS_SPACE_SHARED_CLUSTER:
7486 UseNative |= !FTZ;
7487 break;
7488 }
7489 if (UseNative)
7490 return AtomicExpansionKind::None;
7491 }
7492
7493 if (Ty->isHalfTy() && (!FTZ || AllowFTZAtomics) &&
7494 STI.getSmVersion() >= 70 && STI.getPTXVersion() >= 63)
7495 return AtomicExpansionKind::None;
7496
7497 if (Ty->isBFloatTy() && STI.getSmVersion() >= 90 &&
7498 STI.getPTXVersion() >= 78)
7499 return AtomicExpansionKind::None;
7500
7501 if (Ty->isDoubleTy() && STI.hasAtomAddF64())
7502 return AtomicExpansionKind::None;
7503 }
7504
7505 // PTX's only atomic fp op is `add`; all other ops expand to a CAS loop.
7506 if (AI->isFloatingPointOperation())
7507 return AtomicExpansionKind::CmpXChg;
7508
7509 assert(Ty->isIntegerTy() && "Ty should be integer at this point");
7510 const unsigned BitWidth = cast<IntegerType>(Val: Ty)->getBitWidth();
7511
7512 switch (AI->getOperation()) {
7513 default:
7514 return AtomicExpansionKind::CmpXChg;
7515 case AtomicRMWInst::BinOp::Xchg:
7516 if (BitWidth == 128)
7517 return AtomicExpansionKind::None;
7518 [[fallthrough]];
7519 case AtomicRMWInst::BinOp::And:
7520 case AtomicRMWInst::BinOp::Or:
7521 case AtomicRMWInst::BinOp::Xor:
7522 switch (BitWidth) {
7523 case 8:
7524 case 16:
7525 return AtomicExpansionKind::CmpXChg;
7526 case 32:
7527 return AtomicExpansionKind::None;
7528 case 64:
7529 if (STI.hasAtomBitwise64())
7530 return AtomicExpansionKind::None;
7531 return AtomicExpansionKind::CmpXChg;
7532 case 128:
7533 return AtomicExpansionKind::CmpXChg;
7534 default:
7535 llvm_unreachable("unsupported width encountered");
7536 }
7537 case AtomicRMWInst::BinOp::Add:
7538 case AtomicRMWInst::BinOp::Sub:
7539 case AtomicRMWInst::BinOp::Max:
7540 case AtomicRMWInst::BinOp::Min:
7541 case AtomicRMWInst::BinOp::UMax:
7542 case AtomicRMWInst::BinOp::UMin:
7543 switch (BitWidth) {
7544 case 8:
7545 case 16:
7546 return AtomicExpansionKind::CmpXChg;
7547 case 32:
7548 return AtomicExpansionKind::None;
7549 case 64:
7550 if (STI.hasAtomMinMax64())
7551 return AtomicExpansionKind::None;
7552 return AtomicExpansionKind::CmpXChg;
7553 case 128:
7554 return AtomicExpansionKind::CmpXChg;
7555 default:
7556 llvm_unreachable("unsupported width encountered");
7557 }
7558 case AtomicRMWInst::BinOp::UIncWrap:
7559 case AtomicRMWInst::BinOp::UDecWrap:
7560 switch (BitWidth) {
7561 case 32:
7562 return AtomicExpansionKind::None;
7563 case 8:
7564 case 16:
7565 case 64:
7566 case 128:
7567 return AtomicExpansionKind::CmpXChg;
7568 default:
7569 llvm_unreachable("unsupported width encountered");
7570 }
7571 }
7572
7573 return AtomicExpansionKind::CmpXChg;
7574}
7575
7576bool NVPTXTargetLowering::shouldInsertFencesForAtomic(
7577 const Instruction *I) const {
7578 // This function returns true iff the operation is emulated using a CAS-loop,
7579 // or if it has the memory order seq_cst (which is not natively supported in
7580 // the PTX `atom` instruction).
7581 //
7582 // atomicrmw and cmpxchg instructions not efficiently supported by PTX
7583 // are lowered to CAS emulation loops that preserve their memory order,
7584 // syncscope, and volatile semantics. For PTX, it is more efficient to use
7585 // atom.cas.relaxed.sco instructions within the loop, and fences before and
7586 // after the loop to restore order.
7587 //
7588 // Atomic instructions efficiently supported by PTX are lowered to
7589 // `atom.<op>.<sem>.<scope` instruction with their corresponding memory order
7590 // and scope. Since PTX does not support seq_cst, we emulate it by lowering to
7591 // a fence.sc followed by an atom according to the PTX atomics ABI
7592 // https://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/atomic-abi.html
7593 if (auto *CI = dyn_cast<AtomicCmpXchgInst>(Val: I))
7594 return (cast<IntegerType>(Val: CI->getCompareOperand()->getType())
7595 ->getBitWidth() < STI.getMinCmpXchgSizeInBits()) ||
7596 CI->getMergedOrdering() == AtomicOrdering::SequentiallyConsistent;
7597 if (auto *RI = dyn_cast<AtomicRMWInst>(Val: I))
7598 return shouldExpandAtomicRMWInIR(AI: RI) == AtomicExpansionKind::CmpXChg ||
7599 RI->getOrdering() == AtomicOrdering::SequentiallyConsistent;
7600 return false;
7601}
7602
7603AtomicOrdering NVPTXTargetLowering::atomicOperationOrderAfterFenceSplit(
7604 const Instruction *I) const {
7605 // If the operation is emulated by a CAS-loop, we lower the instruction to
7606 // atom.<op>.relaxed, since AtomicExpandPass will insert fences for enforcing
7607 // the correct memory ordering around the CAS loop.
7608 //
7609 // When the operation is not emulated, but the memory order is seq_cst,
7610 // we must lower to "fence.sc.<scope>; atom.<op>.acquire.<scope>;" to conform
7611 // to the PTX atomics ABI.
7612 // https://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/atomic-abi.html
7613 // For such cases, emitLeadingFence() will separately insert the leading
7614 // "fence.sc.<scope>;". Here, we only set the memory order to acquire.
7615 //
7616 // Otherwise, the operation is not emulated, and the memory order is not
7617 // seq_cst. In this case, the LLVM memory order is natively supported by the
7618 // PTX `atom` instruction, and we just lower to the corresponding
7619 // `atom.<op>.relaxed|acquire|release|acq_rel". For such cases, this function
7620 // will NOT be called.
7621 // prerequisite: shouldInsertFencesForAtomic() should have returned `true` for
7622 // I before its memory order was modified.
7623 if (auto *CI = dyn_cast<AtomicCmpXchgInst>(Val: I);
7624 CI && CI->getMergedOrdering() == AtomicOrdering::SequentiallyConsistent &&
7625 cast<IntegerType>(Val: CI->getCompareOperand()->getType())->getBitWidth() >=
7626 STI.getMinCmpXchgSizeInBits())
7627 return AtomicOrdering::Acquire;
7628 else if (auto *RI = dyn_cast<AtomicRMWInst>(Val: I);
7629 RI && RI->getOrdering() == AtomicOrdering::SequentiallyConsistent &&
7630 shouldExpandAtomicRMWInIR(AI: RI) == AtomicExpansionKind::None)
7631 return AtomicOrdering::Acquire;
7632
7633 return AtomicOrdering::Monotonic;
7634}
7635
7636Instruction *NVPTXTargetLowering::emitLeadingFence(IRBuilderBase &Builder,
7637 Instruction *Inst,
7638 AtomicOrdering Ord) const {
7639 // prerequisite: shouldInsertFencesForAtomic() should have returned `true` for
7640 // `Inst` before its memory order was modified. We cannot enforce this with an
7641 // assert, because AtomicExpandPass will have modified the memory order
7642 // between the initial call to shouldInsertFencesForAtomic() and the call to
7643 // this function.
7644 if (!isa<AtomicCmpXchgInst>(Val: Inst) && !isa<AtomicRMWInst>(Val: Inst))
7645 return TargetLoweringBase::emitLeadingFence(Builder, Inst, Ord);
7646
7647 // Specialize for cmpxchg and atomicrmw
7648 auto SSID = getAtomicSyncScopeID(I: Inst);
7649 assert(SSID.has_value() && "Expected an atomic operation");
7650
7651 if (isReleaseOrStronger(AO: Ord))
7652 return Builder.CreateFence(Ordering: Ord == AtomicOrdering::SequentiallyConsistent
7653 ? AtomicOrdering::SequentiallyConsistent
7654 : AtomicOrdering::Release,
7655 SSID: SSID.value());
7656
7657 return nullptr;
7658}
7659
7660Instruction *NVPTXTargetLowering::emitTrailingFence(IRBuilderBase &Builder,
7661 Instruction *Inst,
7662 AtomicOrdering Ord) const {
7663 // prerequisite: shouldInsertFencesForAtomic() should have returned `true` for
7664 // `Inst` before its memory order was modified. See `emitLeadingFence` for why
7665 // this cannot be enforced with an assert. Specialize for cmpxchg and
7666 // atomicrmw
7667 auto *CI = dyn_cast<AtomicCmpXchgInst>(Val: Inst);
7668 auto *RI = dyn_cast<AtomicRMWInst>(Val: Inst);
7669 if (!CI && !RI)
7670 return TargetLoweringBase::emitTrailingFence(Builder, Inst, Ord);
7671
7672 auto SSID = getAtomicSyncScopeID(I: Inst);
7673 assert(SSID.has_value() && "Expected an atomic operation");
7674
7675 bool IsEmulated =
7676 CI ? cast<IntegerType>(Val: CI->getCompareOperand()->getType())
7677 ->getBitWidth() < STI.getMinCmpXchgSizeInBits()
7678 : shouldExpandAtomicRMWInIR(AI: RI) == AtomicExpansionKind::CmpXChg;
7679
7680 if (isAcquireOrStronger(AO: Ord) && IsEmulated)
7681 return Builder.CreateFence(Ordering: AtomicOrdering::Acquire, SSID: SSID.value());
7682
7683 return nullptr;
7684}
7685
7686// Rather than default to SINT when both UINT and SINT are custom, we only
7687// change the opcode when UINT is not legal and SINT is. UINT is preferred when
7688// both are custom since unsigned CVT instructions can lead to slightly better
7689// SASS code with fewer instructions.
7690unsigned NVPTXTargetLowering::getPreferredFPToIntOpcode(unsigned Op, EVT FromVT,
7691 EVT ToVT) const {
7692 if (isOperationLegal(Op, VT: ToVT))
7693 return Op;
7694 switch (Op) {
7695 case ISD::FP_TO_UINT:
7696 if (isOperationLegal(Op: ISD::FP_TO_SINT, VT: ToVT))
7697 return ISD::FP_TO_SINT;
7698 break;
7699 case ISD::STRICT_FP_TO_UINT:
7700 if (isOperationLegal(Op: ISD::STRICT_FP_TO_SINT, VT: ToVT))
7701 return ISD::STRICT_FP_TO_SINT;
7702 break;
7703 case ISD::VP_FP_TO_UINT:
7704 if (isOperationLegal(Op: ISD::VP_FP_TO_SINT, VT: ToVT))
7705 return ISD::VP_FP_TO_SINT;
7706 break;
7707 default:
7708 break;
7709 }
7710 return Op;
7711}
7712
// Pin NVPTXTargetObjectFile's vtables to this file: the destructor is
// defaulted here, out of line, rather than in the class definition.
NVPTXTargetObjectFile::~NVPTXTargetObjectFile() = default;
7715
// Section selection for a global object. The arguments GO, Kind and TM are
// not consulted: the result of getDataSection() is returned for every global.
MCSection *NVPTXTargetObjectFile::SelectSectionForGlobal(
    const GlobalObject *GO, SectionKind Kind, const TargetMachine &TM) const {
  return getDataSection();
}
7720
7721static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
7722 const SelectionDAG &DAG, unsigned Depth) {
7723 SDValue A = Op.getOperand(i: 0);
7724 SDValue B = Op.getOperand(i: 1);
7725 ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(Val: Op.getOperand(i: 2));
7726 unsigned Mode = Op.getConstantOperandVal(i: 3);
7727
7728 if (!Selector)
7729 return;
7730
7731 KnownBits AKnown = DAG.computeKnownBits(Op: A, Depth);
7732 KnownBits BKnown = DAG.computeKnownBits(Op: B, Depth);
7733
7734 // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
7735 assert(AKnown.getBitWidth() == 32 && BKnown.getBitWidth() == 32 &&
7736 "PRMT must have i32 operands");
7737 assert(Known.getBitWidth() == 32 && "PRMT must have i32 result");
7738 KnownBits BitField = BKnown.concat(Lo: AKnown);
7739
7740 APInt SelectorVal = getPRMTSelector(Selector: Selector->getAPIntValue(), Mode);
7741 for (unsigned I : llvm::seq(Size: 4)) {
7742 APInt Sel = SelectorVal.extractBits(numBits: 4, bitPosition: I * 4);
7743 unsigned Idx = Sel.getLoBits(numBits: 3).getZExtValue();
7744 unsigned Sign = Sel.getHiBits(numBits: 1).getZExtValue();
7745 KnownBits Byte = BitField.extractBits(NumBits: 8, BitPosition: Idx * 8);
7746 if (Sign)
7747 Byte = KnownBits::ashr(LHS: Byte, RHS: KnownBits::makeConstant(C: APInt(8, 7)));
7748 Known.insertBits(SubBits: Byte, BitPosition: I * 8);
7749 }
7750}
7751
7752static void computeKnownBitsForLoadV(const SDValue Op, KnownBits &Known) {
7753 MemSDNode *LD = cast<MemSDNode>(Val: Op);
7754
7755 // We can't do anything without knowing the sign bit.
7756 auto ExtType = LD->getConstantOperandVal(Num: LD->getNumOperands() - 1);
7757 if (ExtType == ISD::SEXTLOAD)
7758 return;
7759
7760 // ExtLoading to vector types is weird and may not work well with known bits.
7761 auto DestVT = LD->getValueType(ResNo: 0);
7762 if (DestVT.isVector())
7763 return;
7764
7765 assert(Known.getBitWidth() == DestVT.getSizeInBits());
7766 auto ElementBitWidth = NVPTXDAGToDAGISel::getFromTypeWidthForLoad(Mem: LD);
7767 Known.Zero.setHighBits(Known.getBitWidth() - ElementBitWidth);
7768}
7769
7770void NVPTXTargetLowering::computeKnownBitsForTargetNode(
7771 const SDValue Op, KnownBits &Known, const APInt &DemandedElts,
7772 const SelectionDAG &DAG, unsigned Depth) const {
7773 Known.resetAll();
7774
7775 switch (Op.getOpcode()) {
7776 case NVPTXISD::PRMT:
7777 computeKnownBitsForPRMT(Op, Known, DAG, Depth);
7778 break;
7779 case NVPTXISD::LoadV2:
7780 case NVPTXISD::LoadV4:
7781 case NVPTXISD::LoadV8:
7782 computeKnownBitsForLoadV(Op, Known);
7783 break;
7784 default:
7785 break;
7786 }
7787}
7788
7789static std::pair<APInt, APInt> getPRMTDemandedBits(const APInt &SelectorVal,
7790 const APInt &DemandedBits) {
7791 APInt DemandedLHS = APInt(32, 0);
7792 APInt DemandedRHS = APInt(32, 0);
7793
7794 for (unsigned I : llvm::seq(Size: 4)) {
7795 if (DemandedBits.extractBits(numBits: 8, bitPosition: I * 8).isZero())
7796 continue;
7797
7798 APInt Sel = SelectorVal.extractBits(numBits: 4, bitPosition: I * 4);
7799 unsigned Idx = Sel.getLoBits(numBits: 3).getZExtValue();
7800 unsigned Sign = Sel.getHiBits(numBits: 1).getZExtValue();
7801
7802 APInt &Src = Idx < 4 ? DemandedLHS : DemandedRHS;
7803 unsigned ByteStart = (Idx % 4) * 8;
7804 if (Sign)
7805 Src.setBit(ByteStart + 7);
7806 else
7807 Src.setBits(loBit: ByteStart, hiBit: ByteStart + 8);
7808 }
7809
7810 return {DemandedLHS, DemandedRHS};
7811}
7812
7813// Replace undef with 0 as this is easier for other optimizations such as
7814// known bits.
7815static SDValue canonicalizePRMTInput(SDValue Op, SelectionDAG &DAG) {
7816 if (!Op)
7817 return SDValue();
7818 if (Op.isUndef())
7819 return DAG.getConstant(Val: 0, DL: SDLoc(), VT: MVT::i32);
7820 return Op;
7821}
7822
7823static SDValue simplifyDemandedBitsForPRMT(SDValue PRMT,
7824 const APInt &DemandedBits,
7825 SelectionDAG &DAG,
7826 const TargetLowering &TLI,
7827 unsigned Depth) {
7828 assert(PRMT.getOpcode() == NVPTXISD::PRMT);
7829 SDValue Op0 = PRMT.getOperand(i: 0);
7830 SDValue Op1 = PRMT.getOperand(i: 1);
7831 auto *SelectorConst = dyn_cast<ConstantSDNode>(Val: PRMT.getOperand(i: 2));
7832 if (!SelectorConst)
7833 return SDValue();
7834
7835 unsigned Mode = PRMT.getConstantOperandVal(i: 3);
7836 const APInt Selector = getPRMTSelector(Selector: SelectorConst->getAPIntValue(), Mode);
7837
7838 // Try to simplify the PRMT to one of the inputs if the used bytes are all
7839 // from the same input in the correct order.
7840 const unsigned LeadingBytes = DemandedBits.countLeadingZeros() / 8;
7841 const unsigned SelBits = (4 - LeadingBytes) * 4;
7842 if (Selector.getLoBits(numBits: SelBits) == APInt(32, 0x3210).getLoBits(numBits: SelBits))
7843 return Op0;
7844 if (Selector.getLoBits(numBits: SelBits) == APInt(32, 0x7654).getLoBits(numBits: SelBits))
7845 return Op1;
7846
7847 auto [DemandedLHS, DemandedRHS] = getPRMTDemandedBits(SelectorVal: Selector, DemandedBits);
7848
7849 // Attempt to avoid multi-use ops if we don't need anything from them.
7850 SDValue DemandedOp0 =
7851 TLI.SimplifyMultipleUseDemandedBits(Op: Op0, DemandedBits: DemandedLHS, DAG, Depth: Depth + 1);
7852 SDValue DemandedOp1 =
7853 TLI.SimplifyMultipleUseDemandedBits(Op: Op1, DemandedBits: DemandedRHS, DAG, Depth: Depth + 1);
7854
7855 DemandedOp0 = canonicalizePRMTInput(Op: DemandedOp0, DAG);
7856 DemandedOp1 = canonicalizePRMTInput(Op: DemandedOp1, DAG);
7857 if ((DemandedOp0 && DemandedOp0 != Op0) ||
7858 (DemandedOp1 && DemandedOp1 != Op1)) {
7859 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
7860 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
7861 return getPRMT(A: Op0, B: Op1, Selector: Selector.getZExtValue(), DL: SDLoc(PRMT), DAG);
7862 }
7863
7864 return SDValue();
7865}
7866
7867bool NVPTXTargetLowering::SimplifyDemandedBitsForTargetNode(
7868 SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
7869 KnownBits &Known, TargetLoweringOpt &TLO, unsigned Depth) const {
7870 Known.resetAll();
7871
7872 switch (Op.getOpcode()) {
7873 case NVPTXISD::PRMT:
7874 if (SDValue Result = simplifyDemandedBitsForPRMT(PRMT: Op, DemandedBits, DAG&: TLO.DAG,
7875 TLI: *this, Depth)) {
7876 TLO.CombineTo(O: Op, N: Result);
7877 return true;
7878 }
7879 break;
7880 default:
7881 break;
7882 }
7883
7884 computeKnownBitsForTargetNode(Op, Known, DemandedElts, DAG: TLO.DAG, Depth);
7885 return false;
7886}
7887