//===-- NVPTXISelLowering.h - NVPTX DAG Lowering Interface ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the interfaces that NVPTX uses to lower LLVM code into a
// selection DAG.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIB_TARGET_NVPTX_NVPTXISELLOWERING_H
#define LLVM_LIB_TARGET_NVPTX_NVPTXISELLOWERING_H

#include "NVPTX.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/Support/AtomicOrdering.h"

namespace llvm {
namespace NVPTXISD {
enum NodeType : unsigned {
  // Start the numbering from where ISD NodeType finishes.
  FIRST_NUMBER = ISD::BUILTIN_OP_END,
  RET_GLUE,

  /// These nodes represent a parameter declaration. In PTX this will look
  /// like:
  ///   .param .align 16 .b8 param0[1024];
  ///   .param .b32 retval0;
  ///
  /// DeclareArrayParam(Chain, Externalsym, Align, Size, Glue)
  /// DeclareScalarParam(Chain, Externalsym, Size, Glue)
  DeclareScalarParam,
  DeclareArrayParam,

  /// This node represents a PTX call instruction. Its operands are as
  /// follows:
  ///
  /// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns,
  ///      NumParams, Callee, Proto, InGlue)
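  ///
  /// For example (an illustrative sketch, not a normative encoding), a
  /// direct, non-convergent call with one return value and one parameter
  /// would carry NumReturns = 1, NumParams = 1, and a cleared IsIndirectCall
  /// flag.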
  CALL,

  MoveParam,
  CallPrototype,
  ProxyReg,
  FSHL_CLAMP,
  FSHR_CLAMP,
  MUL_WIDE_SIGNED,
  MUL_WIDE_UNSIGNED,
  SETP_F16X2,
  SETP_BF16X2,
  BFE,
  BFI,
  PRMT,

  /// This node is similar to ISD::BUILD_VECTOR except that the output may be
  /// implicitly bitcast to a scalar. This allows for the representation of
  /// packing move instructions for vector types which are not legal, e.g.
  /// v2i32.
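  ///
  /// For example (illustrative), a v2i32 build whose result is held in a
  /// single 64-bit register:
  ///   i64 = BUILD_VECTOR i32:a, i32:b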
  BUILD_VECTOR,

  /// This node is the inverse of NVPTXISD::BUILD_VECTOR. It takes a single
  /// value, which may be a scalar, and unpacks it into multiple values by
  /// implicitly converting it to a vector.
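  ///
  /// For example (illustrative), the inverse of the packing above:
  ///   i32:a, i32:b = UNPACK_VECTOR i64:v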
  UNPACK_VECTOR,

  FCOPYSIGN,
  DYNAMIC_STACKALLOC,
  STACKRESTORE,
  STACKSAVE,
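  // Components of a lowered jump-table branch (see LowerBR_JT): BrxStart
  // opens the sequence, BrxItem names one branch target, and BrxEnd performs
  // the brx.idx branch.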
  BrxStart,
  BrxItem,
  BrxEnd,
  CLUSTERLAUNCHCONTROL_QUERY_CANCEL_IS_CANCELED,
  CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_X,
  CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Y,
  CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z,

  FIRST_MEMORY_OPCODE,
  LoadV2 = FIRST_MEMORY_OPCODE,
  LoadV4,
  LoadV8,
  LDUV2, // LDU.v2
  LDUV4, // LDU.v4
  StoreV2,
  StoreV4,
  StoreV8,
  LoadParam,
  LoadParamV2,
  LoadParamV4,
  StoreParam,
  StoreParamV2,
  StoreParamV4,
  LAST_MEMORY_OPCODE = StoreParamV4,
};
} // namespace NVPTXISD

class NVPTXSubtarget;

//===--------------------------------------------------------------------===//
// TargetLowering Implementation
//===--------------------------------------------------------------------===//
class NVPTXTargetLowering : public TargetLowering {
public:
  explicit NVPTXTargetLowering(const NVPTXTargetMachine &TM,
                               const NVPTXSubtarget &STI);
  SDValue LowerOperation(SDValue Op, SelectionDAG &DAG) const override;

  const char *getTargetNodeName(unsigned Opcode) const override;

  bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallInst &I,
                          MachineFunction &MF,
                          unsigned Intrinsic) const override;

  Align getFunctionArgumentAlignment(const Function *F, Type *Ty, unsigned Idx,
                                     const DataLayout &DL) const;

  /// getFunctionParamOptimizedAlign - since function arguments are passed via
  /// .param space, we may want to increase their alignment in a way that
  /// ensures that we can effectively vectorize their loads & stores. We can
  /// increase alignment only if the function has internal or private linkage,
  /// as for other linkage types callers may already rely on the default
  /// alignment. To allow using 128-bit vectorized loads/stores, this function
  /// ensures that alignment is 16 or greater.
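  ///
  /// For example (illustrative), a byte-array parameter with a natural
  /// alignment of 1 would be declared with .align 16 so that it can be
  /// copied with 128-bit vectorized loads and stores.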
  Align getFunctionParamOptimizedAlign(const Function *F, Type *ArgTy,
                                       const DataLayout &DL) const;

  /// Helper for computing alignment of a device function byval parameter.
  Align getFunctionByValParamAlign(const Function *F, Type *ArgTy,
                                   Align InitialAlign,
                                   const DataLayout &DL) const;

  // Helper for getting a function parameter name. The name is composed from
  // the parameter's index and the function name. A negative index corresponds
  // to the special parameter (unsized array) used for passing variable
  // arguments.
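  // For example (illustrative), parameter 0 of function foo would be named
  // "foo_param_0".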
  std::string getParamName(const Function *F, int Idx) const;

  /// isLegalAddressingMode - Return true if the addressing mode represented
  /// by AM is legal for this target, for a load/store of the specified type.
  /// Used to guide target-specific optimizations, like loop strength
  /// reduction (LoopStrengthReduce.cpp) and memory optimization for
  /// address mode (CodeGenPrepare.cpp).
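  ///
  /// For NVPTX, the legal modes are a plain symbol or immediate address, a
  /// base register, and a base register plus a constant offset; scaled-index
  /// forms are rejected (a summary of the usual PTX addressing rules).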
  bool isLegalAddressingMode(const DataLayout &DL, const AddrMode &AM, Type *Ty,
                             unsigned AS,
                             Instruction *I = nullptr) const override;

  bool isTruncateFree(Type *SrcTy, Type *DstTy) const override {
    // Truncating 64-bit to 32-bit is free in SASS.
    if (!SrcTy->isIntegerTy() || !DstTy->isIntegerTy())
      return false;
    return SrcTy->getPrimitiveSizeInBits() == 64 &&
           DstTy->getPrimitiveSizeInBits() == 32;
  }

  EVT getSetCCResultType(const DataLayout &DL, LLVMContext &Ctx,
                         EVT VT) const override {
    if (VT.isVector())
      return EVT::getVectorVT(Ctx, MVT::i1, VT.getVectorNumElements());
    return MVT::i1;
  }

  ConstraintType getConstraintType(StringRef Constraint) const override;
  std::pair<unsigned, const TargetRegisterClass *>
  getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
                               StringRef Constraint, MVT VT) const override;

  SDValue LowerFormalArguments(SDValue Chain, CallingConv::ID CallConv,
                               bool isVarArg,
                               const SmallVectorImpl<ISD::InputArg> &Ins,
                               const SDLoc &dl, SelectionDAG &DAG,
                               SmallVectorImpl<SDValue> &InVals) const override;

  SDValue LowerCall(CallLoweringInfo &CLI,
                    SmallVectorImpl<SDValue> &InVals) const override;

  SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerSTACKSAVE(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerSTACKRESTORE(SDValue Op, SelectionDAG &DAG) const;

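  // Build the PTX ".callprototype" declaration that describes the callee's
  // signature; needed when emitting indirect calls.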
  std::string getPrototype(const DataLayout &DL, Type *, const ArgListTy &,
                           const SmallVectorImpl<ISD::OutputArg> &,
                           std::optional<unsigned> FirstVAArg,
                           const CallBase &CB, unsigned UniqueCallSite) const;

  SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
                      const SmallVectorImpl<ISD::OutputArg> &Outs,
                      const SmallVectorImpl<SDValue> &OutVals, const SDLoc &dl,
                      SelectionDAG &DAG) const override;

  void LowerAsmOperandForConstraint(SDValue Op, StringRef Constraint,
                                    std::vector<SDValue> &Ops,
                                    SelectionDAG &DAG) const override;

  const NVPTXTargetMachine *nvTM;

  // PTX always uses 32-bit shift amounts.
  MVT getScalarShiftAmountTy(const DataLayout &, EVT) const override {
    return MVT::i32;
  }

  TargetLoweringBase::LegalizeTypeAction
  getPreferredVectorAction(MVT VT) const override;

  // Get the degree of precision we want from 32-bit floating point division
  // operations.
  NVPTX::DivPrecisionLevel getDivF32Level(const MachineFunction &MF,
                                          const SDNode &N) const;

  // Get whether we should use a precise or approximate 32-bit floating point
  // sqrt instruction.
  bool usePrecSqrtF32(const MachineFunction &MF,
                      const SDNode *N = nullptr) const;

  // Get whether we should use instructions that flush floating-point denormals
  // to sign-preserving zero.
  bool useF32FTZ(const MachineFunction &MF) const;

  SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
                          int &ExtraSteps, bool &UseOneConst,
                          bool Reciprocal) const override;

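  // Combine two or more FP divisions by the same divisor into a single
  // reciprocal and multiplies.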
  unsigned combineRepeatedFPDivisors() const override { return 2; }

  bool allowFMA(MachineFunction &MF, CodeGenOptLevel OptLevel) const;
  bool allowUnsafeFPMath(const MachineFunction &MF) const;

  bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
                                  EVT) const override {
    return true;
  }

  // The default is the same as pointer type, but brx.idx only accepts i32.
  MVT getJumpTableRegTy(const DataLayout &) const override { return MVT::i32; }

  unsigned getJumpTableEncoding() const override;

  bool enableAggressiveFMAFusion(EVT VT) const override { return true; }

  // The default is to transform llvm.ctlz(x, false) (where false indicates
  // that x == 0 is not undefined behavior) into a branch that checks whether
  // x is 0 and avoids calling ctlz in that case. We have a dedicated ctlz
  // instruction, so we say that ctlz is cheap to speculate.
  bool isCheapToSpeculateCtlz(Type *Ty) const override { return true; }

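  // Atomic loads and stores are lowered directly; never cast them to integer
  // in IR.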
  AtomicExpansionKind shouldCastAtomicLoadInIR(LoadInst *LI) const override {
    return AtomicExpansionKind::None;
  }

  AtomicExpansionKind shouldCastAtomicStoreInIR(StoreInst *SI) const override {
    return AtomicExpansionKind::None;
  }

  AtomicExpansionKind
  shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const override;

  bool aggressivelyPreferBuildVectorSources(EVT VecVT) const override {
    // There's rarely any point in packing something into a vector type if we
    // already have the source data.
    return true;
  }

  bool shouldInsertFencesForAtomic(const Instruction *) const override;

  AtomicOrdering
  atomicOperationOrderAfterFenceSplit(const Instruction *I) const override;

  Instruction *emitLeadingFence(IRBuilderBase &Builder, Instruction *Inst,
                                AtomicOrdering Ord) const override;
  Instruction *emitTrailingFence(IRBuilderBase &Builder, Instruction *Inst,
                                 AtomicOrdering Ord) const override;

  unsigned getPreferredFPToIntOpcode(unsigned Op, EVT FromVT,
                                     EVT ToVT) const override;

private:
  const NVPTXSubtarget &STI; // cache the subtarget here
  mutable unsigned GlobalUniqueCallSite;

  SDValue getParamSymbol(SelectionDAG &DAG, int I, EVT T) const;
  SDValue getCallParamSymbol(SelectionDAG &DAG, int I, EVT T) const;
  SDValue LowerADDRSPACECAST(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerBITCAST(SDValue Op, SelectionDAG &DAG) const;

  SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;

  SDValue LowerFCOPYSIGN(SDValue Op, SelectionDAG &DAG) const;

  SDValue LowerFROUND(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;

  SDValue PromoteBinOpIfF32FTZ(SDValue Op, SelectionDAG &DAG) const;

  SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;

  SDValue LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const;

  SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerLOADi1(SDValue Op, SelectionDAG &DAG) const;

  SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const;

  SDValue LowerShiftRightParts(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerShiftLeftParts(SDValue Op, SelectionDAG &DAG) const;

  SDValue LowerBR_JT(SDValue Op, SelectionDAG &DAG) const;

  SDValue LowerVAARG(SDValue Op, SelectionDAG &DAG) const;
  SDValue LowerVASTART(SDValue Op, SelectionDAG &DAG) const;

  SDValue LowerCopyToReg_128(SDValue Op, SelectionDAG &DAG) const;
  unsigned getNumRegisters(LLVMContext &Context, EVT VT,
                           std::optional<MVT> RegisterVT) const override;
  bool
  splitValueIntoRegisterParts(SelectionDAG &DAG, const SDLoc &DL, SDValue Val,
                              SDValue *Parts, unsigned NumParts, MVT PartVT,
                              std::optional<CallingConv::ID> CC) const override;

  void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
                          SelectionDAG &DAG) const override;
  SDValue PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const override;

  Align getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx,
                             const DataLayout &DL) const;
};

} // namespace llvm

#endif