//===-- NVPTXISelLowering.h - NVPTX DAG Lowering Interface ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the interfaces that NVPTX uses to lower LLVM code into a
// selection DAG.
//
//===----------------------------------------------------------------------===//
13 | |
14 | #ifndef LLVM_LIB_TARGET_NVPTX_NVPTXISELLOWERING_H |
15 | #define LLVM_LIB_TARGET_NVPTX_NVPTXISELLOWERING_H |
16 | |
17 | #include "NVPTX.h" |
18 | #include "llvm/CodeGen/SelectionDAG.h" |
19 | #include "llvm/CodeGen/TargetLowering.h" |
20 | #include "llvm/Support/AtomicOrdering.h" |
21 | |
22 | namespace llvm { |
23 | namespace NVPTXISD { |
24 | enum NodeType : unsigned { |
25 | // Start the numbering from where ISD NodeType finishes. |
26 | FIRST_NUMBER = ISD::BUILTIN_OP_END, |
27 | RET_GLUE, |
28 | |
29 | /// These nodes represent a parameter declaration. In PTX this will look like: |
30 | /// .param .align 16 .b8 param0[1024]; |
31 | /// .param .b32 retval0; |
32 | /// |
33 | /// DeclareArrayParam(Chain, Externalsym, Align, Size, Glue) |
34 | /// DeclareScalarParam(Chain, Externalsym, Size, Glue) |
35 | DeclareScalarParam, |
36 | DeclareArrayParam, |
37 | |
38 | /// This node represents a PTX call instruction. It's operands are as follows: |
39 | /// |
40 | /// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns, |
41 | /// NumParams, Callee, Proto, InGlue) |
42 | CALL, |
43 | |
44 | MoveParam, |
45 | CallPrototype, |
46 | ProxyReg, |
47 | FSHL_CLAMP, |
48 | FSHR_CLAMP, |
49 | MUL_WIDE_SIGNED, |
50 | MUL_WIDE_UNSIGNED, |
51 | SETP_F16X2, |
52 | SETP_BF16X2, |
53 | BFE, |
54 | BFI, |
55 | PRMT, |
56 | |
57 | /// This node is similar to ISD::BUILD_VECTOR except that the output may be |
58 | /// implicitly bitcast to a scalar. This allows for the representation of |
59 | /// packing move instructions for vector types which are not legal i.e. v2i32 |
60 | BUILD_VECTOR, |
61 | |
62 | /// This node is the inverse of NVPTX::BUILD_VECTOR. It takes a single value |
63 | /// which may be a scalar and unpacks it into multiple values by implicitly |
64 | /// converting it to a vector. |
65 | UNPACK_VECTOR, |
66 | |
67 | FCOPYSIGN, |
68 | DYNAMIC_STACKALLOC, |
69 | STACKRESTORE, |
70 | STACKSAVE, |
71 | BrxStart, |
72 | BrxItem, |
73 | BrxEnd, |
74 | CLUSTERLAUNCHCONTROL_QUERY_CANCEL_IS_CANCELED, |
75 | CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_X, |
76 | CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Y, |
77 | CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z, |
78 | |
79 | FIRST_MEMORY_OPCODE, |
80 | LoadV2 = FIRST_MEMORY_OPCODE, |
81 | LoadV4, |
82 | LoadV8, |
83 | LDUV2, // LDU.v2 |
84 | LDUV4, // LDU.v4 |
85 | StoreV2, |
86 | StoreV4, |
87 | StoreV8, |
88 | LoadParam, |
89 | LoadParamV2, |
90 | LoadParamV4, |
91 | StoreParam, |
92 | StoreParamV2, |
93 | StoreParamV4, |
94 | LAST_MEMORY_OPCODE = StoreParamV4, |
95 | }; |
96 | } |
97 | |
98 | class NVPTXSubtarget; |
99 | |
100 | //===--------------------------------------------------------------------===// |
101 | // TargetLowering Implementation |
102 | //===--------------------------------------------------------------------===// |
103 | class NVPTXTargetLowering : public TargetLowering { |
104 | public: |
105 | explicit NVPTXTargetLowering(const NVPTXTargetMachine &TM, |
106 | const NVPTXSubtarget &STI); |
107 | SDValue LowerOperation(SDValue Op, SelectionDAG &DAG) const override; |
108 | |
109 | const char *getTargetNodeName(unsigned Opcode) const override; |
110 | |
111 | bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallInst &I, |
112 | MachineFunction &MF, |
113 | unsigned Intrinsic) const override; |
114 | |
115 | Align getFunctionArgumentAlignment(const Function *F, Type *Ty, unsigned Idx, |
116 | const DataLayout &DL) const; |
117 | |
118 | /// getFunctionParamOptimizedAlign - since function arguments are passed via |
119 | /// .param space, we may want to increase their alignment in a way that |
120 | /// ensures that we can effectively vectorize their loads & stores. We can |
121 | /// increase alignment only if the function has internal or has private |
122 | /// linkage as for other linkage types callers may already rely on default |
123 | /// alignment. To allow using 128-bit vectorized loads/stores, this function |
124 | /// ensures that alignment is 16 or greater. |
125 | Align getFunctionParamOptimizedAlign(const Function *F, Type *ArgTy, |
126 | const DataLayout &DL) const; |
127 | |
128 | /// Helper for computing alignment of a device function byval parameter. |
129 | Align getFunctionByValParamAlign(const Function *F, Type *ArgTy, |
130 | Align InitialAlign, |
131 | const DataLayout &DL) const; |
132 | |
133 | // Helper for getting a function parameter name. Name is composed from |
134 | // its index and the function name. Negative index corresponds to special |
135 | // parameter (unsized array) used for passing variable arguments. |
136 | std::string getParamName(const Function *F, int Idx) const; |
137 | |
138 | /// isLegalAddressingMode - Return true if the addressing mode represented |
139 | /// by AM is legal for this target, for a load/store of the specified type |
140 | /// Used to guide target specific optimizations, like loop strength |
141 | /// reduction (LoopStrengthReduce.cpp) and memory optimization for |
142 | /// address mode (CodeGenPrepare.cpp) |
143 | bool isLegalAddressingMode(const DataLayout &DL, const AddrMode &AM, Type *Ty, |
144 | unsigned AS, |
145 | Instruction *I = nullptr) const override; |
146 | |
147 | bool isTruncateFree(Type *SrcTy, Type *DstTy) const override { |
148 | // Truncating 64-bit to 32-bit is free in SASS. |
149 | if (!SrcTy->isIntegerTy() || !DstTy->isIntegerTy()) |
150 | return false; |
151 | return SrcTy->getPrimitiveSizeInBits() == 64 && |
152 | DstTy->getPrimitiveSizeInBits() == 32; |
153 | } |
154 | |
155 | EVT getSetCCResultType(const DataLayout &DL, LLVMContext &Ctx, |
156 | EVT VT) const override { |
157 | if (VT.isVector()) |
158 | return EVT::getVectorVT(Context&: Ctx, VT: MVT::i1, NumElements: VT.getVectorNumElements()); |
159 | return MVT::i1; |
160 | } |
161 | |
162 | ConstraintType getConstraintType(StringRef Constraint) const override; |
163 | std::pair<unsigned, const TargetRegisterClass *> |
164 | getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, |
165 | StringRef Constraint, MVT VT) const override; |
166 | |
167 | SDValue LowerFormalArguments(SDValue Chain, CallingConv::ID CallConv, |
168 | bool isVarArg, |
169 | const SmallVectorImpl<ISD::InputArg> &Ins, |
170 | const SDLoc &dl, SelectionDAG &DAG, |
171 | SmallVectorImpl<SDValue> &InVals) const override; |
172 | |
173 | SDValue LowerCall(CallLoweringInfo &CLI, |
174 | SmallVectorImpl<SDValue> &InVals) const override; |
175 | |
176 | SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const; |
177 | SDValue LowerSTACKSAVE(SDValue Op, SelectionDAG &DAG) const; |
178 | SDValue LowerSTACKRESTORE(SDValue Op, SelectionDAG &DAG) const; |
179 | |
180 | std::string getPrototype(const DataLayout &DL, Type *, const ArgListTy &, |
181 | const SmallVectorImpl<ISD::OutputArg> &, |
182 | std::optional<unsigned> FirstVAArg, |
183 | const CallBase &CB, unsigned UniqueCallSite) const; |
184 | |
185 | SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg, |
186 | const SmallVectorImpl<ISD::OutputArg> &Outs, |
187 | const SmallVectorImpl<SDValue> &OutVals, const SDLoc &dl, |
188 | SelectionDAG &DAG) const override; |
189 | |
190 | void LowerAsmOperandForConstraint(SDValue Op, StringRef Constraint, |
191 | std::vector<SDValue> &Ops, |
192 | SelectionDAG &DAG) const override; |
193 | |
194 | const NVPTXTargetMachine *nvTM; |
195 | |
196 | // PTX always uses 32-bit shift amounts |
197 | MVT getScalarShiftAmountTy(const DataLayout &, EVT) const override { |
198 | return MVT::i32; |
199 | } |
200 | |
201 | TargetLoweringBase::LegalizeTypeAction |
202 | getPreferredVectorAction(MVT VT) const override; |
203 | |
204 | // Get the degree of precision we want from 32-bit floating point division |
205 | // operations. |
206 | NVPTX::DivPrecisionLevel getDivF32Level(const MachineFunction &MF, |
207 | const SDNode &N) const; |
208 | |
209 | // Get whether we should use a precise or approximate 32-bit floating point |
210 | // sqrt instruction. |
211 | bool usePrecSqrtF32(const MachineFunction &MF, |
212 | const SDNode *N = nullptr) const; |
213 | |
214 | // Get whether we should use instructions that flush floating-point denormals |
215 | // to sign-preserving zero. |
216 | bool useF32FTZ(const MachineFunction &MF) const; |
217 | |
218 | SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled, |
219 | int &, bool &UseOneConst, |
220 | bool Reciprocal) const override; |
221 | |
222 | unsigned combineRepeatedFPDivisors() const override { return 2; } |
223 | |
224 | bool allowFMA(MachineFunction &MF, CodeGenOptLevel OptLevel) const; |
225 | bool allowUnsafeFPMath(const MachineFunction &MF) const; |
226 | |
227 | bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF, |
228 | EVT) const override { |
229 | return true; |
230 | } |
231 | |
232 | // The default is the same as pointer type, but brx.idx only accepts i32 |
233 | MVT getJumpTableRegTy(const DataLayout &) const override { return MVT::i32; } |
234 | |
235 | unsigned getJumpTableEncoding() const override; |
236 | |
237 | bool enableAggressiveFMAFusion(EVT VT) const override { return true; } |
238 | |
239 | // The default is to transform llvm.ctlz(x, false) (where false indicates that |
240 | // x == 0 is not undefined behavior) into a branch that checks whether x is 0 |
241 | // and avoids calling ctlz in that case. We have a dedicated ctlz |
242 | // instruction, so we say that ctlz is cheap to speculate. |
243 | bool isCheapToSpeculateCtlz(Type *Ty) const override { return true; } |
244 | |
245 | AtomicExpansionKind shouldCastAtomicLoadInIR(LoadInst *LI) const override { |
246 | return AtomicExpansionKind::None; |
247 | } |
248 | |
249 | AtomicExpansionKind shouldCastAtomicStoreInIR(StoreInst *SI) const override { |
250 | return AtomicExpansionKind::None; |
251 | } |
252 | |
253 | AtomicExpansionKind |
254 | shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const override; |
255 | |
256 | bool aggressivelyPreferBuildVectorSources(EVT VecVT) const override { |
257 | // There's rarely any point of packing something into a vector type if we |
258 | // already have the source data. |
259 | return true; |
260 | } |
261 | |
262 | bool shouldInsertFencesForAtomic(const Instruction *) const override; |
263 | |
264 | AtomicOrdering |
265 | atomicOperationOrderAfterFenceSplit(const Instruction *I) const override; |
266 | |
267 | Instruction *emitLeadingFence(IRBuilderBase &Builder, Instruction *Inst, |
268 | AtomicOrdering Ord) const override; |
269 | Instruction *emitTrailingFence(IRBuilderBase &Builder, Instruction *Inst, |
270 | AtomicOrdering Ord) const override; |
271 | |
272 | unsigned getPreferredFPToIntOpcode(unsigned Op, EVT FromVT, |
273 | EVT ToVT) const override; |
274 | |
275 | private: |
276 | const NVPTXSubtarget &STI; // cache the subtarget here |
277 | mutable unsigned GlobalUniqueCallSite; |
278 | |
279 | SDValue getParamSymbol(SelectionDAG &DAG, int I, EVT T) const; |
280 | SDValue getCallParamSymbol(SelectionDAG &DAG, int I, EVT T) const; |
281 | SDValue LowerADDRSPACECAST(SDValue Op, SelectionDAG &DAG) const; |
282 | SDValue LowerBITCAST(SDValue Op, SelectionDAG &DAG) const; |
283 | |
284 | SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const; |
285 | SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const; |
286 | SDValue (SDValue Op, SelectionDAG &DAG) const; |
287 | SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const; |
288 | SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const; |
289 | |
290 | SDValue LowerFCOPYSIGN(SDValue Op, SelectionDAG &DAG) const; |
291 | |
292 | SDValue LowerFROUND(SDValue Op, SelectionDAG &DAG) const; |
293 | SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const; |
294 | SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const; |
295 | |
296 | SDValue PromoteBinOpIfF32FTZ(SDValue Op, SelectionDAG &DAG) const; |
297 | |
298 | SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const; |
299 | SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const; |
300 | |
301 | SDValue LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const; |
302 | SDValue LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const; |
303 | |
304 | SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const; |
305 | SDValue LowerLOADi1(SDValue Op, SelectionDAG &DAG) const; |
306 | |
307 | SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const; |
308 | SDValue LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const; |
309 | SDValue LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const; |
310 | |
311 | SDValue LowerShiftRightParts(SDValue Op, SelectionDAG &DAG) const; |
312 | SDValue LowerShiftLeftParts(SDValue Op, SelectionDAG &DAG) const; |
313 | |
314 | SDValue LowerBR_JT(SDValue Op, SelectionDAG &DAG) const; |
315 | |
316 | SDValue LowerVAARG(SDValue Op, SelectionDAG &DAG) const; |
317 | SDValue LowerVASTART(SDValue Op, SelectionDAG &DAG) const; |
318 | |
319 | SDValue LowerCopyToReg_128(SDValue Op, SelectionDAG &DAG) const; |
320 | unsigned getNumRegisters(LLVMContext &Context, EVT VT, |
321 | std::optional<MVT> RegisterVT) const override; |
322 | bool |
323 | splitValueIntoRegisterParts(SelectionDAG &DAG, const SDLoc &DL, SDValue Val, |
324 | SDValue *Parts, unsigned NumParts, MVT PartVT, |
325 | std::optional<CallingConv::ID> CC) const override; |
326 | |
327 | void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results, |
328 | SelectionDAG &DAG) const override; |
329 | SDValue PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const override; |
330 | |
331 | Align getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx, |
332 | const DataLayout &DL) const; |
333 | }; |
334 | |
335 | } // namespace llvm |
336 | |
337 | #endif |
338 | |