1//===-- NVPTXISelLowering.cpp - NVPTX DAG Lowering Implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the interfaces that NVPTX uses to lower LLVM code into a
10// selection DAG.
11//
12//===----------------------------------------------------------------------===//
13
14#include "NVPTXISelLowering.h"
15#include "MCTargetDesc/NVPTXBaseInfo.h"
16#include "NVPTX.h"
17#include "NVPTXSubtarget.h"
18#include "NVPTXTargetMachine.h"
19#include "NVPTXTargetObjectFile.h"
20#include "NVPTXUtilities.h"
21#include "llvm/ADT/APFloat.h"
22#include "llvm/ADT/APInt.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/SmallVector.h"
25#include "llvm/ADT/StringRef.h"
26#include "llvm/CodeGen/Analysis.h"
27#include "llvm/CodeGen/ISDOpcodes.h"
28#include "llvm/CodeGen/MachineFunction.h"
29#include "llvm/CodeGen/MachineJumpTableInfo.h"
30#include "llvm/CodeGen/MachineMemOperand.h"
31#include "llvm/CodeGen/Register.h"
32#include "llvm/CodeGen/SelectionDAG.h"
33#include "llvm/CodeGen/SelectionDAGNodes.h"
34#include "llvm/CodeGen/TargetCallingConv.h"
35#include "llvm/CodeGen/TargetLowering.h"
36#include "llvm/CodeGen/ValueTypes.h"
37#include "llvm/CodeGenTypes/MachineValueType.h"
38#include "llvm/IR/Argument.h"
39#include "llvm/IR/Attributes.h"
40#include "llvm/IR/Constants.h"
41#include "llvm/IR/DataLayout.h"
42#include "llvm/IR/DerivedTypes.h"
43#include "llvm/IR/DiagnosticInfo.h"
44#include "llvm/IR/FPEnv.h"
45#include "llvm/IR/Function.h"
46#include "llvm/IR/GlobalValue.h"
47#include "llvm/IR/IRBuilder.h"
48#include "llvm/IR/Instruction.h"
49#include "llvm/IR/Instructions.h"
50#include "llvm/IR/IntrinsicsNVPTX.h"
51#include "llvm/IR/Module.h"
52#include "llvm/IR/Type.h"
53#include "llvm/IR/Value.h"
54#include "llvm/Support/Alignment.h"
55#include "llvm/Support/AtomicOrdering.h"
56#include "llvm/Support/Casting.h"
57#include "llvm/Support/CodeGen.h"
58#include "llvm/Support/CommandLine.h"
59#include "llvm/Support/ErrorHandling.h"
60#include "llvm/Support/NVPTXAddrSpace.h"
61#include "llvm/Support/raw_ostream.h"
62#include "llvm/Target/TargetMachine.h"
63#include "llvm/Target/TargetOptions.h"
64#include <algorithm>
65#include <cassert>
66#include <cmath>
67#include <cstdint>
68#include <iterator>
69#include <optional>
70#include <string>
71#include <tuple>
72#include <utility>
73#include <vector>
74
75#define DEBUG_TYPE "nvptx-lower"
76
77using namespace llvm;
78
79static cl::opt<bool> sched4reg(
80 "nvptx-sched4reg",
81 cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(Val: false));
82
83static cl::opt<unsigned> FMAContractLevelOpt(
84 "nvptx-fma-level", cl::Hidden,
85 cl::desc("NVPTX Specific: FMA contraction (0: don't do it"
86 " 1: do it 2: do it aggressively"),
87 cl::init(Val: 2));
88
89static cl::opt<NVPTX::DivPrecisionLevel> UsePrecDivF32(
90 "nvptx-prec-divf32", cl::Hidden,
91 cl::desc(
92 "NVPTX Specific: Override the precision of the lowering for f32 fdiv"),
93 cl::values(
94 clEnumValN(NVPTX::DivPrecisionLevel::Approx, "0", "Use div.approx"),
95 clEnumValN(NVPTX::DivPrecisionLevel::Full, "1", "Use div.full"),
96 clEnumValN(NVPTX::DivPrecisionLevel::IEEE754, "2",
97 "Use IEEE Compliant F32 div.rnd if available (default)"),
98 clEnumValN(NVPTX::DivPrecisionLevel::IEEE754_NoFTZ, "3",
99 "Use IEEE Compliant F32 div.rnd if available, no FTZ")),
100 cl::init(Val: NVPTX::DivPrecisionLevel::IEEE754));
101
102static cl::opt<bool> UsePrecSqrtF32(
103 "nvptx-prec-sqrtf32", cl::Hidden,
104 cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."),
105 cl::init(Val: true));
106
107/// CUDA's implementation (see libdevice) uses ex2.approx for exp2(), but it
108/// does NOT use lg2.approx for log2, so this is disabled by default.
109static cl::opt<bool> UseApproxLog2F32(
110 "nvptx-approx-log2f32",
111 cl::desc("NVPTX Specific: whether to use lg2.approx for log2"),
112 cl::init(Val: false));
113
114static cl::opt<bool> ForceMinByValParamAlign(
115 "nvptx-force-min-byval-param-align", cl::Hidden,
116 cl::desc("NVPTX Specific: force 4-byte minimal alignment for byval"
117 " params of device functions."),
118 cl::init(Val: false));
119
120NVPTX::DivPrecisionLevel
121NVPTXTargetLowering::getDivF32Level(const MachineFunction &MF,
122 const SDNode &N) const {
123 // If nvptx-prec-divf32=N is used on the command-line, always honor it
124 if (UsePrecDivF32.getNumOccurrences() > 0)
125 return UsePrecDivF32;
126
127 // Otherwise, use div.approx if fast math is enabled
128 if (allowUnsafeFPMath(MF))
129 return NVPTX::DivPrecisionLevel::Approx;
130
131 const SDNodeFlags Flags = N.getFlags();
132 if (Flags.hasApproximateFuncs())
133 return NVPTX::DivPrecisionLevel::Approx;
134
135 return NVPTX::DivPrecisionLevel::IEEE754;
136}
137
138bool NVPTXTargetLowering::usePrecSqrtF32(const MachineFunction &MF,
139 const SDNode *N) const {
140 // If nvptx-prec-sqrtf32 is used on the command-line, always honor it
141 if (UsePrecSqrtF32.getNumOccurrences() > 0)
142 return UsePrecSqrtF32;
143
144 // Otherwise, use sqrt.approx if fast math is enabled
145 if (allowUnsafeFPMath(MF))
146 return false;
147
148 if (N) {
149 const SDNodeFlags Flags = N->getFlags();
150 if (Flags.hasApproximateFuncs())
151 return false;
152 }
153
154 return true;
155}
156
157bool NVPTXTargetLowering::useF32FTZ(const MachineFunction &MF) const {
158 return MF.getDenormalMode(FPType: APFloat::IEEEsingle()).Output ==
159 DenormalMode::PreserveSign;
160}
161
162static bool IsPTXVectorType(MVT VT) {
163 switch (VT.SimpleTy) {
164 default:
165 return false;
166 case MVT::v2i1:
167 case MVT::v4i1:
168 case MVT::v2i8:
169 case MVT::v4i8:
170 case MVT::v8i8: // <2 x i8x4>
171 case MVT::v16i8: // <4 x i8x4>
172 case MVT::v2i16:
173 case MVT::v4i16:
174 case MVT::v8i16: // <4 x i16x2>
175 case MVT::v2i32:
176 case MVT::v4i32:
177 case MVT::v2i64:
178 case MVT::v2f16:
179 case MVT::v4f16:
180 case MVT::v8f16: // <4 x f16x2>
181 case MVT::v2bf16:
182 case MVT::v4bf16:
183 case MVT::v8bf16: // <4 x bf16x2>
184 case MVT::v2f32:
185 case MVT::v4f32:
186 case MVT::v2f64:
187 case MVT::v4i64:
188 case MVT::v4f64:
189 case MVT::v8i32:
190 case MVT::v8f32:
191 case MVT::v16f16: // <8 x f16x2>
192 case MVT::v16bf16: // <8 x bf16x2>
193 case MVT::v16i16: // <8 x i16x2>
194 case MVT::v32i8: // <8 x i8x4>
195 return true;
196 }
197}
198
199static bool Is16bitsType(MVT VT) {
200 return (VT.SimpleTy == MVT::f16 || VT.SimpleTy == MVT::bf16 ||
201 VT.SimpleTy == MVT::i16);
202}
203
204// This function is called when legalizing vector loads/stores, and it does two
205// things:
206// 1. Determines whether the vector is something we want to custom lower;
207// std::nullopt is returned if we do not want to custom lower it.
208// 2. If we do want to handle it, returns two parameters:
209// - unsigned int NumElts - The number of elements in the final vector
210// - MVT EltVT - The type of the elements in the final vector
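//
// Illustrative examples (derived from the cases below): v4f32 maps to
// {4, f32}, v8f16 maps to {4, v2f16} (packed into 32-bit words), and a scalar
// i128 maps to {2, i64}.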
211static std::optional<std::pair<unsigned int, MVT>>
212getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
213 if (!VectorEVT.isSimple())
214 return std::nullopt;
215 const MVT VectorVT = VectorEVT.getSimpleVT();
216
217 if (!VectorVT.isVector()) {
218 if (VectorVT == MVT::i128 || VectorVT == MVT::f128)
219 return {{2, MVT::i64}};
220 return std::nullopt;
221 }
222
223 const MVT EltVT = VectorVT.getVectorElementType();
224 const unsigned NumElts = VectorVT.getVectorNumElements();
225
226 // We only handle "native" vector sizes for now, e.g. <4 x double> is not
227 // legal. We can (and should) split that into 2 stores of <2 x double> here
228 // but I'm leaving that as a TODO for now.
229 switch (VectorVT.SimpleTy) {
230 default:
231 return std::nullopt;
232 case MVT::v4i64:
233 case MVT::v4f64:
234 case MVT::v8i32:
235 case MVT::v8f32:
236 // This is a "native" vector type iff the address space is global
237 // and the target supports 256-bit loads/stores
238 if (!CanLowerTo256Bit)
239 return std::nullopt;
240 LLVM_FALLTHROUGH;
241 case MVT::v2i8:
242 case MVT::v2i32:
243 case MVT::v2i64:
244 case MVT::v2f32:
245 case MVT::v2f64:
246 case MVT::v4i32:
247 case MVT::v4f32:
248 // This is a "native" vector type
249 return std::pair(NumElts, EltVT);
250 case MVT::v16f16: // <8 x f16x2>
251 case MVT::v16bf16: // <8 x bf16x2>
252 case MVT::v16i16: // <8 x i16x2>
253 case MVT::v32i8: // <8 x i8x4>
254 // This can be upsized into a "native" vector type iff the address space is
255 // global and the target supports 256-bit loads/stores.
256 if (!CanLowerTo256Bit)
257 return std::nullopt;
258 LLVM_FALLTHROUGH;
259 case MVT::v2i16: // <1 x i16x2>
260 case MVT::v2f16: // <1 x f16x2>
261 case MVT::v2bf16: // <1 x bf16x2>
262 case MVT::v4i8: // <1 x i8x4>
263 case MVT::v4i16: // <2 x i16x2>
264 case MVT::v4f16: // <2 x f16x2>
265 case MVT::v4bf16: // <2 x bf16x2>
266 case MVT::v8i8: // <2 x i8x4>
267 case MVT::v8f16: // <4 x f16x2>
268 case MVT::v8bf16: // <4 x bf16x2>
269 case MVT::v8i16: // <4 x i16x2>
270 case MVT::v16i8: // <4 x i8x4>
271 // This can be upsized into a "native" vector type.
272 // Despite vectors like v8i8, v16i8, v8i16 being within the bit-limit for
273 // total load/store size, PTX syntax only supports v2/v4. Thus, we can't use
274 // vectorized loads/stores with the actual element type for i8/i16 as that
275 // would require v8/v16 variants that do not exist.
276 // In order to load/store such vectors efficiently, here in Type
277 // Legalization, we split the vector into word-sized chunks (v2x16/v4i8).
278 // Later, we will lower to PTX as vectors of b32.
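// For example (illustrative): v8i16 (128 bits total) becomes 4 x v2i16 and
// v16i8 becomes 4 x v4i8, since PTX has no v8i16/v16i8 load/store variants.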
279
280 // Number of elements to pack in one word.
281 const unsigned NPerWord = 32 / EltVT.getSizeInBits();
282
283 return std::pair(NumElts / NPerWord, MVT::getVectorVT(VT: EltVT, NumElements: NPerWord));
284 }
285
286 llvm_unreachable("All cases in switch should return.");
287}
288
289/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
290/// EVTs that compose it. Unlike ComputeValueVTs, this will break apart vectors
291/// into their primitive components.
292/// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
293/// same number of types as the Ins/Outs arrays in LowerFormalArguments,
294/// LowerCall, and LowerReturn.
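///
/// For example (illustrative): a struct of {i32, <2 x half>, double} yields
/// ValueVTs = {i32, v2f16, f64} with Offsets = {0, 4, 8}.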
295static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
296 Type *Ty, SmallVectorImpl<EVT> &ValueVTs,
297 SmallVectorImpl<uint64_t> *Offsets = nullptr,
298 uint64_t StartingOffset = 0) {
299 SmallVector<EVT, 16> TempVTs;
300 SmallVector<uint64_t, 16> TempOffsets;
301
302 // Special case for i128 - decompose to (i64, i64)
303 if (Ty->isIntegerTy(Bitwidth: 128) || Ty->isFP128Ty()) {
304 ValueVTs.append(IL: {MVT::i64, MVT::i64});
305
306 if (Offsets)
307 Offsets->append(IL: {StartingOffset + 0, StartingOffset + 8});
308
309 return;
310 }
311
312 // Given a struct type, recursively traverse the elements with custom ComputePTXValueVTs.
313 if (StructType *STy = dyn_cast<StructType>(Val: Ty)) {
314 auto const *SL = DL.getStructLayout(Ty: STy);
315 auto ElementNum = 0;
316 for (auto *EI : STy->elements()) {
317 ComputePTXValueVTs(TLI, DL, Ty: EI, ValueVTs, Offsets,
318 StartingOffset: StartingOffset + SL->getElementOffset(Idx: ElementNum));
319 ++ElementNum;
320 }
321 return;
322 }
323
324 // Given an array type, recursively traverse the elements with custom ComputePTXValueVTs.
325 if (ArrayType *ATy = dyn_cast<ArrayType>(Val: Ty)) {
326 Type *EltTy = ATy->getElementType();
327 uint64_t EltSize = DL.getTypeAllocSize(Ty: EltTy);
328 for (int I : llvm::seq<int>(Size: ATy->getNumElements()))
329 ComputePTXValueVTs(TLI, DL, Ty: EltTy, ValueVTs, Offsets, StartingOffset: StartingOffset + I * EltSize);
330 return;
331 }
332
333 ComputeValueVTs(TLI, DL, Ty, ValueVTs&: TempVTs, FixedOffsets: &TempOffsets, StartingOffset);
334 for (unsigned i = 0, e = TempVTs.size(); i != e; ++i) {
335 EVT VT = TempVTs[i];
336 uint64_t Off = TempOffsets[i];
337 // Split vectors into individual elements, except for packed types
338 // (v2f16/v2bf16/v2i16/v4i8), which we will pass as a single scalar.
339 if (VT.isVector()) {
340 unsigned NumElts = VT.getVectorNumElements();
341 EVT EltVT = VT.getVectorElementType();
342 // We require power-of-2 sized vectors because
343 // TargetLoweringBase::getVectorTypeBreakdown(), which is invoked in
344 // ComputePTXValueVTs(), cannot currently break down non-power-of-2 sized
345 // vectors.
346 if ((Is16bitsType(VT: EltVT.getSimpleVT())) && NumElts % 2 == 0 &&
347 isPowerOf2_32(Value: NumElts)) {
348 // Vectors with an even number of f16 elements will be passed to
349 // us as an array of v2f16/v2bf16 elements. We must match this so we
350 // stay in sync with Ins/Outs.
351 switch (EltVT.getSimpleVT().SimpleTy) {
352 case MVT::f16:
353 EltVT = MVT::v2f16;
354 break;
355 case MVT::bf16:
356 EltVT = MVT::v2bf16;
357 break;
358 case MVT::i16:
359 EltVT = MVT::v2i16;
360 break;
361 default:
362 llvm_unreachable("Unexpected type");
363 }
364 NumElts /= 2;
365 } else if (EltVT.getSimpleVT() == MVT::i8 &&
366 ((NumElts % 4 == 0 && isPowerOf2_32(Value: NumElts)) ||
367 NumElts == 3)) {
368 // v*i8 are formally lowered as v4i8
369 EltVT = MVT::v4i8;
370 NumElts = (NumElts + 3) / 4;
371 } else if (EltVT.getSimpleVT() == MVT::i8 && NumElts == 2) {
372 // v2i8 is promoted to v2i16
373 NumElts = 1;
374 EltVT = MVT::v2i8;
375 }
376 for (unsigned j = 0; j != NumElts; ++j) {
377 ValueVTs.push_back(Elt: EltVT);
378 if (Offsets)
379 Offsets->push_back(Elt: Off + j * EltVT.getStoreSize());
380 }
381 } else {
382 ValueVTs.push_back(Elt: VT);
383 if (Offsets)
384 Offsets->push_back(Elt: Off);
385 }
386 }
387}
388
389/// promoteScalarIntegerPTX
390/// Used to make sure the arguments/returns are suitable for passing
391/// and promote them to a larger size if they're not.
392///
393/// Returns the promoted EVT, or \p VT unchanged if no promotion is needed.
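///
/// For example (illustrative): i1 stays i1, i3 and i6 are promoted to i8, and
/// i24 is promoted to i32.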
394static EVT promoteScalarIntegerPTX(const EVT VT) {
395 if (VT.isScalarInteger()) {
396 switch (PowerOf2Ceil(A: VT.getFixedSizeInBits())) {
397 default:
398 llvm_unreachable(
399 "Promotion is not suitable for scalars of size larger than 64-bits");
400 case 1:
401 return MVT::i1;
402 case 2:
403 case 4:
404 case 8:
405 return MVT::i8;
406 case 16:
407 return MVT::i16;
408 case 32:
409 return MVT::i32;
410 case 64:
411 return MVT::i64;
412 }
413 }
414 return VT;
415}
416
417// Check whether we can merge loads/stores of some of the pieces of a
418// flattened function parameter or return value into a single vector
419// load/store.
420//
421// The flattened parameter is represented as a list of EVTs and
422// offsets, and the whole structure is aligned to ParamAlignment. This
423// function determines whether we can load/store pieces of the
424// parameter starting at index Idx using a single vectorized op of
425// size AccessSize. If so, it returns the number of param pieces
426// covered by the vector op. Otherwise, it returns 1.
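//
// For example (illustrative): with ValueVTs = {f32, f32, f32, f32},
// Offsets = {0, 4, 8, 12}, and ParamAlignment = 16, a query at Idx = 0 with
// AccessSize = 16 returns 4 (a single 128-bit vectorized access).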
427static unsigned CanMergeParamLoadStoresStartingAt(
428 unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs,
429 const SmallVectorImpl<uint64_t> &Offsets, Align ParamAlignment) {
430
431 // Can't vectorize if param alignment is not sufficient.
432 if (ParamAlignment < AccessSize)
433 return 1;
434 // Can't vectorize if offset is not aligned.
435 if (Offsets[Idx] & (AccessSize - 1))
436 return 1;
437
438 EVT EltVT = ValueVTs[Idx];
439 unsigned EltSize = EltVT.getStoreSize();
440
441 // Element is too large to vectorize.
442 if (EltSize >= AccessSize)
443 return 1;
444
445 unsigned NumElts = AccessSize / EltSize;
446 // Can't vectorize if AccessSize is not a multiple of EltSize.
447 if (AccessSize != EltSize * NumElts)
448 return 1;
449
450 // We don't have enough elements to vectorize.
451 if (Idx + NumElts > ValueVTs.size())
452 return 1;
453
454 // PTX ISA can only deal with 2- and 4-element vector ops.
455 if (NumElts != 4 && NumElts != 2)
456 return 1;
457
458 for (unsigned j = Idx + 1; j < Idx + NumElts; ++j) {
459 // Types do not match.
460 if (ValueVTs[j] != EltVT)
461 return 1;
462
463 // Elements are not contiguous.
464 if (Offsets[j] - Offsets[j - 1] != EltSize)
465 return 1;
466 }
467 // OK. We can vectorize ValueVTs[Idx..Idx+NumElts).
468 return NumElts;
469}
470
471// Computes whether and how we can vectorize the loads/stores of a
472// flattened function parameter or return value.
473//
474// The flattened parameter is represented as the list of ValueVTs and
475// Offsets, and is aligned to ParamAlignment bytes. We return a vector whose
476// entries sum to ValueVTs.size(); each entry is the number of consecutive
477// pieces covered by a single access (1 for a scalar access, 2 or 4 for a
478// vectorized load/store).
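//
// For example (illustrative): {f32, f32, f32, f32} at offsets {0, 4, 8, 12}
// with 16-byte alignment yields {4}; with only 8-byte alignment it yields
// {2, 2}.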
479static SmallVector<unsigned, 16>
480VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
481 const SmallVectorImpl<uint64_t> &Offsets,
482 Align ParamAlignment, bool IsVAArg = false) {
483 // Variadic arguments are not vectorized; every piece is loaded/stored as
484 // a scalar.
485
486 if (IsVAArg)
487 return SmallVector<unsigned>(ValueVTs.size(), 1);
488
489 SmallVector<unsigned, 16> VectorInfo;
490
491 const auto GetNumElts = [&](unsigned I) -> unsigned {
492 for (const unsigned AccessSize : {16, 8, 4, 2}) {
493 const unsigned NumElts = CanMergeParamLoadStoresStartingAt(
494 Idx: I, AccessSize, ValueVTs, Offsets, ParamAlignment);
495 assert((NumElts == 1 || NumElts == 2 || NumElts == 4) &&
496 "Unexpected vectorization size");
497 if (NumElts != 1)
498 return NumElts;
499 }
500 return 1;
501 };
502
503 // Check what we can vectorize using 128/64/32-bit accesses.
504 for (unsigned I = 0, E = ValueVTs.size(); I != E;) {
505 const unsigned NumElts = GetNumElts(I);
506 VectorInfo.push_back(Elt: NumElts);
507 I += NumElts;
508 }
509 assert(std::accumulate(VectorInfo.begin(), VectorInfo.end(), 0u) ==
510 ValueVTs.size());
511 return VectorInfo;
512}
513
514// NVPTXTargetLowering Constructor.
515NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
516 const NVPTXSubtarget &STI)
517 : TargetLowering(TM), nvTM(&TM), STI(STI), GlobalUniqueCallSite(0) {
518 // Always lower memset, memcpy, and memmove intrinsics to load/store
519 // instructions, rather than generating calls to memset, memcpy, or
520 // memmove.
521 MaxStoresPerMemset = MaxStoresPerMemsetOptSize = (unsigned)0xFFFFFFFF;
522 MaxStoresPerMemcpy = MaxStoresPerMemcpyOptSize = (unsigned) 0xFFFFFFFF;
523 MaxStoresPerMemmove = MaxStoresPerMemmoveOptSize = (unsigned) 0xFFFFFFFF;
524
525 setBooleanContents(ZeroOrNegativeOneBooleanContent);
526 setBooleanVectorContents(ZeroOrNegativeOneBooleanContent);
527
528 // Jump is Expensive. Don't create extra control flow for 'and', 'or'
529 // condition branches.
530 setJumpIsExpensive(true);
531
532 // Wide divides are _very_ slow. Try to reduce the width of the divide if
533 // possible.
534 addBypassSlowDiv(SlowBitWidth: 64, FastBitWidth: 32);
535
536 // By default, use Source scheduling.
537 if (sched4reg)
538 setSchedulingPreference(Sched::RegPressure);
539 else
540 setSchedulingPreference(Sched::Source);
541
542 auto setFP16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
543 LegalizeAction NoF16Action) {
544 bool IsOpSupported = STI.allowFP16Math();
545 switch (Op) {
546 // Several FP16 instructions are available on sm_80 only.
547 case ISD::FMINNUM:
548 case ISD::FMAXNUM:
549 case ISD::FMAXNUM_IEEE:
550 case ISD::FMINNUM_IEEE:
551 case ISD::FMAXIMUM:
552 case ISD::FMINIMUM:
553 IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
554 break;
555 case ISD::FEXP2:
556 IsOpSupported &= STI.getSmVersion() >= 75 && STI.getPTXVersion() >= 70;
557 break;
558 }
559 setOperationAction(Op, VT, Action: IsOpSupported ? Action : NoF16Action);
560 };
561
562 auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
563 LegalizeAction NoBF16Action) {
564 bool IsOpSupported = STI.hasNativeBF16Support(Opcode: Op);
565 setOperationAction(
566 Op, VT, Action: IsOpSupported ? Action : NoBF16Action);
567 };
568
569 auto setI16x2OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
570 LegalizeAction NoI16x2Action) {
571 bool IsOpSupported = false;
572 // These instructions are available on sm_90 only.
573 switch (Op) {
574 case ISD::ADD:
575 case ISD::SMAX:
576 case ISD::SMIN:
577 case ISD::UMIN:
578 case ISD::UMAX:
579 IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 80;
580 break;
581 }
582 setOperationAction(Op, VT, Action: IsOpSupported ? Action : NoI16x2Action);
583 };
584
585 addRegisterClass(VT: MVT::i1, RC: &NVPTX::B1RegClass);
586 addRegisterClass(VT: MVT::i16, RC: &NVPTX::B16RegClass);
587 addRegisterClass(VT: MVT::v2i16, RC: &NVPTX::B32RegClass);
588 addRegisterClass(VT: MVT::v4i8, RC: &NVPTX::B32RegClass);
589 addRegisterClass(VT: MVT::i32, RC: &NVPTX::B32RegClass);
590 addRegisterClass(VT: MVT::i64, RC: &NVPTX::B64RegClass);
591 addRegisterClass(VT: MVT::f32, RC: &NVPTX::B32RegClass);
592 addRegisterClass(VT: MVT::f64, RC: &NVPTX::B64RegClass);
593 addRegisterClass(VT: MVT::f16, RC: &NVPTX::B16RegClass);
594 addRegisterClass(VT: MVT::v2f16, RC: &NVPTX::B32RegClass);
595 addRegisterClass(VT: MVT::bf16, RC: &NVPTX::B16RegClass);
596 addRegisterClass(VT: MVT::v2bf16, RC: &NVPTX::B32RegClass);
597
598 // Conversion to/from FP16/FP16x2 is always legal.
599 setOperationAction(Op: ISD::BUILD_VECTOR, VT: MVT::v2f16, Action: Custom);
600 setOperationAction(Op: ISD::EXTRACT_VECTOR_ELT, VT: MVT::v2f16, Action: Custom);
601 setOperationAction(Op: ISD::INSERT_VECTOR_ELT, VT: MVT::v2f16, Action: Expand);
602 setOperationAction(Op: ISD::VECTOR_SHUFFLE, VT: MVT::v2f16, Action: Expand);
603
604 setOperationAction(Op: ISD::READCYCLECOUNTER, VT: MVT::i64, Action: Legal);
605 if (STI.getSmVersion() >= 30 && STI.getPTXVersion() > 31)
606 setOperationAction(Op: ISD::READSTEADYCOUNTER, VT: MVT::i64, Action: Legal);
607
608 setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote);
609 setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand);
610
611 // Conversion to/from BF16/BF16x2 is always legal.
612 setOperationAction(Op: ISD::BUILD_VECTOR, VT: MVT::v2bf16, Action: Custom);
613 setOperationAction(Op: ISD::EXTRACT_VECTOR_ELT, VT: MVT::v2bf16, Action: Custom);
614 setOperationAction(Op: ISD::INSERT_VECTOR_ELT, VT: MVT::v2bf16, Action: Expand);
615 setOperationAction(Op: ISD::VECTOR_SHUFFLE, VT: MVT::v2bf16, Action: Expand);
616
617 setBF16OperationAction(ISD::SETCC, MVT::v2bf16, Legal, Expand);
618 setBF16OperationAction(ISD::SETCC, MVT::bf16, Legal, Promote);
619 if (getOperationAction(Op: ISD::SETCC, VT: MVT::bf16) == Promote)
620 AddPromotedToType(Opc: ISD::SETCC, OrigVT: MVT::bf16, DestVT: MVT::f32);
621
622 // Conversion to/from i16/i16x2 is always legal.
623 setOperationAction(Op: ISD::BUILD_VECTOR, VT: MVT::v2i16, Action: Custom);
624 setOperationAction(Op: ISD::EXTRACT_VECTOR_ELT, VT: MVT::v2i16, Action: Custom);
625 setOperationAction(Op: ISD::INSERT_VECTOR_ELT, VT: MVT::v2i16, Action: Expand);
626 setOperationAction(Op: ISD::VECTOR_SHUFFLE, VT: MVT::v2i16, Action: Expand);
627
628 setOperationAction(Op: ISD::BUILD_VECTOR, VT: MVT::v4i8, Action: Custom);
629 setOperationAction(Op: ISD::EXTRACT_VECTOR_ELT, VT: MVT::v4i8, Action: Custom);
630 setOperationAction(Op: ISD::INSERT_VECTOR_ELT, VT: MVT::v4i8, Action: Custom);
631 setOperationAction(Op: ISD::VECTOR_SHUFFLE, VT: MVT::v4i8, Action: Custom);
632
633 // Custom conversions to/from v2i8.
634 setOperationAction(Op: ISD::BITCAST, VT: MVT::v2i8, Action: Custom);
635
636 // Only logical ops can be done on v4i8 directly, others must be done
637 // elementwise.
638 setOperationAction(
639 Ops: {ISD::ABS, ISD::ADD, ISD::ADDC, ISD::ADDE,
640 ISD::BITREVERSE, ISD::CTLZ, ISD::CTPOP, ISD::CTTZ,
641 ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FSHL, ISD::FSHR,
642 ISD::MUL, ISD::MULHS, ISD::MULHU, ISD::PARITY,
643 ISD::ROTL, ISD::ROTR, ISD::SADDO, ISD::SADDO_CARRY,
644 ISD::SADDSAT, ISD::SDIV, ISD::SDIVREM, ISD::SELECT_CC,
645 ISD::SETCC, ISD::SHL, ISD::SINT_TO_FP, ISD::SMAX,
646 ISD::SMIN, ISD::SMULO, ISD::SMUL_LOHI, ISD::SRA,
647 ISD::SREM, ISD::SRL, ISD::SSHLSAT, ISD::SSUBO,
648 ISD::SSUBO_CARRY, ISD::SSUBSAT, ISD::SUB, ISD::SUBC,
649 ISD::SUBE, ISD::UADDO, ISD::UADDO_CARRY, ISD::UADDSAT,
650 ISD::UDIV, ISD::UDIVREM, ISD::UINT_TO_FP, ISD::UMAX,
651 ISD::UMIN, ISD::UMULO, ISD::UMUL_LOHI, ISD::UREM,
652 ISD::USHLSAT, ISD::USUBO, ISD::USUBO_CARRY, ISD::VSELECT,
653 ISD::USUBSAT},
654 VT: MVT::v4i8, Action: Expand);
655
656 // Operations not directly supported by NVPTX.
657 for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
658 MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16, MVT::v4i8,
659 MVT::i32, MVT::i64}) {
660 setOperationAction(Op: ISD::SELECT_CC, VT, Action: Expand);
661 setOperationAction(Op: ISD::BR_CC, VT, Action: Expand);
662 }
663
664 // Some SIGN_EXTEND_INREG can be done using the cvt instruction.
665 // For others we will expand to a SHL/SRA pair.
666 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i64, Action: Legal);
667 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i32, Action: Legal);
668 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i16, Action: Legal);
669 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i8 , Action: Legal);
670 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::i1, Action: Expand);
671 setOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: MVT::v2i16, Action: Expand);
672
673 setOperationAction(Op: ISD::SHL_PARTS, VT: MVT::i32 , Action: Custom);
674 setOperationAction(Op: ISD::SRA_PARTS, VT: MVT::i32 , Action: Custom);
675 setOperationAction(Op: ISD::SRL_PARTS, VT: MVT::i32 , Action: Custom);
676 setOperationAction(Op: ISD::SHL_PARTS, VT: MVT::i64 , Action: Custom);
677 setOperationAction(Op: ISD::SRA_PARTS, VT: MVT::i64 , Action: Custom);
678 setOperationAction(Op: ISD::SRL_PARTS, VT: MVT::i64 , Action: Custom);
679
680 setOperationAction(Op: ISD::BITREVERSE, VT: MVT::i32, Action: Legal);
681 setOperationAction(Op: ISD::BITREVERSE, VT: MVT::i64, Action: Legal);
682
683 setOperationAction(Ops: {ISD::ROTL, ISD::ROTR},
684 VTs: {MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64},
685 Action: Expand);
686
687 if (STI.hasHWROT32()) {
688 setOperationAction(Ops: {ISD::FSHL, ISD::FSHR}, VT: MVT::i32, Action: Legal);
689 setOperationAction(Ops: {ISD::ROTL, ISD::ROTR, ISD::FSHL, ISD::FSHR}, VT: MVT::i64,
690 Action: Custom);
691 }
692
693 setOperationAction(Op: ISD::BSWAP, VT: MVT::i16, Action: Expand);
694
695 setOperationAction(Op: ISD::BR_JT, VT: MVT::Other, Action: Custom);
696 setOperationAction(Op: ISD::BRIND, VT: MVT::Other, Action: Expand);
697
698 // We want to legalize constant-related memmove and memcpy
699 // intrinsics.
700 setOperationAction(Op: ISD::INTRINSIC_W_CHAIN, VT: MVT::Other, Action: Custom);
701
702 // Turn FP extload into load/fpextend
703 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT: MVT::f32, MemVT: MVT::f16, Action: Expand);
704 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT: MVT::f64, MemVT: MVT::f16, Action: Expand);
705 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT: MVT::f32, MemVT: MVT::bf16, Action: Expand);
706 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT: MVT::f64, MemVT: MVT::bf16, Action: Expand);
707 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT: MVT::f64, MemVT: MVT::f32, Action: Expand);
708 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT: MVT::v2f32, MemVT: MVT::v2f16, Action: Expand);
709 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT: MVT::v2f64, MemVT: MVT::v2f16, Action: Expand);
710 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT: MVT::v2f32, MemVT: MVT::v2bf16, Action: Expand);
711 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT: MVT::v2f64, MemVT: MVT::v2bf16, Action: Expand);
712 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT: MVT::v2f64, MemVT: MVT::v2f32, Action: Expand);
713 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT: MVT::v4f32, MemVT: MVT::v4f16, Action: Expand);
714 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT: MVT::v4f64, MemVT: MVT::v4f16, Action: Expand);
715 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT: MVT::v4f32, MemVT: MVT::v4bf16, Action: Expand);
716 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT: MVT::v4f64, MemVT: MVT::v4bf16, Action: Expand);
717 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT: MVT::v4f64, MemVT: MVT::v4f32, Action: Expand);
718 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT: MVT::v8f32, MemVT: MVT::v8f16, Action: Expand);
719 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT: MVT::v8f64, MemVT: MVT::v8f16, Action: Expand);
720 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT: MVT::v8f32, MemVT: MVT::v8bf16, Action: Expand);
721 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT: MVT::v8f64, MemVT: MVT::v8bf16, Action: Expand);
722 // Turn FP truncstore into trunc + store.
723 // FIXME: vector types should also be expanded
724 setTruncStoreAction(ValVT: MVT::f32, MemVT: MVT::f16, Action: Expand);
725 setTruncStoreAction(ValVT: MVT::f64, MemVT: MVT::f16, Action: Expand);
726 setTruncStoreAction(ValVT: MVT::f32, MemVT: MVT::bf16, Action: Expand);
727 setTruncStoreAction(ValVT: MVT::f64, MemVT: MVT::bf16, Action: Expand);
728 setTruncStoreAction(ValVT: MVT::f64, MemVT: MVT::f32, Action: Expand);
729
730 // PTX does not support load/store of predicate registers.
731 setOperationAction(Op: ISD::LOAD, VT: MVT::i1, Action: Custom);
732 setOperationAction(Op: ISD::STORE, VT: MVT::i1, Action: Custom);
733
734 for (MVT VT : MVT::integer_valuetypes()) {
735 setLoadExtAction(ExtType: ISD::SEXTLOAD, ValVT: VT, MemVT: MVT::i1, Action: Promote);
736 setLoadExtAction(ExtType: ISD::ZEXTLOAD, ValVT: VT, MemVT: MVT::i1, Action: Promote);
737 setLoadExtAction(ExtType: ISD::EXTLOAD, ValVT: VT, MemVT: MVT::i1, Action: Promote);
738 setTruncStoreAction(ValVT: VT, MemVT: MVT::i1, Action: Expand);
739 }
740
741 setCondCodeAction(CCs: {ISD::SETNE, ISD::SETEQ, ISD::SETUGE, ISD::SETULE,
742 ISD::SETUGT, ISD::SETULT, ISD::SETGT, ISD::SETLT,
743 ISD::SETGE, ISD::SETLE},
744 VT: MVT::i1, Action: Expand);
745
746 // Expand extload of vectors of integers.
747 setLoadExtAction(ExtTypes: {ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, ValVT: MVT::v2i16,
748 MemVT: MVT::v2i8, Action: Expand);
749 setTruncStoreAction(ValVT: MVT::v2i16, MemVT: MVT::v2i8, Action: Expand);
750
751 // This is legal in NVPTX
752 setOperationAction(Op: ISD::ConstantFP, VT: MVT::f64, Action: Legal);
753 setOperationAction(Op: ISD::ConstantFP, VT: MVT::f32, Action: Legal);
754 setOperationAction(Op: ISD::ConstantFP, VT: MVT::f16, Action: Legal);
755 setOperationAction(Op: ISD::ConstantFP, VT: MVT::bf16, Action: Legal);
756
757 setOperationAction(Ops: ISD::DYNAMIC_STACKALLOC, VTs: {MVT::i32, MVT::i64}, Action: Custom);
758 setOperationAction(Ops: {ISD::STACKRESTORE, ISD::STACKSAVE}, VT: MVT::Other, Action: Custom);
759
760 // TRAP can be lowered to PTX trap
761 setOperationAction(Op: ISD::TRAP, VT: MVT::Other, Action: Legal);
762 // DEBUGTRAP can be lowered to PTX brkpt
763 setOperationAction(Op: ISD::DEBUGTRAP, VT: MVT::Other, Action: Legal);
764
765 // Register custom handling for vector loads/stores
766 for (MVT VT : MVT::fixedlen_vector_valuetypes())
767 if (IsPTXVectorType(VT))
768 setOperationAction(Ops: {ISD::LOAD, ISD::STORE, ISD::INTRINSIC_W_CHAIN}, VT,
769 Action: Custom);
770
771 setOperationAction(Ops: {ISD::LOAD, ISD::STORE, ISD::INTRINSIC_W_CHAIN},
772 VTs: {MVT::i128, MVT::f128}, Action: Custom);
773
774 // Support varargs.
775 setOperationAction(Op: ISD::VASTART, VT: MVT::Other, Action: Custom);
776 setOperationAction(Op: ISD::VAARG, VT: MVT::Other, Action: Custom);
777 setOperationAction(Op: ISD::VACOPY, VT: MVT::Other, Action: Expand);
778 setOperationAction(Op: ISD::VAEND, VT: MVT::Other, Action: Expand);
779
780 // Custom handling for i8 intrinsics
781 setOperationAction(Op: ISD::INTRINSIC_W_CHAIN, VT: MVT::i8, Action: Custom);
782
783 setOperationAction(Ops: {ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX},
784 VTs: {MVT::i16, MVT::i32, MVT::i64}, Action: Legal);
785
786 setOperationAction(Ops: {ISD::CTPOP, ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF}, VT: MVT::i16,
787 Action: Promote);
788 setOperationAction(Ops: {ISD::CTPOP, ISD::CTLZ}, VT: MVT::i32, Action: Legal);
789 setOperationAction(Ops: {ISD::CTPOP, ISD::CTLZ}, VT: MVT::i64, Action: Custom);
790
791 setI16x2OperationAction(ISD::ABS, MVT::v2i16, Legal, Custom);
792 setI16x2OperationAction(ISD::SMIN, MVT::v2i16, Legal, Custom);
793 setI16x2OperationAction(ISD::SMAX, MVT::v2i16, Legal, Custom);
794 setI16x2OperationAction(ISD::UMIN, MVT::v2i16, Legal, Custom);
795 setI16x2OperationAction(ISD::UMAX, MVT::v2i16, Legal, Custom);
796 setI16x2OperationAction(ISD::CTPOP, MVT::v2i16, Legal, Expand);
797 setI16x2OperationAction(ISD::CTLZ, MVT::v2i16, Legal, Expand);
798
799 setI16x2OperationAction(ISD::ADD, MVT::v2i16, Legal, Custom);
800 setI16x2OperationAction(ISD::SUB, MVT::v2i16, Legal, Custom);
801 setI16x2OperationAction(ISD::MUL, MVT::v2i16, Legal, Custom);
802 setI16x2OperationAction(ISD::SHL, MVT::v2i16, Legal, Custom);
803 setI16x2OperationAction(ISD::SREM, MVT::v2i16, Legal, Custom);
804 setI16x2OperationAction(ISD::UREM, MVT::v2i16, Legal, Custom);
805
806 // Other arithmetic and logic ops are unsupported.
807 setOperationAction(Ops: {ISD::SDIV, ISD::UDIV, ISD::SRA, ISD::SRL, ISD::MULHS,
808 ISD::MULHU, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
809 ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::SETCC},
810 VT: MVT::v2i16, Action: Expand);
811
812 setOperationAction(Op: ISD::ADDC, VT: MVT::i32, Action: Legal);
813 setOperationAction(Op: ISD::ADDE, VT: MVT::i32, Action: Legal);
814 setOperationAction(Op: ISD::SUBC, VT: MVT::i32, Action: Legal);
815 setOperationAction(Op: ISD::SUBE, VT: MVT::i32, Action: Legal);
816 if (STI.getPTXVersion() >= 43) {
817 setOperationAction(Op: ISD::ADDC, VT: MVT::i64, Action: Legal);
818 setOperationAction(Op: ISD::ADDE, VT: MVT::i64, Action: Legal);
819 setOperationAction(Op: ISD::SUBC, VT: MVT::i64, Action: Legal);
820 setOperationAction(Op: ISD::SUBE, VT: MVT::i64, Action: Legal);
821 }
822
823 setOperationAction(Op: ISD::CTTZ, VT: MVT::i16, Action: Expand);
824 setOperationAction(Op: ISD::CTTZ, VT: MVT::v2i16, Action: Expand);
825 setOperationAction(Op: ISD::CTTZ, VT: MVT::i32, Action: Expand);
826 setOperationAction(Op: ISD::CTTZ, VT: MVT::i64, Action: Expand);
827
828 // PTX does not directly support SELP of i1, so promote to i32 first
829 setOperationAction(Op: ISD::SELECT, VT: MVT::i1, Action: Custom);
830
831 // PTX cannot multiply two i64s in a single instruction.
832 setOperationAction(Op: ISD::SMUL_LOHI, VT: MVT::i64, Action: Expand);
833 setOperationAction(Op: ISD::UMUL_LOHI, VT: MVT::i64, Action: Expand);
834
835 // We have some custom DAG combine patterns for these nodes
836 setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
837 ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
838 ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
839 ISD::STORE});
840
841 // setcc for f16x2 and bf16x2 needs special handling to prevent
842 // the legalizer's attempt to scalarize it due to v2i1 not being legal.
843 if (STI.allowFP16Math() || STI.hasBF16Math())
844 setTargetDAGCombine(ISD::SETCC);
845
846 // Promote fp16 arithmetic if fp16 hardware isn't available or the
847 // user passed --nvptx-no-fp16-math. The flag is useful because,
848 // although sm_53+ GPUs have some sort of FP16 support in
849 // hardware, only sm_53 and sm_60 have a full implementation. Others
850 // have only a token amount of hardware and are likely to run faster
851 // by using fp32 units instead.
852 for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
853 setFP16OperationAction(Op, MVT::f16, Legal, Promote);
854 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
855 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
856 // bf16 must be promoted to f32.
857 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
858 if (getOperationAction(Op, VT: MVT::bf16) == Promote)
859 AddPromotedToType(Opc: Op, OrigVT: MVT::bf16, DestVT: MVT::f32);
860 }
861
862 // On sm_80, we select add/mul/sub as fma to avoid promotion to float.
863 for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB}) {
864 for (const auto &VT : {MVT::bf16, MVT::v2bf16}) {
865 if (!STI.hasNativeBF16Support(Opcode: Op) && STI.hasNativeBF16Support(Opcode: ISD::FMA)) {
866 setOperationAction(Op, VT, Action: Custom);
867 }
868 }
869 }
870
871 // f16/f16x2 neg was introduced in PTX 60, SM_53.
872 const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
873 STI.getPTXVersion() >= 60 &&
874 STI.allowFP16Math();
875 for (const auto &VT : {MVT::f16, MVT::v2f16})
876 setOperationAction(Op: ISD::FNEG, VT,
877 Action: IsFP16FP16x2NegAvailable ? Legal : Expand);
878
879 setBF16OperationAction(ISD::FNEG, MVT::bf16, Legal, Expand);
880 setBF16OperationAction(ISD::FNEG, MVT::v2bf16, Legal, Expand);
881 // (would be) Library functions.
882
883 // These map to conversion instructions for scalar FP types.
884 for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
885 ISD::FROUNDEVEN, ISD::FTRUNC}) {
886 setOperationAction(Op, VT: MVT::f16, Action: Legal);
887 setOperationAction(Op, VT: MVT::f32, Action: Legal);
888 setOperationAction(Op, VT: MVT::f64, Action: Legal);
889 setOperationAction(Op, VT: MVT::v2f16, Action: Expand);
890 setOperationAction(Op, VT: MVT::v2bf16, Action: Expand);
891 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
892 if (getOperationAction(Op, VT: MVT::bf16) == Promote)
893 AddPromotedToType(Opc: Op, OrigVT: MVT::bf16, DestVT: MVT::f32);
894 }
895
896 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71) {
897 setOperationAction(Op: ISD::BF16_TO_FP, VT: MVT::f32, Action: Expand);
898 }
899 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
900 for (MVT VT : {MVT::bf16, MVT::f32, MVT::f64}) {
901 setOperationAction(Op: ISD::FP_EXTEND, VT, Action: Custom);
902 setOperationAction(Op: ISD::FP_ROUND, VT, Action: Custom);
903 }
904 }
905
906 // sm_80 only has conversions between f32 and bf16. Custom lower all other
907 // bf16 conversions.
908 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
909 for (MVT VT : {MVT::i1, MVT::i16, MVT::i32, MVT::i64}) {
910 setOperationAction(
911 Ops: {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
912 VT, Action: Custom);
913 }
914 setOperationAction(
915 Ops: {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
916 VT: MVT::bf16, Action: Custom);
917 }
918
919 setOperationAction(Op: ISD::FROUND, VT: MVT::f16, Action: Promote);
920 setOperationAction(Op: ISD::FROUND, VT: MVT::v2f16, Action: Expand);
921 setOperationAction(Op: ISD::FROUND, VT: MVT::v2bf16, Action: Expand);
922 setOperationAction(Op: ISD::FROUND, VT: MVT::f32, Action: Custom);
923 setOperationAction(Op: ISD::FROUND, VT: MVT::f64, Action: Custom);
924 setOperationAction(Op: ISD::FROUND, VT: MVT::bf16, Action: Promote);
925 AddPromotedToType(Opc: ISD::FROUND, OrigVT: MVT::bf16, DestVT: MVT::f32);
926
927 // 'Expand' implements FCOPYSIGN without calling an external library.
928 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::f16, Action: Expand);
929 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::v2f16, Action: Expand);
930 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::bf16, Action: Expand);
931 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::v2bf16, Action: Expand);
932 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::f32, Action: Custom);
933 setOperationAction(Op: ISD::FCOPYSIGN, VT: MVT::f64, Action: Custom);
934
935 // These map to corresponding instructions for f32/f64. f16 must be
936 // promoted to f32. v2f16 is expanded to f16, which is then promoted
937 // to f32.
938 for (const auto &Op :
939 {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS}) {
940 setOperationAction(Op, VT: MVT::f16, Action: Promote);
941 setOperationAction(Op, VT: MVT::f32, Action: Legal);
942 setOperationAction(Op, VT: MVT::f64, Action: Legal);
943 setOperationAction(Op, VT: MVT::v2f16, Action: Expand);
944 setOperationAction(Op, VT: MVT::v2bf16, Action: Expand);
945 setOperationAction(Op, VT: MVT::bf16, Action: Promote);
946 AddPromotedToType(Opc: Op, OrigVT: MVT::bf16, DestVT: MVT::f32);
947 }
948 setOperationAction(Ops: ISD::FREM, VTs: {MVT::f32, MVT::f64}, Action: Custom);
949
950 setOperationAction(Ops: ISD::FABS, VTs: {MVT::f32, MVT::f64}, Action: Legal);
951 if (STI.getPTXVersion() >= 65) {
952 setFP16OperationAction(ISD::FABS, MVT::f16, Legal, Promote);
953 setFP16OperationAction(ISD::FABS, MVT::v2f16, Legal, Expand);
954 } else {
955 setOperationAction(Op: ISD::FABS, VT: MVT::f16, Action: Promote);
956 setOperationAction(Op: ISD::FABS, VT: MVT::v2f16, Action: Expand);
957 }
958 setBF16OperationAction(ISD::FABS, MVT::v2bf16, Legal, Expand);
959 setBF16OperationAction(ISD::FABS, MVT::bf16, Legal, Promote);
960 if (getOperationAction(Op: ISD::FABS, VT: MVT::bf16) == Promote)
961 AddPromotedToType(Opc: ISD::FABS, OrigVT: MVT::bf16, DestVT: MVT::f32);
962
963 for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
964 setOperationAction(Op, VT: MVT::f32, Action: Legal);
965 setOperationAction(Op, VT: MVT::f64, Action: Legal);
966 setFP16OperationAction(Op, MVT::f16, Legal, Promote);
967 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
968 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
969 setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
970 if (getOperationAction(Op, VT: MVT::bf16) == Promote)
971 AddPromotedToType(Opc: Op, OrigVT: MVT::bf16, DestVT: MVT::f32);
972 }
973 bool SupportsF32MinMaxNaN =
974 STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
975 for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
976 setOperationAction(Op, VT: MVT::f32, Action: SupportsF32MinMaxNaN ? Legal : Expand);
977 setFP16OperationAction(Op, MVT::f16, Legal, Expand);
978 setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
979 setBF16OperationAction(Op, MVT::bf16, Legal, Expand);
980 setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
981 }
982
983 // Custom lowering for inline asm with 128-bit operands
984 setOperationAction(Op: ISD::CopyToReg, VT: MVT::i128, Action: Custom);
985 setOperationAction(Op: ISD::CopyFromReg, VT: MVT::i128, Action: Custom);
986
987 // FEXP2 support:
988 // - f32
989 // - f16/f16x2 (sm_70+, PTX 7.0+)
990 // - bf16/bf16x2 (sm_90+, PTX 7.8+)
991 // When f16/bf16 types aren't supported, they are promoted/expanded to f32.
992 setOperationAction(Op: ISD::FEXP2, VT: MVT::f32, Action: Legal);
993 setFP16OperationAction(ISD::FEXP2, MVT::f16, Legal, Promote);
994 setFP16OperationAction(ISD::FEXP2, MVT::v2f16, Legal, Expand);
995 setBF16OperationAction(ISD::FEXP2, MVT::bf16, Legal, Promote);
996 setBF16OperationAction(ISD::FEXP2, MVT::v2bf16, Legal, Expand);
997
998 // FLOG2 supports f32 only
999 // f16/bf16 types aren't supported, but they are promoted/expanded to f32.
1000 if (UseApproxLog2F32) {
1001 setOperationAction(Op: ISD::FLOG2, VT: MVT::f32, Action: Legal);
1002 setOperationPromotedToType(Opc: ISD::FLOG2, OrigVT: MVT::f16, DestVT: MVT::f32);
1003 setOperationPromotedToType(Opc: ISD::FLOG2, OrigVT: MVT::bf16, DestVT: MVT::f32);
1004 setOperationAction(Ops: ISD::FLOG2, VTs: {MVT::v2f16, MVT::v2bf16}, Action: Expand);
1005 }
1006
1007 setOperationAction(Ops: ISD::ADDRSPACECAST, VTs: {MVT::i32, MVT::i64}, Action: Custom);
1008
1009 setOperationAction(Ops: ISD::ATOMIC_LOAD_SUB, VTs: {MVT::i32, MVT::i64}, Action: Expand);
1010 // No FPOW or FREM in PTX.
1011
1012 // Now deduce the information based on the above-mentioned
1013 // actions.
1014 computeRegisterProperties(TRI: STI.getRegisterInfo());
1015
1016 // PTX support for 16-bit CAS is emulated. Only use 32-bit and wider.
1017 setMinCmpXchgSizeInBits(STI.getMinCmpXchgSizeInBits());
1018 setMaxAtomicSizeInBitsSupported(64);
1019 setMaxDivRemBitWidthSupported(64);
1020
1021 // Custom lowering for tcgen05.ld vector operands
1022 setOperationAction(Ops: ISD::INTRINSIC_W_CHAIN,
1023 VTs: {MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
1024 MVT::v32i32, MVT::v64i32, MVT::v128i32},
1025 Action: Custom);
1026
1027 // Custom lowering for tcgen05.st vector operands
1028 setOperationAction(Ops: ISD::INTRINSIC_VOID,
1029 VTs: {MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
1030 MVT::v32i32, MVT::v64i32, MVT::v128i32},
1031 Action: Custom);
1032
1033 setOperationAction(Op: ISD::INTRINSIC_WO_CHAIN, VT: MVT::Other, Action: Custom);
1034 // Enable custom lowering for the i128 operand of clusterlaunchcontrol.
1035 setOperationAction(Op: ISD::INTRINSIC_WO_CHAIN, VT: MVT::i128, Action: Custom);
1036}
1037
1038const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
1039
1040#define MAKE_CASE(V) \
1041 case V: \
1042 return #V;
1043
1044 switch ((NVPTXISD::NodeType)Opcode) {
1045 case NVPTXISD::FIRST_NUMBER:
1046 break;
1047
1048 MAKE_CASE(NVPTXISD::RET_GLUE)
1049 MAKE_CASE(NVPTXISD::DeclareArrayParam)
1050 MAKE_CASE(NVPTXISD::DeclareScalarParam)
1051 MAKE_CASE(NVPTXISD::CALL)
1052 MAKE_CASE(NVPTXISD::LoadParam)
1053 MAKE_CASE(NVPTXISD::LoadParamV2)
1054 MAKE_CASE(NVPTXISD::LoadParamV4)
1055 MAKE_CASE(NVPTXISD::StoreParam)
1056 MAKE_CASE(NVPTXISD::StoreParamV2)
1057 MAKE_CASE(NVPTXISD::StoreParamV4)
1058 MAKE_CASE(NVPTXISD::MoveParam)
1059 MAKE_CASE(NVPTXISD::UNPACK_VECTOR)
1060 MAKE_CASE(NVPTXISD::BUILD_VECTOR)
1061 MAKE_CASE(NVPTXISD::CallPrototype)
1062 MAKE_CASE(NVPTXISD::ProxyReg)
1063 MAKE_CASE(NVPTXISD::LoadV2)
1064 MAKE_CASE(NVPTXISD::LoadV4)
1065 MAKE_CASE(NVPTXISD::LoadV8)
1066 MAKE_CASE(NVPTXISD::LDUV2)
1067 MAKE_CASE(NVPTXISD::LDUV4)
1068 MAKE_CASE(NVPTXISD::StoreV2)
1069 MAKE_CASE(NVPTXISD::StoreV4)
1070 MAKE_CASE(NVPTXISD::StoreV8)
1071 MAKE_CASE(NVPTXISD::FSHL_CLAMP)
1072 MAKE_CASE(NVPTXISD::FSHR_CLAMP)
1073 MAKE_CASE(NVPTXISD::BFE)
1074 MAKE_CASE(NVPTXISD::BFI)
1075 MAKE_CASE(NVPTXISD::PRMT)
1076 MAKE_CASE(NVPTXISD::FCOPYSIGN)
1077 MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
1078 MAKE_CASE(NVPTXISD::STACKRESTORE)
1079 MAKE_CASE(NVPTXISD::STACKSAVE)
1080 MAKE_CASE(NVPTXISD::SETP_F16X2)
1081 MAKE_CASE(NVPTXISD::SETP_BF16X2)
1082 MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED)
1083 MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED)
1084 MAKE_CASE(NVPTXISD::BrxEnd)
1085 MAKE_CASE(NVPTXISD::BrxItem)
1086 MAKE_CASE(NVPTXISD::BrxStart)
1087 MAKE_CASE(NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_IS_CANCELED)
1088 MAKE_CASE(NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_X)
1089 MAKE_CASE(NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Y)
1090 MAKE_CASE(NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z)
1091 }
1092 return nullptr;
1093
1094#undef MAKE_CASE
1095}
1096
1097TargetLoweringBase::LegalizeTypeAction
1098NVPTXTargetLowering::getPreferredVectorAction(MVT VT) const {
1099 if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
1100 VT.getScalarType() == MVT::i1)
1101 return TypeSplitVector;
1102 return TargetLoweringBase::getPreferredVectorAction(VT);
1103}
1104
1105SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
1106 int Enabled, int &ExtraSteps,
1107 bool &UseOneConst,
1108 bool Reciprocal) const {
1109 if (!(Enabled == ReciprocalEstimate::Enabled ||
1110 (Enabled == ReciprocalEstimate::Unspecified &&
1111 !usePrecSqrtF32(MF: DAG.getMachineFunction()))))
1112 return SDValue();
1113
1114 if (ExtraSteps == ReciprocalEstimate::Unspecified)
1115 ExtraSteps = 0;
1116
1117 SDLoc DL(Operand);
1118 EVT VT = Operand.getValueType();
1119 bool Ftz = useF32FTZ(MF: DAG.getMachineFunction());
1120
1121 auto MakeIntrinsicCall = [&](Intrinsic::ID IID) {
1122 return DAG.getNode(Opcode: ISD::INTRINSIC_WO_CHAIN, DL, VT,
1123 N1: DAG.getConstant(Val: IID, DL, VT: MVT::i32), N2: Operand);
1124 };
1125
1126 // The sqrt and rsqrt refinement processes assume we always start out with an
1127 // approximation of the rsqrt. Therefore, if we're going to do any refinement
1128 // (i.e. ExtraSteps > 0), we must return an rsqrt. But if we're *not* doing
1129 // any refinement, we must return a regular sqrt.
1130 if (Reciprocal || ExtraSteps > 0) {
1131 if (VT == MVT::f32)
1132 return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_rsqrt_approx_ftz_f
1133 : Intrinsic::nvvm_rsqrt_approx_f);
1134 else if (VT == MVT::f64)
1135 return MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d);
1136 else
1137 return SDValue();
1138 } else {
1139 if (VT == MVT::f32)
1140 return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f
1141 : Intrinsic::nvvm_sqrt_approx_f);
1142 else {
1143 // There's no sqrt.approx.f64 instruction, so we emit
1144 // reciprocal(rsqrt(x)). This is faster than
1145 // select(x == 0, 0, x * rsqrt(x)). (In fact, it's faster than plain
1146 // x * rsqrt(x).)
1147 return DAG.getNode(
1148 Opcode: ISD::INTRINSIC_WO_CHAIN, DL, VT,
1149 N1: DAG.getConstant(Val: Intrinsic::nvvm_rcp_approx_ftz_d, DL, VT: MVT::i32),
1150 N2: MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d));
1151 }
1152 }
1153}
1154
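// Builds the PTX ".callprototype" declaration used for indirect calls. As an
// illustrative example, a call returning i32 and taking a single i32 argument
// would produce something like:
//   prototype_0 : .callprototype (.param .b32 _) _ (.param .b32 _);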
1155std::string NVPTXTargetLowering::getPrototype(
1156 const DataLayout &DL, Type *RetTy, const ArgListTy &Args,
1157 const SmallVectorImpl<ISD::OutputArg> &Outs,
1158 std::optional<unsigned> FirstVAArg, const CallBase &CB,
1159 unsigned UniqueCallSite) const {
1160 auto PtrVT = getPointerTy(DL);
1161
1162 std::string Prototype;
1163 raw_string_ostream O(Prototype);
1164 O << "prototype_" << UniqueCallSite << " : .callprototype ";
1165
1166 if (RetTy->isVoidTy()) {
1167 O << "()";
1168 } else {
1169 O << "(";
1170 if (shouldPassAsArray(Ty: RetTy)) {
1171 const Align RetAlign = getArgumentAlignment(CB: &CB, Ty: RetTy, Idx: 0, DL);
1172 O << ".param .align " << RetAlign.value() << " .b8 _["
1173 << DL.getTypeAllocSize(Ty: RetTy) << "]";
1174 } else if (RetTy->isFloatingPointTy() || RetTy->isIntegerTy()) {
1175 unsigned size = 0;
1176 if (auto *ITy = dyn_cast<IntegerType>(Val: RetTy)) {
1177 size = ITy->getBitWidth();
1178 } else {
1179 assert(RetTy->isFloatingPointTy() &&
1180 "Floating point type expected here");
1181 size = RetTy->getPrimitiveSizeInBits();
1182 }
1183 // PTX ABI requires all scalar return values to be at least 32
1184 // bits in size. fp16 normally uses .b16 as its storage type in
1185 // PTX, so its size must be adjusted here, too.
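// For example (illustrative), an i8 or f16 return value ends up being
// declared as ".param .b32 _".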
1186 size = promoteScalarArgumentSize(size);
1187
1188 O << ".param .b" << size << " _";
1189 } else if (isa<PointerType>(Val: RetTy)) {
1190 O << ".param .b" << PtrVT.getSizeInBits() << " _";
1191 } else {
1192 llvm_unreachable("Unknown return type");
1193 }
1194 O << ") ";
1195 }
1196 O << "_ (";
1197
1198 bool first = true;
1199
1200 const unsigned NumArgs = FirstVAArg.value_or(u: Args.size());
1201 auto AllOuts = ArrayRef(Outs);
1202 for (const unsigned I : llvm::seq(Size: NumArgs)) {
1203 const auto ArgOuts =
1204 AllOuts.take_while(Pred: [I](auto O) { return O.OrigArgIndex == I; });
1205 AllOuts = AllOuts.drop_front(N: ArgOuts.size());
1206
1207 Type *Ty = Args[I].Ty;
1208 if (!first) {
1209 O << ", ";
1210 }
1211 first = false;
1212
1213 if (ArgOuts[0].Flags.isByVal()) {
1214 // Indirect calls need strict ABI alignment, so we disable optimizations by
1215 // not providing a function to optimize.
1216 Type *ETy = Args[I].IndirectType;
1217 Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
1218 Align ParamByValAlign =
1219 getFunctionByValParamAlign(/*F=*/nullptr, ArgTy: ETy, InitialAlign, DL);
1220
1221 O << ".param .align " << ParamByValAlign.value() << " .b8 _["
1222 << ArgOuts[0].Flags.getByValSize() << "]";
1223 } else {
1224 if (shouldPassAsArray(Ty)) {
1225 Align ParamAlign =
1226 getArgumentAlignment(CB: &CB, Ty, Idx: I + AttributeList::FirstArgIndex, DL);
1227 O << ".param .align " << ParamAlign.value() << " .b8 _["
1228 << DL.getTypeAllocSize(Ty) << "]";
1229 continue;
1230 }
1231 // i8 types in IR will be i16 types in SDAG
1232 assert((getValueType(DL, Ty) == ArgOuts[0].VT ||
1233 (getValueType(DL, Ty) == MVT::i8 && ArgOuts[0].VT == MVT::i16)) &&
1234 "type mismatch between callee prototype and arguments");
1235 // scalar type
1236 unsigned sz = 0;
1237 if (auto *ITy = dyn_cast<IntegerType>(Val: Ty)) {
1238 sz = promoteScalarArgumentSize(size: ITy->getBitWidth());
1239 } else if (isa<PointerType>(Val: Ty)) {
1240 sz = PtrVT.getSizeInBits();
1241 } else {
1242 sz = Ty->getPrimitiveSizeInBits();
1243 }
1244 O << ".param .b" << sz << " _";
1245 }
1246 }
1247
1248 if (FirstVAArg)
1249 O << (first ? "" : ",") << " .param .align "
1250 << STI.getMaxRequiredAlignment() << " .b8 _[]";
1251 O << ")";
1252 if (shouldEmitPTXNoReturn(V: &CB, TM: *nvTM))
1253 O << " .noreturn";
1254 O << ";";
1255
1256 return Prototype;
1257}
1258
1259Align NVPTXTargetLowering::getFunctionArgumentAlignment(
1260 const Function *F, Type *Ty, unsigned Idx, const DataLayout &DL) const {
1261 return getAlign(F: *F, Index: Idx).value_or(u: getFunctionParamOptimizedAlign(F, ArgTy: Ty, DL));
1262}
1263
1264Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
1265 unsigned Idx,
1266 const DataLayout &DL) const {
1267 if (!CB) {
1268 // CallSite is zero; fall back to the ABI type alignment.
1269 return DL.getABITypeAlign(Ty);
1270 }
1271
1272 const Function *DirectCallee = CB->getCalledFunction();
1273
1274 if (!DirectCallee) {
1275 // We don't have a direct function symbol, but that may be because of
1276 // constant cast instructions in the call.
1277
1278 // With bitcast'd call targets, the instruction will be the call
1279 if (const auto *CI = dyn_cast<CallInst>(Val: CB)) {
1280 // Check if we have call alignment metadata
1281 if (MaybeAlign StackAlign = getAlign(*CI, Idx))
1282 return StackAlign.value();
1283 }
1284 DirectCallee = getMaybeBitcastedCallee(CB);
1285 }
1286
1287 // Check for function alignment information if we found that the
1288 // ultimate target is a Function
1289 if (DirectCallee)
1290 return getFunctionArgumentAlignment(F: DirectCallee, Ty, Idx, DL);
1291
1292 // Call is indirect, fall back to the ABI type alignment
1293 return DL.getABITypeAlign(Ty);
1294}
1295
1296static bool adjustElementType(EVT &ElementType) {
1297 switch (ElementType.getSimpleVT().SimpleTy) {
1298 default:
1299 return false;
1300 case MVT::f16:
1301 case MVT::bf16:
1302 ElementType = MVT::i16;
1303 return true;
1304 case MVT::f32:
1305 case MVT::v2f16:
1306 case MVT::v2bf16:
1307 ElementType = MVT::i32;
1308 return true;
1309 case MVT::f64:
1310 ElementType = MVT::i64;
1311 return true;
1312 }
1313}
1314
1315// Use byte-store when the param address of the argument value is unaligned.
1316// This may happen when the argument value is a field of a packed structure.
1317//
1318// This is called in LowerCall() when passing the param values.
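//
// For example (illustrative): an f32 argument at a misaligned param offset is
// bitcast to i32 and written with four st.param.b8 stores of its successive
// bytes (shifted right by 0, 8, 16, and 24 bits).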
1319static SDValue LowerUnalignedStoreParam(SelectionDAG &DAG, SDValue Chain,
1320 uint64_t Offset, EVT ElementType,
1321 SDValue StVal, SDValue &InGlue,
1322 unsigned ArgID, const SDLoc &dl) {
1323 // Bit logic only works on integer types
1324 if (adjustElementType(ElementType))
1325 StVal = DAG.getNode(Opcode: ISD::BITCAST, DL: dl, VT: ElementType, Operand: StVal);
1326
1327 // Store each byte
1328 SDVTList StoreVTs = DAG.getVTList(VT1: MVT::Other, VT2: MVT::Glue);
1329 for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
1330 // Shift the byte to the last byte position
1331 SDValue ShiftVal = DAG.getNode(Opcode: ISD::SRL, DL: dl, VT: ElementType, N1: StVal,
1332 N2: DAG.getConstant(Val: i * 8, DL: dl, VT: MVT::i32));
1333 SDValue StoreOperands[] = {Chain, DAG.getConstant(Val: ArgID, DL: dl, VT: MVT::i32),
1334 DAG.getConstant(Val: Offset + i, DL: dl, VT: MVT::i32),
1335 ShiftVal, InGlue};
1336 // Trunc store only the last byte by using
1337 // st.param.b8
1338 // The register type can be larger than b8.
1339 Chain = DAG.getMemIntrinsicNode(
1340 Opcode: NVPTXISD::StoreParam, dl, VTList: StoreVTs, Ops: StoreOperands, MemVT: MVT::i8,
1341 PtrInfo: MachinePointerInfo(), Alignment: Align(1), Flags: MachineMemOperand::MOStore);
1342 InGlue = Chain.getValue(R: 1);
1343 }
1344 return Chain;
1345}
1346
1347// Use byte-load when the param address of the returned value is unaligned.
1348// This may happen when the returned value is a field of a packed structure.
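//
// For example (illustrative): an f32 return value at a misaligned param offset
// is reassembled from four byte loads that are zero-extended, masked, shifted,
// and OR'ed together, then bitcast back to f32.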
1349static SDValue
1350LowerUnalignedLoadRetParam(SelectionDAG &DAG, SDValue &Chain, uint64_t Offset,
1351 EVT ElementType, SDValue &InGlue,
1352 SmallVectorImpl<SDValue> &TempProxyRegOps,
1353 const SDLoc &dl) {
1354 // Bit logic only works on integer types
1355 EVT MergedType = ElementType;
1356 adjustElementType(ElementType&: MergedType);
1357
1358 // Load each byte and construct the whole value. Initialize it to 0.
1359 SDValue RetVal = DAG.getConstant(Val: 0, DL: dl, VT: MergedType);
1360 // LoadParamMemI8 loads into i16 register only
1361 SDVTList LoadVTs = DAG.getVTList(VT1: MVT::i16, VT2: MVT::Other, VT3: MVT::Glue);
1362 for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
1363 SDValue LoadOperands[] = {Chain, DAG.getConstant(Val: 1, DL: dl, VT: MVT::i32),
1364 DAG.getConstant(Val: Offset + i, DL: dl, VT: MVT::i32),
1365 InGlue};
1366 // This will be selected to LoadParamMemI8
1367 SDValue LdVal =
1368 DAG.getMemIntrinsicNode(Opcode: NVPTXISD::LoadParam, dl, VTList: LoadVTs, Ops: LoadOperands,
1369 MemVT: MVT::i8, PtrInfo: MachinePointerInfo(), Alignment: Align(1));
1370 SDValue TmpLdVal = LdVal.getValue(R: 0);
1371 Chain = LdVal.getValue(R: 1);
1372 InGlue = LdVal.getValue(R: 2);
1373
1374 TmpLdVal = DAG.getNode(Opcode: NVPTXISD::ProxyReg, DL: dl,
1375 VT: TmpLdVal.getSimpleValueType(), Operand: TmpLdVal);
1376 TempProxyRegOps.push_back(Elt: TmpLdVal);
1377
1378 SDValue CMask = DAG.getConstant(Val: 255, DL: dl, VT: MergedType);
1379 SDValue CShift = DAG.getConstant(Val: i * 8, DL: dl, VT: MVT::i32);
1380 // Need to extend the i16 register to the whole width.
1381 TmpLdVal = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: dl, VT: MergedType, Operand: TmpLdVal);
1382 // Mask off the high bits. Leave only the lower 8 bits.
1383 // Do this because we are using loadparam.b8.
1384 TmpLdVal = DAG.getNode(Opcode: ISD::AND, DL: dl, VT: MergedType, N1: TmpLdVal, N2: CMask);
1385 // Shift and merge
1386 TmpLdVal = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT: MergedType, N1: TmpLdVal, N2: CShift);
1387 RetVal = DAG.getNode(Opcode: ISD::OR, DL: dl, VT: MergedType, N1: RetVal, N2: TmpLdVal);
1388 }
1389 if (ElementType != MergedType)
1390 RetVal = DAG.getNode(Opcode: ISD::BITCAST, DL: dl, VT: ElementType, Operand: RetVal);
1391
1392 return RetVal;
1393}
1394
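// A direct call can only be used when the callsite's function type matches
// the callee's declared type; otherwise the callsite is converted to an
// indirect call (see the use in LowerCall below).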
1395static bool shouldConvertToIndirectCall(const CallBase *CB,
1396 const GlobalAddressSDNode *Func) {
1397 if (!Func)
1398 return false;
1399 if (auto *CalleeFunc = dyn_cast<Function>(Val: Func->getGlobal()))
1400 return CB->getFunctionType() != CalleeFunc->getFunctionType();
1401 return false;
1402}
1403
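// Try to refine the address space of a pointer used as the source of a byval
// copy: frame indices are rewritten to the local address space, and an
// addrspacecast-to-generic is peeled off, so the load can be issued in the
// more specific address space. Returns the corresponding MachinePointerInfo.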
1404static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
1405 const DataLayout &DL,
1406 const TargetLowering &TL) {
1407 if (Ptr->getOpcode() == ISD::FrameIndex) {
1408 auto Ty = TL.getPointerTy(DL, AS: ADDRESS_SPACE_LOCAL);
1409 Ptr = DAG.getAddrSpaceCast(dl: SDLoc(), VT: Ty, Ptr, SrcAS: ADDRESS_SPACE_GENERIC,
1410 DestAS: ADDRESS_SPACE_LOCAL);
1411
1412 return MachinePointerInfo(ADDRESS_SPACE_LOCAL);
1413 }
1414
1415 // Peel off an addrspacecast to generic and load directly from the specific
1416 // address space.
1417 if (Ptr->getOpcode() == ISD::ADDRSPACECAST) {
1418 const auto *ASC = cast<AddrSpaceCastSDNode>(Val&: Ptr);
1419 if (ASC->getDestAddressSpace() == ADDRESS_SPACE_GENERIC) {
1420 Ptr = ASC->getOperand(Num: 0);
1421 return MachinePointerInfo(ASC->getSrcAddressSpace());
1422 }
1423 }
1424
1425 return MachinePointerInfo();
1426}
1427
1428static ISD::NodeType getExtOpcode(const ISD::ArgFlagsTy &Flags) {
1429 if (Flags.isSExt())
1430 return ISD::SIGN_EXTEND;
1431 if (Flags.isZExt())
1432 return ISD::ZERO_EXTEND;
1433 return ISD::ANY_EXTEND;
1434}
1435
1436static SDValue correctParamType(SDValue V, EVT ExpectedVT,
1437 ISD::ArgFlagsTy Flags, SelectionDAG &DAG,
1438 SDLoc dl) {
1439 const EVT ActualVT = V.getValueType();
1440 assert((ActualVT == ExpectedVT ||
1441 (ExpectedVT.isInteger() && ActualVT.isInteger())) &&
1442 "Non-integer argument type size mismatch");
1443 if (ExpectedVT.bitsGT(VT: ActualVT))
1444 return DAG.getNode(Opcode: getExtOpcode(Flags), DL: dl, VT: ExpectedVT, Operand: V);
1445 if (ExpectedVT.bitsLT(VT: ActualVT))
1446 return DAG.getNode(Opcode: ISD::TRUNCATE, DL: dl, VT: ExpectedVT, Operand: V);
1447
1448 return V;
1449}
1450
1451SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
1452 SmallVectorImpl<SDValue> &InVals) const {
1453
1454 if (CLI.IsVarArg && (STI.getPTXVersion() < 60 || STI.getSmVersion() < 30))
1455 report_fatal_error(
1456 reason: "Support for variadic functions (unsized array parameter) introduced "
1457 "in PTX ISA version 6.0 and requires target sm_30.");
1458
1459 SelectionDAG &DAG = CLI.DAG;
1460 SDLoc dl = CLI.DL;
1461 SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
1462 SDValue Chain = CLI.Chain;
1463 SDValue Callee = CLI.Callee;
1464 bool &isTailCall = CLI.IsTailCall;
1465 ArgListTy &Args = CLI.getArgs();
1466 Type *RetTy = CLI.RetTy;
1467 const CallBase *CB = CLI.CB;
1468 const DataLayout &DL = DAG.getDataLayout();
1469
1470 const auto GetI32 = [&](const unsigned I) {
1471 return DAG.getConstant(Val: I, DL: dl, VT: MVT::i32);
1472 };
1473
1474 // Variadic arguments.
1475 //
1476 // Normally, for each argument, we declare a param scalar or a param
1477 // byte array in the .param space, and store the argument value to that
1478 // param scalar or array starting at offset 0.
1479 //
1480 // In the case of the first variadic argument, we declare a vararg byte array
1481 // with size 0. The exact size of this array isn't known at this point, so
1482 // it'll be patched later. All the variadic arguments will be stored to this
1483 // array at a certain offset (which gets tracked by 'VAOffset'). The offset is
1484 // initially set to 0, so it can be used for non-variadic arguments (which use
1485 // 0 offset) to simplify the code.
1486 //
1487 // After all variadic arguments are processed, 'VAOffset' holds the size of
1488 // the vararg byte array.
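// For illustration only (the parameter name and alignment are illustrative):
// for a variadic callee, the single vararg byte array may be declared roughly
// as
//   .param .align 8 .b8 param1[0];
// with size 0 that is patched to the final VAOffset once all variadic
// arguments have been processed.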
1489
1490 SDValue VADeclareParam; // vararg byte array
1491 const unsigned FirstVAArg = CLI.NumFixedArgs; // position of first variadic
1492 unsigned VAOffset = 0; // current offset in the param array
1493
1494 const unsigned UniqueCallSite = GlobalUniqueCallSite++;
1495 SDValue TempChain = Chain;
1496 Chain = DAG.getCALLSEQ_START(Chain, InSize: UniqueCallSite, OutSize: 0, DL: dl);
1497 SDValue InGlue = Chain.getValue(R: 1);
1498
1499 // Args.size() and Outs.size() need not match.
1500 // Outs.size() will be larger
1501 // * if there is an aggregate argument with multiple fields (each field
1502 // showing up separately in Outs)
1503 // * if there is a vector argument with more than typical vector-length
1504 // elements (generally if more than 4) where each vector element is
1505 // individually present in Outs.
1506 // So a different index should be used for indexing into Outs/OutVals.
1507 // See similar issue in LowerFormalArguments.
1508 auto AllOuts = ArrayRef(CLI.Outs);
1509 auto AllOutVals = ArrayRef(CLI.OutVals);
1510 assert(AllOuts.size() == AllOutVals.size() &&
1511 "Outs and OutVals must be the same size");
1512 // Declare the .params or .reg needed to pass values
1513 // to the function.
1514 for (const auto E : llvm::enumerate(First&: Args)) {
1515 const auto ArgI = E.index();
1516 const auto Arg = E.value();
1517 const auto ArgOuts =
1518 AllOuts.take_while(Pred: [&](auto O) { return O.OrigArgIndex == ArgI; });
1519 const auto ArgOutVals = AllOutVals.take_front(N: ArgOuts.size());
1520 AllOuts = AllOuts.drop_front(N: ArgOuts.size());
1521 AllOutVals = AllOutVals.drop_front(N: ArgOuts.size());
1522
1523 const bool IsVAArg = (ArgI >= FirstVAArg);
1524 const bool IsByVal = Arg.IsByVal;
1525
1526 const SDValue ParamSymbol =
1527 getCallParamSymbol(DAG, I: IsVAArg ? FirstVAArg : ArgI, T: MVT::i32);
1528
1529 SmallVector<EVT, 16> VTs;
1530 SmallVector<uint64_t, 16> Offsets;
1531
1532 assert((!IsByVal || Arg.IndirectType) &&
1533 "byval arg must have indirect type");
1534 Type *ETy = (IsByVal ? Arg.IndirectType : Arg.Ty);
1535 ComputePTXValueVTs(TLI: *this, DL, Ty: ETy, ValueVTs&: VTs, Offsets: &Offsets, StartingOffset: IsByVal ? 0 : VAOffset);
1536 assert(VTs.size() == Offsets.size() && "Size mismatch");
1537 assert((IsByVal || VTs.size() == ArgOuts.size()) && "Size mismatch");
1538
1539 const Align ArgAlign = [&]() {
1540 if (IsByVal) {
1541 // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
1542 // so we don't need to worry whether it's naturally aligned or not.
1543 // See TargetLowering::LowerCallTo().
1544 const Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
1545 const Align ByValAlign = getFunctionByValParamAlign(
1546 F: CB->getCalledFunction(), ArgTy: ETy, InitialAlign, DL);
1547 if (IsVAArg)
1548 VAOffset = alignTo(Size: VAOffset, A: ByValAlign);
1549 return ByValAlign;
1550 }
1551 return getArgumentAlignment(CB, Ty: Arg.Ty, Idx: ArgI + 1, DL);
1552 }();
1553
1554 const unsigned TypeSize = DL.getTypeAllocSize(Ty: ETy);
1555 assert((!IsByVal || TypeSize == ArgOuts[0].Flags.getByValSize()) &&
1556 "type size mismatch");
1557
1558 const std::optional<SDValue> ArgDeclare = [&]() -> std::optional<SDValue> {
1559 if (IsVAArg) {
1560 if (ArgI == FirstVAArg) {
1561 VADeclareParam = DAG.getNode(
1562 Opcode: NVPTXISD::DeclareArrayParam, DL: dl, ResultTys: {MVT::Other, MVT::Glue},
1563 Ops: {Chain, ParamSymbol, GetI32(STI.getMaxRequiredAlignment()),
1564 GetI32(0), InGlue});
1565 return VADeclareParam;
1566 }
1567 return std::nullopt;
1568 }
1569 if (IsByVal || shouldPassAsArray(Ty: Arg.Ty)) {
1570 // declare .param .align <align> .b8 .param<n>[<size>];
1571 return DAG.getNode(Opcode: NVPTXISD::DeclareArrayParam, DL: dl,
1572 ResultTys: {MVT::Other, MVT::Glue},
1573 Ops: {Chain, ParamSymbol, GetI32(ArgAlign.value()),
1574 GetI32(TypeSize), InGlue});
1575 }
1576 assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
1577 // declare .param .b<size> .param<n>;
1578
1579 // PTX ABI requires integral types to be at least 32 bits in
1580 // size. FP16 is loaded/stored using i16, so it's handled
1581 // here as well.
1582 const unsigned PromotedSize =
1583 (ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint())
1584 ? promoteScalarArgumentSize(size: TypeSize * 8)
1585 : TypeSize * 8;
1586
1587 return DAG.getNode(Opcode: NVPTXISD::DeclareScalarParam, DL: dl,
1588 ResultTys: {MVT::Other, MVT::Glue},
1589 Ops: {Chain, ParamSymbol, GetI32(PromotedSize), InGlue});
1590 }();
1591 if (ArgDeclare) {
1592 Chain = ArgDeclare->getValue(R: 0);
1593 InGlue = ArgDeclare->getValue(R: 1);
1594 }
1595
1596 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
1597 // than 32-bits are sign extended or zero extended, depending on
1598 // whether they are signed or unsigned types. This case applies
1599 // only to scalar parameters and not to aggregate values.
1600 const bool ExtendIntegerParam =
1601 Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty: Arg.Ty) < 32;
1602
1603 const auto GetStoredValue = [&](const unsigned I, EVT EltVT,
1604 const Align PartAlign) {
1605 SDValue StVal;
1606 if (IsByVal) {
1607 SDValue Ptr = ArgOutVals[0];
1608 auto MPI = refinePtrAS(Ptr, DAG, DL, TL: *this);
1609 SDValue SrcAddr =
1610 DAG.getObjectPtrOffset(SL: dl, Ptr, Offset: TypeSize::getFixed(ExactSize: Offsets[I]));
1611
1612 StVal = DAG.getLoad(VT: EltVT, dl, Chain: TempChain, Ptr: SrcAddr, PtrInfo: MPI, Alignment: PartAlign);
1613 } else {
1614 StVal = ArgOutVals[I];
1615
1616 auto PromotedVT = promoteScalarIntegerPTX(VT: StVal.getValueType());
1617 if (PromotedVT != StVal.getValueType()) {
1618 StVal = DAG.getNode(Opcode: getExtOpcode(Flags: ArgOuts[I].Flags), DL: dl, VT: PromotedVT,
1619 Operand: StVal);
1620 }
1621 }
1622
1623 if (ExtendIntegerParam) {
1624 assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
1625 // zext/sext to i32
1626 StVal =
1627 DAG.getNode(Opcode: getExtOpcode(Flags: ArgOuts[I].Flags), DL: dl, VT: MVT::i32, Operand: StVal);
1628 } else if (EltVT.getSizeInBits() < 16) {
1629 // Use 16-bit registers for small stores as it's the
1630 // smallest general purpose register size supported by NVPTX.
1631 StVal = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL: dl, VT: MVT::i16, Operand: StVal);
1632 }
1633 return StVal;
1634 };
1635
1636 const auto VectorInfo =
1637 VectorizePTXValueVTs(ValueVTs: VTs, Offsets, ParamAlignment: ArgAlign, IsVAArg);
1638
1639 unsigned J = 0;
1640 for (const unsigned NumElts : VectorInfo) {
1641 const int CurOffset = Offsets[J];
1642 EVT EltVT = promoteScalarIntegerPTX(VT: VTs[J]);
1643 const Align PartAlign = commonAlignment(A: ArgAlign, Offset: CurOffset);
1644
1645 // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
1646 // scalar store. In such cases, fall back to byte stores.
1647 if (NumElts == 1 && !IsVAArg && PartAlign < DAG.getEVTAlign(MemoryVT: EltVT)) {
1648
1649 SDValue StVal = GetStoredValue(J, EltVT, PartAlign);
1650 Chain = LowerUnalignedStoreParam(DAG, Chain,
1651 Offset: CurOffset + (IsByVal ? VAOffset : 0),
1652 ElementType: EltVT, StVal, InGlue, ArgID: ArgI, dl);
1653
1654 // LowerUnalignedStoreParam took care of inserting the necessary nodes
1655 // into the SDAG, so just move on to the next element.
1656 J++;
1657 continue;
1658 }
1659
1660 if (IsVAArg && !IsByVal)
1661 // Align each part of the variadic argument to its type.
1662 VAOffset = alignTo(Size: VAOffset, A: DAG.getEVTAlign(MemoryVT: EltVT));
1663
1664 assert((IsVAArg || VAOffset == 0) &&
1665 "VAOffset must be 0 for non-VA args");
1666 SmallVector<SDValue, 6> StoreOperands{
1667 Chain, GetI32(IsVAArg ? FirstVAArg : ArgI),
1668 GetI32(VAOffset + ((IsVAArg && !IsByVal) ? 0 : CurOffset))};
1669
1670 // Record the values to store.
1671 for (const unsigned K : llvm::seq(Size: NumElts))
1672 StoreOperands.push_back(Elt: GetStoredValue(J + K, EltVT, PartAlign));
1673 StoreOperands.push_back(Elt: InGlue);
1674
1675 NVPTXISD::NodeType Op;
1676 switch (NumElts) {
1677 case 1:
1678 Op = NVPTXISD::StoreParam;
1679 break;
1680 case 2:
1681 Op = NVPTXISD::StoreParamV2;
1682 break;
1683 case 4:
1684 Op = NVPTXISD::StoreParamV4;
1685 break;
1686 default:
1687 llvm_unreachable("Invalid vector info.");
1688 }
1689 // Adjust type of the store op if we've extended the scalar
1690 // return value.
1691 EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
1692
1693 Chain = DAG.getMemIntrinsicNode(
1694 Opcode: Op, dl, VTList: DAG.getVTList(VT1: MVT::Other, VT2: MVT::Glue), Ops: StoreOperands,
1695 MemVT: TheStoreType, PtrInfo: MachinePointerInfo(), Alignment: PartAlign,
1696 Flags: MachineMemOperand::MOStore);
1697 InGlue = Chain.getValue(R: 1);
1698
1699 // TODO: We may need to support vector types that can be passed
1700 // as scalars in variadic arguments.
1701 if (IsVAArg && !IsByVal) {
1702 assert(NumElts == 1 &&
1703 "Vectorization is expected to be disabled for variadics.");
1704 VAOffset +=
1705 DL.getTypeAllocSize(Ty: TheStoreType.getTypeForEVT(Context&: *DAG.getContext()));
1706 }
1707
1708 J += NumElts;
1709 }
1710 if (IsVAArg && IsByVal)
1711 VAOffset += TypeSize;
1712 }
1713
1714 GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Val: Callee.getNode());
1715
1716 // Handle Result
1717 if (!Ins.empty()) {
1718 const SDValue RetDeclare = [&]() {
1719 const SDValue RetSymbol = DAG.getExternalSymbol(Sym: "retval0", VT: MVT::i32);
1720 const unsigned ResultSize = DL.getTypeAllocSizeInBits(Ty: RetTy);
1721 if (shouldPassAsArray(Ty: RetTy)) {
1722 const Align RetAlign = getArgumentAlignment(CB, Ty: RetTy, Idx: 0, DL);
1723 return DAG.getNode(Opcode: NVPTXISD::DeclareArrayParam, DL: dl,
1724 ResultTys: {MVT::Other, MVT::Glue},
1725 Ops: {Chain, RetSymbol, GetI32(RetAlign.value()),
1726 GetI32(ResultSize / 8), InGlue});
1727 }
1728 const auto PromotedResultSize = promoteScalarArgumentSize(size: ResultSize);
1729 return DAG.getNode(
1730 Opcode: NVPTXISD::DeclareScalarParam, DL: dl, ResultTys: {MVT::Other, MVT::Glue},
1731 Ops: {Chain, RetSymbol, GetI32(PromotedResultSize), InGlue});
1732 }();
1733 Chain = RetDeclare.getValue(R: 0);
1734 InGlue = RetDeclare.getValue(R: 1);
1735 }
1736
1737 const bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
1738 // Set the size of the vararg param byte array if the callee is a variadic
1739 // function and the variadic part is not empty.
1740 if (HasVAArgs) {
1741 SDValue DeclareParamOps[] = {VADeclareParam.getOperand(i: 0),
1742 VADeclareParam.getOperand(i: 1),
1743 VADeclareParam.getOperand(i: 2), GetI32(VAOffset),
1744 VADeclareParam.getOperand(i: 4)};
1745 DAG.MorphNodeTo(N: VADeclareParam.getNode(), Opc: VADeclareParam.getOpcode(),
1746 VTs: VADeclareParam->getVTList(), Ops: DeclareParamOps);
1747 }
1748
1749 // If the type of the callsite does not match that of the function, convert
1750 // the callsite to an indirect call.
1751 const bool ConvertToIndirectCall = shouldConvertToIndirectCall(CB, Func);
1752
1753 // Both indirect calls and libcalls have nullptr Func. In order to distinguish
1754 // between them we must rely on the call site value which is valid for
1755 // indirect calls but is always null for libcalls.
1756 const bool IsIndirectCall = (!Func && CB) || ConvertToIndirectCall;
1757
1758 if (isa<ExternalSymbolSDNode>(Val: Callee)) {
1759 Function* CalleeFunc = nullptr;
1760
1761 // Try to find the callee in the current module.
1762 Callee = DAG.getSymbolFunctionGlobalAddress(Op: Callee, TargetFunction: &CalleeFunc);
1763 assert(CalleeFunc != nullptr && "Libcall callee must be set.");
1764
1765 // Set the "libcall callee" attribute to indicate that the function
1766 // must always have a declaration.
1767 CalleeFunc->addFnAttr(Kind: "nvptx-libcall-callee", Val: "true");
1768 }
1769
1770 if (IsIndirectCall) {
1771 // This is the indirect function call case: PTX requires a prototype of the
1772 // form
1773 // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
1774 // to be emitted, and the label has to be used as the last arg of the call
1775 // instruction.
1776 // The prototype is embedded in a string and put as the operand for a
1777 // CallPrototype SDNode which will print out to the value of the string.
1778 std::string Proto =
1779 getPrototype(DL, RetTy, Args, Outs: CLI.Outs,
1780 FirstVAArg: HasVAArgs ? std::optional(FirstVAArg) : std::nullopt, CB: *CB,
1781 UniqueCallSite);
1782 const char *ProtoStr = nvTM->getStrPool().save(S: Proto).data();
1783 Chain = DAG.getNode(
1784 Opcode: NVPTXISD::CallPrototype, DL: dl, ResultTys: {MVT::Other, MVT::Glue},
1785 Ops: {Chain, DAG.getTargetExternalSymbol(Sym: ProtoStr, VT: MVT::i32), InGlue});
1786 InGlue = Chain.getValue(R: 1);
1787 }
1788
1789 if (ConvertToIndirectCall) {
1790 // Copy the function ptr to a PTX register and use the register to call the
1791 // function.
1792 const MVT DestVT = Callee.getValueType().getSimpleVT();
1793 MachineRegisterInfo &MRI = DAG.getMachineFunction().getRegInfo();
1794 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
1795 Register DestReg = MRI.createVirtualRegister(RegClass: TLI.getRegClassFor(VT: DestVT));
1796 auto RegCopy = DAG.getCopyToReg(Chain: DAG.getEntryNode(), dl, Reg: DestReg, N: Callee);
1797 Callee = DAG.getCopyFromReg(Chain: RegCopy, dl, Reg: DestReg, VT: DestVT);
1798 }
1799
1800 const unsigned Proto = IsIndirectCall ? UniqueCallSite : 0;
1801 const unsigned NumArgs =
1802 std::min<unsigned>(a: CLI.NumFixedArgs + 1, b: Args.size());
1803 /// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns,
1804 /// NumParams, Callee, Proto, InGlue)
1805 Chain = DAG.getNode(Opcode: NVPTXISD::CALL, DL: dl, ResultTys: {MVT::Other, MVT::Glue},
1806 Ops: {Chain, GetI32(CLI.IsConvergent), GetI32(IsIndirectCall),
1807 GetI32(Ins.empty() ? 0 : 1), GetI32(NumArgs), Callee,
1808 GetI32(Proto), InGlue});
1809 InGlue = Chain.getValue(R: 1);
1810
1811 SmallVector<SDValue, 16> ProxyRegOps;
1812 // An entry in this vector is set when the corresponding element does not need
1813 // a ProxyReg operation and should be added to InVals as-is. ProxyRegOps holds
1814 // an empty SDValue at the same index.
1815 SmallVector<SDValue, 16> RetElts;
1816 // Temporary ProxyReg operations inserted in `LowerUnalignedLoadRetParam()`
1817 // to use the values of `LoadParam`s; they are replaced later, after
1818 // `CALLSEQ_END` is added.
1819 SmallVector<SDValue, 16> TempProxyRegOps;
1820
1821 // Generate loads from param memory/moves from registers for result
1822 if (!Ins.empty()) {
1823 SmallVector<EVT, 16> VTs;
1824 SmallVector<uint64_t, 16> Offsets;
1825 ComputePTXValueVTs(TLI: *this, DL, Ty: RetTy, ValueVTs&: VTs, Offsets: &Offsets, StartingOffset: 0);
1826 assert(VTs.size() == Ins.size() && "Bad value decomposition");
1827
1828 const Align RetAlign = getArgumentAlignment(CB, Ty: RetTy, Idx: 0, DL);
1829
1830 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
1831 // 32-bits are sign extended or zero extended, depending on whether
1832 // they are signed or unsigned types.
1833 const bool ExtendIntegerRetVal =
1834 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty: RetTy) < 32;
1835
1836 const auto VectorInfo = VectorizePTXValueVTs(ValueVTs: VTs, Offsets, ParamAlignment: RetAlign);
1837 unsigned I = 0;
1838 for (const unsigned VectorizedSize : VectorInfo) {
1839 EVT TheLoadType = promoteScalarIntegerPTX(VT: VTs[I]);
1840 EVT EltType = Ins[I].VT;
1841 const Align EltAlign = commonAlignment(A: RetAlign, Offset: Offsets[I]);
1842
1843 if (TheLoadType != VTs[I])
1844 EltType = TheLoadType;
1845
1846 if (ExtendIntegerRetVal) {
1847 TheLoadType = MVT::i32;
1848 EltType = MVT::i32;
1849 } else if (TheLoadType.getSizeInBits() < 16) {
1850 EltType = MVT::i16;
1851 }
1852
1853 // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
1854 // scalar load. In such cases, fall back to byte loads.
1855 if (VectorizedSize == 1 && RetTy->isAggregateType() &&
1856 EltAlign < DAG.getEVTAlign(MemoryVT: TheLoadType)) {
1857 SDValue Ret = LowerUnalignedLoadRetParam(
1858 DAG, Chain, Offset: Offsets[I], ElementType: TheLoadType, InGlue, TempProxyRegOps, dl);
1859 ProxyRegOps.push_back(Elt: SDValue());
1860 RetElts.resize(N: I);
1861 RetElts.push_back(Elt: Ret);
1862
1863 I++;
1864 continue;
1865 }
1866
1867 SmallVector<EVT, 6> LoadVTs(VectorizedSize, EltType);
1868 LoadVTs.append(IL: {MVT::Other, MVT::Glue});
1869
1870 NVPTXISD::NodeType Op;
1871 switch (VectorizedSize) {
1872 case 1:
1873 Op = NVPTXISD::LoadParam;
1874 break;
1875 case 2:
1876 Op = NVPTXISD::LoadParamV2;
1877 break;
1878 case 4:
1879 Op = NVPTXISD::LoadParamV4;
1880 break;
1881 default:
1882 llvm_unreachable("Invalid vector info.");
1883 }
1884
1885 SDValue LoadOperands[] = {Chain, GetI32(1), GetI32(Offsets[I]), InGlue};
1886 SDValue RetVal = DAG.getMemIntrinsicNode(
1887 Opcode: Op, dl, VTList: DAG.getVTList(VTs: LoadVTs), Ops: LoadOperands, MemVT: TheLoadType,
1888 PtrInfo: MachinePointerInfo(), Alignment: EltAlign, Flags: MachineMemOperand::MOLoad);
1889
1890 for (const unsigned J : llvm::seq(Size: VectorizedSize)) {
1891 ProxyRegOps.push_back(Elt: RetVal.getValue(R: J));
1892 }
1893
1894 Chain = RetVal.getValue(R: VectorizedSize);
1895 InGlue = RetVal.getValue(R: VectorizedSize + 1);
1896
1897 I += VectorizedSize;
1898 }
1899 }
1900
1901 Chain =
1902 DAG.getCALLSEQ_END(Chain, Size1: UniqueCallSite, Size2: UniqueCallSite + 1, Glue: InGlue, DL: dl);
1903 InGlue = Chain.getValue(R: 1);
1904
1905 // Append ProxyReg instructions to the chain to make sure that `callseq_end`
1906 // will not get lost. Otherwise, during libcalls expansion, the nodes can become
1907 // dangling.
1908 for (const unsigned I : llvm::seq(Size: ProxyRegOps.size())) {
1909 if (I < RetElts.size() && RetElts[I]) {
1910 InVals.push_back(Elt: RetElts[I]);
1911 continue;
1912 }
1913
1914 SDValue Ret =
1915 DAG.getNode(Opcode: NVPTXISD::ProxyReg, DL: dl, VT: ProxyRegOps[I].getSimpleValueType(),
1916 Ops: {Chain, ProxyRegOps[I]});
1917
1918 const EVT ExpectedVT = Ins[I].VT;
1919 if (!Ret.getValueType().bitsEq(VT: ExpectedVT)) {
1920 Ret = DAG.getNode(Opcode: ISD::TRUNCATE, DL: dl, VT: ExpectedVT, Operand: Ret);
1921 }
1922 InVals.push_back(Elt: Ret);
1923 }
1924
1925 for (SDValue &T : TempProxyRegOps) {
1926 SDValue Repl = DAG.getNode(Opcode: NVPTXISD::ProxyReg, DL: dl, VT: T.getSimpleValueType(),
1927 Ops: {Chain, T.getOperand(i: 0)});
1928 DAG.ReplaceAllUsesWith(From: T, To: Repl);
1929 DAG.RemoveDeadNode(N: T.getNode());
1930 }
1931
1932 // Set isTailCall to false for now, until we figure out how to express
1933 // tail call optimization in PTX.
1934 isTailCall = false;
1935 return Chain;
1936}
1937
1938SDValue NVPTXTargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op,
1939 SelectionDAG &DAG) const {
1940
1941 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1942 const Function &Fn = DAG.getMachineFunction().getFunction();
1943
1944 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
1945 Fn,
1946 "Support for dynamic alloca introduced in PTX ISA version 7.3 and "
1947 "requires target sm_52.",
1948 SDLoc(Op).getDebugLoc()));
1949 auto Ops = {DAG.getConstant(Val: 0, DL: SDLoc(), VT: Op.getValueType()),
1950 Op.getOperand(i: 0)};
1951 return DAG.getMergeValues(Ops, dl: SDLoc());
1952 }
1953
1954 SDLoc DL(Op.getNode());
1955 SDValue Chain = Op.getOperand(i: 0);
1956 SDValue Size = Op.getOperand(i: 1);
1957 uint64_t Align = Op.getConstantOperandVal(i: 2);
1958
1959 // The alignment on a ISD::DYNAMIC_STACKALLOC node may be 0 to indicate that
1960 // the default stack alignment should be used.
1961 if (Align == 0)
1962 Align = DAG.getSubtarget().getFrameLowering()->getStackAlign().value();
1963
1964 // The size for the PTX alloca instruction is 64-bit for m64 and 32-bit for m32.
1965 const MVT LocalVT = getPointerTy(DL: DAG.getDataLayout(), AS: ADDRESS_SPACE_LOCAL);
1966
1967 SDValue Alloc =
1968 DAG.getNode(Opcode: NVPTXISD::DYNAMIC_STACKALLOC, DL, ResultTys: {LocalVT, MVT::Other},
1969 Ops: {Chain, DAG.getZExtOrTrunc(Op: Size, DL, VT: LocalVT),
1970 DAG.getTargetConstant(Val: Align, DL, VT: MVT::i32)});
1971
1972 SDValue ASC = DAG.getAddrSpaceCast(
1973 dl: DL, VT: Op.getValueType(), Ptr: Alloc, SrcAS: ADDRESS_SPACE_LOCAL, DestAS: ADDRESS_SPACE_GENERIC);
1974
1975 return DAG.getMergeValues(Ops: {ASC, SDValue(Alloc.getNode(), 1)}, dl: DL);
1976}
1977
1978SDValue NVPTXTargetLowering::LowerSTACKRESTORE(SDValue Op,
1979 SelectionDAG &DAG) const {
1980 SDLoc DL(Op.getNode());
1981 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
1982 const Function &Fn = DAG.getMachineFunction().getFunction();
1983
1984 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
1985 Fn,
1986 "Support for stackrestore requires PTX ISA version >= 7.3 and target "
1987 ">= sm_52.",
1988 DL.getDebugLoc()));
1989 return Op.getOperand(i: 0);
1990 }
1991
1992 const MVT LocalVT = getPointerTy(DL: DAG.getDataLayout(), AS: ADDRESS_SPACE_LOCAL);
1993 SDValue Chain = Op.getOperand(i: 0);
1994 SDValue Ptr = Op.getOperand(i: 1);
1995 SDValue ASC = DAG.getAddrSpaceCast(dl: DL, VT: LocalVT, Ptr, SrcAS: ADDRESS_SPACE_GENERIC,
1996 DestAS: ADDRESS_SPACE_LOCAL);
1997 return DAG.getNode(Opcode: NVPTXISD::STACKRESTORE, DL, VT: MVT::Other, Ops: {Chain, ASC});
1998}
1999
2000SDValue NVPTXTargetLowering::LowerSTACKSAVE(SDValue Op,
2001 SelectionDAG &DAG) const {
2002 SDLoc DL(Op.getNode());
2003 if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
2004 const Function &Fn = DAG.getMachineFunction().getFunction();
2005
2006 DAG.getContext()->diagnose(DI: DiagnosticInfoUnsupported(
2007 Fn,
2008 "Support for stacksave requires PTX ISA version >= 7.3 and target >= "
2009 "sm_52.",
2010 DL.getDebugLoc()));
2011 auto Ops = {DAG.getConstant(Val: 0, DL, VT: Op.getValueType()), Op.getOperand(i: 0)};
2012 return DAG.getMergeValues(Ops, dl: DL);
2013 }
2014
2015 const MVT LocalVT = getPointerTy(DL: DAG.getDataLayout(), AS: ADDRESS_SPACE_LOCAL);
2016 SDValue Chain = Op.getOperand(i: 0);
2017 SDValue SS =
2018 DAG.getNode(Opcode: NVPTXISD::STACKSAVE, DL, ResultTys: {LocalVT, MVT::Other}, Ops: Chain);
2019 SDValue ASC = DAG.getAddrSpaceCast(
2020 dl: DL, VT: Op.getValueType(), Ptr: SS, SrcAS: ADDRESS_SPACE_LOCAL, DestAS: ADDRESS_SPACE_GENERIC);
2021 return DAG.getMergeValues(Ops: {ASC, SDValue(SS.getNode(), 1)}, dl: DL);
2022}
2023
2024// By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
2025// (see LegalizeDAG.cpp). This is slow and uses local memory.
2026// We use extract/insert/build_vector instead, just as LegalizeOp() did in LLVM 2.5.
2027SDValue
2028NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
2029 SDNode *Node = Op.getNode();
2030 SDLoc dl(Node);
2031 SmallVector<SDValue, 8> Ops;
2032 unsigned NumOperands = Node->getNumOperands();
2033 for (unsigned i = 0; i < NumOperands; ++i) {
2034 SDValue SubOp = Node->getOperand(Num: i);
2035 EVT VVT = SubOp.getNode()->getValueType(ResNo: 0);
2036 EVT EltVT = VVT.getVectorElementType();
2037 unsigned NumSubElem = VVT.getVectorNumElements();
2038 for (unsigned j = 0; j < NumSubElem; ++j) {
2039 Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT: EltVT, N1: SubOp,
2040 N2: DAG.getIntPtrConstant(Val: j, DL: dl)));
2041 }
2042 }
2043 return DAG.getBuildVector(VT: Node->getValueType(ResNo: 0), DL: dl, Ops);
2044}
2045
2046SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
2047 // Handle bitcasting from v2i8 without hitting the default promotion
2048 // strategy which goes through stack memory.
2049 EVT FromVT = Op->getOperand(Num: 0)->getValueType(ResNo: 0);
2050 if (FromVT != MVT::v2i8) {
2051 return Op;
2052 }
2053
2054 // Pack vector elements into i16 and bitcast to final type
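// e.g. <2 x i8> <a, b> becomes the i16 value (zext(b) << 8) | zext(a), which
// is then bitcast to the 16-bit result type.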
2055 SDLoc DL(Op);
2056 SDValue Vec0 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i8,
2057 N1: Op->getOperand(Num: 0), N2: DAG.getIntPtrConstant(Val: 0, DL));
2058 SDValue Vec1 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i8,
2059 N1: Op->getOperand(Num: 0), N2: DAG.getIntPtrConstant(Val: 1, DL));
2060 SDValue Extend0 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: MVT::i16, Operand: Vec0);
2061 SDValue Extend1 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: MVT::i16, Operand: Vec1);
2062 SDValue Const8 = DAG.getConstant(Val: 8, DL, VT: MVT::i16);
2063 SDValue AsInt = DAG.getNode(
2064 Opcode: ISD::OR, DL, VT: MVT::i16,
2065 Ops: {Extend0, DAG.getNode(Opcode: ISD::SHL, DL, VT: MVT::i16, Ops: {Extend1, Const8})});
2066 EVT ToVT = Op->getValueType(ResNo: 0);
2067 return DAG.getBitcast(VT: ToVT, V: AsInt);
2068}
2069
2070// We can initialize a constant f16x2/v2i16/v4i8 with a single .b32 move.
2071// Normally it would get lowered as two constant loads and a vector-packing move.
2072// Instead we want just a constant move:
2073// mov.b32 %r2, 0x40003C00
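// e.g. the constant <half 1.0, half 2.0> packs as 0x3C00 | (0x4000 << 16),
// i.e. the 0x40003C00 shown above.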
2074SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
2075 SelectionDAG &DAG) const {
2076 EVT VT = Op->getValueType(ResNo: 0);
2077 if (!(Isv2x16VT(VT) || VT == MVT::v4i8))
2078 return Op;
2079 SDLoc DL(Op);
2080
2081 if (!llvm::all_of(Range: Op->ops(), P: [](SDValue Operand) {
2082 return Operand->isUndef() || isa<ConstantSDNode>(Val: Operand) ||
2083 isa<ConstantFPSDNode>(Val: Operand);
2084 })) {
2085 if (VT != MVT::v4i8)
2086 return Op;
2087 // Lower a non-constant v4i8 vector as a byte-wise constructed i32, which
2088 // allows us to optimize the calculation of its constant parts.
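// As a rough sketch of the PRMT semantics used here: each 4-bit nibble of the
// selector picks one of the eight bytes of the two 32-bit sources (0-3 from
// the first operand, 4-7 from the second), with the lowest nibble selecting
// the lowest destination byte. Selector 0x3340 thus places the low byte of
// each source into destination bytes 0 and 1, and 0x5410 then interleaves the
// two partial results into the final { op3, op2, op1, op0 } byte order.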
2089 auto GetPRMT = [&](const SDValue Left, const SDValue Right, bool Cast,
2090 uint64_t SelectionValue) -> SDValue {
2091 SDValue L = Left;
2092 SDValue R = Right;
2093 if (Cast) {
2094 L = DAG.getAnyExtOrTrunc(Op: L, DL, VT: MVT::i32);
2095 R = DAG.getAnyExtOrTrunc(Op: R, DL, VT: MVT::i32);
2096 }
2097 return DAG.getNode(
2098 Opcode: NVPTXISD::PRMT, DL, VT: MVT::v4i8,
2099 Ops: {L, R, DAG.getConstant(Val: SelectionValue, DL, VT: MVT::i32),
2100 DAG.getConstant(Val: NVPTX::PTXPrmtMode::NONE, DL, VT: MVT::i32)});
2101 };
2102 auto PRMT__10 = GetPRMT(Op->getOperand(Num: 0), Op->getOperand(Num: 1), true, 0x3340);
2103 auto PRMT__32 = GetPRMT(Op->getOperand(Num: 2), Op->getOperand(Num: 3), true, 0x3340);
2104 auto PRMT3210 = GetPRMT(PRMT__10, PRMT__32, false, 0x5410);
2105 return DAG.getNode(Opcode: ISD::BITCAST, DL, VT, Operand: PRMT3210);
2106 }
2107
2108 // Get the value of the Nth operand as an APInt(32). Undef values are treated as 0.
2109 auto GetOperand = [](SDValue Op, int N) -> APInt {
2110 const SDValue &Operand = Op->getOperand(Num: N);
2111 EVT VT = Op->getValueType(ResNo: 0);
2112 if (Operand->isUndef())
2113 return APInt(32, 0);
2114 APInt Value;
2115 if (VT == MVT::v2f16 || VT == MVT::v2bf16)
2116 Value = cast<ConstantFPSDNode>(Val: Operand)->getValueAPF().bitcastToAPInt();
2117 else if (VT == MVT::v2i16 || VT == MVT::v4i8)
2118 Value = Operand->getAsAPIntVal();
2119 else
2120 llvm_unreachable("Unsupported type");
2121 // i8 values are carried around as i16, so we need to zero out the upper bits
2122 // so that they do not get in the way of combining individual byte values.
2123 if (VT == MVT::v4i8)
2124 Value = Value.trunc(width: 8);
2125 return Value.zext(width: 32);
2126 };
2127 APInt Value;
2128 if (Isv2x16VT(VT)) {
2129 Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(shiftAmt: 16);
2130 } else if (VT == MVT::v4i8) {
2131 Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(shiftAmt: 8) |
2132 GetOperand(Op, 2).shl(shiftAmt: 16) | GetOperand(Op, 3).shl(shiftAmt: 24);
2133 } else {
2134 llvm_unreachable("Unsupported type");
2135 }
2136 SDValue Const = DAG.getConstant(Val: Value, DL, VT: MVT::i32);
2137 return DAG.getNode(Opcode: ISD::BITCAST, DL, VT: Op->getValueType(ResNo: 0), Operand: Const);
2138}
2139
2140SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
2141 SelectionDAG &DAG) const {
2142 SDValue Index = Op->getOperand(Num: 1);
2143 SDValue Vector = Op->getOperand(Num: 0);
2144 SDLoc DL(Op);
2145 EVT VectorVT = Vector.getValueType();
2146
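// For v4i8, extract element Index as a bit-field extract of 8 bits starting
// at bit Index * 8 of the underlying i32.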
2147 if (VectorVT == MVT::v4i8) {
2148 SDValue BFE =
2149 DAG.getNode(Opcode: NVPTXISD::BFE, DL, VT: MVT::i32,
2150 Ops: {Vector,
2151 DAG.getNode(Opcode: ISD::MUL, DL, VT: MVT::i32,
2152 N1: DAG.getZExtOrTrunc(Op: Index, DL, VT: MVT::i32),
2153 N2: DAG.getConstant(Val: 8, DL, VT: MVT::i32)),
2154 DAG.getConstant(Val: 8, DL, VT: MVT::i32)});
2155 return DAG.getAnyExtOrTrunc(Op: BFE, DL, VT: Op->getValueType(ResNo: 0));
2156 }
2157
2158 // Constant index will be matched by tablegen.
2159 if (isa<ConstantSDNode>(Val: Index.getNode()))
2160 return Op;
2161
2162 // Extract individual elements and select one of them.
2163 assert(Isv2x16VT(VectorVT) && "Unexpected vector type.");
2164 EVT EltVT = VectorVT.getVectorElementType();
2165
2166 SDLoc dl(Op.getNode());
2167 SDValue E0 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT: EltVT, N1: Vector,
2168 N2: DAG.getIntPtrConstant(Val: 0, DL: dl));
2169 SDValue E1 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT: EltVT, N1: Vector,
2170 N2: DAG.getIntPtrConstant(Val: 1, DL: dl));
2171 return DAG.getSelectCC(DL: dl, LHS: Index, RHS: DAG.getIntPtrConstant(Val: 0, DL: dl), True: E0, False: E1,
2172 Cond: ISD::CondCode::SETEQ);
2173}
2174
2175SDValue NVPTXTargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
2176 SelectionDAG &DAG) const {
2177 SDValue Vector = Op->getOperand(Num: 0);
2178 EVT VectorVT = Vector.getValueType();
2179
2180 if (VectorVT != MVT::v4i8)
2181 return Op;
2182 SDLoc DL(Op);
2183 SDValue Value = Op->getOperand(Num: 1);
2184 if (Value->isUndef())
2185 return Vector;
2186
2187 SDValue Index = Op->getOperand(Num: 2);
2188
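// Insert the (zero-extended) byte with a bit-field insert of 8 bits at bit
// Index * 8 of the underlying i32, then bitcast back to v4i8.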
2189 SDValue BFI =
2190 DAG.getNode(Opcode: NVPTXISD::BFI, DL, VT: MVT::i32,
2191 Ops: {DAG.getZExtOrTrunc(Op: Value, DL, VT: MVT::i32), Vector,
2192 DAG.getNode(Opcode: ISD::MUL, DL, VT: MVT::i32,
2193 N1: DAG.getZExtOrTrunc(Op: Index, DL, VT: MVT::i32),
2194 N2: DAG.getConstant(Val: 8, DL, VT: MVT::i32)),
2195 DAG.getConstant(Val: 8, DL, VT: MVT::i32)});
2196 return DAG.getNode(Opcode: ISD::BITCAST, DL, VT: Op->getValueType(ResNo: 0), Operand: BFI);
2197}
2198
2199SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
2200 SelectionDAG &DAG) const {
2201 SDValue V1 = Op.getOperand(i: 0);
2202 EVT VectorVT = V1.getValueType();
2203 if (VectorVT != MVT::v4i8 || Op.getValueType() != MVT::v4i8)
2204 return Op;
2205
2206 // Lower shuffle to PRMT instruction.
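// Each shuffle-mask element selects one byte out of the concatenated pair
// {V2:V1} (0-3 from V1, 4-7 from V2), which matches the byte numbering of
// PRMT, so mask element I simply becomes nibble I of the PRMT selector.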
2207 const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Val: Op.getNode());
2208 SDValue V2 = Op.getOperand(i: 1);
2209 uint32_t Selector = 0;
2210 for (auto I : llvm::enumerate(First: SVN->getMask())) {
2211 if (I.value() != -1) // -1 is a placeholder for undef.
2212 Selector |= (I.value() << (I.index() * 4));
2213 }
2214
2215 SDLoc DL(Op);
2216 return DAG.getNode(Opcode: NVPTXISD::PRMT, DL, VT: MVT::v4i8, N1: V1, N2: V2,
2217 N3: DAG.getConstant(Val: Selector, DL, VT: MVT::i32),
2218 N4: DAG.getConstant(Val: NVPTX::PTXPrmtMode::NONE, DL, VT: MVT::i32));
2219}
2220/// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
2221/// 1) return two i32 values and take a 2 x i32 value to shift plus a shift
2222/// amount, or
2223/// 2) return two i64 values and take a 2 x i64 value to shift plus a shift
2224/// amount.
2225SDValue NVPTXTargetLowering::LowerShiftRightParts(SDValue Op,
2226 SelectionDAG &DAG) const {
2227 assert(Op.getNumOperands() == 3 && "Not a double-shift!");
2228 assert(Op.getOpcode() == ISD::SRA_PARTS || Op.getOpcode() == ISD::SRL_PARTS);
2229
2230 EVT VT = Op.getValueType();
2231 unsigned VTBits = VT.getSizeInBits();
2232 SDLoc dl(Op);
2233 SDValue ShOpLo = Op.getOperand(i: 0);
2234 SDValue ShOpHi = Op.getOperand(i: 1);
2235 SDValue ShAmt = Op.getOperand(i: 2);
2236 unsigned Opc = (Op.getOpcode() == ISD::SRA_PARTS) ? ISD::SRA : ISD::SRL;
2237
2238 if (VTBits == 32 && STI.getSmVersion() >= 35) {
2239 // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
2240 // {dHi, dLo} = {aHi, aLo} >> Amt
2241 // dHi = aHi >> Amt
2242 // dLo = shf.r.clamp aLo, aHi, Amt
2243
2244 SDValue Hi = DAG.getNode(Opcode: Opc, DL: dl, VT, N1: ShOpHi, N2: ShAmt);
2245 SDValue Lo =
2246 DAG.getNode(Opcode: NVPTXISD::FSHR_CLAMP, DL: dl, VT, N1: ShOpHi, N2: ShOpLo, N3: ShAmt);
2247
2248 SDValue Ops[2] = { Lo, Hi };
2249 return DAG.getMergeValues(Ops, dl);
2250 }
2251 else {
2252 // {dHi, dLo} = {aHi, aLo} >> Amt
2253 // - if (Amt>=size) then
2254 // dLo = aHi >> (Amt-size)
2255 // dHi = aHi >> Amt (this is either all 0 or all 1)
2256 // else
2257 // dLo = (aLo >>logic Amt) | (aHi << (size-Amt))
2258 // dHi = aHi >> Amt
2259
2260 SDValue RevShAmt = DAG.getNode(Opcode: ISD::SUB, DL: dl, VT: MVT::i32,
2261 N1: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32),
2262 N2: ShAmt);
2263 SDValue Tmp1 = DAG.getNode(Opcode: ISD::SRL, DL: dl, VT, N1: ShOpLo, N2: ShAmt);
2264 SDValue ExtraShAmt = DAG.getNode(Opcode: ISD::SUB, DL: dl, VT: MVT::i32, N1: ShAmt,
2265 N2: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32));
2266 SDValue Tmp2 = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpHi, N2: RevShAmt);
2267 SDValue FalseVal = DAG.getNode(Opcode: ISD::OR, DL: dl, VT, N1: Tmp1, N2: Tmp2);
2268 SDValue TrueVal = DAG.getNode(Opcode: Opc, DL: dl, VT, N1: ShOpHi, N2: ExtraShAmt);
2269
2270 SDValue Cmp = DAG.getSetCC(DL: dl, VT: MVT::i1, LHS: ShAmt,
2271 RHS: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32),
2272 Cond: ISD::SETGE);
2273 SDValue Hi = DAG.getNode(Opcode: Opc, DL: dl, VT, N1: ShOpHi, N2: ShAmt);
2274 SDValue Lo = DAG.getNode(Opcode: ISD::SELECT, DL: dl, VT, N1: Cmp, N2: TrueVal, N3: FalseVal);
2275
2276 SDValue Ops[2] = { Lo, Hi };
2277 return DAG.getMergeValues(Ops, dl);
2278 }
2279}
2280
2281/// LowerShiftLeftParts - Lower SHL_PARTS, which
2282/// 1) return two i32 values and take a 2 x i32 value to shift plus a shift
2283/// amount, or
2284/// 2) return two i64 values and take a 2 x i64 value to shift plus a shift
2285/// amount.
2286SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op,
2287 SelectionDAG &DAG) const {
2288 assert(Op.getNumOperands() == 3 && "Not a double-shift!");
2289 assert(Op.getOpcode() == ISD::SHL_PARTS);
2290
2291 EVT VT = Op.getValueType();
2292 unsigned VTBits = VT.getSizeInBits();
2293 SDLoc dl(Op);
2294 SDValue ShOpLo = Op.getOperand(i: 0);
2295 SDValue ShOpHi = Op.getOperand(i: 1);
2296 SDValue ShAmt = Op.getOperand(i: 2);
2297
2298 if (VTBits == 32 && STI.getSmVersion() >= 35) {
2299 // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
2300 // {dHi, dLo} = {aHi, aLo} << Amt
2301 // dHi = shf.l.clamp aLo, aHi, Amt
2302 // dLo = aLo << Amt
2303
2304 SDValue Hi =
2305 DAG.getNode(Opcode: NVPTXISD::FSHL_CLAMP, DL: dl, VT, N1: ShOpHi, N2: ShOpLo, N3: ShAmt);
2306 SDValue Lo = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpLo, N2: ShAmt);
2307
2308 SDValue Ops[2] = { Lo, Hi };
2309 return DAG.getMergeValues(Ops, dl);
2310 }
2311 else {
2312 // {dHi, dLo} = {aHi, aLo} << Amt
2313 // - if (Amt>=size) then
2314 // dLo = aLo << Amt (all 0)
2315 //    dHi = aLo << (Amt-size)
2316 // else
2317 // dLo = aLo << Amt
2318 // dHi = (aHi << Amt) | (aLo >> (size-Amt))
2319
2320 SDValue RevShAmt = DAG.getNode(Opcode: ISD::SUB, DL: dl, VT: MVT::i32,
2321 N1: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32),
2322 N2: ShAmt);
2323 SDValue Tmp1 = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpHi, N2: ShAmt);
2324 SDValue ExtraShAmt = DAG.getNode(Opcode: ISD::SUB, DL: dl, VT: MVT::i32, N1: ShAmt,
2325 N2: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32));
2326 SDValue Tmp2 = DAG.getNode(Opcode: ISD::SRL, DL: dl, VT, N1: ShOpLo, N2: RevShAmt);
2327 SDValue FalseVal = DAG.getNode(Opcode: ISD::OR, DL: dl, VT, N1: Tmp1, N2: Tmp2);
2328 SDValue TrueVal = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpLo, N2: ExtraShAmt);
2329
2330 SDValue Cmp = DAG.getSetCC(DL: dl, VT: MVT::i1, LHS: ShAmt,
2331 RHS: DAG.getConstant(Val: VTBits, DL: dl, VT: MVT::i32),
2332 Cond: ISD::SETGE);
2333 SDValue Lo = DAG.getNode(Opcode: ISD::SHL, DL: dl, VT, N1: ShOpLo, N2: ShAmt);
2334 SDValue Hi = DAG.getNode(Opcode: ISD::SELECT, DL: dl, VT, N1: Cmp, N2: TrueVal, N3: FalseVal);
2335
2336 SDValue Ops[2] = { Lo, Hi };
2337 return DAG.getMergeValues(Ops, dl);
2338 }
2339}
2340
2341/// If the types match, convert the generic copysign to the NVPTXISD version,
2342/// otherwise bail out, ensuring that mismatched cases are properly expanded.
2343SDValue NVPTXTargetLowering::LowerFCOPYSIGN(SDValue Op,
2344 SelectionDAG &DAG) const {
2345 EVT VT = Op.getValueType();
2346 SDLoc DL(Op);
2347
2348 SDValue In1 = Op.getOperand(i: 0);
2349 SDValue In2 = Op.getOperand(i: 1);
2350 EVT SrcVT = In2.getValueType();
2351
2352 if (!SrcVT.bitsEq(VT))
2353 return SDValue();
2354
2355 return DAG.getNode(Opcode: NVPTXISD::FCOPYSIGN, DL, VT, N1: In1, N2: In2);
2356}
2357
2358SDValue NVPTXTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const {
2359 EVT VT = Op.getValueType();
2360
2361 if (VT == MVT::f32)
2362 return LowerFROUND32(Op, DAG);
2363
2364 if (VT == MVT::f64)
2365 return LowerFROUND64(Op, DAG);
2366
2367 llvm_unreachable("unhandled type");
2368}
2369
2370// This is the rounding method used in CUDA libdevice, in C-like code:
2371// float roundf(float A)
2372// {
2373// float RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f));
2374// RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
2375// return abs(A) < 0.5 ? (float)(int)A : RoundedA;
2376// }
2377SDValue NVPTXTargetLowering::LowerFROUND32(SDValue Op,
2378 SelectionDAG &DAG) const {
2379 SDLoc SL(Op);
2380 SDValue A = Op.getOperand(i: 0);
2381 EVT VT = Op.getValueType();
2382
2383 SDValue AbsA = DAG.getNode(Opcode: ISD::FABS, DL: SL, VT, Operand: A);
2384
2385 // RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f))
2386 SDValue Bitcast = DAG.getNode(Opcode: ISD::BITCAST, DL: SL, VT: MVT::i32, Operand: A);
2387 const unsigned SignBitMask = 0x80000000;
2388 SDValue Sign = DAG.getNode(Opcode: ISD::AND, DL: SL, VT: MVT::i32, N1: Bitcast,
2389 N2: DAG.getConstant(Val: SignBitMask, DL: SL, VT: MVT::i32));
2390 const unsigned PointFiveInBits = 0x3F000000;
2391 SDValue PointFiveWithSignRaw =
2392 DAG.getNode(Opcode: ISD::OR, DL: SL, VT: MVT::i32, N1: Sign,
2393 N2: DAG.getConstant(Val: PointFiveInBits, DL: SL, VT: MVT::i32));
2394 SDValue PointFiveWithSign =
2395 DAG.getNode(Opcode: ISD::BITCAST, DL: SL, VT, Operand: PointFiveWithSignRaw);
2396 SDValue AdjustedA = DAG.getNode(Opcode: ISD::FADD, DL: SL, VT, N1: A, N2: PointFiveWithSign);
2397 SDValue RoundedA = DAG.getNode(Opcode: ISD::FTRUNC, DL: SL, VT, Operand: AdjustedA);
2398
2399 // RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
2400 EVT SetCCVT = getSetCCResultType(DL: DAG.getDataLayout(), Ctx&: *DAG.getContext(), VT);
2401 SDValue IsLarge =
2402 DAG.getSetCC(DL: SL, VT: SetCCVT, LHS: AbsA, RHS: DAG.getConstantFP(Val: pow(x: 2.0, y: 23.0), DL: SL, VT),
2403 Cond: ISD::SETOGT);
2404 RoundedA = DAG.getNode(Opcode: ISD::SELECT, DL: SL, VT, N1: IsLarge, N2: A, N3: RoundedA);
2405
2406 // return abs(A) < 0.5 ? (float)(int)A : RoundedA;
2407 SDValue IsSmall = DAG.getSetCC(DL: SL, VT: SetCCVT, LHS: AbsA,
2408 RHS: DAG.getConstantFP(Val: 0.5, DL: SL, VT), Cond: ISD::SETOLT);
2409 SDValue RoundedAForSmallA = DAG.getNode(Opcode: ISD::FTRUNC, DL: SL, VT, Operand: A);
2410 return DAG.getNode(Opcode: ISD::SELECT, DL: SL, VT, N1: IsSmall, N2: RoundedAForSmallA, N3: RoundedA);
2411}
2412
2413// The implementation of round(double) is similar to that of round(float) in
2414// that they both separate the value range into three regions and use a method
2415// specific to the region to round the values. However, round(double) first
2416// calculates the round of the absolute value and then adds the sign back while
2417// round(float) directly rounds the value with sign.
2418SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
2419 SelectionDAG &DAG) const {
2420 SDLoc SL(Op);
2421 SDValue A = Op.getOperand(i: 0);
2422 EVT VT = Op.getValueType();
2423
2424 SDValue AbsA = DAG.getNode(Opcode: ISD::FABS, DL: SL, VT, Operand: A);
2425
2426 // double RoundedA = (double) (int) (abs(A) + 0.5f);
2427 SDValue AdjustedA = DAG.getNode(Opcode: ISD::FADD, DL: SL, VT, N1: AbsA,
2428 N2: DAG.getConstantFP(Val: 0.5, DL: SL, VT));
2429 SDValue RoundedA = DAG.getNode(Opcode: ISD::FTRUNC, DL: SL, VT, Operand: AdjustedA);
2430
2431 // RoundedA = abs(A) < 0.5 ? (double)0 : RoundedA;
2432 EVT SetCCVT = getSetCCResultType(DL: DAG.getDataLayout(), Ctx&: *DAG.getContext(), VT);
2433 SDValue IsSmall = DAG.getSetCC(DL: SL, VT: SetCCVT, LHS: AbsA,
2434 RHS: DAG.getConstantFP(Val: 0.5, DL: SL, VT), Cond: ISD::SETOLT);
2435 RoundedA = DAG.getNode(Opcode: ISD::SELECT, DL: SL, VT, N1: IsSmall,
2436 N2: DAG.getConstantFP(Val: 0, DL: SL, VT),
2437 N3: RoundedA);
2438
2439 // Add sign to rounded_A
2440 RoundedA = DAG.getNode(Opcode: ISD::FCOPYSIGN, DL: SL, VT, N1: RoundedA, N2: A);
2442
2443 // RoundedA = abs(A) > 0x1.0p52 ? A : RoundedA;
2444 SDValue IsLarge =
2445 DAG.getSetCC(DL: SL, VT: SetCCVT, LHS: AbsA, RHS: DAG.getConstantFP(Val: pow(x: 2.0, y: 52.0), DL: SL, VT),
2446 Cond: ISD::SETOGT);
2447 return DAG.getNode(Opcode: ISD::SELECT, DL: SL, VT, N1: IsLarge, N2: A, N3: RoundedA);
2448}
2449
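// Promote the operands of a binary FP operation (scalar or vector) to f32,
// perform the operation in f32, and round the result back to the original
// type.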
2450static SDValue PromoteBinOpToF32(SDNode *N, SelectionDAG &DAG) {
2451 EVT VT = N->getValueType(ResNo: 0);
2452 EVT NVT = MVT::f32;
2453 if (VT.isVector()) {
2454 NVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NVT, EC: VT.getVectorElementCount());
2455 }
2456 SDLoc DL(N);
2457 SDValue Tmp0 = DAG.getFPExtendOrRound(Op: N->getOperand(Num: 0), DL, VT: NVT);
2458 SDValue Tmp1 = DAG.getFPExtendOrRound(Op: N->getOperand(Num: 1), DL, VT: NVT);
2459 SDValue Res = DAG.getNode(Opcode: N->getOpcode(), DL, VT: NVT, N1: Tmp0, N2: Tmp1, Flags: N->getFlags());
2460 return DAG.getFPExtendOrRound(Op: Res, DL, VT);
2461}
2462
2463SDValue NVPTXTargetLowering::PromoteBinOpIfF32FTZ(SDValue Op,
2464 SelectionDAG &DAG) const {
2465 if (useF32FTZ(MF: DAG.getMachineFunction())) {
2466 return PromoteBinOpToF32(N: Op.getNode(), DAG);
2467 }
2468 return Op;
2469}
2470
2471SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
2472 SelectionDAG &DAG) const {
2473 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2474
2475 if (Op.getValueType() == MVT::bf16) {
2476 SDLoc Loc(Op);
2477 return DAG.getNode(
2478 Opcode: ISD::FP_ROUND, DL: Loc, VT: MVT::bf16,
2479 N1: DAG.getNode(Opcode: Op.getOpcode(), DL: Loc, VT: MVT::f32, Operand: Op.getOperand(i: 0)),
2480 N2: DAG.getIntPtrConstant(Val: 0, DL: Loc, /*isTarget=*/true));
2481 }
2482
2483 // Everything else is considered legal.
2484 return Op;
2485}
2486
2487SDValue NVPTXTargetLowering::LowerFP_TO_INT(SDValue Op,
2488 SelectionDAG &DAG) const {
2489 assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);
2490
2491 if (Op.getOperand(i: 0).getValueType() == MVT::bf16) {
2492 SDLoc Loc(Op);
2493 return DAG.getNode(
2494 Opcode: Op.getOpcode(), DL: Loc, VT: Op.getValueType(),
2495 Operand: DAG.getNode(Opcode: ISD::FP_EXTEND, DL: Loc, VT: MVT::f32, Operand: Op.getOperand(i: 0)));
2496 }
2497
2498 // Everything else is considered legal.
2499 return Op;
2500}
2501
2502SDValue NVPTXTargetLowering::LowerFP_ROUND(SDValue Op,
2503 SelectionDAG &DAG) const {
2504 EVT NarrowVT = Op.getValueType();
2505 SDValue Wide = Op.getOperand(i: 0);
2506 EVT WideVT = Wide.getValueType();
2507 if (NarrowVT.getScalarType() == MVT::bf16) {
2508 const TargetLowering *TLI = STI.getTargetLowering();
2509 if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 70) {
2510 return TLI->expandFP_ROUND(Node: Op.getNode(), DAG);
2511 }
2512 if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
2513 // This combination was the first to support f32 -> bf16.
2514 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70) {
2515 if (WideVT.getScalarType() == MVT::f32) {
2516 return Op;
2517 }
2518 if (WideVT.getScalarType() == MVT::f64) {
2519 SDLoc Loc(Op);
2520 // Round-inexact-to-odd f64 to f32, then do the final rounding using
2521 // the hardware f32 -> bf16 instruction.
2522 SDValue rod = TLI->expandRoundInexactToOdd(
2523 ResultVT: WideVT.isVector() ? WideVT.changeVectorElementType(EltVT: MVT::f32)
2524 : MVT::f32,
2525 Op: Wide, DL: Loc, DAG);
2526 return DAG.getFPExtendOrRound(Op: rod, DL: Loc, VT: NarrowVT);
2527 }
2528 }
2529 return TLI->expandFP_ROUND(Node: Op.getNode(), DAG);
2530 }
2531 }
2532
2533 // Everything else is considered legal.
2534 return Op;
2535}
2536
2537SDValue NVPTXTargetLowering::LowerFP_EXTEND(SDValue Op,
2538 SelectionDAG &DAG) const {
2539 SDValue Narrow = Op.getOperand(i: 0);
2540 EVT NarrowVT = Narrow.getValueType();
2541 EVT WideVT = Op.getValueType();
2542 if (NarrowVT.getScalarType() == MVT::bf16) {
2543 if (WideVT.getScalarType() == MVT::f32 &&
2544 (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71)) {
2545 SDLoc Loc(Op);
2546 return DAG.getNode(Opcode: ISD::BF16_TO_FP, DL: Loc, VT: WideVT, Operand: Narrow);
2547 }
2548 if (WideVT.getScalarType() == MVT::f64 &&
2549 (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
2550 EVT F32 = NarrowVT.isVector() ? NarrowVT.changeVectorElementType(EltVT: MVT::f32)
2551 : MVT::f32;
2552 SDLoc Loc(Op);
2553 if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 71) {
2554 Op = DAG.getNode(Opcode: ISD::FP_EXTEND, DL: Loc, VT: F32, Operand: Narrow);
2555 } else {
2556 Op = DAG.getNode(Opcode: ISD::BF16_TO_FP, DL: Loc, VT: F32, Operand: Narrow);
2557 }
2558 return DAG.getNode(Opcode: ISD::FP_EXTEND, DL: Loc, VT: WideVT, Operand: Op);
2559 }
2560 }
2561
2562 // Everything else is considered legal.
2563 return Op;
2564}
2565
2566static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) {
2567 SDLoc DL(Op);
2568 if (Op.getValueType() != MVT::v2i16)
2569 return Op;
2570 EVT EltVT = Op.getValueType().getVectorElementType();
2571 SmallVector<SDValue> VecElements;
2572 for (int I = 0, E = Op.getValueType().getVectorNumElements(); I < E; I++) {
2573 SmallVector<SDValue> ScalarArgs;
2574 llvm::transform(Range: Op->ops(), d_first: std::back_inserter(x&: ScalarArgs),
2575 F: [&](const SDUse &O) {
2576 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT,
2577 N1: O.get(), N2: DAG.getIntPtrConstant(Val: I, DL));
2578 });
2579 VecElements.push_back(Elt: DAG.getNode(Opcode: Op.getOpcode(), DL, VT: EltVT, Ops: ScalarArgs));
2580 }
2581 SDValue V =
2582 DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: Op.getValueType(), Ops: VecElements);
2583 return V;
2584}
2585
2586static SDValue LowerTcgen05St(SDValue Op, SelectionDAG &DAG) {
2587 SDNode *N = Op.getNode();
2588 SDLoc DL(N);
2589 SmallVector<SDValue, 32> Ops;
2590
2591 // Split the vector arguments into scalar elements.
2592 for (size_t I = 0; I < N->getNumOperands(); I++) {
2593 SDValue Val = N->getOperand(Num: I);
2594 EVT ValVT = Val.getValueType();
2595 if (ValVT.isVector()) {
2596 EVT EltVT = ValVT.getVectorElementType();
2597 for (unsigned J = 0, NElts = ValVT.getVectorNumElements(); J < NElts; J++)
2598 Ops.push_back(Elt: DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Val,
2599 N2: DAG.getIntPtrConstant(Val: J, DL)));
2600 } else
2601 Ops.push_back(Elt: Val);
2602 }
2603
2604 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
2605 SDValue Tcgen05StNode =
2606 DAG.getMemIntrinsicNode(Opcode: ISD::INTRINSIC_VOID, dl: DL, VTList: N->getVTList(), Ops,
2607 MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());
2608
2609 return Tcgen05StNode;
2610}
2611
2612static SDValue LowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
2613 SDNode *N = Op.getNode();
2614 SDValue Intrin = N->getOperand(Num: 1);
2615
2616 // Get the intrinsic ID
2617 unsigned IntrinNo = cast<ConstantSDNode>(Val: Intrin.getNode())->getZExtValue();
2618 switch (IntrinNo) {
2619 default:
2620 break;
2621 case Intrinsic::nvvm_tcgen05_st_16x64b_x1:
2622 case Intrinsic::nvvm_tcgen05_st_16x64b_x2:
2623 case Intrinsic::nvvm_tcgen05_st_16x64b_x4:
2624 case Intrinsic::nvvm_tcgen05_st_16x64b_x8:
2625 case Intrinsic::nvvm_tcgen05_st_16x64b_x16:
2626 case Intrinsic::nvvm_tcgen05_st_16x64b_x32:
2627 case Intrinsic::nvvm_tcgen05_st_16x64b_x128:
2628 case Intrinsic::nvvm_tcgen05_st_16x128b_x1:
2629 case Intrinsic::nvvm_tcgen05_st_16x128b_x2:
2630 case Intrinsic::nvvm_tcgen05_st_16x128b_x4:
2631 case Intrinsic::nvvm_tcgen05_st_16x128b_x8:
2632 case Intrinsic::nvvm_tcgen05_st_16x128b_x16:
2633 case Intrinsic::nvvm_tcgen05_st_16x128b_x32:
2634 case Intrinsic::nvvm_tcgen05_st_16x128b_x64:
2635 case Intrinsic::nvvm_tcgen05_st_16x256b_x1:
2636 case Intrinsic::nvvm_tcgen05_st_16x256b_x2:
2637 case Intrinsic::nvvm_tcgen05_st_16x256b_x4:
2638 case Intrinsic::nvvm_tcgen05_st_16x256b_x8:
2639 case Intrinsic::nvvm_tcgen05_st_16x256b_x16:
2640 case Intrinsic::nvvm_tcgen05_st_16x256b_x32:
2641 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x1:
2642 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x2:
2643 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x4:
2644 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x8:
2645 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x16:
2646 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x32:
2647 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x64:
2648 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x128:
2649 case Intrinsic::nvvm_tcgen05_st_32x32b_x1:
2650 case Intrinsic::nvvm_tcgen05_st_32x32b_x2:
2651 case Intrinsic::nvvm_tcgen05_st_32x32b_x4:
2652 case Intrinsic::nvvm_tcgen05_st_32x32b_x8:
2653 case Intrinsic::nvvm_tcgen05_st_32x32b_x16:
2654 case Intrinsic::nvvm_tcgen05_st_32x32b_x32:
2655 case Intrinsic::nvvm_tcgen05_st_16x64b_x64:
2656 case Intrinsic::nvvm_tcgen05_st_32x32b_x64:
2657 case Intrinsic::nvvm_tcgen05_st_32x32b_x128:
2658 return LowerTcgen05St(Op, DAG);
2659 }
2660 return Op;
2661}
2662
2663static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op,
2664 SelectionDAG &DAG) {
2665
2666 SDNode *N = Op.getNode();
2667 if (N->getOperand(Num: 1).getValueType() != MVT::i128) {
2668 // Return if the operand has already been lowered.
2669 return SDValue();
2670 }
2671
2672 unsigned IID =
2673 cast<ConstantSDNode>(Val: N->getOperand(Num: 0).getNode())->getZExtValue();
2674 auto Opcode = [&]() {
2675 switch (IID) {
2676 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:
2677 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_IS_CANCELED;
2678 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x:
2679 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_X;
2680 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y:
2681 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Y;
2682 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z:
2683 return NVPTXISD::CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z;
2684 default:
2685 llvm_unreachable("unsupported/unhandled intrinsic");
2686 }
2687 }();
2688
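// Split the i128 try-cancel response into two i64 halves via a v2i64 bitcast
// and pass them to the target node as separate operands.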
2689 SDLoc DL(N);
2690 SDValue TryCancelResponse = N->getOperand(Num: 1);
2691 SDValue Cast = DAG.getNode(Opcode: ISD::BITCAST, DL, VT: MVT::v2i64, Operand: TryCancelResponse);
2692 SDValue TryCancelResponse0 =
2693 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i64, N1: Cast,
2694 N2: DAG.getIntPtrConstant(Val: 0, DL));
2695 SDValue TryCancelResponse1 =
2696 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i64, N1: Cast,
2697 N2: DAG.getIntPtrConstant(Val: 1, DL));
2698
2699 return DAG.getNode(Opcode, DL, VTList: N->getVTList(),
2700 Ops: {TryCancelResponse0, TryCancelResponse1});
2701}
2702
2703static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG) {
2704 switch (Op->getConstantOperandVal(Num: 0)) {
2705 default:
2706 return Op;
2707 case Intrinsic::nvvm_internal_addrspace_wrap:
2708 return Op.getOperand(i: 1);
2709 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:
2710 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x:
2711 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y:
2712 case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z:
2713 return LowerClusterLaunchControlQueryCancel(Op, DAG);
2714 }
2715}
2716
2717// In PTX, 64-bit CTLZ and CTPOP are supported, but they return a 32-bit
2718// value. Lower these into a node producing an i32 result, which is then
2719// zero-extended back to i64.
2720static SDValue lowerCTLZCTPOP(SDValue Op, SelectionDAG &DAG) {
2721 SDValue V = Op->getOperand(Num: 0);
2722 assert(V.getValueType() == MVT::i64 &&
2723 "Unexpected CTLZ/CTPOP type to legalize");
2724
2725 SDLoc DL(Op);
2726 SDValue CT = DAG.getNode(Opcode: Op->getOpcode(), DL, VT: MVT::i32, Operand: V);
2727 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: MVT::i64, Operand: CT, Flags: SDNodeFlags::NonNeg);
2728}
2729
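// Expand a 64-bit funnel shift with a constant shift amount into two 32-bit
// funnel shifts over the unpacked halves of A and B, then repack the two
// 32-bit results into an i64. Returns an empty SDValue when the shift amount
// is not a constant.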
2730static SDValue expandFSH64(SDValue A, SDValue B, SDValue ShiftAmount, SDLoc DL,
2731 unsigned Opcode, SelectionDAG &DAG) {
2732 assert(A.getValueType() == MVT::i64 && B.getValueType() == MVT::i64);
2733
2734 const auto *AmtConst = dyn_cast<ConstantSDNode>(Val&: ShiftAmount);
2735 if (!AmtConst)
2736 return SDValue();
2737 const auto Amt = AmtConst->getZExtValue() & 63;
2738
2739 SDValue UnpackA =
2740 DAG.getNode(Opcode: NVPTXISD::UNPACK_VECTOR, DL, ResultTys: {MVT::i32, MVT::i32}, Ops: A);
2741 SDValue UnpackB =
2742 DAG.getNode(Opcode: NVPTXISD::UNPACK_VECTOR, DL, ResultTys: {MVT::i32, MVT::i32}, Ops: B);
2743
2744  // Arch is little-endian: 0 = low bits, 1 = high bits
2745 SDValue ALo = UnpackA.getValue(R: 0);
2746 SDValue AHi = UnpackA.getValue(R: 1);
2747 SDValue BLo = UnpackB.getValue(R: 0);
2748 SDValue BHi = UnpackB.getValue(R: 1);
2749
2750  // The bitfield consists of { AHi : ALo : BHi : BLo }
2751 //
2752 // * FSHL, Amt < 32 - The window will contain { AHi : ALo : BHi }
2753 // * FSHL, Amt >= 32 - The window will contain { ALo : BHi : BLo }
2754 // * FSHR, Amt < 32 - The window will contain { ALo : BHi : BLo }
2755 // * FSHR, Amt >= 32 - The window will contain { AHi : ALo : BHi }
2756 //
2757 // Note that Amt = 0 and Amt = 32 are special cases where 32-bit funnel shifts
2758 // are not needed at all. Amt = 0 is a no-op producing either A or B depending
2759  // on the direction. Amt = 32 can be implemented by a packing and unpacking
2760  // move to select and arrange the 32-bit values. For simplicity, these cases
2761 // are not handled here explicitly and instead we rely on DAGCombiner to
2762 // remove the no-op funnel shifts we insert.
2763 auto [High, Mid, Low] = ((Opcode == ISD::FSHL) == (Amt < 32))
2764 ? std::make_tuple(args&: AHi, args&: ALo, args&: BHi)
2765 : std::make_tuple(args&: ALo, args&: BHi, args&: BLo);
2766
2767 SDValue NewAmt = DAG.getConstant(Val: Amt & 31, DL, VT: MVT::i32);
2768 SDValue RHi = DAG.getNode(Opcode, DL, VT: MVT::i32, Ops: {High, Mid, NewAmt});
2769 SDValue RLo = DAG.getNode(Opcode, DL, VT: MVT::i32, Ops: {Mid, Low, NewAmt});
2770
2771 return DAG.getNode(Opcode: NVPTXISD::BUILD_VECTOR, DL, VT: MVT::i64, Ops: {RLo, RHi});
2772}
2773
2774static SDValue lowerFSH(SDValue Op, SelectionDAG &DAG) {
2775 return expandFSH64(A: Op->getOperand(Num: 0), B: Op->getOperand(Num: 1), ShiftAmount: Op->getOperand(Num: 2),
2776 DL: SDLoc(Op), Opcode: Op->getOpcode(), DAG);
2777}
2778
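// Lower i64 ROTL/ROTR as a funnel shift (FSHL/FSHR) with both data operands
// set to the rotated value, reusing the 64-bit funnel-shift expansion above.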
2779static SDValue lowerROT(SDValue Op, SelectionDAG &DAG) {
2780 unsigned Opcode = Op->getOpcode() == ISD::ROTL ? ISD::FSHL : ISD::FSHR;
2781 return expandFSH64(A: Op->getOperand(Num: 0), B: Op->getOperand(Num: 0), ShiftAmount: Op->getOperand(Num: 1),
2782 DL: SDLoc(Op), Opcode, DAG);
2783}
2784
2785static SDValue lowerFREM(SDValue Op, SelectionDAG &DAG,
2786 bool AllowUnsafeFPMath) {
2787 // Lower (frem x, y) into (sub x, (mul (ftrunc (div x, y)) y)),
2788 // i.e. "poor man's fmod()". When y is infinite, x is returned. This matches
2789 // the semantics of LLVM's frem.
2790 SDLoc DL(Op);
2791 SDValue X = Op->getOperand(Num: 0);
2792 SDValue Y = Op->getOperand(Num: 1);
2793 EVT Ty = Op.getValueType();
2794 SDNodeFlags Flags = Op->getFlags();
2795
2796 SDValue Div = DAG.getNode(Opcode: ISD::FDIV, DL, VT: Ty, N1: X, N2: Y, Flags);
2797 SDValue Trunc = DAG.getNode(Opcode: ISD::FTRUNC, DL, VT: Ty, Operand: Div, Flags);
2798 SDValue Mul = DAG.getNode(Opcode: ISD::FMUL, DL, VT: Ty, N1: Trunc, N2: Y,
2799 Flags: Flags | SDNodeFlags::AllowContract);
2800 SDValue Sub = DAG.getNode(Opcode: ISD::FSUB, DL, VT: Ty, N1: X, N2: Mul,
2801 Flags: Flags | SDNodeFlags::AllowContract);
2802
2803 if (AllowUnsafeFPMath || Flags.hasNoInfs())
2804 return Sub;
2805
2806 // If Y is infinite, return X
2807 SDValue AbsY = DAG.getNode(Opcode: ISD::FABS, DL, VT: Ty, Operand: Y);
2808 SDValue Inf =
2809 DAG.getConstantFP(Val: APFloat::getInf(Sem: Ty.getFltSemantics()), DL, VT: Ty);
2810 SDValue IsInf = DAG.getSetCC(DL, VT: MVT::i1, LHS: AbsY, RHS: Inf, Cond: ISD::SETEQ);
2811 return DAG.getSelect(DL, VT: Ty, Cond: IsInf, LHS: X, RHS: Sub);
2812}
2813
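// Custom lowering for i1 SELECT. If both operands are truncations, push the
// select through the truncates and truncate the wider result back to i1.
// Otherwise expand it to (or (and cond, trueval), (and (not cond), falseval)),
// which tends to fold into surrounding logic here or in ptxas.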
2814static SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) {
2815 assert(Op.getValueType() == MVT::i1 && "Custom lowering enabled only for i1");
2816
2817 SDValue Cond = Op->getOperand(Num: 0);
2818 SDValue TrueVal = Op->getOperand(Num: 1);
2819 SDValue FalseVal = Op->getOperand(Num: 2);
2820 SDLoc DL(Op);
2821
2822 // If both operands are truncated, we push the select through the truncates.
2823 if (TrueVal.getOpcode() == ISD::TRUNCATE &&
2824 FalseVal.getOpcode() == ISD::TRUNCATE) {
2825 TrueVal = TrueVal.getOperand(i: 0);
2826 FalseVal = FalseVal.getOperand(i: 0);
2827
2828 EVT VT = TrueVal.getSimpleValueType().bitsLE(VT: FalseVal.getSimpleValueType())
2829 ? TrueVal.getValueType()
2830 : FalseVal.getValueType();
2831 TrueVal = DAG.getAnyExtOrTrunc(Op: TrueVal, DL, VT);
2832 FalseVal = DAG.getAnyExtOrTrunc(Op: FalseVal, DL, VT);
2833 SDValue Select = DAG.getSelect(DL, VT, Cond, LHS: TrueVal, RHS: FalseVal);
2834 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i1, Operand: Select);
2835 }
2836
2837  // Otherwise, expand the select into a series of logical operations. These
2838  // can often be folded into other operations either by us or by ptxas.
2839 TrueVal = DAG.getFreeze(V: TrueVal);
2840 FalseVal = DAG.getFreeze(V: FalseVal);
2841 SDValue And1 = DAG.getNode(Opcode: ISD::AND, DL, VT: MVT::i1, N1: Cond, N2: TrueVal);
2842 SDValue NotCond = DAG.getNOT(DL, Val: Cond, VT: MVT::i1);
2843 SDValue And2 = DAG.getNode(Opcode: ISD::AND, DL, VT: MVT::i1, N1: NotCond, N2: FalseVal);
2844 SDValue Or = DAG.getNode(Opcode: ISD::OR, DL, VT: MVT::i1, N1: And1, N2: And2);
2845 return Or;
2846}
2847
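// Custom-lowering entry point: dispatch each custom-lowered opcode to its
// dedicated helper.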
2848SDValue
2849NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
2850 switch (Op.getOpcode()) {
2851 case ISD::RETURNADDR:
2852 return SDValue();
2853 case ISD::FRAMEADDR:
2854 return SDValue();
2855 case ISD::ADDRSPACECAST:
2856 return LowerADDRSPACECAST(Op, DAG);
2857 case ISD::INTRINSIC_W_CHAIN:
2858 return Op;
2859 case ISD::INTRINSIC_WO_CHAIN:
2860 return lowerIntrinsicWOChain(Op, DAG);
2861 case ISD::INTRINSIC_VOID:
2862 return LowerIntrinsicVoid(Op, DAG);
2863 case ISD::BUILD_VECTOR:
2864 return LowerBUILD_VECTOR(Op, DAG);
2865 case ISD::BITCAST:
2866 return LowerBITCAST(Op, DAG);
2867 case ISD::EXTRACT_SUBVECTOR:
2868 return Op;
2869 case ISD::EXTRACT_VECTOR_ELT:
2870 return LowerEXTRACT_VECTOR_ELT(Op, DAG);
2871 case ISD::INSERT_VECTOR_ELT:
2872 return LowerINSERT_VECTOR_ELT(Op, DAG);
2873 case ISD::VECTOR_SHUFFLE:
2874 return LowerVECTOR_SHUFFLE(Op, DAG);
2875 case ISD::CONCAT_VECTORS:
2876 return LowerCONCAT_VECTORS(Op, DAG);
2877 case ISD::STORE:
2878 return LowerSTORE(Op, DAG);
2879 case ISD::LOAD:
2880 return LowerLOAD(Op, DAG);
2881 case ISD::SHL_PARTS:
2882 return LowerShiftLeftParts(Op, DAG);
2883 case ISD::SRA_PARTS:
2884 case ISD::SRL_PARTS:
2885 return LowerShiftRightParts(Op, DAG);
2886 case ISD::SELECT:
2887 return lowerSELECT(Op, DAG);
2888 case ISD::FROUND:
2889 return LowerFROUND(Op, DAG);
2890 case ISD::FCOPYSIGN:
2891 return LowerFCOPYSIGN(Op, DAG);
2892 case ISD::SINT_TO_FP:
2893 case ISD::UINT_TO_FP:
2894 return LowerINT_TO_FP(Op, DAG);
2895 case ISD::FP_TO_SINT:
2896 case ISD::FP_TO_UINT:
2897 return LowerFP_TO_INT(Op, DAG);
2898 case ISD::FP_ROUND:
2899 return LowerFP_ROUND(Op, DAG);
2900 case ISD::FP_EXTEND:
2901 return LowerFP_EXTEND(Op, DAG);
2902 case ISD::BR_JT:
2903 return LowerBR_JT(Op, DAG);
2904 case ISD::VAARG:
2905 return LowerVAARG(Op, DAG);
2906 case ISD::VASTART:
2907 return LowerVASTART(Op, DAG);
2908 case ISD::FSHL:
2909 case ISD::FSHR:
2910 return lowerFSH(Op, DAG);
2911 case ISD::ROTL:
2912 case ISD::ROTR:
2913 return lowerROT(Op, DAG);
2914 case ISD::ABS:
2915 case ISD::SMIN:
2916 case ISD::SMAX:
2917 case ISD::UMIN:
2918 case ISD::UMAX:
2919 case ISD::ADD:
2920 case ISD::SUB:
2921 case ISD::MUL:
2922 case ISD::SHL:
2923 case ISD::SREM:
2924 case ISD::UREM:
2925 return LowerVectorArith(Op, DAG);
2926 case ISD::DYNAMIC_STACKALLOC:
2927 return LowerDYNAMIC_STACKALLOC(Op, DAG);
2928 case ISD::STACKRESTORE:
2929 return LowerSTACKRESTORE(Op, DAG);
2930 case ISD::STACKSAVE:
2931 return LowerSTACKSAVE(Op, DAG);
2932 case ISD::CopyToReg:
2933 return LowerCopyToReg_128(Op, DAG);
2934 case ISD::FADD:
2935 case ISD::FSUB:
2936 case ISD::FMUL:
2937    // Used only for bf16 on SM80, where we select fma for non-ftz operations.
2938 return PromoteBinOpIfF32FTZ(Op, DAG);
2939 case ISD::CTPOP:
2940 case ISD::CTLZ:
2941 return lowerCTLZCTPOP(Op, DAG);
2942 case ISD::FREM:
2943 return lowerFREM(Op, DAG, AllowUnsafeFPMath: allowUnsafeFPMath(MF: DAG.getMachineFunction()));
2944
2945 default:
2946 llvm_unreachable("Custom lowering not defined for operation");
2947 }
2948}
2949
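// Lower BR_JT into the inline jump-table sequence used by NVPTX: a BrxStart
// node, one BrxItem per target block except the last, and a BrxEnd node
// carrying the final block, the index operand, and the jump-table id. The
// nodes are chained and glued so the table stays together; the targets are
// emitted inline rather than as a separate table (see getJumpTableEncoding
// below).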
2950SDValue NVPTXTargetLowering::LowerBR_JT(SDValue Op, SelectionDAG &DAG) const {
2951 SDLoc DL(Op);
2952 SDValue Chain = Op.getOperand(i: 0);
2953 const auto *JT = cast<JumpTableSDNode>(Val: Op.getOperand(i: 1));
2954 SDValue Index = Op.getOperand(i: 2);
2955
2956 unsigned JId = JT->getIndex();
2957 MachineJumpTableInfo *MJTI = DAG.getMachineFunction().getJumpTableInfo();
2958 ArrayRef<MachineBasicBlock *> MBBs = MJTI->getJumpTables()[JId].MBBs;
2959
2960 SDValue IdV = DAG.getConstant(Val: JId, DL, VT: MVT::i32);
2961
2962 // Generate BrxStart node
2963 SDVTList VTs = DAG.getVTList(VT1: MVT::Other, VT2: MVT::Glue);
2964 Chain = DAG.getNode(Opcode: NVPTXISD::BrxStart, DL, VTList: VTs, N1: Chain, N2: IdV);
2965
2966 // Generate BrxItem nodes
2967 assert(!MBBs.empty());
2968 for (MachineBasicBlock *MBB : MBBs.drop_back())
2969 Chain = DAG.getNode(Opcode: NVPTXISD::BrxItem, DL, VTList: VTs, N1: Chain.getValue(R: 0),
2970 N2: DAG.getBasicBlock(MBB), N3: Chain.getValue(R: 1));
2971
2972  // Generate the BrxEnd node
2973 SDValue EndOps[] = {Chain.getValue(R: 0), DAG.getBasicBlock(MBB: MBBs.back()), Index,
2974 IdV, Chain.getValue(R: 1)};
2975 SDValue BrxEnd = DAG.getNode(Opcode: NVPTXISD::BrxEnd, DL, VTList: VTs, Ops: EndOps);
2976
2977 return BrxEnd;
2978}
2979
2980// This will prevent AsmPrinter from trying to print the jump tables itself.
2981unsigned NVPTXTargetLowering::getJumpTableEncoding() const {
2982 return MachineJumpTableInfo::EK_Inline;
2983}
2984
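// Custom lowering for ADDRSPACECAST. Casts to or from the generic address
// space are left untouched. Between specific address spaces, only the
// shared <-> shared::cluster conversion is meaningful and is lowered as a
// round trip through the generic space; any other specific-to-specific cast
// is folded to UNDEF.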
2985SDValue NVPTXTargetLowering::LowerADDRSPACECAST(SDValue Op,
2986 SelectionDAG &DAG) const {
2987 AddrSpaceCastSDNode *N = cast<AddrSpaceCastSDNode>(Val: Op.getNode());
2988 unsigned SrcAS = N->getSrcAddressSpace();
2989 unsigned DestAS = N->getDestAddressSpace();
2990 if (SrcAS != llvm::ADDRESS_SPACE_GENERIC &&
2991 DestAS != llvm::ADDRESS_SPACE_GENERIC) {
2992 // Shared and SharedCluster can be converted to each other through generic
2993 // space
2994 if ((SrcAS == llvm::ADDRESS_SPACE_SHARED &&
2995 DestAS == llvm::ADDRESS_SPACE_SHARED_CLUSTER) ||
2996 (SrcAS == llvm::ADDRESS_SPACE_SHARED_CLUSTER &&
2997 DestAS == llvm::ADDRESS_SPACE_SHARED)) {
2998 SDLoc DL(Op.getNode());
2999      const MVT GenericVT =
3000          getPointerTy(DL: DAG.getDataLayout(), AS: ADDRESS_SPACE_GENERIC);
3001      SDValue GenericConversion = DAG.getAddrSpaceCast(
3002          dl: DL, VT: GenericVT, Ptr: Op.getOperand(i: 0), SrcAS, DestAS: ADDRESS_SPACE_GENERIC);
3003 SDValue SharedClusterConversion =
3004 DAG.getAddrSpaceCast(dl: DL, VT: Op.getValueType(), Ptr: GenericConversion,
3005 SrcAS: ADDRESS_SPACE_GENERIC, DestAS);
3006 return SharedClusterConversion;
3007 }
3008
3009 return DAG.getUNDEF(VT: Op.getValueType());
3010 }
3011
3012 return Op;
3013}
3014
3015// This function is almost a copy of SelectionDAG::expandVAArg(). The only
3016// difference is that this one produces loads from the local address space.
3017SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
3018 const TargetLowering *TLI = STI.getTargetLowering();
3019 SDLoc DL(Op);
3020
3021 SDNode *Node = Op.getNode();
3022 const Value *V = cast<SrcValueSDNode>(Val: Node->getOperand(Num: 2))->getValue();
3023 EVT VT = Node->getValueType(ResNo: 0);
3024 auto *Ty = VT.getTypeForEVT(Context&: *DAG.getContext());
3025 SDValue Tmp1 = Node->getOperand(Num: 0);
3026 SDValue Tmp2 = Node->getOperand(Num: 1);
3027 const MaybeAlign MA(Node->getConstantOperandVal(Num: 3));
3028
3029 SDValue VAListLoad = DAG.getLoad(VT: TLI->getPointerTy(DL: DAG.getDataLayout()), dl: DL,
3030 Chain: Tmp1, Ptr: Tmp2, PtrInfo: MachinePointerInfo(V));
3031 SDValue VAList = VAListLoad;
3032
3033 if (MA && *MA > TLI->getMinStackArgumentAlignment()) {
3034 VAList = DAG.getNode(
3035 Opcode: ISD::ADD, DL, VT: VAList.getValueType(), N1: VAList,
3036 N2: DAG.getConstant(Val: MA->value() - 1, DL, VT: VAList.getValueType()));
3037
3038 VAList = DAG.getNode(Opcode: ISD::AND, DL, VT: VAList.getValueType(), N1: VAList,
3039 N2: DAG.getSignedConstant(Val: -(int64_t)MA->value(), DL,
3040 VT: VAList.getValueType()));
3041 }
3042
3043 // Increment the pointer, VAList, to the next vaarg
3044 Tmp1 = DAG.getNode(Opcode: ISD::ADD, DL, VT: VAList.getValueType(), N1: VAList,
3045 N2: DAG.getConstant(Val: DAG.getDataLayout().getTypeAllocSize(Ty),
3046 DL, VT: VAList.getValueType()));
3047
3048 // Store the incremented VAList to the legalized pointer
3049 Tmp1 = DAG.getStore(Chain: VAListLoad.getValue(R: 1), dl: DL, Val: Tmp1, Ptr: Tmp2,
3050 PtrInfo: MachinePointerInfo(V));
3051
3052 const Value *SrcV = Constant::getNullValue(
3053 Ty: PointerType::get(C&: *DAG.getContext(), AddressSpace: ADDRESS_SPACE_LOCAL));
3054
3055 // Load the actual argument out of the pointer VAList
3056 return DAG.getLoad(VT, dl: DL, Chain: Tmp1, Ptr: VAList, PtrInfo: MachinePointerInfo(SrcV));
3057}
3058
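// va_start is lowered to a single store: the va_list object receives the
// address of the special vararg parameter symbol (index -1).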
3059SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
3060 const TargetLowering *TLI = STI.getTargetLowering();
3061 SDLoc DL(Op);
3062 EVT PtrVT = TLI->getPointerTy(DL: DAG.getDataLayout());
3063
3064 // Store the address of unsized array <function>_vararg[] in the ap object.
3065 SDValue VAReg = getParamSymbol(DAG, /* vararg */ I: -1, T: PtrVT);
3066
3067 const Value *SV = cast<SrcValueSDNode>(Val: Op.getOperand(i: 2))->getValue();
3068 return DAG.getStore(Chain: Op.getOperand(i: 0), dl: DL, Val: VAReg, Ptr: Op.getOperand(i: 1),
3069 PtrInfo: MachinePointerInfo(SV));
3070}
3071
3072SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
3073 if (Op.getValueType() == MVT::i1)
3074 return LowerLOADi1(Op, DAG);
3075
3076 // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
3077 // unaligned loads and have to handle it here.
3078 EVT VT = Op.getValueType();
3079 if (Isv2x16VT(VT) || VT == MVT::v4i8) {
3080 LoadSDNode *Load = cast<LoadSDNode>(Val&: Op);
3081 EVT MemVT = Load->getMemoryVT();
3082 if (!allowsMemoryAccessForAlignment(Context&: *DAG.getContext(), DL: DAG.getDataLayout(),
3083 VT: MemVT, MMO: *Load->getMemOperand())) {
3084 SDValue Ops[2];
3085 std::tie(args&: Ops[0], args&: Ops[1]) = expandUnalignedLoad(LD: Load, DAG);
3086 return DAG.getMergeValues(Ops, dl: SDLoc(Op));
3087 }
3088 }
3089
3090 return SDValue();
3091}
3092
3093// v = ld i1* addr
3094// =>
3095// v1 = ld i8* addr (-> i16)
3096// v = trunc i16 to i1
3097SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
3098 SDNode *Node = Op.getNode();
3099 LoadSDNode *LD = cast<LoadSDNode>(Val: Node);
3100 SDLoc dl(Node);
3101 assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
3102 assert(Node->getValueType(0) == MVT::i1 &&
3103 "Custom lowering for i1 load only");
3104 SDValue newLD = DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl, VT: MVT::i16, Chain: LD->getChain(),
3105 Ptr: LD->getBasePtr(), PtrInfo: LD->getPointerInfo(),
3106 MemVT: MVT::i8, Alignment: LD->getAlign(),
3107 MMOFlags: LD->getMemOperand()->getFlags());
3108 SDValue result = DAG.getNode(Opcode: ISD::TRUNCATE, DL: dl, VT: MVT::i1, Operand: newLD);
3109 // The legalizer (the caller) is expecting two values from the legalized
3110 // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
3111 // in LegalizeDAG.cpp which also uses MergeValues.
3112 SDValue Ops[] = { result, LD->getChain() };
3113 return DAG.getMergeValues(Ops, dl);
3114}
3115
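// Custom lowering for stores: i1 stores go through LowerSTOREi1, misaligned
// stores of the packed 32-bit types (v2f16/v2bf16/v2i16/v4i8) are expanded,
// and other vector stores are given a chance to become StoreV2/V4/V8 nodes
// via LowerSTOREVector.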
3116SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
3117 StoreSDNode *Store = cast<StoreSDNode>(Val&: Op);
3118 EVT VT = Store->getMemoryVT();
3119
3120 if (VT == MVT::i1)
3121 return LowerSTOREi1(Op, DAG);
3122
3123  // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on the legalizer to
3124  // handle unaligned stores and have to handle it here.
3125 if ((Isv2x16VT(VT) || VT == MVT::v4i8) &&
3126 !allowsMemoryAccessForAlignment(Context&: *DAG.getContext(), DL: DAG.getDataLayout(),
3127 VT, MMO: *Store->getMemOperand()))
3128 return expandUnalignedStore(ST: Store, DAG);
3129
3130 // v2f16, v2bf16 and v2i16 don't need special handling.
3131 if (Isv2x16VT(VT) || VT == MVT::v4i8)
3132 return SDValue();
3133
3134 return LowerSTOREVector(Op, DAG);
3135}
3136
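// Try to lower a vector store into a single StoreV2/V4/V8 target node. The
// value is split into NumElts pieces of EltVT (possibly packed v2x16/v4i8
// subvectors stored as b32), with sub-16-bit scalars widened to i16 because
// target nodes bypass DAG type legalization. Returns an empty SDValue to let
// the generic legalizer handle anything we do not vectorize here.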
3137SDValue
3138NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
3139 MemSDNode *N = cast<MemSDNode>(Val: Op.getNode());
3140 SDValue Val = N->getOperand(Num: 1);
3141 SDLoc DL(N);
3142 const EVT ValVT = Val.getValueType();
3143 const EVT MemVT = N->getMemoryVT();
3144
3145 // If we're truncating as part of the store, avoid lowering to a StoreV node.
3146 // TODO: consider relaxing this restriction.
3147 if (ValVT != MemVT)
3148 return SDValue();
3149
3150 const auto NumEltsAndEltVT = getVectorLoweringShape(
3151 VectorEVT: ValVT, CanLowerTo256Bit: STI.has256BitVectorLoadStore(AS: N->getAddressSpace()));
3152 if (!NumEltsAndEltVT)
3153 return SDValue();
3154 const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
3155
3156 const DataLayout &TD = DAG.getDataLayout();
3157
3158 Align Alignment = N->getAlign();
3159 Align PrefAlign = TD.getPrefTypeAlign(Ty: ValVT.getTypeForEVT(Context&: *DAG.getContext()));
3160 if (Alignment < PrefAlign) {
3161 // This store is not sufficiently aligned, so bail out and let this vector
3162 // store be scalarized. Note that we may still be able to emit smaller
3163 // vector stores. For example, if we are storing a <4 x float> with an
3164 // alignment of 8, this check will fail but the legalizer will try again
3165 // with 2 x <2 x float>, which will succeed with an alignment of 8.
3166 return SDValue();
3167 }
3168
3169 unsigned Opcode;
3170 switch (NumElts) {
3171 default:
3172 return SDValue();
3173 case 2:
3174 Opcode = NVPTXISD::StoreV2;
3175 break;
3176 case 4:
3177 Opcode = NVPTXISD::StoreV4;
3178 break;
3179 case 8:
3180 Opcode = NVPTXISD::StoreV8;
3181 break;
3182 }
3183
3184 SmallVector<SDValue, 8> Ops;
3185
3186 // First is the chain
3187 Ops.push_back(Elt: N->getOperand(Num: 0));
3188
3189 // Then the split values
3190 if (EltVT.isVector()) {
3191 assert(EVT(EltVT.getVectorElementType()) == ValVT.getVectorElementType());
3192 assert(NumElts * EltVT.getVectorNumElements() ==
3193 ValVT.getVectorNumElements());
3194 // Combine individual elements into v2[i,f,bf]16/v4i8 subvectors to be
3195 // stored as b32s
3196 const unsigned NumEltsPerSubVector = EltVT.getVectorNumElements();
3197 for (const unsigned I : llvm::seq(Size: NumElts)) {
3198 SmallVector<SDValue, 4> SubVectorElts;
3199 DAG.ExtractVectorElements(Op: Val, Args&: SubVectorElts, Start: I * NumEltsPerSubVector,
3200 Count: NumEltsPerSubVector);
3201 Ops.push_back(Elt: DAG.getBuildVector(VT: EltVT, DL, Ops: SubVectorElts));
3202 }
3203 } else {
3204 SDValue V = DAG.getBitcast(VT: MVT::getVectorVT(VT: EltVT, NumElements: NumElts), V: Val);
3205 for (const unsigned I : llvm::seq(Size: NumElts)) {
3206 SDValue ExtVal = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: V,
3207 N2: DAG.getIntPtrConstant(Val: I, DL));
3208
3209 // Since StoreV2 is a target node, we cannot rely on DAG type
3210 // legalization. Therefore, we must ensure the type is legal. For i1 and
3211 // i8, we set the stored type to i16 and propagate the "real" type as the
3212 // memory type.
3213 if (EltVT.getSizeInBits() < 16)
3214 ExtVal = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: MVT::i16, Operand: ExtVal);
3215 Ops.push_back(Elt: ExtVal);
3216 }
3217 }
3218
3219 // Then any remaining arguments
3220 Ops.append(in_start: N->op_begin() + 2, in_end: N->op_end());
3221
3222 SDValue NewSt =
3223 DAG.getMemIntrinsicNode(Opcode, dl: DL, VTList: DAG.getVTList(VT: MVT::Other), Ops,
3224 MemVT: N->getMemoryVT(), MMO: N->getMemOperand());
3225
3227 return NewSt;
3228}
3229
3230// st i1 v, addr
3231// =>
3232// v1 = zxt v to i16
3233// st.u8 i16, addr
3234SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
3235 SDNode *Node = Op.getNode();
3236 SDLoc dl(Node);
3237 StoreSDNode *ST = cast<StoreSDNode>(Val: Node);
3238 SDValue Tmp1 = ST->getChain();
3239 SDValue Tmp2 = ST->getBasePtr();
3240 SDValue Tmp3 = ST->getValue();
3241 assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
3242 Tmp3 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: dl, VT: MVT::i16, Operand: Tmp3);
3243 SDValue Result =
3244 DAG.getTruncStore(Chain: Tmp1, dl, Val: Tmp3, Ptr: Tmp2, PtrInfo: ST->getPointerInfo(), SVT: MVT::i8,
3245 Alignment: ST->getAlign(), MMOFlags: ST->getMemOperand()->getFlags());
3246 return Result;
3247}
3248
3249SDValue NVPTXTargetLowering::LowerCopyToReg_128(SDValue Op,
3250 SelectionDAG &DAG) const {
3251 // Change the CopyToReg to take in two 64-bit operands instead of a 128-bit
3252 // operand so that it can pass the legalization.
3253
3254 assert(Op.getOperand(1).getValueType() == MVT::i128 &&
3255 "Custom lowering for 128-bit CopyToReg only");
3256
3257 SDNode *Node = Op.getNode();
3258 SDLoc DL(Node);
3259
3260 SDValue Cast = DAG.getBitcast(VT: MVT::v2i64, V: Op->getOperand(Num: 2));
3261 SDValue Lo = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i64, N1: Cast,
3262 N2: DAG.getIntPtrConstant(Val: 0, DL));
3263 SDValue Hi = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i64, N1: Cast,
3264 N2: DAG.getIntPtrConstant(Val: 1, DL));
3265
3266 SmallVector<SDValue, 5> NewOps(Op->getNumOperands() + 1);
3267 SmallVector<EVT, 3> ResultsType(Node->values());
3268
3269 NewOps[0] = Op->getOperand(Num: 0); // Chain
3270 NewOps[1] = Op->getOperand(Num: 1); // Dst Reg
3271 NewOps[2] = Lo; // Lower 64-bit
3272 NewOps[3] = Hi; // Higher 64-bit
3273 if (Op.getNumOperands() == 4)
3274 NewOps[4] = Op->getOperand(Num: 3); // Glue if exists
3275
3276 return DAG.getNode(Opcode: ISD::CopyToReg, DL, ResultTys: ResultsType, Ops: NewOps);
3277}
3278
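// Report i128 with an explicitly requested i128 RegisterVT as a single
// register; all other cases defer to the base implementation.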
3279unsigned NVPTXTargetLowering::getNumRegisters(
3280 LLVMContext &Context, EVT VT,
3281 std::optional<MVT> RegisterVT = std::nullopt) const {
3282 if (VT == MVT::i128 && RegisterVT == MVT::i128)
3283 return 1;
3284 return TargetLoweringBase::getNumRegisters(Context, VT, RegisterVT);
3285}
3286
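// Keep an i128 value as a single part rather than splitting it when only one
// part is requested; otherwise fall back to the default splitting.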
3287bool NVPTXTargetLowering::splitValueIntoRegisterParts(
3288 SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
3289 unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
3290 if (Val.getValueType() == MVT::i128 && NumParts == 1) {
3291 Parts[0] = Val;
3292 return true;
3293 }
3294 return false;
3295}
3296
3297// This creates a target external symbol for a function parameter.
3298// The name of the symbol is composed from its index and the function name.
3299// A negative index corresponds to the special parameter (unsized array) used
3300// for passing variable arguments.
3301SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int I,
3302 EVT T) const {
3303 StringRef SavedStr = nvTM->getStrPool().save(
3304 S: getParamName(F: &DAG.getMachineFunction().getFunction(), Idx: I));
3305 return DAG.getExternalSymbol(Sym: SavedStr.data(), VT: T);
3306}
3307
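// Creates the external symbol (e.g. "param0") used to refer to the I-th
// parameter of an outgoing call.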
3308SDValue NVPTXTargetLowering::getCallParamSymbol(SelectionDAG &DAG, int I,
3309 EVT T) const {
3310 const StringRef SavedStr = nvTM->getStrPool().save(S: "param" + Twine(I));
3311 return DAG.getExternalSymbol(Sym: SavedStr.data(), VT: T);
3312}
3313
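// Lower incoming formal arguments. Each IR argument maps to a param symbol:
// byval pointer arguments become MoveParam (plus a local-to-generic cast for
// non-kernel functions), while all other arguments are loaded from the param
// address space, vectorized according to VectorizePTXValueVTs.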
3314SDValue NVPTXTargetLowering::LowerFormalArguments(
3315 SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
3316 const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
3317 SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
3318 MachineFunction &MF = DAG.getMachineFunction();
3319 const DataLayout &DL = DAG.getDataLayout();
3320 auto PtrVT = getPointerTy(DL: DAG.getDataLayout());
3321
3322 const Function *F = &MF.getFunction();
3323
3324 SDValue Root = DAG.getRoot();
3325 SmallVector<SDValue, 16> OutChains;
3326
3327  // The number of IR arguments (F->args()) and Ins.size() need not match;
3328  // Ins.size() will be larger
3329  //   * if there is an aggregate argument with multiple fields (each field
3330  //     showing up separately in Ins),
3331  //   * if there is a vector argument with more than typical vector-length
3332  //     elements (generally if more than 4) where each vector element is
3333  //     individually present in Ins.
3334  // So a different index should be used for indexing into Ins.
3335  // See the similar issue in LowerCall.
3336
3337 auto AllIns = ArrayRef(Ins);
3338 for (const auto &Arg : F->args()) {
3339 const auto ArgIns = AllIns.take_while(
3340 Pred: [&](auto I) { return I.OrigArgIndex == Arg.getArgNo(); });
3341 AllIns = AllIns.drop_front(N: ArgIns.size());
3342
3343 Type *Ty = Arg.getType();
3344
3345 if (ArgIns.empty())
3346 report_fatal_error(reason: "Empty parameter types are not supported");
3347
3348 if (Arg.use_empty()) {
3349 // argument is dead
3350 for (const auto &In : ArgIns) {
3351 assert(!In.Used && "Arg.use_empty() is true but Arg is used?");
3352 InVals.push_back(Elt: DAG.getUNDEF(VT: In.VT));
3353 }
3354 continue;
3355 }
3356
3357 SDValue ArgSymbol = getParamSymbol(DAG, I: Arg.getArgNo(), T: PtrVT);
3358
3359    // In the following cases, assign an IR order of "ArgNo + 1" to newly
3360    // created nodes. The SDNodes for params have to appear in the same order
3361    // as their order of appearance in the original function, and
3362    // "ArgNo + 1" preserves that order.
3363 if (Arg.hasByValAttr()) {
3364      // Param has the ByVal attribute: return MoveParam(param symbol).
3365      // Ideally, the param symbol could be returned directly, but when the
3366      // SDNode builder decides to use it in a CopyToReg(), the machine
3367      // instruction fails because TargetExternalSymbol (not lowered) is
3368      // target dependent, and CopyToReg assumes the source is already
3369      // lowered.
3371 assert(ArgIns.size() == 1 && "ByVal argument must be a pointer");
3372 const auto &ByvalIn = ArgIns[0];
3373 assert(getValueType(DL, Ty) == ByvalIn.VT &&
3374 "Ins type did not match function type");
3375 assert(ByvalIn.VT == PtrVT && "ByVal argument must be a pointer");
3376
3377 SDValue P;
3378 if (isKernelFunction(F: *F)) {
3379 P = ArgSymbol;
3380 P.getNode()->setIROrder(Arg.getArgNo() + 1);
3381 } else {
3382 P = DAG.getNode(Opcode: NVPTXISD::MoveParam, DL: dl, VT: ByvalIn.VT, Operand: ArgSymbol);
3383 P.getNode()->setIROrder(Arg.getArgNo() + 1);
3384 P = DAG.getAddrSpaceCast(dl, VT: ByvalIn.VT, Ptr: P, SrcAS: ADDRESS_SPACE_LOCAL,
3385 DestAS: ADDRESS_SPACE_GENERIC);
3386 }
3387 InVals.push_back(Elt: P);
3388 } else {
3389 SmallVector<EVT, 16> VTs;
3390 SmallVector<uint64_t, 16> Offsets;
3391 ComputePTXValueVTs(TLI: *this, DL, Ty, ValueVTs&: VTs, Offsets: &Offsets, StartingOffset: 0);
3392 assert(VTs.size() == ArgIns.size() && "Size mismatch");
3393 assert(VTs.size() == Offsets.size() && "Size mismatch");
3394
3395 const Align ArgAlign = getFunctionArgumentAlignment(
3396 F, Ty, Idx: Arg.getArgNo() + AttributeList::FirstArgIndex, DL);
3397
3398 const auto VectorInfo = VectorizePTXValueVTs(ValueVTs: VTs, Offsets, ParamAlignment: ArgAlign);
3399 unsigned I = 0;
3400 for (const unsigned NumElts : VectorInfo) {
3401 // i1 is loaded/stored as i8
3402 const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
3403        // The element may itself be a packed type (e.g. v2f16, v4i8)
3404        // holding multiple scalar elements.
3405 const unsigned PackingAmt =
3406 LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1;
3407
3408 const EVT VecVT =
3409 NumElts == 1
3410 ? LoadVT
3411 : EVT::getVectorVT(Context&: F->getContext(), VT: LoadVT.getScalarType(),
3412 NumElements: NumElts * PackingAmt);
3413
3414 SDValue VecAddr = DAG.getObjectPtrOffset(
3415 SL: dl, Ptr: ArgSymbol, Offset: TypeSize::getFixed(ExactSize: Offsets[I]));
3416
3417 const MaybeAlign PartAlign = commonAlignment(A: ArgAlign, Offset: Offsets[I]);
3418 SDValue P =
3419 DAG.getLoad(VT: VecVT, dl, Chain: Root, Ptr: VecAddr,
3420 PtrInfo: MachinePointerInfo(ADDRESS_SPACE_PARAM), Alignment: PartAlign,
3421 MMOFlags: MachineMemOperand::MODereferenceable |
3422 MachineMemOperand::MOInvariant);
3423 if (P.getNode())
3424 P.getNode()->setIROrder(Arg.getArgNo() + 1);
3425 for (const unsigned J : llvm::seq(Size: NumElts)) {
3426 SDValue Elt =
3427 NumElts == 1
3428 ? P
3429 : DAG.getNode(Opcode: LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
3430 : ISD::EXTRACT_VECTOR_ELT,
3431 DL: dl, VT: LoadVT, N1: P,
3432 N2: DAG.getVectorIdxConstant(Val: J * PackingAmt, DL: dl));
3433
3434 Elt = correctParamType(V: Elt, ExpectedVT: ArgIns[I + J].VT, Flags: ArgIns[I + J].Flags,
3435 DAG, dl);
3436 InVals.push_back(Elt);
3437 }
3438 I += NumElts;
3439 }
3440 }
3441 }
3442
3443 if (!OutChains.empty())
3444 DAG.setRoot(DAG.getTokenFactor(DL: dl, Vals&: OutChains));
3445
3446 return Chain;
3447}
3448
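// Lower a return by storing the decomposed return value into the
// "func_retval0" param-space symbol, vectorizing adjacent pieces where the
// alignment allows, and terminating the chain with a RET_GLUE node.
// Sub-32-bit integer returns are widened to i32 per the PTX interoperability
// rules noted below.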
3449SDValue
3450NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
3451 bool isVarArg,
3452 const SmallVectorImpl<ISD::OutputArg> &Outs,
3453 const SmallVectorImpl<SDValue> &OutVals,
3454 const SDLoc &dl, SelectionDAG &DAG) const {
3455 const MachineFunction &MF = DAG.getMachineFunction();
3456 const Function &F = MF.getFunction();
3457 Type *RetTy = MF.getFunction().getReturnType();
3458
3459 if (RetTy->isVoidTy()) {
3460    assert(OutVals.empty() && Outs.empty() && "No return values expected for void");
3461 return DAG.getNode(Opcode: NVPTXISD::RET_GLUE, DL: dl, VT: MVT::Other, Operand: Chain);
3462 }
3463
3464 const DataLayout &DL = DAG.getDataLayout();
3465 SmallVector<EVT, 16> VTs;
3466 SmallVector<uint64_t, 16> Offsets;
3467 ComputePTXValueVTs(TLI: *this, DL, Ty: RetTy, ValueVTs&: VTs, Offsets: &Offsets);
3468 assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
3469
3470 // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
3471 // 32-bits are sign extended or zero extended, depending on whether
3472 // they are signed or unsigned types.
3473 const bool ExtendIntegerRetVal =
3474 RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty: RetTy) < 32;
3475
3476 const auto GetRetVal = [&](unsigned I) -> SDValue {
3477 SDValue RetVal = OutVals[I];
3478 assert(promoteScalarIntegerPTX(RetVal.getValueType()) ==
3479 RetVal.getValueType() &&
3480 "OutVal type should always be legal");
3481
3482 const EVT VTI = promoteScalarIntegerPTX(VT: VTs[I]);
3483 const EVT StoreVT =
3484 ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
3485 return correctParamType(V: RetVal, ExpectedVT: StoreVT, Flags: Outs[I].Flags, DAG, dl);
3486 };
3487
3488 const auto RetAlign = getFunctionParamOptimizedAlign(F: &F, ArgTy: RetTy, DL);
3489 const auto VectorInfo = VectorizePTXValueVTs(ValueVTs: VTs, Offsets, ParamAlignment: RetAlign);
3490 unsigned I = 0;
3491 for (const unsigned NumElts : VectorInfo) {
3492 const MaybeAlign CurrentAlign = ExtendIntegerRetVal
3493 ? MaybeAlign(std::nullopt)
3494 : commonAlignment(A: RetAlign, Offset: Offsets[I]);
3495
3496 SDValue Val;
3497 if (NumElts == 1) {
3498 Val = GetRetVal(I);
3499 } else {
3500 SmallVector<SDValue, 4> StoreVals;
3501 for (const unsigned J : llvm::seq(Size: NumElts)) {
3502 SDValue ValJ = GetRetVal(I + J);
3503 if (ValJ.getValueType().isVector())
3504 DAG.ExtractVectorElements(Op: ValJ, Args&: StoreVals);
3505 else
3506 StoreVals.push_back(Elt: ValJ);
3507 }
3508
3509 EVT VT = EVT::getVectorVT(Context&: F.getContext(), VT: StoreVals[0].getValueType(),
3510 NumElements: StoreVals.size());
3511 Val = DAG.getBuildVector(VT, DL: dl, Ops: StoreVals);
3512 }
3513
3514 const SDValue RetSymbol = DAG.getExternalSymbol(Sym: "func_retval0", VT: MVT::i32);
3515 SDValue Ptr =
3516 DAG.getObjectPtrOffset(SL: dl, Ptr: RetSymbol, Offset: TypeSize::getFixed(ExactSize: Offsets[I]));
3517
3518 Chain = DAG.getStore(Chain, dl, Val, Ptr,
3519 PtrInfo: MachinePointerInfo(ADDRESS_SPACE_PARAM), Alignment: CurrentAlign);
3520
3521 I += NumElts;
3522 }
3523
3524 return DAG.getNode(Opcode: NVPTXISD::RET_GLUE, DL: dl, VT: MVT::Other, Operand: Chain);
3525}
3526
3527void NVPTXTargetLowering::LowerAsmOperandForConstraint(
3528 SDValue Op, StringRef Constraint, std::vector<SDValue> &Ops,
3529 SelectionDAG &DAG) const {
3530 if (Constraint.size() > 1)
3531 return;
3532 TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
3533}
3534
3535// llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
3536// TgtMemIntrinsic because we need the information that is only available in
3537// the "Value" type of the destination pointer. In particular, the address
3538// space information.
3540bool NVPTXTargetLowering::getTgtMemIntrinsic(
3541 IntrinsicInfo &Info, const CallInst &I,
3542 MachineFunction &MF, unsigned Intrinsic) const {
3543 switch (Intrinsic) {
3544 default:
3545 return false;
3546 case Intrinsic::nvvm_match_all_sync_i32p:
3547 case Intrinsic::nvvm_match_all_sync_i64p:
3548 Info.opc = ISD::INTRINSIC_W_CHAIN;
3549 // memVT is bogus. These intrinsics have IntrInaccessibleMemOnly attribute
3550 // in order to model data exchange with other threads, but perform no real
3551 // memory accesses.
3552 Info.memVT = MVT::i1;
3553
3554 // Our result depends on both our and other thread's arguments.
3555 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
3556 return true;
3557 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col:
3558 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row:
3559 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col_stride:
3560 case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride:
3561 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col:
3562 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row:
3563 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride:
3564 case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride:
3565 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col:
3566 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row:
3567 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride:
3568 case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride:
3569 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col:
3570 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row:
3571 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride:
3572 case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride:
3573 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col:
3574 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row:
3575 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride:
3576 case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride:
3577 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col:
3578 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row:
3579 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride:
3580 case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride: {
3581 Info.opc = ISD::INTRINSIC_W_CHAIN;
3582 Info.memVT = MVT::v8f16;
3583 Info.ptrVal = I.getArgOperand(i: 0);
3584 Info.offset = 0;
3585 Info.flags = MachineMemOperand::MOLoad;
3586 Info.align = Align(16);
3587 return true;
3588 }
3589 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col:
3590 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col_stride:
3591 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col_stride:
3592 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col:
3593 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row:
3594 case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row_stride:
3595 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row_stride:
3596 case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row:
3597 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col:
3598 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col_stride:
3599 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row:
3600 case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row_stride:
3601 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col:
3602 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col_stride:
3603 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col_stride:
3604 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col:
3605 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row:
3606 case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row_stride:
3607 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row_stride:
3608 case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row:
3609 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col:
3610 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col_stride:
3611 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row:
3612 case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row_stride: {
3613 Info.opc = ISD::INTRINSIC_W_CHAIN;
3614 Info.memVT = MVT::v2i32;
3615 Info.ptrVal = I.getArgOperand(i: 0);
3616 Info.offset = 0;
3617 Info.flags = MachineMemOperand::MOLoad;
3618 Info.align = Align(8);
3619 return true;
3620 }
3621
3622 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col:
3623 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col_stride:
3624 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col_stride:
3625 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col:
3626 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row:
3627 case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row_stride:
3628 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row_stride:
3629 case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row:
3630 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col:
3631 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col_stride:
3632 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row:
3633 case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row_stride:
3634 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col:
3635 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col_stride:
3636 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row:
3637 case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row_stride:
3638
3639 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col:
3640 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col_stride:
3641 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col_stride:
3642 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col:
3643 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row:
3644 case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row_stride:
3645 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row_stride:
3646 case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row:
3647 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col:
3648 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col_stride:
3649 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row:
3650 case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row_stride:
3651 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col:
3652 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col_stride:
3653 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
3654 case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride:
3655 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16:
3656 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16:
3657 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8:
3658 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64:
3659 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32:
3660 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64:
3661 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32: {
3662 Info.opc = ISD::INTRINSIC_W_CHAIN;
3663 Info.memVT = MVT::v4i32;
3664 Info.ptrVal = I.getArgOperand(i: 0);
3665 Info.offset = 0;
3666 Info.flags = MachineMemOperand::MOLoad;
3667 Info.align = Align(16);
3668 return true;
3669 }
3670
3671 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col:
3672 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col_stride:
3673 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col_stride:
3674 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col:
3675 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row:
3676 case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row_stride:
3677 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row_stride:
3678 case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row:
3679
3680 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col:
3681 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col_stride:
3682 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col_stride:
3683 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col:
3684 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row:
3685 case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row_stride:
3686 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row_stride:
3687 case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row:
3688 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row:
3689 case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row_stride:
3690 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col:
3691 case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col_stride:
3692 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row:
3693 case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row_stride:
3694 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row_stride:
3695 case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row:
3696 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col:
3697 case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col_stride:
3698 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
3699 case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col:
3700 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16:
3701 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16:
3702 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64:
3703 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32: {
3704 Info.opc = ISD::INTRINSIC_W_CHAIN;
3705 Info.memVT = MVT::i32;
3706 Info.ptrVal = I.getArgOperand(i: 0);
3707 Info.offset = 0;
3708 Info.flags = MachineMemOperand::MOLoad;
3709 Info.align = Align(4);
3710 return true;
3711 }
3712
3713 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col:
3714 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row:
3715 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride:
3716 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride:
3717 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col:
3718 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row:
3719 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride:
3720 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride:
3721 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col:
3722 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row:
3723 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride:
3724 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride: {
3725 Info.opc = ISD::INTRINSIC_W_CHAIN;
3726 Info.memVT = MVT::v4f16;
3727 Info.ptrVal = I.getArgOperand(i: 0);
3728 Info.offset = 0;
3729 Info.flags = MachineMemOperand::MOLoad;
3730 Info.align = Align(16);
3731 return true;
3732 }
3733
3734 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col:
3735 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row:
3736 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride:
3737 case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride:
3738 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col:
3739 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row:
3740 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride:
3741 case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride:
3742 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col:
3743 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row:
3744 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride:
3745 case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride:
3746 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col:
3747 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row:
3748 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col_stride:
3749 case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row_stride: {
3750 Info.opc = ISD::INTRINSIC_W_CHAIN;
3751 Info.memVT = MVT::v8f32;
3752 Info.ptrVal = I.getArgOperand(i: 0);
3753 Info.offset = 0;
3754 Info.flags = MachineMemOperand::MOLoad;
3755 Info.align = Align(16);
3756 return true;
3757 }
3758
3759 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col:
3760 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col_stride:
3761 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row:
3762 case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row_stride:
3763
3764 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col:
3765 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col_stride:
3766 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row:
3767 case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row_stride:
3768
3769 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col:
3770 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col_stride:
3771 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row:
3772 case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row_stride:
3773 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col:
3774 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col_stride:
3775 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row:
3776 case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row_stride:
3777 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col:
3778 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col_stride:
3779 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row:
3780 case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row_stride: {
3781 Info.opc = ISD::INTRINSIC_W_CHAIN;
3782 Info.memVT = MVT::v8i32;
3783 Info.ptrVal = I.getArgOperand(i: 0);
3784 Info.offset = 0;
3785 Info.flags = MachineMemOperand::MOLoad;
3786 Info.align = Align(16);
3787 return true;
3788 }
3789
3790 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col:
3791 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col_stride:
3792 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row:
3793 case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row_stride:
3794 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col:
3795 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col_stride:
3796 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
3797 case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride:
3798 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16:
3799 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16:
3800 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8:
3801 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64:
3802 case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32:
3803 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64:
3804 case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32: {
3805 Info.opc = ISD::INTRINSIC_W_CHAIN;
3806 Info.memVT = MVT::v2i32;
3807 Info.ptrVal = I.getArgOperand(i: 0);
3808 Info.offset = 0;
3809 Info.flags = MachineMemOperand::MOLoad;
3810 Info.align = Align(8);
3811 return true;
3812 }
3813
3814 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col:
3815 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col_stride:
3816 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row:
3817 case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row_stride:
3818
3819 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col:
3820 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col_stride:
3821 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row:
3822 case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row_stride: {
3823 Info.opc = ISD::INTRINSIC_W_CHAIN;
3824 Info.memVT = MVT::f64;
3825 Info.ptrVal = I.getArgOperand(i: 0);
3826 Info.offset = 0;
3827 Info.flags = MachineMemOperand::MOLoad;
3828 Info.align = Align(8);
3829 return true;
3830 }
3831
3832 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col:
3833 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col_stride:
3834 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row:
3835 case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row_stride: {
3836 Info.opc = ISD::INTRINSIC_W_CHAIN;
3837 Info.memVT = MVT::v2f64;
3838 Info.ptrVal = I.getArgOperand(i: 0);
3839 Info.offset = 0;
3840 Info.flags = MachineMemOperand::MOLoad;
3841 Info.align = Align(16);
3842 return true;
3843 }
3844
3845 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
3846 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
3847 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
3848 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride:
3849 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col:
3850 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row:
3851 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride:
3852 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride:
3853 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col:
3854 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row:
3855 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride:
3856 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride: {
3857 Info.opc = ISD::INTRINSIC_VOID;
3858 Info.memVT = MVT::v4f16;
3859 Info.ptrVal = I.getArgOperand(i: 0);
3860 Info.offset = 0;
3861 Info.flags = MachineMemOperand::MOStore;
3862 Info.align = Align(16);
3863 return true;
3864 }
3865
3866 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col:
3867 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row:
3868 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride:
3869 case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride:
3870 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col:
3871 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row:
3872 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride:
3873 case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride:
3874 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col:
3875 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row:
3876 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride:
3877 case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride:
3878 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col:
3879 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row:
3880 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col_stride:
3881 case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row_stride: {
3882 Info.opc = ISD::INTRINSIC_VOID;
3883 Info.memVT = MVT::v8f32;
3884 Info.ptrVal = I.getArgOperand(i: 0);
3885 Info.offset = 0;
3886 Info.flags = MachineMemOperand::MOStore;
3887 Info.align = Align(16);
3888 return true;
3889 }
3890
3891 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col:
3892 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col_stride:
3893 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row:
3894 case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row_stride:
3895 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col:
3896 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col_stride:
3897 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row:
3898 case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row_stride:
3899 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col:
3900 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col_stride:
3901 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row:
3902 case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row_stride: {
3903 Info.opc = ISD::INTRINSIC_VOID;
3904 Info.memVT = MVT::v8i32;
3905 Info.ptrVal = I.getArgOperand(i: 0);
3906 Info.offset = 0;
3907 Info.flags = MachineMemOperand::MOStore;
3908 Info.align = Align(16);
3909 return true;
3910 }
3911
3912 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col:
3913 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col_stride:
3914 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row:
3915 case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row_stride:
3916 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
3917 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
3918 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
3919 case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: {
3920 Info.opc = ISD::INTRINSIC_VOID;
3921 Info.memVT = MVT::v2i32;
3922 Info.ptrVal = I.getArgOperand(i: 0);
3923 Info.offset = 0;
3924 Info.flags = MachineMemOperand::MOStore;
3925 Info.align = Align(8);
3926 return true;
3927 }
3928
3929 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col:
3930 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col_stride:
3931 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row:
3932 case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row_stride: {
3933 Info.opc = ISD::INTRINSIC_VOID;
3934 Info.memVT = MVT::v2f64;
3935 Info.ptrVal = I.getArgOperand(i: 0);
3936 Info.offset = 0;
3937 Info.flags = MachineMemOperand::MOStore;
3938 Info.align = Align(16);
3939 return true;
3940 }
3941
3942 case Intrinsic::nvvm_atomic_add_gen_f_cta:
3943 case Intrinsic::nvvm_atomic_add_gen_f_sys:
3944 case Intrinsic::nvvm_atomic_add_gen_i_cta:
3945 case Intrinsic::nvvm_atomic_add_gen_i_sys:
3946 case Intrinsic::nvvm_atomic_and_gen_i_cta:
3947 case Intrinsic::nvvm_atomic_and_gen_i_sys:
3948 case Intrinsic::nvvm_atomic_cas_gen_i_cta:
3949 case Intrinsic::nvvm_atomic_cas_gen_i_sys:
3950 case Intrinsic::nvvm_atomic_dec_gen_i_cta:
3951 case Intrinsic::nvvm_atomic_dec_gen_i_sys:
3952 case Intrinsic::nvvm_atomic_inc_gen_i_cta:
3953 case Intrinsic::nvvm_atomic_inc_gen_i_sys:
3954 case Intrinsic::nvvm_atomic_max_gen_i_cta:
3955 case Intrinsic::nvvm_atomic_max_gen_i_sys:
3956 case Intrinsic::nvvm_atomic_min_gen_i_cta:
3957 case Intrinsic::nvvm_atomic_min_gen_i_sys:
3958 case Intrinsic::nvvm_atomic_or_gen_i_cta:
3959 case Intrinsic::nvvm_atomic_or_gen_i_sys:
3960 case Intrinsic::nvvm_atomic_exch_gen_i_cta:
3961 case Intrinsic::nvvm_atomic_exch_gen_i_sys:
3962 case Intrinsic::nvvm_atomic_xor_gen_i_cta:
3963 case Intrinsic::nvvm_atomic_xor_gen_i_sys: {
3964 auto &DL = I.getDataLayout();
3965 Info.opc = ISD::INTRINSIC_W_CHAIN;
3966 Info.memVT = getValueType(DL, Ty: I.getType());
3967 Info.ptrVal = I.getArgOperand(i: 0);
3968 Info.offset = 0;
3969 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
3970 Info.align.reset();
3971 return true;
3972 }
3973
3974 case Intrinsic::nvvm_ldu_global_i:
3975 case Intrinsic::nvvm_ldu_global_f:
3976 case Intrinsic::nvvm_ldu_global_p: {
3977 auto &DL = I.getDataLayout();
3978 Info.opc = ISD::INTRINSIC_W_CHAIN;
3979 if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
3980 Info.memVT = getValueType(DL, Ty: I.getType());
3981    else if (Intrinsic == Intrinsic::nvvm_ldu_global_p)
3982 Info.memVT = getPointerTy(DL);
3983 else
3984 Info.memVT = getValueType(DL, Ty: I.getType());
3985 Info.ptrVal = I.getArgOperand(i: 0);
3986 Info.offset = 0;
3987 Info.flags = MachineMemOperand::MOLoad;
3988 Info.align = cast<ConstantInt>(Val: I.getArgOperand(i: 1))->getMaybeAlignValue();
3989
3990 return true;
3991 }
3992 case Intrinsic::nvvm_tex_1d_v4f32_s32:
3993 case Intrinsic::nvvm_tex_1d_v4f32_f32:
3994 case Intrinsic::nvvm_tex_1d_level_v4f32_f32:
3995 case Intrinsic::nvvm_tex_1d_grad_v4f32_f32:
3996 case Intrinsic::nvvm_tex_1d_array_v4f32_s32:
3997 case Intrinsic::nvvm_tex_1d_array_v4f32_f32:
3998 case Intrinsic::nvvm_tex_1d_array_level_v4f32_f32:
3999 case Intrinsic::nvvm_tex_1d_array_grad_v4f32_f32:
4000 case Intrinsic::nvvm_tex_2d_v4f32_s32:
4001 case Intrinsic::nvvm_tex_2d_v4f32_f32:
4002 case Intrinsic::nvvm_tex_2d_level_v4f32_f32:
4003 case Intrinsic::nvvm_tex_2d_grad_v4f32_f32:
4004 case Intrinsic::nvvm_tex_2d_array_v4f32_s32:
4005 case Intrinsic::nvvm_tex_2d_array_v4f32_f32:
4006 case Intrinsic::nvvm_tex_2d_array_level_v4f32_f32:
4007 case Intrinsic::nvvm_tex_2d_array_grad_v4f32_f32:
4008 case Intrinsic::nvvm_tex_3d_v4f32_s32:
4009 case Intrinsic::nvvm_tex_3d_v4f32_f32:
4010 case Intrinsic::nvvm_tex_3d_level_v4f32_f32:
4011 case Intrinsic::nvvm_tex_3d_grad_v4f32_f32:
4012 case Intrinsic::nvvm_tex_cube_v4f32_f32:
4013 case Intrinsic::nvvm_tex_cube_level_v4f32_f32:
4014 case Intrinsic::nvvm_tex_cube_array_v4f32_f32:
4015 case Intrinsic::nvvm_tex_cube_array_level_v4f32_f32:
4016 case Intrinsic::nvvm_tld4_r_2d_v4f32_f32:
4017 case Intrinsic::nvvm_tld4_g_2d_v4f32_f32:
4018 case Intrinsic::nvvm_tld4_b_2d_v4f32_f32:
4019 case Intrinsic::nvvm_tld4_a_2d_v4f32_f32:
4020 case Intrinsic::nvvm_tex_unified_1d_v4f32_s32:
4021 case Intrinsic::nvvm_tex_unified_1d_v4f32_f32:
4022 case Intrinsic::nvvm_tex_unified_1d_level_v4f32_f32:
4023 case Intrinsic::nvvm_tex_unified_1d_grad_v4f32_f32:
4024 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_s32:
4025 case Intrinsic::nvvm_tex_unified_1d_array_v4f32_f32:
4026 case Intrinsic::nvvm_tex_unified_1d_array_level_v4f32_f32:
4027 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4f32_f32:
4028 case Intrinsic::nvvm_tex_unified_2d_v4f32_s32:
4029 case Intrinsic::nvvm_tex_unified_2d_v4f32_f32:
4030 case Intrinsic::nvvm_tex_unified_2d_level_v4f32_f32:
4031 case Intrinsic::nvvm_tex_unified_2d_grad_v4f32_f32:
4032 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_s32:
4033 case Intrinsic::nvvm_tex_unified_2d_array_v4f32_f32:
4034 case Intrinsic::nvvm_tex_unified_2d_array_level_v4f32_f32:
4035 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4f32_f32:
4036 case Intrinsic::nvvm_tex_unified_3d_v4f32_s32:
4037 case Intrinsic::nvvm_tex_unified_3d_v4f32_f32:
4038 case Intrinsic::nvvm_tex_unified_3d_level_v4f32_f32:
4039 case Intrinsic::nvvm_tex_unified_3d_grad_v4f32_f32:
4040 case Intrinsic::nvvm_tex_unified_cube_v4f32_f32:
4041 case Intrinsic::nvvm_tex_unified_cube_level_v4f32_f32:
4042 case Intrinsic::nvvm_tex_unified_cube_array_v4f32_f32:
4043 case Intrinsic::nvvm_tex_unified_cube_array_level_v4f32_f32:
4044 case Intrinsic::nvvm_tex_unified_cube_grad_v4f32_f32:
4045 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4f32_f32:
4046 case Intrinsic::nvvm_tld4_unified_r_2d_v4f32_f32:
4047 case Intrinsic::nvvm_tld4_unified_g_2d_v4f32_f32:
4048 case Intrinsic::nvvm_tld4_unified_b_2d_v4f32_f32:
4049 case Intrinsic::nvvm_tld4_unified_a_2d_v4f32_f32:
4050 Info.opc = ISD::INTRINSIC_W_CHAIN;
4051 Info.memVT = MVT::v4f32;
4052 Info.ptrVal = nullptr;
4053 Info.offset = 0;
4054 Info.flags = MachineMemOperand::MOLoad;
4055 Info.align = Align(16);
4056 return true;
4057
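  // Texture fetch and tld4 gather intrinsics returning four 32-bit integer
  // components (signed or unsigned); modeled as 16-byte-aligned v4i32 loads.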
4058 case Intrinsic::nvvm_tex_1d_v4s32_s32:
4059 case Intrinsic::nvvm_tex_1d_v4s32_f32:
4060 case Intrinsic::nvvm_tex_1d_level_v4s32_f32:
4061 case Intrinsic::nvvm_tex_1d_grad_v4s32_f32:
4062 case Intrinsic::nvvm_tex_1d_array_v4s32_s32:
4063 case Intrinsic::nvvm_tex_1d_array_v4s32_f32:
4064 case Intrinsic::nvvm_tex_1d_array_level_v4s32_f32:
4065 case Intrinsic::nvvm_tex_1d_array_grad_v4s32_f32:
4066 case Intrinsic::nvvm_tex_2d_v4s32_s32:
4067 case Intrinsic::nvvm_tex_2d_v4s32_f32:
4068 case Intrinsic::nvvm_tex_2d_level_v4s32_f32:
4069 case Intrinsic::nvvm_tex_2d_grad_v4s32_f32:
4070 case Intrinsic::nvvm_tex_2d_array_v4s32_s32:
4071 case Intrinsic::nvvm_tex_2d_array_v4s32_f32:
4072 case Intrinsic::nvvm_tex_2d_array_level_v4s32_f32:
4073 case Intrinsic::nvvm_tex_2d_array_grad_v4s32_f32:
4074 case Intrinsic::nvvm_tex_3d_v4s32_s32:
4075 case Intrinsic::nvvm_tex_3d_v4s32_f32:
4076 case Intrinsic::nvvm_tex_3d_level_v4s32_f32:
4077 case Intrinsic::nvvm_tex_3d_grad_v4s32_f32:
4078 case Intrinsic::nvvm_tex_cube_v4s32_f32:
4079 case Intrinsic::nvvm_tex_cube_level_v4s32_f32:
4080 case Intrinsic::nvvm_tex_cube_array_v4s32_f32:
4081 case Intrinsic::nvvm_tex_cube_array_level_v4s32_f32:
4082 case Intrinsic::nvvm_tex_cube_v4u32_f32:
4083 case Intrinsic::nvvm_tex_cube_level_v4u32_f32:
4084 case Intrinsic::nvvm_tex_cube_array_v4u32_f32:
4085 case Intrinsic::nvvm_tex_cube_array_level_v4u32_f32:
4086 case Intrinsic::nvvm_tex_1d_v4u32_s32:
4087 case Intrinsic::nvvm_tex_1d_v4u32_f32:
4088 case Intrinsic::nvvm_tex_1d_level_v4u32_f32:
4089 case Intrinsic::nvvm_tex_1d_grad_v4u32_f32:
4090 case Intrinsic::nvvm_tex_1d_array_v4u32_s32:
4091 case Intrinsic::nvvm_tex_1d_array_v4u32_f32:
4092 case Intrinsic::nvvm_tex_1d_array_level_v4u32_f32:
4093 case Intrinsic::nvvm_tex_1d_array_grad_v4u32_f32:
4094 case Intrinsic::nvvm_tex_2d_v4u32_s32:
4095 case Intrinsic::nvvm_tex_2d_v4u32_f32:
4096 case Intrinsic::nvvm_tex_2d_level_v4u32_f32:
4097 case Intrinsic::nvvm_tex_2d_grad_v4u32_f32:
4098 case Intrinsic::nvvm_tex_2d_array_v4u32_s32:
4099 case Intrinsic::nvvm_tex_2d_array_v4u32_f32:
4100 case Intrinsic::nvvm_tex_2d_array_level_v4u32_f32:
4101 case Intrinsic::nvvm_tex_2d_array_grad_v4u32_f32:
4102 case Intrinsic::nvvm_tex_3d_v4u32_s32:
4103 case Intrinsic::nvvm_tex_3d_v4u32_f32:
4104 case Intrinsic::nvvm_tex_3d_level_v4u32_f32:
4105 case Intrinsic::nvvm_tex_3d_grad_v4u32_f32:
4106 case Intrinsic::nvvm_tld4_r_2d_v4s32_f32:
4107 case Intrinsic::nvvm_tld4_g_2d_v4s32_f32:
4108 case Intrinsic::nvvm_tld4_b_2d_v4s32_f32:
4109 case Intrinsic::nvvm_tld4_a_2d_v4s32_f32:
4110 case Intrinsic::nvvm_tld4_r_2d_v4u32_f32:
4111 case Intrinsic::nvvm_tld4_g_2d_v4u32_f32:
4112 case Intrinsic::nvvm_tld4_b_2d_v4u32_f32:
4113 case Intrinsic::nvvm_tld4_a_2d_v4u32_f32:
4114 case Intrinsic::nvvm_tex_unified_1d_v4s32_s32:
4115 case Intrinsic::nvvm_tex_unified_1d_v4s32_f32:
4116 case Intrinsic::nvvm_tex_unified_1d_level_v4s32_f32:
4117 case Intrinsic::nvvm_tex_unified_1d_grad_v4s32_f32:
4118 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_s32:
4119 case Intrinsic::nvvm_tex_unified_1d_array_v4s32_f32:
4120 case Intrinsic::nvvm_tex_unified_1d_array_level_v4s32_f32:
4121 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4s32_f32:
4122 case Intrinsic::nvvm_tex_unified_2d_v4s32_s32:
4123 case Intrinsic::nvvm_tex_unified_2d_v4s32_f32:
4124 case Intrinsic::nvvm_tex_unified_2d_level_v4s32_f32:
4125 case Intrinsic::nvvm_tex_unified_2d_grad_v4s32_f32:
4126 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_s32:
4127 case Intrinsic::nvvm_tex_unified_2d_array_v4s32_f32:
4128 case Intrinsic::nvvm_tex_unified_2d_array_level_v4s32_f32:
4129 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4s32_f32:
4130 case Intrinsic::nvvm_tex_unified_3d_v4s32_s32:
4131 case Intrinsic::nvvm_tex_unified_3d_v4s32_f32:
4132 case Intrinsic::nvvm_tex_unified_3d_level_v4s32_f32:
4133 case Intrinsic::nvvm_tex_unified_3d_grad_v4s32_f32:
4134 case Intrinsic::nvvm_tex_unified_1d_v4u32_s32:
4135 case Intrinsic::nvvm_tex_unified_1d_v4u32_f32:
4136 case Intrinsic::nvvm_tex_unified_1d_level_v4u32_f32:
4137 case Intrinsic::nvvm_tex_unified_1d_grad_v4u32_f32:
4138 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_s32:
4139 case Intrinsic::nvvm_tex_unified_1d_array_v4u32_f32:
4140 case Intrinsic::nvvm_tex_unified_1d_array_level_v4u32_f32:
4141 case Intrinsic::nvvm_tex_unified_1d_array_grad_v4u32_f32:
4142 case Intrinsic::nvvm_tex_unified_2d_v4u32_s32:
4143 case Intrinsic::nvvm_tex_unified_2d_v4u32_f32:
4144 case Intrinsic::nvvm_tex_unified_2d_level_v4u32_f32:
4145 case Intrinsic::nvvm_tex_unified_2d_grad_v4u32_f32:
4146 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_s32:
4147 case Intrinsic::nvvm_tex_unified_2d_array_v4u32_f32:
4148 case Intrinsic::nvvm_tex_unified_2d_array_level_v4u32_f32:
4149 case Intrinsic::nvvm_tex_unified_2d_array_grad_v4u32_f32:
4150 case Intrinsic::nvvm_tex_unified_3d_v4u32_s32:
4151 case Intrinsic::nvvm_tex_unified_3d_v4u32_f32:
4152 case Intrinsic::nvvm_tex_unified_3d_level_v4u32_f32:
4153 case Intrinsic::nvvm_tex_unified_3d_grad_v4u32_f32:
4154 case Intrinsic::nvvm_tex_unified_cube_v4s32_f32:
4155 case Intrinsic::nvvm_tex_unified_cube_level_v4s32_f32:
4156 case Intrinsic::nvvm_tex_unified_cube_array_v4s32_f32:
4157 case Intrinsic::nvvm_tex_unified_cube_array_level_v4s32_f32:
4158 case Intrinsic::nvvm_tex_unified_cube_v4u32_f32:
4159 case Intrinsic::nvvm_tex_unified_cube_level_v4u32_f32:
4160 case Intrinsic::nvvm_tex_unified_cube_array_v4u32_f32:
4161 case Intrinsic::nvvm_tex_unified_cube_array_level_v4u32_f32:
4162 case Intrinsic::nvvm_tex_unified_cube_grad_v4s32_f32:
4163 case Intrinsic::nvvm_tex_unified_cube_grad_v4u32_f32:
4164 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4s32_f32:
4165 case Intrinsic::nvvm_tex_unified_cube_array_grad_v4u32_f32:
4166 case Intrinsic::nvvm_tld4_unified_r_2d_v4s32_f32:
4167 case Intrinsic::nvvm_tld4_unified_g_2d_v4s32_f32:
4168 case Intrinsic::nvvm_tld4_unified_b_2d_v4s32_f32:
4169 case Intrinsic::nvvm_tld4_unified_a_2d_v4s32_f32:
4170 case Intrinsic::nvvm_tld4_unified_r_2d_v4u32_f32:
4171 case Intrinsic::nvvm_tld4_unified_g_2d_v4u32_f32:
4172 case Intrinsic::nvvm_tld4_unified_b_2d_v4u32_f32:
4173 case Intrinsic::nvvm_tld4_unified_a_2d_v4u32_f32:
4174 Info.opc = ISD::INTRINSIC_W_CHAIN;
4175 Info.memVT = MVT::v4i32;
4176 Info.ptrVal = nullptr;
4177 Info.offset = 0;
4178 Info.flags = MachineMemOperand::MOLoad;
4179 Info.align = Align(16);
4180 return true;
4181
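  // Surface loads (suld), grouped below by component element type (i8, i16,
  // i32, i64). Each group is modeled as a load of that element type with
  // 16-byte alignment and no analyzable pointer.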
4182 case Intrinsic::nvvm_suld_1d_i8_clamp:
4183 case Intrinsic::nvvm_suld_1d_v2i8_clamp:
4184 case Intrinsic::nvvm_suld_1d_v4i8_clamp:
4185 case Intrinsic::nvvm_suld_1d_array_i8_clamp:
4186 case Intrinsic::nvvm_suld_1d_array_v2i8_clamp:
4187 case Intrinsic::nvvm_suld_1d_array_v4i8_clamp:
4188 case Intrinsic::nvvm_suld_2d_i8_clamp:
4189 case Intrinsic::nvvm_suld_2d_v2i8_clamp:
4190 case Intrinsic::nvvm_suld_2d_v4i8_clamp:
4191 case Intrinsic::nvvm_suld_2d_array_i8_clamp:
4192 case Intrinsic::nvvm_suld_2d_array_v2i8_clamp:
4193 case Intrinsic::nvvm_suld_2d_array_v4i8_clamp:
4194 case Intrinsic::nvvm_suld_3d_i8_clamp:
4195 case Intrinsic::nvvm_suld_3d_v2i8_clamp:
4196 case Intrinsic::nvvm_suld_3d_v4i8_clamp:
4197 case Intrinsic::nvvm_suld_1d_i8_trap:
4198 case Intrinsic::nvvm_suld_1d_v2i8_trap:
4199 case Intrinsic::nvvm_suld_1d_v4i8_trap:
4200 case Intrinsic::nvvm_suld_1d_array_i8_trap:
4201 case Intrinsic::nvvm_suld_1d_array_v2i8_trap:
4202 case Intrinsic::nvvm_suld_1d_array_v4i8_trap:
4203 case Intrinsic::nvvm_suld_2d_i8_trap:
4204 case Intrinsic::nvvm_suld_2d_v2i8_trap:
4205 case Intrinsic::nvvm_suld_2d_v4i8_trap:
4206 case Intrinsic::nvvm_suld_2d_array_i8_trap:
4207 case Intrinsic::nvvm_suld_2d_array_v2i8_trap:
4208 case Intrinsic::nvvm_suld_2d_array_v4i8_trap:
4209 case Intrinsic::nvvm_suld_3d_i8_trap:
4210 case Intrinsic::nvvm_suld_3d_v2i8_trap:
4211 case Intrinsic::nvvm_suld_3d_v4i8_trap:
4212 case Intrinsic::nvvm_suld_1d_i8_zero:
4213 case Intrinsic::nvvm_suld_1d_v2i8_zero:
4214 case Intrinsic::nvvm_suld_1d_v4i8_zero:
4215 case Intrinsic::nvvm_suld_1d_array_i8_zero:
4216 case Intrinsic::nvvm_suld_1d_array_v2i8_zero:
4217 case Intrinsic::nvvm_suld_1d_array_v4i8_zero:
4218 case Intrinsic::nvvm_suld_2d_i8_zero:
4219 case Intrinsic::nvvm_suld_2d_v2i8_zero:
4220 case Intrinsic::nvvm_suld_2d_v4i8_zero:
4221 case Intrinsic::nvvm_suld_2d_array_i8_zero:
4222 case Intrinsic::nvvm_suld_2d_array_v2i8_zero:
4223 case Intrinsic::nvvm_suld_2d_array_v4i8_zero:
4224 case Intrinsic::nvvm_suld_3d_i8_zero:
4225 case Intrinsic::nvvm_suld_3d_v2i8_zero:
4226 case Intrinsic::nvvm_suld_3d_v4i8_zero:
4227 Info.opc = ISD::INTRINSIC_W_CHAIN;
4228 Info.memVT = MVT::i8;
4229 Info.ptrVal = nullptr;
4230 Info.offset = 0;
4231 Info.flags = MachineMemOperand::MOLoad;
4232 Info.align = Align(16);
4233 return true;
4234
4235 case Intrinsic::nvvm_suld_1d_i16_clamp:
4236 case Intrinsic::nvvm_suld_1d_v2i16_clamp:
4237 case Intrinsic::nvvm_suld_1d_v4i16_clamp:
4238 case Intrinsic::nvvm_suld_1d_array_i16_clamp:
4239 case Intrinsic::nvvm_suld_1d_array_v2i16_clamp:
4240 case Intrinsic::nvvm_suld_1d_array_v4i16_clamp:
4241 case Intrinsic::nvvm_suld_2d_i16_clamp:
4242 case Intrinsic::nvvm_suld_2d_v2i16_clamp:
4243 case Intrinsic::nvvm_suld_2d_v4i16_clamp:
4244 case Intrinsic::nvvm_suld_2d_array_i16_clamp:
4245 case Intrinsic::nvvm_suld_2d_array_v2i16_clamp:
4246 case Intrinsic::nvvm_suld_2d_array_v4i16_clamp:
4247 case Intrinsic::nvvm_suld_3d_i16_clamp:
4248 case Intrinsic::nvvm_suld_3d_v2i16_clamp:
4249 case Intrinsic::nvvm_suld_3d_v4i16_clamp:
4250 case Intrinsic::nvvm_suld_1d_i16_trap:
4251 case Intrinsic::nvvm_suld_1d_v2i16_trap:
4252 case Intrinsic::nvvm_suld_1d_v4i16_trap:
4253 case Intrinsic::nvvm_suld_1d_array_i16_trap:
4254 case Intrinsic::nvvm_suld_1d_array_v2i16_trap:
4255 case Intrinsic::nvvm_suld_1d_array_v4i16_trap:
4256 case Intrinsic::nvvm_suld_2d_i16_trap:
4257 case Intrinsic::nvvm_suld_2d_v2i16_trap:
4258 case Intrinsic::nvvm_suld_2d_v4i16_trap:
4259 case Intrinsic::nvvm_suld_2d_array_i16_trap:
4260 case Intrinsic::nvvm_suld_2d_array_v2i16_trap:
4261 case Intrinsic::nvvm_suld_2d_array_v4i16_trap:
4262 case Intrinsic::nvvm_suld_3d_i16_trap:
4263 case Intrinsic::nvvm_suld_3d_v2i16_trap:
4264 case Intrinsic::nvvm_suld_3d_v4i16_trap:
4265 case Intrinsic::nvvm_suld_1d_i16_zero:
4266 case Intrinsic::nvvm_suld_1d_v2i16_zero:
4267 case Intrinsic::nvvm_suld_1d_v4i16_zero:
4268 case Intrinsic::nvvm_suld_1d_array_i16_zero:
4269 case Intrinsic::nvvm_suld_1d_array_v2i16_zero:
4270 case Intrinsic::nvvm_suld_1d_array_v4i16_zero:
4271 case Intrinsic::nvvm_suld_2d_i16_zero:
4272 case Intrinsic::nvvm_suld_2d_v2i16_zero:
4273 case Intrinsic::nvvm_suld_2d_v4i16_zero:
4274 case Intrinsic::nvvm_suld_2d_array_i16_zero:
4275 case Intrinsic::nvvm_suld_2d_array_v2i16_zero:
4276 case Intrinsic::nvvm_suld_2d_array_v4i16_zero:
4277 case Intrinsic::nvvm_suld_3d_i16_zero:
4278 case Intrinsic::nvvm_suld_3d_v2i16_zero:
4279 case Intrinsic::nvvm_suld_3d_v4i16_zero:
4280 Info.opc = ISD::INTRINSIC_W_CHAIN;
4281 Info.memVT = MVT::i16;
4282 Info.ptrVal = nullptr;
4283 Info.offset = 0;
4284 Info.flags = MachineMemOperand::MOLoad;
4285 Info.align = Align(16);
4286 return true;
4287
4288 case Intrinsic::nvvm_suld_1d_i32_clamp:
4289 case Intrinsic::nvvm_suld_1d_v2i32_clamp:
4290 case Intrinsic::nvvm_suld_1d_v4i32_clamp:
4291 case Intrinsic::nvvm_suld_1d_array_i32_clamp:
4292 case Intrinsic::nvvm_suld_1d_array_v2i32_clamp:
4293 case Intrinsic::nvvm_suld_1d_array_v4i32_clamp:
4294 case Intrinsic::nvvm_suld_2d_i32_clamp:
4295 case Intrinsic::nvvm_suld_2d_v2i32_clamp:
4296 case Intrinsic::nvvm_suld_2d_v4i32_clamp:
4297 case Intrinsic::nvvm_suld_2d_array_i32_clamp:
4298 case Intrinsic::nvvm_suld_2d_array_v2i32_clamp:
4299 case Intrinsic::nvvm_suld_2d_array_v4i32_clamp:
4300 case Intrinsic::nvvm_suld_3d_i32_clamp:
4301 case Intrinsic::nvvm_suld_3d_v2i32_clamp:
4302 case Intrinsic::nvvm_suld_3d_v4i32_clamp:
4303 case Intrinsic::nvvm_suld_1d_i32_trap:
4304 case Intrinsic::nvvm_suld_1d_v2i32_trap:
4305 case Intrinsic::nvvm_suld_1d_v4i32_trap:
4306 case Intrinsic::nvvm_suld_1d_array_i32_trap:
4307 case Intrinsic::nvvm_suld_1d_array_v2i32_trap:
4308 case Intrinsic::nvvm_suld_1d_array_v4i32_trap:
4309 case Intrinsic::nvvm_suld_2d_i32_trap:
4310 case Intrinsic::nvvm_suld_2d_v2i32_trap:
4311 case Intrinsic::nvvm_suld_2d_v4i32_trap:
4312 case Intrinsic::nvvm_suld_2d_array_i32_trap:
4313 case Intrinsic::nvvm_suld_2d_array_v2i32_trap:
4314 case Intrinsic::nvvm_suld_2d_array_v4i32_trap:
4315 case Intrinsic::nvvm_suld_3d_i32_trap:
4316 case Intrinsic::nvvm_suld_3d_v2i32_trap:
4317 case Intrinsic::nvvm_suld_3d_v4i32_trap:
4318 case Intrinsic::nvvm_suld_1d_i32_zero:
4319 case Intrinsic::nvvm_suld_1d_v2i32_zero:
4320 case Intrinsic::nvvm_suld_1d_v4i32_zero:
4321 case Intrinsic::nvvm_suld_1d_array_i32_zero:
4322 case Intrinsic::nvvm_suld_1d_array_v2i32_zero:
4323 case Intrinsic::nvvm_suld_1d_array_v4i32_zero:
4324 case Intrinsic::nvvm_suld_2d_i32_zero:
4325 case Intrinsic::nvvm_suld_2d_v2i32_zero:
4326 case Intrinsic::nvvm_suld_2d_v4i32_zero:
4327 case Intrinsic::nvvm_suld_2d_array_i32_zero:
4328 case Intrinsic::nvvm_suld_2d_array_v2i32_zero:
4329 case Intrinsic::nvvm_suld_2d_array_v4i32_zero:
4330 case Intrinsic::nvvm_suld_3d_i32_zero:
4331 case Intrinsic::nvvm_suld_3d_v2i32_zero:
4332 case Intrinsic::nvvm_suld_3d_v4i32_zero:
4333 Info.opc = ISD::INTRINSIC_W_CHAIN;
4334 Info.memVT = MVT::i32;
4335 Info.ptrVal = nullptr;
4336 Info.offset = 0;
4337 Info.flags = MachineMemOperand::MOLoad;
4338 Info.align = Align(16);
4339 return true;
4340
4341 case Intrinsic::nvvm_suld_1d_i64_clamp:
4342 case Intrinsic::nvvm_suld_1d_v2i64_clamp:
4343 case Intrinsic::nvvm_suld_1d_array_i64_clamp:
4344 case Intrinsic::nvvm_suld_1d_array_v2i64_clamp:
4345 case Intrinsic::nvvm_suld_2d_i64_clamp:
4346 case Intrinsic::nvvm_suld_2d_v2i64_clamp:
4347 case Intrinsic::nvvm_suld_2d_array_i64_clamp:
4348 case Intrinsic::nvvm_suld_2d_array_v2i64_clamp:
4349 case Intrinsic::nvvm_suld_3d_i64_clamp:
4350 case Intrinsic::nvvm_suld_3d_v2i64_clamp:
4351 case Intrinsic::nvvm_suld_1d_i64_trap:
4352 case Intrinsic::nvvm_suld_1d_v2i64_trap:
4353 case Intrinsic::nvvm_suld_1d_array_i64_trap:
4354 case Intrinsic::nvvm_suld_1d_array_v2i64_trap:
4355 case Intrinsic::nvvm_suld_2d_i64_trap:
4356 case Intrinsic::nvvm_suld_2d_v2i64_trap:
4357 case Intrinsic::nvvm_suld_2d_array_i64_trap:
4358 case Intrinsic::nvvm_suld_2d_array_v2i64_trap:
4359 case Intrinsic::nvvm_suld_3d_i64_trap:
4360 case Intrinsic::nvvm_suld_3d_v2i64_trap:
4361 case Intrinsic::nvvm_suld_1d_i64_zero:
4362 case Intrinsic::nvvm_suld_1d_v2i64_zero:
4363 case Intrinsic::nvvm_suld_1d_array_i64_zero:
4364 case Intrinsic::nvvm_suld_1d_array_v2i64_zero:
4365 case Intrinsic::nvvm_suld_2d_i64_zero:
4366 case Intrinsic::nvvm_suld_2d_v2i64_zero:
4367 case Intrinsic::nvvm_suld_2d_array_i64_zero:
4368 case Intrinsic::nvvm_suld_2d_array_v2i64_zero:
4369 case Intrinsic::nvvm_suld_3d_i64_zero:
4370 case Intrinsic::nvvm_suld_3d_v2i64_zero:
4371 Info.opc = ISD::INTRINSIC_W_CHAIN;
4372 Info.memVT = MVT::i64;
4373 Info.ptrVal = nullptr;
4374 Info.offset = 0;
4375 Info.flags = MachineMemOperand::MOLoad;
4376 Info.align = Align(16);
4377 return true;
4378
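  // tcgen05.ld intrinsics read 32-bit registers through the pointer in
  // operand 0. The memory VT in each group below tracks the xN repeat count
  // of the shape, from v1i32 up to v128i32.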
4379 case Intrinsic::nvvm_tcgen05_ld_16x64b_x1:
4380 case Intrinsic::nvvm_tcgen05_ld_32x32b_x1:
4381 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x1: {
4382 Info.opc = ISD::INTRINSIC_W_CHAIN;
4383 Info.memVT = MVT::v1i32;
4384 Info.ptrVal = I.getArgOperand(i: 0);
4385 Info.offset = 0;
4386 Info.flags = MachineMemOperand::MOLoad;
4387 Info.align.reset();
4388 return true;
4389 }
4390
4391 case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
4392 case Intrinsic::nvvm_tcgen05_ld_16x128b_x1:
4393 case Intrinsic::nvvm_tcgen05_ld_32x32b_x2:
4394 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2: {
4395 Info.opc = ISD::INTRINSIC_W_CHAIN;
4396 Info.memVT = MVT::v2i32;
4397 Info.ptrVal = I.getArgOperand(i: 0);
4398 Info.offset = 0;
4399 Info.flags = MachineMemOperand::MOLoad;
4400 Info.align.reset();
4401 return true;
4402 }
4403
4404 case Intrinsic::nvvm_tcgen05_ld_16x64b_x4:
4405 case Intrinsic::nvvm_tcgen05_ld_16x128b_x2:
4406 case Intrinsic::nvvm_tcgen05_ld_32x32b_x4:
4407 case Intrinsic::nvvm_tcgen05_ld_16x256b_x1:
4408 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x4: {
4409 Info.opc = ISD::INTRINSIC_W_CHAIN;
4410 Info.memVT = MVT::v4i32;
4411 Info.ptrVal = I.getArgOperand(i: 0);
4412 Info.offset = 0;
4413 Info.flags = MachineMemOperand::MOLoad;
4414 Info.align.reset();
4415 return true;
4416 }
4417
4418 case Intrinsic::nvvm_tcgen05_ld_16x64b_x8:
4419 case Intrinsic::nvvm_tcgen05_ld_16x128b_x4:
4420 case Intrinsic::nvvm_tcgen05_ld_16x256b_x2:
4421 case Intrinsic::nvvm_tcgen05_ld_32x32b_x8:
4422 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x8: {
4423 Info.opc = ISD::INTRINSIC_W_CHAIN;
4424 Info.memVT = MVT::v8i32;
4425 Info.ptrVal = I.getArgOperand(i: 0);
4426 Info.offset = 0;
4427 Info.flags = MachineMemOperand::MOLoad;
4428 Info.align.reset();
4429 return true;
4430 }
4431
4432 case Intrinsic::nvvm_tcgen05_ld_16x64b_x16:
4433 case Intrinsic::nvvm_tcgen05_ld_16x128b_x8:
4434 case Intrinsic::nvvm_tcgen05_ld_16x256b_x4:
4435 case Intrinsic::nvvm_tcgen05_ld_32x32b_x16:
4436 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x16: {
4437 Info.opc = ISD::INTRINSIC_W_CHAIN;
4438 Info.memVT = MVT::v16i32;
4439 Info.ptrVal = I.getArgOperand(i: 0);
4440 Info.offset = 0;
4441 Info.flags = MachineMemOperand::MOLoad;
4442 Info.align.reset();
4443 return true;
4444 }
4445
4446 case Intrinsic::nvvm_tcgen05_ld_16x64b_x32:
4447 case Intrinsic::nvvm_tcgen05_ld_16x128b_x16:
4448 case Intrinsic::nvvm_tcgen05_ld_16x256b_x8:
4449 case Intrinsic::nvvm_tcgen05_ld_32x32b_x32:
4450 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x32: {
4451 Info.opc = ISD::INTRINSIC_W_CHAIN;
4452 Info.memVT = MVT::v32i32;
4453 Info.ptrVal = I.getArgOperand(i: 0);
4454 Info.offset = 0;
4455 Info.flags = MachineMemOperand::MOLoad;
4456 Info.align.reset();
4457 return true;
4458 }
4459
4460 case Intrinsic::nvvm_tcgen05_ld_16x64b_x64:
4461 case Intrinsic::nvvm_tcgen05_ld_16x128b_x32:
4462 case Intrinsic::nvvm_tcgen05_ld_16x256b_x16:
4463 case Intrinsic::nvvm_tcgen05_ld_32x32b_x64:
4464 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x64: {
4465 Info.opc = ISD::INTRINSIC_W_CHAIN;
4466 Info.memVT = MVT::v64i32;
4467 Info.ptrVal = I.getArgOperand(i: 0);
4468 Info.offset = 0;
4469 Info.flags = MachineMemOperand::MOLoad;
4470 Info.align.reset();
4471 return true;
4472 }
4473
4474 case Intrinsic::nvvm_tcgen05_ld_16x64b_x128:
4475 case Intrinsic::nvvm_tcgen05_ld_16x128b_x64:
4476 case Intrinsic::nvvm_tcgen05_ld_16x256b_x32:
4477 case Intrinsic::nvvm_tcgen05_ld_32x32b_x128:
4478 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x128: {
4479 Info.opc = ISD::INTRINSIC_W_CHAIN;
4480 Info.memVT = MVT::v128i32;
4481 Info.ptrVal = I.getArgOperand(i: 0);
4482 Info.offset = 0;
4483 Info.flags = MachineMemOperand::MOLoad;
4484 Info.align.reset();
4485 return true;
4486 }
4487
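  // tcgen05.st intrinsics mirror the loads above: stores of 1 to 128 32-bit
  // registers through the pointer in operand 0, with the memory VT again
  // tracking the xN repeat count.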
4488 case Intrinsic::nvvm_tcgen05_st_16x64b_x1:
4489 case Intrinsic::nvvm_tcgen05_st_32x32b_x1:
4490 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x1: {
4491 Info.opc = ISD::INTRINSIC_VOID;
4492 Info.memVT = MVT::i32;
4493 Info.ptrVal = I.getArgOperand(i: 0);
4494 Info.offset = 0;
4495 Info.flags = MachineMemOperand::MOStore;
4496 Info.align.reset();
4497 return true;
4498 }
4499
4500 case Intrinsic::nvvm_tcgen05_st_16x64b_x2:
4501 case Intrinsic::nvvm_tcgen05_st_16x128b_x1:
4502 case Intrinsic::nvvm_tcgen05_st_32x32b_x2:
4503 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x2: {
4504 Info.opc = ISD::INTRINSIC_VOID;
4505 Info.memVT = MVT::v2i32;
4506 Info.ptrVal = I.getArgOperand(i: 0);
4507 Info.offset = 0;
4508 Info.flags = MachineMemOperand::MOStore;
4509 Info.align.reset();
4510 return true;
4511 }
4512
4513 case Intrinsic::nvvm_tcgen05_st_16x64b_x4:
4514 case Intrinsic::nvvm_tcgen05_st_16x128b_x2:
4515 case Intrinsic::nvvm_tcgen05_st_16x256b_x1:
4516 case Intrinsic::nvvm_tcgen05_st_32x32b_x4:
4517 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x4: {
4518 Info.opc = ISD::INTRINSIC_VOID;
4519 Info.memVT = MVT::v4i32;
4520 Info.ptrVal = I.getArgOperand(i: 0);
4521 Info.offset = 0;
4522 Info.flags = MachineMemOperand::MOStore;
4523 Info.align.reset();
4524 return true;
4525 }
4526
4527 case Intrinsic::nvvm_tcgen05_st_16x64b_x8:
4528 case Intrinsic::nvvm_tcgen05_st_16x128b_x4:
4529 case Intrinsic::nvvm_tcgen05_st_16x256b_x2:
4530 case Intrinsic::nvvm_tcgen05_st_32x32b_x8:
4531 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x8: {
4532 Info.opc = ISD::INTRINSIC_VOID;
4533 Info.memVT = MVT::v8i32;
4534 Info.ptrVal = I.getArgOperand(i: 0);
4535 Info.offset = 0;
4536 Info.flags = MachineMemOperand::MOStore;
4537 Info.align.reset();
4538 return true;
4539 }
4540
4541 case Intrinsic::nvvm_tcgen05_st_16x64b_x16:
4542 case Intrinsic::nvvm_tcgen05_st_16x128b_x8:
4543 case Intrinsic::nvvm_tcgen05_st_16x256b_x4:
4544 case Intrinsic::nvvm_tcgen05_st_32x32b_x16:
4545 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x16: {
4546 Info.opc = ISD::INTRINSIC_VOID;
4547 Info.memVT = MVT::v16i32;
4548 Info.ptrVal = I.getArgOperand(i: 0);
4549 Info.offset = 0;
4550 Info.flags = MachineMemOperand::MOStore;
4551 Info.align.reset();
4552 return true;
4553 }
4554
4555 case Intrinsic::nvvm_tcgen05_st_16x64b_x32:
4556 case Intrinsic::nvvm_tcgen05_st_16x128b_x16:
4557 case Intrinsic::nvvm_tcgen05_st_16x256b_x8:
4558 case Intrinsic::nvvm_tcgen05_st_32x32b_x32:
4559 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x32: {
4560 Info.opc = ISD::INTRINSIC_VOID;
4561 Info.memVT = MVT::v32i32;
4562 Info.ptrVal = I.getArgOperand(i: 0);
4563 Info.offset = 0;
4564 Info.flags = MachineMemOperand::MOStore;
4565 Info.align.reset();
4566 return true;
4567 }
4568
4569 case Intrinsic::nvvm_tcgen05_st_16x64b_x64:
4570 case Intrinsic::nvvm_tcgen05_st_16x128b_x32:
4571 case Intrinsic::nvvm_tcgen05_st_16x256b_x16:
4572 case Intrinsic::nvvm_tcgen05_st_32x32b_x64:
4573 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x64: {
4574 Info.opc = ISD::INTRINSIC_VOID;
4575 Info.memVT = MVT::v64i32;
4576 Info.ptrVal = I.getArgOperand(i: 0);
4577 Info.offset = 0;
4578 Info.flags = MachineMemOperand::MOStore;
4579 Info.align.reset();
4580 return true;
4581 }
4582
4583 case Intrinsic::nvvm_tcgen05_st_16x64b_x128:
4584 case Intrinsic::nvvm_tcgen05_st_16x128b_x64:
4585 case Intrinsic::nvvm_tcgen05_st_16x256b_x32:
4586 case Intrinsic::nvvm_tcgen05_st_32x32b_x128:
4587 case Intrinsic::nvvm_tcgen05_st_16x32bx2_x128: {
4588 Info.opc = ISD::INTRINSIC_VOID;
4589 Info.memVT = MVT::v128i32;
4590 Info.ptrVal = I.getArgOperand(i: 0);
4591 Info.offset = 0;
4592 Info.flags = MachineMemOperand::MOStore;
4593 Info.align.reset();
4594 return true;
4595 }
4596 }
4597 return false;
4598}
4599
4600/// getFunctionParamOptimizedAlign - since function arguments are passed via
4601/// .param space, we may want to increase their alignment in a way that
4602/// ensures that we can effectively vectorize their loads & stores. We can
4603/// increase alignment only if the function has internal or private
4604/// linkage, as for other linkage types callers may already rely on the
4605/// default alignment. To allow using 128-bit vectorized loads/stores, this
4606/// function ensures that alignment is 16 or greater.
4607Align NVPTXTargetLowering::getFunctionParamOptimizedAlign(
4608 const Function *F, Type *ArgTy, const DataLayout &DL) const {
4609 // Capping the alignment to 128 bytes as that is the maximum alignment
4610 // supported by PTX.
4611 const Align ABITypeAlign = std::min(a: Align(128), b: DL.getABITypeAlign(Ty: ArgTy));
4612
4613 // If a function has linkage different from internal or private, we
4614 // must use default ABI alignment as external users rely on it. Same
4615 // for a function that may be called from a function pointer.
4616 if (!F || !F->hasLocalLinkage() ||
4617 F->hasAddressTaken(/*Users=*/nullptr,
4618 /*IgnoreCallbackUses=*/false,
4619 /*IgnoreAssumeLikeCalls=*/true,
4620 /*IgnoreLLVMUsed=*/IngoreLLVMUsed: true))
4621 return ABITypeAlign;
4622
4623 assert(!isKernelFunction(*F) && "Expect kernels to have non-local linkage");
4624 return std::max(a: Align(16), b: ABITypeAlign);
4625}
4626
4627/// Helper for computing alignment of a device function byval parameter.
4628Align NVPTXTargetLowering::getFunctionByValParamAlign(
4629 const Function *F, Type *ArgTy, Align InitialAlign,
4630 const DataLayout &DL) const {
4631 Align ArgAlign = InitialAlign;
4632 // Try to increase alignment to enhance vectorization options.
4633 if (F)
4634 ArgAlign = std::max(a: ArgAlign, b: getFunctionParamOptimizedAlign(F, ArgTy, DL));
4635
4636  // Old ptxas versions have a bug: when PTX code takes the address of a
4637  // byval parameter with alignment < 4, ptxas generates code to
4638  // spill the argument into memory. Alas, on sm_50+ ptxas generates
4639  // SASS code that fails with a misaligned access. To work around
4640  // the problem, make sure that we align byval parameters to at
4641  // least 4. This bug seems to be fixed starting from
4642  // ptxas > 9.0.
4643 // TODO: remove this after verifying the bug is not reproduced
4644 // on non-deprecated ptxas versions.
4645 if (ForceMinByValParamAlign)
4646 ArgAlign = std::max(a: ArgAlign, b: Align(4));
4647
4648 return ArgAlign;
4649}
4650
4651// Helper for getting a function parameter name. The name is composed from
4652// its index and the function name. A negative index corresponds to the
4653// special parameter (an unsized array) used for passing variable arguments.
4654std::string NVPTXTargetLowering::getParamName(const Function *F,
4655 int Idx) const {
4656 std::string ParamName;
4657 raw_string_ostream ParamStr(ParamName);
4658
4659 ParamStr << getTargetMachine().getSymbol(GV: F)->getName();
4660 if (Idx < 0)
4661 ParamStr << "_vararg";
4662 else
4663 ParamStr << "_param_" << Idx;
4664
4665 return ParamName;
4666}
4667
4668/// isLegalAddressingMode - Return true if the addressing mode represented
4669/// by AM is legal for this target, for a load/store of the specified type.
4670/// Used to guide target specific optimizations, like loop strength reduction
4671/// (LoopStrengthReduce.cpp) and memory optimization for address mode
4672/// (CodeGenPrepare.cpp)
4673bool NVPTXTargetLowering::isLegalAddressingMode(const DataLayout &DL,
4674 const AddrMode &AM, Type *Ty,
4675 unsigned AS, Instruction *I) const {
4676 // AddrMode - This represents an addressing mode of:
4677 // BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
4678 //
4679 // The legal address modes are
4680 // - [avar]
4681 // - [areg]
4682 // - [areg+immoff]
4683 // - [immAddr]
4684
4685 // immoff must fit in a signed 32-bit int
4686 if (!APInt(64, AM.BaseOffs).isSignedIntN(N: 32))
4687 return false;
4688
4689 if (AM.BaseGV)
4690 return !AM.BaseOffs && !AM.HasBaseReg && !AM.Scale;
4691
4692 switch (AM.Scale) {
4693 case 0: // "r", "r+i" or "i" is allowed
4694 break;
4695 case 1:
4696 if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
4697 return false;
4698 // Otherwise we have r+i.
4699 break;
4700 default:
4701 // No scale > 1 is allowed
4702 return false;
4703 }
4704 return true;
4705}
4706
4707//===----------------------------------------------------------------------===//
4708// NVPTX Inline Assembly Support
4709//===----------------------------------------------------------------------===//
4710
4711/// getConstraintType - Given a constraint letter, return the type of
4712/// constraint it is for this target.
4713NVPTXTargetLowering::ConstraintType
4714NVPTXTargetLowering::getConstraintType(StringRef Constraint) const {
4715 if (Constraint.size() == 1) {
4716 switch (Constraint[0]) {
4717 default:
4718 break;
4719 case 'b':
4720 case 'r':
4721 case 'h':
4722 case 'c':
4723 case 'l':
4724 case 'f':
4725 case 'd':
4726 case 'q':
4727 case '0':
4728 case 'N':
4729 return C_RegisterClass;
4730 }
4731 }
4732 return TargetLowering::getConstraintType(Constraint);
4733}
4734
4735std::pair<unsigned, const TargetRegisterClass *>
4736NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
4737 StringRef Constraint,
4738 MVT VT) const {
4739 if (Constraint.size() == 1) {
4740 switch (Constraint[0]) {
4741 case 'b':
4742 return std::make_pair(x: 0U, y: &NVPTX::B1RegClass);
4743 case 'c':
4744 case 'h':
4745 return std::make_pair(x: 0U, y: &NVPTX::B16RegClass);
4746 case 'r':
4747 case 'f':
4748 return std::make_pair(x: 0U, y: &NVPTX::B32RegClass);
4749 case 'l':
4750 case 'N':
4751 case 'd':
4752 return std::make_pair(x: 0U, y: &NVPTX::B64RegClass);
4753 case 'q': {
4754 if (STI.getSmVersion() < 70)
4755 report_fatal_error(reason: "Inline asm with 128 bit operands is only "
4756 "supported for sm_70 and higher!");
4757 return std::make_pair(x: 0U, y: &NVPTX::B128RegClass);
4758 }
4759 }
4760 }
4761 return TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
4762}
4763
4764//===----------------------------------------------------------------------===//
4765// NVPTX DAG Combining
4766//===----------------------------------------------------------------------===//
4767
4768bool NVPTXTargetLowering::allowFMA(MachineFunction &MF,
4769 CodeGenOptLevel OptLevel) const {
4770 // Always honor command-line argument
4771 if (FMAContractLevelOpt.getNumOccurrences() > 0)
4772 return FMAContractLevelOpt > 0;
4773
4774 // Do not contract if we're not optimizing the code.
4775 if (OptLevel == CodeGenOptLevel::None)
4776 return false;
4777
4778 // Honor TargetOptions flags that explicitly say fusion is okay.
4779 if (MF.getTarget().Options.AllowFPOpFusion == FPOpFusion::Fast)
4780 return true;
4781
4782 return allowUnsafeFPMath(MF);
4783}
4784
4785bool NVPTXTargetLowering::allowUnsafeFPMath(const MachineFunction &MF) const {
4786 // Honor TargetOptions flags that explicitly say unsafe math is okay.
4787 if (MF.getTarget().Options.UnsafeFPMath)
4788 return true;
4789
4790 // Allow unsafe math if unsafe-fp-math attribute explicitly says so.
4791 const Function &F = MF.getFunction();
4792 return F.getFnAttribute(Kind: "unsafe-fp-math").getValueAsBool();
4793}
4794
4795static bool isConstZero(const SDValue &Operand) {
4796 const auto *Const = dyn_cast<ConstantSDNode>(Val: Operand);
4797 return Const && Const->getZExtValue() == 0;
4798}
4799
4800/// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
4801/// operands N0 and N1. This is a helper for PerformADDCombine that is
4802/// called with the default operands, and if that fails, with commuted
4803/// operands.
4804static SDValue
4805PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
4806 TargetLowering::DAGCombinerInfo &DCI) {
4807 EVT VT = N0.getValueType();
4808
4809 // Since integer multiply-add costs the same as integer multiply
4810 // but is more costly than integer add, do the fusion only when
4811 // the mul is only used in the add.
4812 // TODO: this may not be true for later architectures, consider relaxing this
4813 if (!N0.getNode()->hasOneUse())
4814 return SDValue();
4815
4816 // fold (add (select cond, 0, (mul a, b)), c)
4817 // -> (select cond, c, (add (mul a, b), c))
4818 //
4819 if (N0.getOpcode() == ISD::SELECT) {
4820 unsigned ZeroOpNum;
4821 if (isConstZero(Operand: N0->getOperand(Num: 1)))
4822 ZeroOpNum = 1;
4823 else if (isConstZero(Operand: N0->getOperand(Num: 2)))
4824 ZeroOpNum = 2;
4825 else
4826 return SDValue();
4827
4828 SDValue M = N0->getOperand(Num: (ZeroOpNum == 1) ? 2 : 1);
4829 if (M->getOpcode() != ISD::MUL || !M.getNode()->hasOneUse())
4830 return SDValue();
4831
4832 SDLoc DL(N);
4833 SDValue Mul =
4834 DCI.DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: M->getOperand(Num: 0), N2: M->getOperand(Num: 1));
4835 SDValue MAD = DCI.DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Mul, N2: N1);
4836 return DCI.DAG.getSelect(DL: SDLoc(N), VT, Cond: N0->getOperand(Num: 0),
4837 LHS: ((ZeroOpNum == 1) ? N1 : MAD),
4838 RHS: ((ZeroOpNum == 1) ? MAD : N1));
4839 }
4840
4841 return SDValue();
4842}
4843
4844static SDValue
4845PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
4846 TargetLowering::DAGCombinerInfo &DCI,
4847 CodeGenOptLevel OptLevel) {
4848 EVT VT = N0.getValueType();
4849 if (N0.getOpcode() == ISD::FMUL) {
4850 const auto *TLI = static_cast<const NVPTXTargetLowering *>(
4851 &DCI.DAG.getTargetLoweringInfo());
4852 if (!(TLI->allowFMA(MF&: DCI.DAG.getMachineFunction(), OptLevel) ||
4853 (N->getFlags().hasAllowContract() &&
4854 N0->getFlags().hasAllowContract())))
4855 return SDValue();
4856
4857 // For floating point:
4858    // Do the fusion only when the mul has fewer than 5 uses, all of
4859    // which are adds.
4860    // The heuristic is that if a use is not an add, then that use
4861    // cannot be fused into an fma, so the mul is still needed anyway.
4862    // If there are more than 4 uses, even if they are all adds, fusing
4863    // them will increase register pressure.
4864 //
4865 int numUses = 0;
4866 int nonAddCount = 0;
4867 for (const SDNode *User : N0.getNode()->users()) {
4868 numUses++;
4869 if (User->getOpcode() != ISD::FADD)
4870 ++nonAddCount;
4871 if (numUses >= 5)
4872 return SDValue();
4873 }
4874 if (nonAddCount) {
4875 int orderNo = N->getIROrder();
4876 int orderNo2 = N0.getNode()->getIROrder();
4877      // Simple heuristic for considering potential register
4878      // pressure: the difference in IR order is used to measure the
4879      // distance between def and use, and a longer distance is more
4880      // likely to cause register pressure.
4881 if (orderNo - orderNo2 < 500)
4882 return SDValue();
4883
4884 // Now, check if at least one of the FMUL's operands is live beyond the
4885 // node N, which guarantees that the FMA will not increase register
4886 // pressure at node N.
4887 bool opIsLive = false;
4888 const SDNode *left = N0.getOperand(i: 0).getNode();
4889 const SDNode *right = N0.getOperand(i: 1).getNode();
4890
4891 if (isa<ConstantSDNode>(Val: left) || isa<ConstantSDNode>(Val: right))
4892 opIsLive = true;
4893
4894 if (!opIsLive)
4895 for (const SDNode *User : left->users()) {
4896 int orderNo3 = User->getIROrder();
4897 if (orderNo3 > orderNo) {
4898 opIsLive = true;
4899 break;
4900 }
4901 }
4902
4903 if (!opIsLive)
4904 for (const SDNode *User : right->users()) {
4905 int orderNo3 = User->getIROrder();
4906 if (orderNo3 > orderNo) {
4907 opIsLive = true;
4908 break;
4909 }
4910 }
4911
4912 if (!opIsLive)
4913 return SDValue();
4914 }
4915
4916 return DCI.DAG.getNode(Opcode: ISD::FMA, DL: SDLoc(N), VT, N1: N0.getOperand(i: 0),
4917 N2: N0.getOperand(i: 1), N3: N1);
4918 }
4919
4920 return SDValue();
4921}
4922
4923/// Fold extractelts into a load by increasing the number of return values.
4924///
4925/// ex:
4926/// L: v2f16,ch = load <p>
4927/// a: f16 = extractelt L:0, 0
4928/// b: f16 = extractelt L:0, 1
4929/// use(a, b)
4930///
4931/// ...is turned into...
4932/// L: f16,f16,ch = LoadV2 <p>
4933/// use(L:0, L:1)
4934static SDValue
4935combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
4936 // Don't run this optimization before the legalizer
4937 if (!DCI.isAfterLegalizeDAG())
4938 return SDValue();
4939
4940 EVT ElemVT = N->getValueType(ResNo: 0);
4941 if (!Isv2x16VT(VT: ElemVT))
4942 return SDValue();
4943
4944 // Check whether all outputs are either used by an extractelt or are
4945 // glue/chain nodes
4946 if (!all_of(Range: N->uses(), P: [&](SDUse &U) {
4947 // Skip glue, chain nodes
4948 if (U.getValueType() == MVT::Glue || U.getValueType() == MVT::Other)
4949 return true;
4950 if (U.getUser()->getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
4951 if (N->getOpcode() != ISD::LOAD)
4952 return true;
4953 // Since this is an ISD::LOAD, check all extractelts are used. If
4954 // any are not used, we don't want to defeat another optimization that
4955 // will narrow the load.
4956 //
4957 // For example:
4958 //
4959 // L: v2f16,ch = load <p>
4960 // e0: f16 = extractelt L:0, 0
4961 // e1: f16 = extractelt L:0, 1 <-- unused
4962 // store e0
4963 //
4964 // Can be optimized by DAGCombiner to:
4965 //
4966 // L: f16,ch = load <p>
4967 // store L:0
4968 return !U.getUser()->use_empty();
4969 }
4970
4971 // Otherwise, this use prevents us from splitting a value.
4972 return false;
4973 }))
4974 return SDValue();
4975
4976 auto *LD = cast<MemSDNode>(Val: N);
4977 EVT MemVT = LD->getMemoryVT();
4978 SDLoc DL(LD);
4979
4980  // The new opcode after we double the number of result values.
4981 NVPTXISD::NodeType Opcode;
4982 SmallVector<SDValue> Operands(LD->ops());
4983 unsigned OldNumOutputs; // non-glue, non-chain outputs
4984 switch (LD->getOpcode()) {
4985 case ISD::LOAD:
4986 OldNumOutputs = 1;
4987 // Any packed type is legal, so the legalizer will not have lowered
4988 // ISD::LOAD -> NVPTXISD::Load (unless it's under-aligned). We have to do it
4989 // here.
4990 Opcode = NVPTXISD::LoadV2;
4991 Operands.push_back(Elt: DCI.DAG.getIntPtrConstant(
4992 Val: cast<LoadSDNode>(Val: LD)->getExtensionType(), DL));
4993 break;
4994 case NVPTXISD::LoadParamV2:
4995 OldNumOutputs = 2;
4996 Opcode = NVPTXISD::LoadParamV4;
4997 break;
4998 case NVPTXISD::LoadV2:
4999 OldNumOutputs = 2;
5000 Opcode = NVPTXISD::LoadV4;
5001 break;
5002 case NVPTXISD::LoadV4:
5003 case NVPTXISD::LoadV8:
5004 // PTX doesn't support the next doubling of outputs
5005 return SDValue();
5006 }
5007
5008 // the non-glue, non-chain outputs in the new load
5009 const unsigned NewNumOutputs = OldNumOutputs * 2;
5010 SmallVector<EVT> NewVTs(NewNumOutputs, ElemVT.getVectorElementType());
5011 // add remaining chain and glue values
5012 NewVTs.append(in_start: LD->value_begin() + OldNumOutputs, in_end: LD->value_end());
5013
5014 // Create the new load
5015 SDValue NewLoad =
5016 DCI.DAG.getMemIntrinsicNode(Opcode, dl: DL, VTList: DCI.DAG.getVTList(VTs: NewVTs),
5017 Ops: Operands, MemVT, MMO: LD->getMemOperand());
5018
5019 // Now we use a combination of BUILD_VECTORs and a MERGE_VALUES node to keep
5020 // the outputs the same. These nodes will be optimized away in later
5021 // DAGCombiner iterations.
5022 SmallVector<SDValue> Results;
5023 for (unsigned I : seq(Size: OldNumOutputs))
5024 Results.push_back(Elt: DCI.DAG.getBuildVector(
5025 VT: ElemVT, DL, Ops: {NewLoad.getValue(R: I * 2), NewLoad.getValue(R: I * 2 + 1)}));
5026 // Add remaining chain and glue nodes
5027 for (unsigned I : seq(Size: NewLoad->getNumValues() - NewNumOutputs))
5028 Results.push_back(Elt: NewLoad.getValue(R: NewNumOutputs + I));
5029
5030 return DCI.DAG.getMergeValues(Ops: Results, dl: DL);
5031}
5032
5033/// Fold a packing mov into a store.
5034///
5035/// ex:
5036/// v: v2f16 = BUILD_VECTOR a:f16, b:f16
5037/// StoreRetval v
5038///
5039/// ...is turned into...
5040///
5041/// StoreRetvalV2 a:f16, b:f16
5042static SDValue combinePackingMovIntoStore(SDNode *N,
5043 TargetLowering::DAGCombinerInfo &DCI,
5044 unsigned Front, unsigned Back) {
5045 // We want to run this as late as possible since other optimizations may
5046 // eliminate the BUILD_VECTORs.
5047 if (!DCI.isAfterLegalizeDAG())
5048 return SDValue();
5049
5050 // Get the type of the operands being stored.
5051 EVT ElementVT = N->getOperand(Num: Front).getValueType();
5052
5053 if (!Isv2x16VT(VT: ElementVT))
5054 return SDValue();
5055
5056 auto *ST = cast<MemSDNode>(Val: N);
5057 EVT MemVT = ElementVT.getVectorElementType();
5058
5059 // The new opcode after we double the number of operands.
5060 NVPTXISD::NodeType Opcode;
5061 switch (N->getOpcode()) {
5062 case ISD::STORE:
5063 // Any packed type is legal, so the legalizer will not have lowered
5064 // ISD::STORE -> NVPTXISD::Store (unless it's under-aligned). We have to do
5065 // it here.
5066 MemVT = ST->getMemoryVT();
5067 Opcode = NVPTXISD::StoreV2;
5068 break;
5069 case NVPTXISD::StoreParam:
5070 Opcode = NVPTXISD::StoreParamV2;
5071 break;
5072 case NVPTXISD::StoreParamV2:
5073 Opcode = NVPTXISD::StoreParamV4;
5074 break;
5075 case NVPTXISD::StoreV2:
5076 MemVT = ST->getMemoryVT();
5077 Opcode = NVPTXISD::StoreV4;
5078 break;
5079 case NVPTXISD::StoreV4:
5080 case NVPTXISD::StoreParamV4:
5081 case NVPTXISD::StoreV8:
5082 // PTX doesn't support the next doubling of operands
5083 return SDValue();
5084 default:
5085 llvm_unreachable("Unhandled store opcode");
5086 }
5087
5088 // Scan the operands and if they're all BUILD_VECTORs, we'll have gathered
5089 // their elements.
5090 SmallVector<SDValue, 4> Operands(N->ops().take_front(N: Front));
5091 for (SDValue BV : N->ops().drop_front(N: Front).drop_back(N: Back)) {
5092 if (BV.getOpcode() != ISD::BUILD_VECTOR)
5093 return SDValue();
5094
5095 // If the operand has multiple uses, this optimization can increase register
5096 // pressure.
5097 if (!BV.hasOneUse())
5098 return SDValue();
5099
5100 // DAGCombiner visits nodes bottom-up. Check the BUILD_VECTOR operands for
5101 // any signs they may be folded by some other pattern or rule.
5102 for (SDValue Op : BV->ops()) {
5103 // Peek through bitcasts
5104 if (Op.getOpcode() == ISD::BITCAST)
5105 Op = Op.getOperand(i: 0);
5106
5107 // This may be folded into a PRMT.
5108 if (Op.getValueType() == MVT::i16 && Op.getOpcode() == ISD::TRUNCATE &&
5109 Op->getOperand(Num: 0).getValueType() == MVT::i32)
5110 return SDValue();
5111
5112 // This may be folded into cvt.bf16x2
5113 if (Op.getOpcode() == ISD::FP_ROUND)
5114 return SDValue();
5115 }
5116 Operands.append(IL: {BV.getOperand(i: 0), BV.getOperand(i: 1)});
5117 }
5118 Operands.append(in_start: N->op_end() - Back, in_end: N->op_end());
5119
5120 // Now we replace the store
5121 return DCI.DAG.getMemIntrinsicNode(Opcode, dl: SDLoc(N), VTList: N->getVTList(), Ops: Operands,
5122 MemVT, MMO: ST->getMemOperand());
5123}
5124
5125static SDValue PerformStoreCombineHelper(SDNode *N,
5126 TargetLowering::DAGCombinerInfo &DCI,
5127 unsigned Front, unsigned Back) {
5128 if (all_of(Range: N->ops().drop_front(N: Front).drop_back(N: Back),
5129 P: [](const SDUse &U) { return U.get()->isUndef(); }))
5130 // Operand 0 is the previous value in the chain. Cannot return EntryToken
5131 // as the previous value will become unused and eliminated later.
5132 return N->getOperand(Num: 0);
5133
5134 return combinePackingMovIntoStore(N, DCI, Front, Back);
5135}
5136
5137static SDValue PerformStoreCombine(SDNode *N,
5138 TargetLowering::DAGCombinerInfo &DCI) {
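  // For the store nodes handled here the values being stored sit between the
  // leading chain operand and the two trailing address operands
  // ({Chain, Val..., BasePtr, Offset}), hence Front = 1 and Back = 2.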
5139 return combinePackingMovIntoStore(N, DCI, Front: 1, Back: 2);
5140}
5141
5142static SDValue PerformStoreParamCombine(SDNode *N,
5143 TargetLowering::DAGCombinerInfo &DCI) {
5144 // Operands from the 3rd to the 2nd last one are the values to be stored.
5145 // {Chain, ArgID, Offset, Val, Glue}
5146 return PerformStoreCombineHelper(N, DCI, Front: 3, Back: 1);
5147}
5148
5149/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
5150///
5151static SDValue PerformADDCombine(SDNode *N,
5152 TargetLowering::DAGCombinerInfo &DCI,
5153 CodeGenOptLevel OptLevel) {
5154 if (OptLevel == CodeGenOptLevel::None)
5155 return SDValue();
5156
5157 SDValue N0 = N->getOperand(Num: 0);
5158 SDValue N1 = N->getOperand(Num: 1);
5159
5160 // Skip non-integer, non-scalar case
5161 EVT VT = N0.getValueType();
5162 if (VT.isVector() || VT != MVT::i32)
5163 return SDValue();
5164
5165 // First try with the default operand order.
5166 if (SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI))
5167 return Result;
5168
5169 // If that didn't work, try again with the operands commuted.
5170 return PerformADDCombineWithOperands(N, N0: N1, N1: N0, DCI);
5171}
5172
5173/// PerformFADDCombine - Target-specific dag combine xforms for ISD::FADD.
5174///
5175static SDValue PerformFADDCombine(SDNode *N,
5176 TargetLowering::DAGCombinerInfo &DCI,
5177 CodeGenOptLevel OptLevel) {
5178 SDValue N0 = N->getOperand(Num: 0);
5179 SDValue N1 = N->getOperand(Num: 1);
5180
5181 EVT VT = N0.getValueType();
5182 if (VT.isVector() || !(VT == MVT::f32 || VT == MVT::f64))
5183 return SDValue();
5184
5185 // First try with the default operand order.
5186 if (SDValue Result = PerformFADDCombineWithOperands(N, N0, N1, DCI, OptLevel))
5187 return Result;
5188
5189 // If that didn't work, try again with the operands commuted.
5190 return PerformFADDCombineWithOperands(N, N0: N1, N1: N0, DCI, OptLevel);
5191}
5192
5193static SDValue PerformANDCombine(SDNode *N,
5194 TargetLowering::DAGCombinerInfo &DCI) {
5195 // The type legalizer turns a vector load of i8 values into a zextload to i16
5196 // registers, optionally ANY_EXTENDs it (if target type is integer),
5197 // and ANDs off the high 8 bits. Since we turn this load into a
5198 // target-specific DAG node, the DAG combiner fails to eliminate these AND
5199 // nodes. Do that here.
5200 SDValue Val = N->getOperand(Num: 0);
5201 SDValue Mask = N->getOperand(Num: 1);
5202
5203 if (isa<ConstantSDNode>(Val)) {
5204 std::swap(a&: Val, b&: Mask);
5205 }
5206
5207 SDValue AExt;
5208
5209  // Convert BFE -> truncate i16 -> and 255
5210  // to just BFE -> truncate i16, as the value already has all the bits in
5211  // the right places.
5212 if (Val.getOpcode() == ISD::TRUNCATE) {
5213 SDValue BFE = Val.getOperand(i: 0);
5214 if (BFE.getOpcode() != NVPTXISD::BFE)
5215 return SDValue();
5216
5217 ConstantSDNode *BFEBits = dyn_cast<ConstantSDNode>(Val: BFE.getOperand(i: 0));
5218 if (!BFEBits)
5219 return SDValue();
5220 uint64_t BFEBitsVal = BFEBits->getZExtValue();
5221
5222 ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Val&: Mask);
5223 if (!MaskCnst) {
5224 // Not an AND with a constant
5225 return SDValue();
5226 }
5227 uint64_t MaskVal = MaskCnst->getZExtValue();
5228
5229 if (MaskVal != (uint64_t(1) << BFEBitsVal) - 1)
5230 return SDValue();
5231 // If we get here, the AND is unnecessary. Just replace it with the trunc
5232 DCI.CombineTo(N, Res: Val, AddTo: false);
5233 }
5234 // Generally, we will see zextload -> IMOV16rr -> ANY_EXTEND -> and
5235 if (Val.getOpcode() == ISD::ANY_EXTEND) {
5236 AExt = Val;
5237 Val = Val->getOperand(Num: 0);
5238 }
5239
5240 if (Val->getOpcode() == NVPTXISD::LoadV2 ||
5241 Val->getOpcode() == NVPTXISD::LoadV4) {
5242 ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Val&: Mask);
5243 if (!MaskCnst) {
5244 // Not an AND with a constant
5245 return SDValue();
5246 }
5247
5248 uint64_t MaskVal = MaskCnst->getZExtValue();
5249 if (MaskVal != 0xff) {
5250 // Not an AND that chops off top 8 bits
5251 return SDValue();
5252 }
5253
5254 MemSDNode *Mem = dyn_cast<MemSDNode>(Val);
5255 if (!Mem) {
5256 // Not a MemSDNode?!?
5257 return SDValue();
5258 }
5259
5260 EVT MemVT = Mem->getMemoryVT();
5261 if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
5262 // We only handle the i8 case
5263 return SDValue();
5264 }
5265
5266 unsigned ExtType = Val->getConstantOperandVal(Num: Val->getNumOperands() - 1);
5267 if (ExtType == ISD::SEXTLOAD) {
5268 // If for some reason the load is a sextload, the and is needed to zero
5269 // out the high 8 bits
5270 return SDValue();
5271 }
5272
5273 bool AddTo = false;
5274 if (AExt.getNode() != nullptr) {
5275 // Re-insert the ext as a zext.
5276 Val = DCI.DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: SDLoc(N),
5277 VT: AExt.getValueType(), Operand: Val);
5278 AddTo = true;
5279 }
5280
5281 // If we get here, the AND is unnecessary. Just replace it with the load
5282 DCI.CombineTo(N, Res: Val, AddTo);
5283 }
5284
5285 return SDValue();
5286}
5287
5288static SDValue PerformREMCombine(SDNode *N,
5289 TargetLowering::DAGCombinerInfo &DCI,
5290 CodeGenOptLevel OptLevel) {
5291 assert(N->getOpcode() == ISD::SREM || N->getOpcode() == ISD::UREM);
5292
5293 // Don't do anything at less than -O2.
5294 if (OptLevel < CodeGenOptLevel::Default)
5295 return SDValue();
5296
5297 SelectionDAG &DAG = DCI.DAG;
5298 SDLoc DL(N);
5299 EVT VT = N->getValueType(ResNo: 0);
5300 bool IsSigned = N->getOpcode() == ISD::SREM;
5301 unsigned DivOpc = IsSigned ? ISD::SDIV : ISD::UDIV;
5302
5303 const SDValue &Num = N->getOperand(Num: 0);
5304 const SDValue &Den = N->getOperand(Num: 1);
5305
5306 for (const SDNode *U : Num->users()) {
5307 if (U->getOpcode() == DivOpc && U->getOperand(Num: 0) == Num &&
5308 U->getOperand(Num: 1) == Den) {
5309 // Num % Den -> Num - (Num / Den) * Den
5310 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Num,
5311 N2: DAG.getNode(Opcode: ISD::MUL, DL, VT,
5312 N1: DAG.getNode(Opcode: DivOpc, DL, VT, N1: Num, N2: Den),
5313 N2: Den));
5314 }
5315 }
5316 return SDValue();
5317}
5318
5319enum OperandSignedness {
5320 Signed = 0,
5321 Unsigned,
5322 Unknown
5323};
5324
5325/// IsMulWideOperandDemotable - Checks if the provided DAG node is an operand
5326/// that can be demoted to \p OptSize bits without loss of information. The
5327/// signedness of the operand, if determinable, is placed in \p S.
5328static bool IsMulWideOperandDemotable(SDValue Op,
5329 unsigned OptSize,
5330 OperandSignedness &S) {
5331 S = Unknown;
5332
5333 if (Op.getOpcode() == ISD::SIGN_EXTEND ||
5334 Op.getOpcode() == ISD::SIGN_EXTEND_INREG) {
5335 EVT OrigVT = Op.getOperand(i: 0).getValueType();
5336 if (OrigVT.getFixedSizeInBits() <= OptSize) {
5337 S = Signed;
5338 return true;
5339 }
5340 } else if (Op.getOpcode() == ISD::ZERO_EXTEND) {
5341 EVT OrigVT = Op.getOperand(i: 0).getValueType();
5342 if (OrigVT.getFixedSizeInBits() <= OptSize) {
5343 S = Unsigned;
5344 return true;
5345 }
5346 }
5347
5348 return false;
5349}
5350
5351/// AreMulWideOperandsDemotable - Checks if the given LHS and RHS operands can
5352/// be demoted to \p OptSize bits without loss of information. If the operands
5353/// contain a constant, it should appear as the RHS operand. The signedness of
5354/// the operands is placed in \p IsSigned.
5355static bool AreMulWideOperandsDemotable(SDValue LHS, SDValue RHS,
5356 unsigned OptSize,
5357 bool &IsSigned) {
5358 OperandSignedness LHSSign;
5359
5360 // The LHS operand must be a demotable op
5361 if (!IsMulWideOperandDemotable(Op: LHS, OptSize, S&: LHSSign))
5362 return false;
5363
5364 // We should have been able to determine the signedness from the LHS
5365 if (LHSSign == Unknown)
5366 return false;
5367
5368 IsSigned = (LHSSign == Signed);
5369
5370 // The RHS can be a demotable op or a constant
5371 if (ConstantSDNode *CI = dyn_cast<ConstantSDNode>(Val&: RHS)) {
5372 const APInt &Val = CI->getAPIntValue();
5373 if (LHSSign == Unsigned) {
5374 return Val.isIntN(N: OptSize);
5375 } else {
5376 return Val.isSignedIntN(N: OptSize);
5377 }
5378 } else {
5379 OperandSignedness RHSSign;
5380 if (!IsMulWideOperandDemotable(Op: RHS, OptSize, S&: RHSSign))
5381 return false;
5382
5383 return LHSSign == RHSSign;
5384 }
5385}
5386
5387/// TryMULWIDECombine - Attempt to replace a multiply of M bits with a multiply
5388/// of M/2 bits that produces an M-bit result (i.e. mul.wide). This transform
5389/// works on both multiply DAG nodes and SHL DAG nodes with a constant shift
5390/// amount.
5391static SDValue TryMULWIDECombine(SDNode *N,
5392 TargetLowering::DAGCombinerInfo &DCI) {
5393 EVT MulType = N->getValueType(ResNo: 0);
5394 if (MulType != MVT::i32 && MulType != MVT::i64) {
5395 return SDValue();
5396 }
5397
5398 SDLoc DL(N);
5399 unsigned OptSize = MulType.getSizeInBits() >> 1;
5400 SDValue LHS = N->getOperand(Num: 0);
5401 SDValue RHS = N->getOperand(Num: 1);
5402
5403 // Canonicalize the multiply so the constant (if any) is on the right
5404 if (N->getOpcode() == ISD::MUL) {
5405 if (isa<ConstantSDNode>(Val: LHS)) {
5406 std::swap(a&: LHS, b&: RHS);
5407 }
5408 }
5409
5410 // If we have a SHL, determine the actual multiply amount
5411 if (N->getOpcode() == ISD::SHL) {
5412 ConstantSDNode *ShlRHS = dyn_cast<ConstantSDNode>(Val&: RHS);
5413 if (!ShlRHS) {
5414 return SDValue();
5415 }
5416
5417 APInt ShiftAmt = ShlRHS->getAPIntValue();
5418 unsigned BitWidth = MulType.getSizeInBits();
5419 if (ShiftAmt.sge(RHS: 0) && ShiftAmt.slt(RHS: BitWidth)) {
5420 APInt MulVal = APInt(BitWidth, 1) << ShiftAmt;
5421 RHS = DCI.DAG.getConstant(Val: MulVal, DL, VT: MulType);
5422 } else {
5423 return SDValue();
5424 }
5425 }
5426
5427 bool Signed;
5428 // Verify that our operands are demotable
5429 if (!AreMulWideOperandsDemotable(LHS, RHS, OptSize, IsSigned&: Signed)) {
5430 return SDValue();
5431 }
5432
5433 EVT DemotedVT;
5434 if (MulType == MVT::i32) {
5435 DemotedVT = MVT::i16;
5436 } else {
5437 DemotedVT = MVT::i32;
5438 }
5439
5440 // Truncate the operands to the correct size. Note that these are just for
5441 // type consistency and will (likely) be eliminated in later phases.
5442 SDValue TruncLHS =
5443 DCI.DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: DemotedVT, Operand: LHS);
5444 SDValue TruncRHS =
5445 DCI.DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: DemotedVT, Operand: RHS);
5446
5447 unsigned Opc;
5448 if (Signed) {
5449 Opc = NVPTXISD::MUL_WIDE_SIGNED;
5450 } else {
5451 Opc = NVPTXISD::MUL_WIDE_UNSIGNED;
5452 }
5453
5454 return DCI.DAG.getNode(Opcode: Opc, DL, VT: MulType, N1: TruncLHS, N2: TruncRHS);
5455}
5456
5457static bool isConstOne(const SDValue &Operand) {
5458 const auto *Const = dyn_cast<ConstantSDNode>(Val: Operand);
5459 return Const && Const->getZExtValue() == 1;
5460}
5461
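// If Add has the form (add y, 1) or (add 1, y), return y; otherwise return
// a null SDValue.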
5462static SDValue matchMADConstOnePattern(SDValue Add) {
5463 if (Add->getOpcode() != ISD::ADD)
5464 return SDValue();
5465
5466 if (isConstOne(Operand: Add->getOperand(Num: 0)))
5467 return Add->getOperand(Num: 1);
5468
5469 if (isConstOne(Operand: Add->getOperand(Num: 1)))
5470 return Add->getOperand(Num: 0);
5471
5472 return SDValue();
5473}
5474
5475static SDValue combineMADConstOne(SDValue X, SDValue Add, EVT VT, SDLoc DL,
5476 TargetLowering::DAGCombinerInfo &DCI) {
5477
5478 if (SDValue Y = matchMADConstOnePattern(Add)) {
5479 SDValue Mul = DCI.DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: X, N2: Y);
5480 return DCI.DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Mul, N2: X);
5481 }
5482
5483 return SDValue();
5484}
5485
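// Fold (mul x, (select cond, 1, y)) -> (select cond, x, (mul x, y)) (and the
// symmetric form), but only when the new multiply can itself be folded into a
// MAD, i.e. y matches the add-with-constant-one pattern; otherwise the
// combine is not obviously profitable.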
5486static SDValue combineMulSelectConstOne(SDValue X, SDValue Select, EVT VT,
5487 SDLoc DL,
5488 TargetLowering::DAGCombinerInfo &DCI) {
5489 if (Select->getOpcode() != ISD::SELECT)
5490 return SDValue();
5491
5492 SDValue Cond = Select->getOperand(Num: 0);
5493
5494 unsigned ConstOpNo;
5495 if (isConstOne(Operand: Select->getOperand(Num: 1)))
5496 ConstOpNo = 1;
5497 else if (isConstOne(Operand: Select->getOperand(Num: 2)))
5498 ConstOpNo = 2;
5499 else
5500 return SDValue();
5501
5502 SDValue Y = Select->getOperand(Num: (ConstOpNo == 1) ? 2 : 1);
5503
5504 // Do not combine if the resulting sequence is not obviously profitable.
5505 if (!matchMADConstOnePattern(Add: Y))
5506 return SDValue();
5507
5508 SDValue NewMul = DCI.DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: X, N2: Y);
5509
5510 return DCI.DAG.getNode(Opcode: ISD::SELECT, DL, VT, N1: Cond,
5511 N2: (ConstOpNo == 1) ? X : NewMul,
5512 N3: (ConstOpNo == 1) ? NewMul : X);
5513}
5514
5515static SDValue
5516PerformMULCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
5517 TargetLowering::DAGCombinerInfo &DCI) {
5518
5519 EVT VT = N0.getValueType();
5520 if (VT.isVector())
5521 return SDValue();
5522
5523 if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
5524 return SDValue();
5525
5526 SDLoc DL(N);
5527
5528 // (mul x, (add y, 1)) -> (add (mul x, y), x)
5529 if (SDValue Res = combineMADConstOne(X: N0, Add: N1, VT, DL, DCI))
5530 return Res;
5531 if (SDValue Res = combineMADConstOne(X: N1, Add: N0, VT, DL, DCI))
5532 return Res;
5533
5534 // (mul x, (select c, 1, y)) -> (select c, x, (mul x, y)), and the symmetric form
5535 if (SDValue Res = combineMulSelectConstOne(X: N0, Select: N1, VT, DL, DCI))
5536 return Res;
5537 if (SDValue Res = combineMulSelectConstOne(X: N1, Select: N0, VT, DL, DCI))
5538 return Res;
5539
5540 return SDValue();
5541}
5542
5543/// PerformMULCombine - Runs PTX-specific DAG combine patterns on MUL nodes.
5544static SDValue PerformMULCombine(SDNode *N,
5545 TargetLowering::DAGCombinerInfo &DCI,
5546 CodeGenOptLevel OptLevel) {
5547 if (OptLevel == CodeGenOptLevel::None)
5548 return SDValue();
5549
5550 if (SDValue Ret = TryMULWIDECombine(N, DCI))
5551 return Ret;
5552
5553 SDValue N0 = N->getOperand(Num: 0);
5554 SDValue N1 = N->getOperand(Num: 1);
5555 return PerformMULCombineWithOperands(N, N0, N1, DCI);
5556}
5557
5558/// PerformSHLCombine - Runs PTX-specific DAG combine patterns on SHL nodes.
5559static SDValue PerformSHLCombine(SDNode *N,
5560 TargetLowering::DAGCombinerInfo &DCI,
5561 CodeGenOptLevel OptLevel) {
5562 if (OptLevel > CodeGenOptLevel::None) {
5563 // Try mul.wide combining at OptLevel > 0
5564 if (SDValue Ret = TryMULWIDECombine(N, DCI))
5565 return Ret;
5566 }
5567
5568 return SDValue();
5569}
5570
5571static SDValue PerformSETCCCombine(SDNode *N,
5572 TargetLowering::DAGCombinerInfo &DCI,
5573 unsigned int SmVersion) {
5574 EVT CCType = N->getValueType(ResNo: 0);
5575 SDValue A = N->getOperand(Num: 0);
5576 SDValue B = N->getOperand(Num: 1);
5577
5578 EVT AType = A.getValueType();
5579 if (!(CCType == MVT::v2i1 && (AType == MVT::v2f16 || AType == MVT::v2bf16)))
5580 return SDValue();
5581
5582 if (A.getValueType() == MVT::v2bf16 && SmVersion < 90)
5583 return SDValue();
5584
5585 SDLoc DL(N);
5586 // setp.f16x2 returns two scalar predicates, which we need to
5587 // convert back to v2i1. The returned result will be scalarized by
5588 // the legalizer, but the comparison will remain a single vector
5589 // instruction.
5590 SDValue CCNode = DCI.DAG.getNode(
5591 Opcode: A.getValueType() == MVT::v2f16 ? NVPTXISD::SETP_F16X2
5592 : NVPTXISD::SETP_BF16X2,
5593 DL, VTList: DCI.DAG.getVTList(VT1: MVT::i1, VT2: MVT::i1), Ops: {A, B, N->getOperand(Num: 2)});
5594 return DCI.DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: CCType, N1: CCNode.getValue(R: 0),
5595 N2: CCNode.getValue(R: 1));
5596}
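
// For example (illustrative, not verified PTX): a v2f16 less-than compare
// handled above becomes roughly
//   setp.lt.f16x2 %p0|%p1, %a, %b;
// with the two predicates packed back into a v2i1 via build_vector.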
5597
5598static SDValue PerformEXTRACTCombine(SDNode *N,
5599 TargetLowering::DAGCombinerInfo &DCI) {
5600 SDValue Vector = N->getOperand(Num: 0);
5601 if (Vector->getOpcode() == ISD::FREEZE)
5602 Vector = Vector->getOperand(Num: 0);
5603 SDLoc DL(N);
5604 EVT VectorVT = Vector.getValueType();
5605 if (Vector->getOpcode() == ISD::LOAD && VectorVT.isSimple() &&
5606 IsPTXVectorType(VT: VectorVT.getSimpleVT()))
5607 return SDValue(); // Native vector loads already combine nicely w/
5608 // extract_vector_elt.
5609 // Don't mess with singletons or v2*16, v4i8 and v8i8 types; we already
5610 // handle them OK.
5611 if (VectorVT.getVectorNumElements() == 1 || Isv2x16VT(VT: VectorVT) ||
5612 VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8)
5613 return SDValue();
5614
5615 // Don't mess with undef values as sra may be simplified to 0, not undef.
5616 if (Vector->isUndef() || ISD::allOperandsUndef(N: Vector.getNode()))
5617 return SDValue();
5618
5619 uint64_t VectorBits = VectorVT.getSizeInBits();
5620 // We only handle the types we can extract in-register.
5621 if (!(VectorBits == 16 || VectorBits == 32 || VectorBits == 64))
5622 return SDValue();
5623
5624 ConstantSDNode *Index = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
5625 // Index == 0 is handled by generic DAG combiner.
5626 if (!Index || Index->getZExtValue() == 0)
5627 return SDValue();
5628
5629 MVT IVT = MVT::getIntegerVT(BitWidth: VectorBits);
5630 EVT EltVT = VectorVT.getVectorElementType();
5631 EVT EltIVT = EltVT.changeTypeToInteger();
5632 uint64_t EltBits = EltVT.getScalarSizeInBits();
5633
5634 SDValue Result = DCI.DAG.getNode(
5635 Opcode: ISD::TRUNCATE, DL, VT: EltIVT,
5636 Operand: DCI.DAG.getNode(
5637 Opcode: ISD::SRA, DL, VT: IVT, N1: DCI.DAG.getNode(Opcode: ISD::BITCAST, DL, VT: IVT, Operand: Vector),
5638 N2: DCI.DAG.getConstant(Val: Index->getZExtValue() * EltBits, DL, VT: IVT)));
5639
5640 // If element has non-integer type, bitcast it back to the expected type.
5641 if (EltVT != EltIVT)
5642 Result = DCI.DAG.getNode(Opcode: ISD::BITCAST, DL, VT: EltVT, Operand: Result);
5643 // Past the legalizer, we may need to extend i8 -> i16 to match the register type.
5644 if (EltVT != N->getValueType(ResNo: 0))
5645 Result = DCI.DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: N->getValueType(ResNo: 0), Operand: Result);
5646
5647 return Result;
5648}
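
// Illustrative sketch of the extraction above, assuming a v2f16 source and a
// constant index of 1:
//   (extract_vector_elt v2f16 %v, 1)
//     --> (bitcast f16 (trunc i16 (sra i32 (bitcast i32 %v), 16)))
// i.e. the element is shifted down within the in-register integer value
// rather than being spilled and reloaded.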
5649
5650static SDValue PerformVSELECTCombine(SDNode *N,
5651 TargetLowering::DAGCombinerInfo &DCI) {
5652 SDValue VA = N->getOperand(Num: 1);
5653 EVT VectorVT = VA.getValueType();
5654 if (VectorVT != MVT::v4i8)
5655 return SDValue();
5656
5657 // We need to split the vselect into individual per-element operations. Because
5658 // we use BFE/BFI instructions for byte extraction/insertion, we end up with
5659 // 32-bit values anyway, so we may as well do the comparison as i32 to avoid
5660 // the conversions to/from i16 normally used for i8 values.
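  // For example (illustrative), the lane-0 result computed below is roughly
  //   (trunc i8 (select (extractelt VCond, 0),
  //                     (any_ext i32 (extractelt VA, 0)),
  //                     (any_ext i32 (extractelt VB, 0))))
  // and the four per-lane results are reassembled with a build_vector.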
5661 SmallVector<SDValue, 4> E;
5662 SDLoc DL(N);
5663 SDValue VCond = N->getOperand(Num: 0);
5664 SDValue VB = N->getOperand(Num: 2);
5665 for (int I = 0; I < 4; ++I) {
5666 SDValue C = DCI.DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i1, N1: VCond,
5667 N2: DCI.DAG.getConstant(Val: I, DL, VT: MVT::i32));
5668 SDValue EA = DCI.DAG.getAnyExtOrTrunc(
5669 Op: DCI.DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i8, N1: VA,
5670 N2: DCI.DAG.getConstant(Val: I, DL, VT: MVT::i32)),
5671 DL, VT: MVT::i32);
5672 SDValue EB = DCI.DAG.getAnyExtOrTrunc(
5673 Op: DCI.DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: MVT::i8, N1: VB,
5674 N2: DCI.DAG.getConstant(Val: I, DL, VT: MVT::i32)),
5675 DL, VT: MVT::i32);
5676 E.push_back(Elt: DCI.DAG.getAnyExtOrTrunc(
5677 Op: DCI.DAG.getNode(Opcode: ISD::SELECT, DL, VT: MVT::i32, N1: C, N2: EA, N3: EB), DL, VT: MVT::i8));
5678 }
5679 return DCI.DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: MVT::v4i8, Ops: E);
5680}
5681
5682static SDValue
5683PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
5684 auto VT = N->getValueType(ResNo: 0);
5685 if (!DCI.isAfterLegalizeDAG() || !Isv2x16VT(VT))
5686 return SDValue();
5687
5688 auto Op0 = N->getOperand(Num: 0);
5689 auto Op1 = N->getOperand(Num: 1);
5690
5691 // Start out by assuming we want to take the lower 2 bytes of each i32
5692 // operand.
5693 uint64_t Op0Bytes = 0x10;
5694 uint64_t Op1Bytes = 0x54;
5695
5696 std::pair<SDValue *, uint64_t *> OpData[2] = {{&Op0, &Op0Bytes},
5697 {&Op1, &Op1Bytes}};
5698
5699 // Check that each operand is an i16, truncated from an i32 operand. We'll
5700 // select individual bytes from those original operands. Optionally, fold in a
5701 // shift right of that original operand.
5702 for (auto &[Op, OpBytes] : OpData) {
5703 // Eat up any bitcast
5704 if (Op->getOpcode() == ISD::BITCAST)
5705 *Op = Op->getOperand(i: 0);
5706
5707 if (!(Op->getValueType() == MVT::i16 && Op->getOpcode() == ISD::TRUNCATE &&
5708 Op->getOperand(i: 0).getValueType() == MVT::i32))
5709 return SDValue();
5710
5711 // If the truncate has multiple uses, this optimization can increase
5712 // register pressure
5713 if (!Op->hasOneUse())
5714 return SDValue();
5715
5716 *Op = Op->getOperand(i: 0);
5717
5718 // Optionally, fold in a shift-right of the original operand and let permute
5719 // pick the two higher bytes of the original value directly.
5720 if (Op->getOpcode() == ISD::SRL && isa<ConstantSDNode>(Val: Op->getOperand(i: 1))) {
5721 if (cast<ConstantSDNode>(Val: Op->getOperand(i: 1))->getZExtValue() == 16) {
5722 // Shift the PRMT byte selector to pick upper bytes from each respective
5723 // value, instead of the lower ones: 0x10 -> 0x32, 0x54 -> 0x76
5724 assert((*OpBytes == 0x10 || *OpBytes == 0x54) &&
5725 "PRMT selector values out of range");
5726 *OpBytes += 0x22;
5727 *Op = Op->getOperand(i: 0);
5728 }
5729 }
5730 }
5731
5732 SDLoc DL(N);
5733 auto &DAG = DCI.DAG;
5734
5735 auto PRMT = DAG.getNode(
5736 Opcode: NVPTXISD::PRMT, DL, VT: MVT::v4i8,
5737 Ops: {Op0, Op1, DAG.getConstant(Val: (Op1Bytes << 8) | Op0Bytes, DL, VT: MVT::i32),
5738 DAG.getConstant(Val: NVPTX::PTXPrmtMode::NONE, DL, VT: MVT::i32)});
5739 return DAG.getNode(Opcode: ISD::BITCAST, DL, VT, Operand: PRMT);
5740}
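
// Worked example of the PRMT selectors above (illustrative): with no shifts
// folded, the combined selector is (0x54 << 8) | 0x10 = 0x5410, which picks
// bytes 0-1 of the first i32 operand and bytes 4-5 of the second. If an
// operand was (srl x, 16), its selector is bumped by 0x22 (0x10 -> 0x32,
// 0x54 -> 0x76) so its upper two bytes are selected instead.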
5741
5742static SDValue combineADDRSPACECAST(SDNode *N,
5743 TargetLowering::DAGCombinerInfo &DCI) {
5744 auto *ASCN1 = cast<AddrSpaceCastSDNode>(Val: N);
5745
5746 if (auto *ASCN2 = dyn_cast<AddrSpaceCastSDNode>(Val: ASCN1->getOperand(Num: 0))) {
5747 assert(ASCN2->getDestAddressSpace() == ASCN1->getSrcAddressSpace());
5748
5749 // Fold asc[B -> A](asc[A -> B](x)) -> x
5750 if (ASCN1->getDestAddressSpace() == ASCN2->getSrcAddressSpace())
5751 return ASCN2->getOperand(Num: 0);
5752 }
5753
5754 return SDValue();
5755}
5756
5757SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
5758 DAGCombinerInfo &DCI) const {
5759 CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
5760 switch (N->getOpcode()) {
5761 default: break;
5762 case ISD::ADD:
5763 return PerformADDCombine(N, DCI, OptLevel);
5764 case ISD::FADD:
5765 return PerformFADDCombine(N, DCI, OptLevel);
5766 case ISD::MUL:
5767 return PerformMULCombine(N, DCI, OptLevel);
5768 case ISD::SHL:
5769 return PerformSHLCombine(N, DCI, OptLevel);
5770 case ISD::AND:
5771 return PerformANDCombine(N, DCI);
5772 case ISD::UREM:
5773 case ISD::SREM:
5774 return PerformREMCombine(N, DCI, OptLevel);
5775 case ISD::SETCC:
5776 return PerformSETCCCombine(N, DCI, SmVersion: STI.getSmVersion());
5777 case ISD::LOAD:
5778 case NVPTXISD::LoadParamV2:
5779 case NVPTXISD::LoadV2:
5780 case NVPTXISD::LoadV4:
5781 return combineUnpackingMovIntoLoad(N, DCI);
5782 case NVPTXISD::StoreParam:
5783 case NVPTXISD::StoreParamV2:
5784 case NVPTXISD::StoreParamV4:
5785 return PerformStoreParamCombine(N, DCI);
5786 case ISD::STORE:
5787 case NVPTXISD::StoreV2:
5788 case NVPTXISD::StoreV4:
5789 return PerformStoreCombine(N, DCI);
5790 case ISD::EXTRACT_VECTOR_ELT:
5791 return PerformEXTRACTCombine(N, DCI);
5792 case ISD::VSELECT:
5793 return PerformVSELECTCombine(N, DCI);
5794 case ISD::BUILD_VECTOR:
5795 return PerformBUILD_VECTORCombine(N, DCI);
5796 case ISD::ADDRSPACECAST:
5797 return combineADDRSPACECAST(N, DCI);
5798 }
5799 return SDValue();
5800}
5801
5802static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
5803 SmallVectorImpl<SDValue> &Results) {
5804 // Handle bitcasting to v2i8 without hitting the default promotion
5805 // strategy which goes through stack memory.
5806 SDValue Op(Node, 0);
5807 EVT ToVT = Op->getValueType(ResNo: 0);
5808 if (ToVT != MVT::v2i8) {
5809 return;
5810 }
5811
5812 // Bitcast to i16 and unpack elements into a vector
5813 SDLoc DL(Node);
5814 SDValue AsInt = DAG.getBitcast(VT: MVT::i16, V: Op->getOperand(Num: 0));
5815 SDValue Vec0 = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i8, Operand: AsInt);
5816 SDValue Const8 = DAG.getConstant(Val: 8, DL, VT: MVT::i16);
5817 SDValue Vec1 =
5818 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i8,
5819 Operand: DAG.getNode(Opcode: ISD::SRL, DL, VT: MVT::i16, Ops: {AsInt, Const8}));
5820 Results.push_back(
5821 Elt: DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: MVT::v2i8, Ops: {Vec0, Vec1}));
5822}
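
// Illustrative result of the expansion above: (bitcast v2i8 %x), with %x
// bitcast to an i16 value %i, becomes
//   (build_vector (trunc i8 %i), (trunc i8 (srl %i, 8)))
// keeping the value in registers instead of routing it through a stack slot.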
5823
5824/// ReplaceLoadVector - Convert vector loads into multi-output scalar loads.
5825static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
5826 SmallVectorImpl<SDValue> &Results,
5827 const NVPTXSubtarget &STI) {
5828 LoadSDNode *LD = cast<LoadSDNode>(Val: N);
5829 const EVT ResVT = LD->getValueType(ResNo: 0);
5830 const EVT MemVT = LD->getMemoryVT();
5831
5832 // If we're doing sign/zero extension as part of the load, avoid lowering to
5833 // a LoadV node. TODO: consider relaxing this restriction.
5834 if (ResVT != MemVT)
5835 return;
5836
5837 const auto NumEltsAndEltVT = getVectorLoweringShape(
5838 VectorEVT: ResVT, CanLowerTo256Bit: STI.has256BitVectorLoadStore(AS: LD->getAddressSpace()));
5839 if (!NumEltsAndEltVT)
5840 return;
5841 const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
5842
5843 Align Alignment = LD->getAlign();
5844 const auto &TD = DAG.getDataLayout();
5845 Align PrefAlign = TD.getPrefTypeAlign(Ty: MemVT.getTypeForEVT(Context&: *DAG.getContext()));
5846 if (Alignment < PrefAlign) {
5847 // This load is not sufficiently aligned, so bail out and let this vector
5848 // load be scalarized. Note that we may still be able to emit smaller
5849 // vector loads. For example, if we are loading a <4 x float> with an
5850 // alignment of 8, this check will fail but the legalizer will try again
5851 // with 2 x <2 x float>, which will succeed with an alignment of 8.
5852 return;
5853 }
5854
5855 // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
5856 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
5857 // loaded type to i16 and propagate the "real" type as the memory type.
5858 const MVT LoadEltVT = (EltVT.getSizeInBits() < 16) ? MVT::i16 : EltVT;
5859
5860 unsigned Opcode;
5861 switch (NumElts) {
5862 default:
5863 return;
5864 case 2:
5865 Opcode = NVPTXISD::LoadV2;
5866 break;
5867 case 4:
5868 Opcode = NVPTXISD::LoadV4;
5869 break;
5870 case 8:
5871 Opcode = NVPTXISD::LoadV8;
5872 break;
5873 }
5874 auto ListVTs = SmallVector<EVT, 9>(NumElts, LoadEltVT);
5875 ListVTs.push_back(Elt: MVT::Other);
5876 SDVTList LdResVTs = DAG.getVTList(VTs: ListVTs);
5877
5878 SDLoc DL(LD);
5879
5880 // Copy regular operands
5881 SmallVector<SDValue, 8> OtherOps(LD->ops());
5882
5883 // The select routine does not have access to the LoadSDNode instance, so
5884 // pass along the extension information
5885 OtherOps.push_back(Elt: DAG.getIntPtrConstant(Val: LD->getExtensionType(), DL));
5886
5887 SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, dl: DL, VTList: LdResVTs, Ops: OtherOps,
5888 MemVT: LD->getMemoryVT(),
5889 MMO: LD->getMemOperand());
5890
5891 SmallVector<SDValue> ScalarRes;
5892 if (EltVT.isVector()) {
5893 assert(EVT(EltVT.getVectorElementType()) == ResVT.getVectorElementType());
5894 assert(NumElts * EltVT.getVectorNumElements() ==
5895 ResVT.getVectorNumElements());
5896 // Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
5897 // into individual elements.
5898 for (const unsigned I : llvm::seq(Size: NumElts)) {
5899 SDValue SubVector = NewLD.getValue(R: I);
5900 DAG.ExtractVectorElements(Op: SubVector, Args&: ScalarRes);
5901 }
5902 } else {
5903 for (const unsigned I : llvm::seq(Size: NumElts)) {
5904 SDValue Res = NewLD.getValue(R: I);
5905 if (LoadEltVT != EltVT)
5906 Res = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: EltVT, Operand: Res);
5907 ScalarRes.push_back(Elt: Res);
5908 }
5909 }
5910
5911 SDValue LoadChain = NewLD.getValue(R: NumElts);
5912
5913 const MVT BuildVecVT =
5914 MVT::getVectorVT(VT: EltVT.getScalarType(), NumElements: ScalarRes.size());
5915 SDValue BuildVec = DAG.getBuildVector(VT: BuildVecVT, DL, Ops: ScalarRes);
5916 SDValue LoadValue = DAG.getBitcast(VT: ResVT, V: BuildVec);
5917
5918 Results.append(IL: {LoadValue, LoadChain});
5919}
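
// Illustrative sketch, assuming a sufficiently aligned, non-extending load of
// <4 x float>:
//   (load v4f32 %ptr)
//     --> (NVPTXISD::LoadV4 f32, f32, f32, f32, ch)
// with the four scalar results reassembled into a v4f32 via build_vector (and
// a bitcast when sub-vector element types such as v2f16 are involved).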
5920
5921// Lower vector return type of tcgen05.ld intrinsics
5922static void ReplaceTcgen05Ld(SDNode *N, SelectionDAG &DAG,
5923 SmallVectorImpl<SDValue> &Results,
5924 bool hasOffset = false) {
5925 SDLoc DL(N);
5926 EVT ResVT = N->getValueType(ResNo: 0);
5927 if (!ResVT.isVector())
5928 return; // already legalized.
5929
5930 const unsigned NumElts = ResVT.getVectorNumElements();
5931
5932 // Create the return type of the instructions
5933 SmallVector<EVT, 5> ListVTs;
5934 for (unsigned i = 0; i < NumElts; ++i)
5935 ListVTs.push_back(Elt: MVT::i32);
5936
5937 ListVTs.push_back(Elt: N->getValueType(ResNo: 1)); // Chain
5938
5939 SDVTList ResVTs = DAG.getVTList(VTs: ListVTs);
5940
5941 SmallVector<SDValue, 8> Ops{N->getOperand(Num: 0), N->getOperand(Num: 1),
5942 N->getOperand(Num: 2)};
5943
5944 if (hasOffset) {
5945 Ops.push_back(Elt: N->getOperand(Num: 3)); // offset
5946 Ops.push_back(Elt: N->getOperand(Num: 4)); // Pack flag
5947 } else
5948 Ops.push_back(Elt: N->getOperand(Num: 3)); // Pack flag
5949
5950 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
5951 SDValue NewNode =
5952 DAG.getMemIntrinsicNode(Opcode: ISD::INTRINSIC_W_CHAIN, dl: DL, VTList: ResVTs, Ops,
5953 MemVT: MemSD->getMemoryVT(), MMO: MemSD->getMemOperand());
5954
5955 // split the vector result
5956 SmallVector<SDValue, 4> ScalarRes;
5957 for (unsigned i = 0; i < NumElts; ++i) {
5958 SDValue Res = NewNode.getValue(R: i);
5959 ScalarRes.push_back(Elt: Res);
5960 }
5961
5962 SDValue Chain = NewNode.getValue(R: NumElts);
5963 SDValue BuildVector = DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: ResVT, Ops: ScalarRes);
5964 Results.push_back(Elt: BuildVector); // Build Vector
5965 Results.push_back(Elt: Chain); // Chain
5966}
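
// Illustrative sketch: a tcgen05.ld variant whose IR result is <4 x i32> is
// rebuilt above as an INTRINSIC_W_CHAIN node returning four scalar i32 values
// plus a chain, and the vector the IR expects is reassembled from those
// scalars with a build_vector.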
5967
5968static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
5969 SmallVectorImpl<SDValue> &Results) {
5970 SDValue Chain = N->getOperand(Num: 0);
5971 SDValue Intrin = N->getOperand(Num: 1);
5972 SDLoc DL(N);
5973
5974 // Get the intrinsic ID
5975 unsigned IntrinNo = Intrin.getNode()->getAsZExtVal();
5976 switch (IntrinNo) {
5977 default:
5978 return;
5979 case Intrinsic::nvvm_ldu_global_i:
5980 case Intrinsic::nvvm_ldu_global_f:
5981 case Intrinsic::nvvm_ldu_global_p: {
5982 EVT ResVT = N->getValueType(ResNo: 0);
5983
5984 if (ResVT.isVector()) {
5985 // Vector LDG/LDU
5986
5987 unsigned NumElts = ResVT.getVectorNumElements();
5988 EVT EltVT = ResVT.getVectorElementType();
5989
5990 // Since LDU/LDG are target nodes, we cannot rely on DAG type
5991 // legalization.
5992 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
5993 // loaded type to i16 and propagate the "real" type as the memory type.
5994 bool NeedTrunc = false;
5995 if (EltVT.getSizeInBits() < 16) {
5996 EltVT = MVT::i16;
5997 NeedTrunc = true;
5998 }
5999
6000 unsigned Opcode = 0;
6001 SDVTList LdResVTs;
6002
6003 switch (NumElts) {
6004 default:
6005 return;
6006 case 2:
6007 Opcode = NVPTXISD::LDUV2;
6008 LdResVTs = DAG.getVTList(VT1: EltVT, VT2: EltVT, VT3: MVT::Other);
6009 break;
6010 case 4: {
6011 Opcode = NVPTXISD::LDUV4;
6012 EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
6013 LdResVTs = DAG.getVTList(VTs: ListVTs);
6014 break;
6015 }
6016 }
6017
6018 SmallVector<SDValue, 8> OtherOps;
6019
6020 // Copy regular operands
6021
6022 OtherOps.push_back(Elt: Chain); // Chain
6023 // Skip operand 1 (intrinsic ID)
6024 // Others
6025 OtherOps.append(in_start: N->op_begin() + 2, in_end: N->op_end());
6026
6027 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
6028
6029 SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, dl: DL, VTList: LdResVTs, Ops: OtherOps,
6030 MemVT: MemSD->getMemoryVT(),
6031 MMO: MemSD->getMemOperand());
6032
6033 SmallVector<SDValue, 4> ScalarRes;
6034
6035 for (unsigned i = 0; i < NumElts; ++i) {
6036 SDValue Res = NewLD.getValue(R: i);
6037 if (NeedTrunc)
6038 Res =
6039 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: ResVT.getVectorElementType(), Operand: Res);
6040 ScalarRes.push_back(Elt: Res);
6041 }
6042
6043 SDValue LoadChain = NewLD.getValue(R: NumElts);
6044
6045 SDValue BuildVec =
6046 DAG.getBuildVector(VT: ResVT, DL, Ops: ScalarRes);
6047
6048 Results.push_back(Elt: BuildVec);
6049 Results.push_back(Elt: LoadChain);
6050 } else {
6051 // i8 LDG/LDU
6052 assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
6053 "Custom handling of non-i8 ldu/ldg?");
6054
6055 // Just copy all operands as-is
6056 SmallVector<SDValue, 4> Ops(N->ops());
6057
6058 // Force output to i16
6059 SDVTList LdResVTs = DAG.getVTList(VT1: MVT::i16, VT2: MVT::Other);
6060
6061 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(Val: N);
6062
6063 // We make sure the memory type is i8, which will be used during isel
6064 // to select the proper instruction.
6065 SDValue NewLD =
6066 DAG.getMemIntrinsicNode(Opcode: ISD::INTRINSIC_W_CHAIN, dl: DL, VTList: LdResVTs, Ops,
6067 MemVT: MVT::i8, MMO: MemSD->getMemOperand());
6068
6069 Results.push_back(Elt: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MVT::i8,
6070 Operand: NewLD.getValue(R: 0)));
6071 Results.push_back(Elt: NewLD.getValue(R: 1));
6072 }
6073 return;
6074 }
6075
6076 case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
6077 case Intrinsic::nvvm_tcgen05_ld_16x64b_x4:
6078 case Intrinsic::nvvm_tcgen05_ld_16x64b_x8:
6079 case Intrinsic::nvvm_tcgen05_ld_16x64b_x16:
6080 case Intrinsic::nvvm_tcgen05_ld_16x64b_x32:
6081 case Intrinsic::nvvm_tcgen05_ld_16x64b_x64:
6082 case Intrinsic::nvvm_tcgen05_ld_16x64b_x128:
6083 case Intrinsic::nvvm_tcgen05_ld_32x32b_x2:
6084 case Intrinsic::nvvm_tcgen05_ld_32x32b_x4:
6085 case Intrinsic::nvvm_tcgen05_ld_32x32b_x8:
6086 case Intrinsic::nvvm_tcgen05_ld_32x32b_x16:
6087 case Intrinsic::nvvm_tcgen05_ld_32x32b_x32:
6088 case Intrinsic::nvvm_tcgen05_ld_32x32b_x64:
6089 case Intrinsic::nvvm_tcgen05_ld_32x32b_x128:
6090 case Intrinsic::nvvm_tcgen05_ld_16x128b_x1:
6091 case Intrinsic::nvvm_tcgen05_ld_16x128b_x2:
6092 case Intrinsic::nvvm_tcgen05_ld_16x128b_x4:
6093 case Intrinsic::nvvm_tcgen05_ld_16x128b_x8:
6094 case Intrinsic::nvvm_tcgen05_ld_16x128b_x16:
6095 case Intrinsic::nvvm_tcgen05_ld_16x128b_x32:
6096 case Intrinsic::nvvm_tcgen05_ld_16x128b_x64:
6097 case Intrinsic::nvvm_tcgen05_ld_16x256b_x1:
6098 case Intrinsic::nvvm_tcgen05_ld_16x256b_x2:
6099 case Intrinsic::nvvm_tcgen05_ld_16x256b_x4:
6100 case Intrinsic::nvvm_tcgen05_ld_16x256b_x8:
6101 case Intrinsic::nvvm_tcgen05_ld_16x256b_x16:
6102 case Intrinsic::nvvm_tcgen05_ld_16x256b_x32:
6103 return ReplaceTcgen05Ld(N, DAG, Results);
6104
6105 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2:
6106 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x4:
6107 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x8:
6108 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x16:
6109 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x32:
6110 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x64:
6111 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x128:
6112 return ReplaceTcgen05Ld(N, DAG, Results, /* Offset */ hasOffset: true);
6113 }
6114}
6115
6116static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
6117 SmallVectorImpl<SDValue> &Results) {
6118 // Change the CopyFromReg to output two 64-bit results instead of a 128-bit
6119 // result so that it can pass type legalization.
6120 SDLoc DL(N);
6121 SDValue Chain = N->getOperand(Num: 0);
6122 SDValue Reg = N->getOperand(Num: 1);
6123 SDValue Glue = N->getOperand(Num: 2);
6124
6125 assert(Reg.getValueType() == MVT::i128 &&
6126 "Custom lowering for CopyFromReg with 128-bit reg only");
6127 SmallVector<EVT, 4> ResultsType = {MVT::i64, MVT::i64, N->getValueType(ResNo: 1),
6128 N->getValueType(ResNo: 2)};
6129 SmallVector<SDValue, 3> NewOps = {Chain, Reg, Glue};
6130
6131 SDValue NewValue = DAG.getNode(Opcode: ISD::CopyFromReg, DL, ResultTys: ResultsType, Ops: NewOps);
6132 SDValue Pair = DAG.getNode(Opcode: ISD::BUILD_PAIR, DL, VT: MVT::i128,
6133 Ops: {NewValue.getValue(R: 0), NewValue.getValue(R: 1)});
6134
6135 Results.push_back(Elt: Pair);
6136 Results.push_back(Elt: NewValue.getValue(R: 2));
6137 Results.push_back(Elt: NewValue.getValue(R: 3));
6138}
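
// Illustrative sketch of the replacement above: a (CopyFromReg i128 %vreg)
// becomes a CopyFromReg producing two i64 values, and the i128 its users
// expect is reconstituted from them with (build_pair i128 %lo, %hi).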
6139
6140void NVPTXTargetLowering::ReplaceNodeResults(
6141 SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
6142 switch (N->getOpcode()) {
6143 default:
6144 report_fatal_error(reason: "Unhandled custom legalization");
6145 case ISD::BITCAST:
6146 ReplaceBITCAST(Node: N, DAG, Results);
6147 return;
6148 case ISD::LOAD:
6149 ReplaceLoadVector(N, DAG, Results, STI);
6150 return;
6151 case ISD::INTRINSIC_W_CHAIN:
6152 ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
6153 return;
6154 case ISD::CopyFromReg:
6155 ReplaceCopyFromReg_128(N, DAG, Results);
6156 return;
6157 }
6158}
6159
6160NVPTXTargetLowering::AtomicExpansionKind
6161NVPTXTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
6162 Type *Ty = AI->getValOperand()->getType();
6163
6164 if (AI->isFloatingPointOperation()) {
6165 if (AI->getOperation() == AtomicRMWInst::BinOp::FAdd) {
6166 if (Ty->isHalfTy() && STI.getSmVersion() >= 70 &&
6167 STI.getPTXVersion() >= 63)
6168 return AtomicExpansionKind::None;
6169 if (Ty->isBFloatTy() && STI.getSmVersion() >= 90 &&
6170 STI.getPTXVersion() >= 78)
6171 return AtomicExpansionKind::None;
6172 if (Ty->isFloatTy())
6173 return AtomicExpansionKind::None;
6174 if (Ty->isDoubleTy() && STI.hasAtomAddF64())
6175 return AtomicExpansionKind::None;
6176 }
6177 return AtomicExpansionKind::CmpXChg;
6178 }
6179
6180 assert(Ty->isIntegerTy() && "Ty should be integer at this point");
6181 auto ITy = cast<llvm::IntegerType>(Val: Ty);
6182
6183 switch (AI->getOperation()) {
6184 default:
6185 return AtomicExpansionKind::CmpXChg;
6186 case AtomicRMWInst::BinOp::And:
6187 case AtomicRMWInst::BinOp::Or:
6188 case AtomicRMWInst::BinOp::Xor:
6189 case AtomicRMWInst::BinOp::Xchg:
6190 switch (ITy->getBitWidth()) {
6191 case 8:
6192 case 16:
6193 return AtomicExpansionKind::CmpXChg;
6194 case 32:
6195 return AtomicExpansionKind::None;
6196 case 64:
6197 if (STI.hasAtomBitwise64())
6198 return AtomicExpansionKind::None;
6199 return AtomicExpansionKind::CmpXChg;
6200 default:
6201 llvm_unreachable("unsupported width encountered");
6202 }
6203 case AtomicRMWInst::BinOp::Add:
6204 case AtomicRMWInst::BinOp::Sub:
6205 case AtomicRMWInst::BinOp::Max:
6206 case AtomicRMWInst::BinOp::Min:
6207 case AtomicRMWInst::BinOp::UMax:
6208 case AtomicRMWInst::BinOp::UMin:
6209 switch (ITy->getBitWidth()) {
6210 case 8:
6211 case 16:
6212 return AtomicExpansionKind::CmpXChg;
6213 case 32:
6214 return AtomicExpansionKind::None;
6215 case 64:
6216 if (STI.hasAtomMinMax64())
6217 return AtomicExpansionKind::None;
6218 return AtomicExpansionKind::CmpXChg;
6219 default:
6220 llvm_unreachable("unsupported width encountered");
6221 }
6222 case AtomicRMWInst::BinOp::UIncWrap:
6223 case AtomicRMWInst::BinOp::UDecWrap:
6224 switch (ITy->getBitWidth()) {
6225 case 32:
6226 return AtomicExpansionKind::None;
6227 case 8:
6228 case 16:
6229 case 64:
6230 return AtomicExpansionKind::CmpXChg;
6231 default:
6232 llvm_unreachable("unsupported width encountered");
6233 }
6234 }
6235
6236 return AtomicExpansionKind::CmpXChg;
6237}
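
// For example (illustrative): an i16 atomicrmw add returns CmpXChg above, so
// the generic AtomicExpand pass rewrites it as a compare-and-swap loop, while
// an i32 atomicrmw add returns None and lowers directly to a native atom.add.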
6238
6239bool NVPTXTargetLowering::shouldInsertFencesForAtomic(
6240 const Instruction *I) const {
6241 auto *CI = dyn_cast<AtomicCmpXchgInst>(Val: I);
6242 // When CAS bitwidth is not supported on the hardware, the CAS is emulated
6243 // using a retry loop that uses a higher-bitwidth monotonic CAS. We enforce
6244 // the memory order using explicit fences around the retry loop.
6245 // The memory order of natively supported CAS operations can be enforced
6246 // by lowering to an atom.cas with the right memory synchronizing effect.
6247 // However, atom.cas only supports relaxed, acquire, release and acq_rel.
6248 // So we also use explicit fences for enforcing memory order for
6249 // seq_cst CAS with natively-supported bitwidths.
6250 return CI &&
6251 (cast<IntegerType>(Val: CI->getCompareOperand()->getType())->getBitWidth() <
6252 STI.getMinCmpXchgSizeInBits() ||
6253 CI->getMergedOrdering() == AtomicOrdering::SequentiallyConsistent);
6254}
6255
6256AtomicOrdering NVPTXTargetLowering::atomicOperationOrderAfterFenceSplit(
6257 const Instruction *I) const {
6258 auto *CI = dyn_cast<AtomicCmpXchgInst>(Val: I);
6259 bool BitwidthSupportedAndIsSeqCst =
6260 CI && CI->getMergedOrdering() == AtomicOrdering::SequentiallyConsistent &&
6261 cast<IntegerType>(Val: CI->getCompareOperand()->getType())->getBitWidth() >=
6262 STI.getMinCmpXchgSizeInBits();
6263 return BitwidthSupportedAndIsSeqCst ? AtomicOrdering::Acquire
6264 : AtomicOrdering::Monotonic;
6265}
6266
6267Instruction *NVPTXTargetLowering::emitLeadingFence(IRBuilderBase &Builder,
6268 Instruction *Inst,
6269 AtomicOrdering Ord) const {
6270 if (!isa<AtomicCmpXchgInst>(Val: Inst))
6271 return TargetLoweringBase::emitLeadingFence(Builder, Inst, Ord);
6272
6273 // Specialize for cmpxchg
6274 // Emit a fence.sc leading fence for cmpxchg seq_cst which are not emulated
6275 if (isReleaseOrStronger(AO: Ord))
6276 return Ord == AtomicOrdering::SequentiallyConsistent
6277 ? Builder.CreateFence(Ordering: AtomicOrdering::SequentiallyConsistent)
6278 : Builder.CreateFence(Ordering: AtomicOrdering::Release);
6279
6280 return nullptr;
6281}
6282
6283Instruction *NVPTXTargetLowering::emitTrailingFence(IRBuilderBase &Builder,
6284 Instruction *Inst,
6285 AtomicOrdering Ord) const {
6286 // Specialize for cmpxchg
6287 if (!isa<AtomicCmpXchgInst>(Val: Inst))
6288 return TargetLoweringBase::emitTrailingFence(Builder, Inst, Ord);
6289
6290 auto CASWidth =
6291 cast<IntegerType>(
6292 Val: dyn_cast<AtomicCmpXchgInst>(Val: Inst)->getCompareOperand()->getType())
6293 ->getBitWidth();
6294 // Do not emit a trailing fence for cmpxchg seq_cst which are not emulated
6295 if (isAcquireOrStronger(AO: Ord) &&
6296 (Ord != AtomicOrdering::SequentiallyConsistent ||
6297 CASWidth < STI.getMinCmpXchgSizeInBits()))
6298 return Builder.CreateFence(Ordering: AtomicOrdering::Acquire);
6299
6300 return nullptr;
6301}
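
// For example (illustrative): a seq_cst cmpxchg on an i16, which is narrower
// than the minimum native CAS width, is bracketed by a leading seq_cst fence
// (emitLeadingFence) and a trailing acquire fence (emitTrailingFence) around
// the emulation loop, while a natively supported seq_cst cmpxchg gets only
// the leading fence.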
6302
6303// Rather than default to SINT when both UINT and SINT are custom, we only
6304// change the opcode when UINT is not legal and SINT is. UINT is preferred when
6305// both are custom since unsigned CVT instructions can lead to slightly better
6306// SASS code with fewer instructions.
6307unsigned NVPTXTargetLowering::getPreferredFPToIntOpcode(unsigned Op, EVT FromVT,
6308 EVT ToVT) const {
6309 if (isOperationLegal(Op, VT: ToVT))
6310 return Op;
6311 switch (Op) {
6312 case ISD::FP_TO_UINT:
6313 if (isOperationLegal(Op: ISD::FP_TO_SINT, VT: ToVT))
6314 return ISD::FP_TO_SINT;
6315 break;
6316 case ISD::STRICT_FP_TO_UINT:
6317 if (isOperationLegal(Op: ISD::STRICT_FP_TO_SINT, VT: ToVT))
6318 return ISD::STRICT_FP_TO_SINT;
6319 break;
6320 case ISD::VP_FP_TO_UINT:
6321 if (isOperationLegal(Op: ISD::VP_FP_TO_SINT, VT: ToVT))
6322 return ISD::VP_FP_TO_SINT;
6323 break;
6324 default:
6325 break;
6326 }
6327 return Op;
6328}
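
// For example (illustrative): if FP_TO_UINT is not legal for the result type
// but FP_TO_SINT is, an fptoui lowering is steered to FP_TO_SINT here; when
// both are merely custom, the unsigned opcode is kept, per the comment above.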
6329
6330// Pin NVPTXTargetObjectFile's vtables to this file.
6331NVPTXTargetObjectFile::~NVPTXTargetObjectFile() = default;
6332
6333MCSection *NVPTXTargetObjectFile::SelectSectionForGlobal(
6334 const GlobalObject *GO, SectionKind Kind, const TargetMachine &TM) const {
6335 return getDataSection();
6336}
6337