//===-- NVPTXISelLowering.h - NVPTX DAG Lowering Interface ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the interfaces that NVPTX uses to lower LLVM code into a
// selection DAG.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIB_TARGET_NVPTX_NVPTXISELLOWERING_H
#define LLVM_LIB_TARGET_NVPTX_NVPTXISELLOWERING_H

#include "NVPTX.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/Support/AtomicOrdering.h"
namespace llvm {

class NVPTXSubtarget;

//===--------------------------------------------------------------------===//
// TargetLowering Implementation
//===--------------------------------------------------------------------===//
29class NVPTXTargetLowering : public TargetLowering {
30public:
31 explicit NVPTXTargetLowering(const NVPTXTargetMachine &TM,
32 const NVPTXSubtarget &STI);
33 SDValue LowerOperation(SDValue Op, SelectionDAG &DAG) const override;
34
35 void getTgtMemIntrinsic(SmallVectorImpl<IntrinsicInfo> &Infos,
36 const CallBase &I, MachineFunction &MF,
37 unsigned Intrinsic) const override;
38
39 Align getFunctionArgumentAlignment(const Function *F, Type *Ty, unsigned Idx,
40 const DataLayout &DL) const;
41
42 /// getFunctionParamOptimizedAlign - since function arguments are passed via
43 /// .param space, we may want to increase their alignment in a way that
44 /// ensures that we can effectively vectorize their loads & stores. We can
45 /// increase alignment only if the function has internal or has private
46 /// linkage as for other linkage types callers may already rely on default
47 /// alignment. To allow using 128-bit vectorized loads/stores, this function
48 /// ensures that alignment is 16 or greater.
49 Align getFunctionParamOptimizedAlign(const Function *F, Type *ArgTy,
50 const DataLayout &DL) const;
51
52 /// Helper for computing alignment of a device function byval parameter.
53 Align getFunctionByValParamAlign(const Function *F, Type *ArgTy,
54 Align InitialAlign,
55 const DataLayout &DL) const;
56
57 // Helper for getting a function parameter name. Name is composed from
58 // its index and the function name. Negative index corresponds to special
59 // parameter (unsized array) used for passing variable arguments.
60 std::string getParamName(const Function *F, int Idx) const;
61
62 /// isLegalAddressingMode - Return true if the addressing mode represented
63 /// by AM is legal for this target, for a load/store of the specified type
64 /// Used to guide target specific optimizations, like loop strength
65 /// reduction (LoopStrengthReduce.cpp) and memory optimization for
66 /// address mode (CodeGenPrepare.cpp)
67 bool isLegalAddressingMode(const DataLayout &DL, const AddrMode &AM, Type *Ty,
68 unsigned AS,
69 Instruction *I = nullptr) const override;
70
71 bool isTruncateFree(Type *SrcTy, Type *DstTy) const override {
72 // Truncating 64-bit to 32-bit is free in SASS.
73 if (!SrcTy->isIntegerTy() || !DstTy->isIntegerTy())
74 return false;
75 return SrcTy->getPrimitiveSizeInBits() == 64 &&
76 DstTy->getPrimitiveSizeInBits() == 32;
77 }
78
79 EVT getSetCCResultType(const DataLayout &DL, LLVMContext &Ctx,
80 EVT VT) const override {
81 if (VT.isVector())
82 return EVT::getVectorVT(Context&: Ctx, VT: MVT::i1, NumElements: VT.getVectorNumElements());
83 return MVT::i1;
84 }
85
86 ConstraintType getConstraintType(StringRef Constraint) const override;
87 std::pair<unsigned, const TargetRegisterClass *>
88 getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
89 StringRef Constraint, MVT VT) const override;
90
91 SDValue LowerFormalArguments(SDValue Chain, CallingConv::ID CallConv,
92 bool isVarArg,
93 const SmallVectorImpl<ISD::InputArg> &Ins,
94 const SDLoc &dl, SelectionDAG &DAG,
95 SmallVectorImpl<SDValue> &InVals) const override;
96
97 SDValue LowerCall(CallLoweringInfo &CLI,
98 SmallVectorImpl<SDValue> &InVals) const override;
99
100 SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
101 SDValue LowerSTACKSAVE(SDValue Op, SelectionDAG &DAG) const;
102 SDValue LowerSTACKRESTORE(SDValue Op, SelectionDAG &DAG) const;
103
104 std::string getPrototype(const DataLayout &DL, Type *, const ArgListTy &,
105 const SmallVectorImpl<ISD::OutputArg> &,
106 std::optional<unsigned> FirstVAArg,
107 const CallBase &CB, unsigned UniqueCallSite) const;
108
109 SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
110 const SmallVectorImpl<ISD::OutputArg> &Outs,
111 const SmallVectorImpl<SDValue> &OutVals, const SDLoc &dl,
112 SelectionDAG &DAG) const override;
113
114 void LowerAsmOperandForConstraint(SDValue Op, StringRef Constraint,
115 std::vector<SDValue> &Ops,
116 SelectionDAG &DAG) const override;
117
118 const NVPTXTargetMachine *nvTM;
119
120 // PTX always uses 32-bit shift amounts
121 MVT getScalarShiftAmountTy(const DataLayout &, EVT) const override {
122 return MVT::i32;
123 }
124
125 TargetLoweringBase::LegalizeTypeAction
126 getPreferredVectorAction(MVT VT) const override;
127
128 // Get the degree of precision we want from 32-bit floating point division
129 // operations.
130 NVPTX::DivPrecisionLevel getDivF32Level(const MachineFunction &MF,
131 const SDNode &N) const;
132
133 // Get whether we should use a precise or approximate 32-bit floating point
134 // sqrt instruction.
135 bool usePrecSqrtF32(const SDNode *N = nullptr) const;
136
137 // Get whether we should use instructions that flush floating-point denormals
138 // to sign-preserving zero.
139 bool useF32FTZ(const MachineFunction &MF) const;
140
141 SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
142 int &ExtraSteps, bool &UseOneConst,
143 bool Reciprocal) const override;
144
145 unsigned combineRepeatedFPDivisors() const override { return 2; }
146
147 bool allowFMA(MachineFunction &MF, CodeGenOptLevel OptLevel) const;
148
149 bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
150 EVT) const override {
151 return true;
152 }
153
154 // The default is the same as pointer type, but brx.idx only accepts i32
155 MVT getJumpTableRegTy(const DataLayout &) const override { return MVT::i32; }
156
157 unsigned getJumpTableEncoding() const override;
158
159 bool enableAggressiveFMAFusion(EVT VT) const override { return true; }
160
161 // The default is to transform llvm.ctlz(x, false) (where false indicates that
162 // x == 0 is not undefined behavior) into a branch that checks whether x is 0
163 // and avoids calling ctlz in that case. We have a dedicated ctlz
164 // instruction, so we say that ctlz is cheap to speculate.
165 bool isCheapToSpeculateCtlz(Type *Ty) const override { return true; }
166
167 AtomicExpansionKind shouldCastAtomicLoadInIR(LoadInst *LI) const override {
168 return AtomicExpansionKind::None;
169 }
170
171 AtomicExpansionKind shouldCastAtomicStoreInIR(StoreInst *SI) const override {
172 return AtomicExpansionKind::None;
173 }
174
175 AtomicExpansionKind
176 shouldExpandAtomicRMWInIR(const AtomicRMWInst *AI) const override;
177
178 bool aggressivelyPreferBuildVectorSources(EVT VecVT) const override {
179 // There's rarely any point of packing something into a vector type if we
180 // already have the source data.
181 return true;
182 }
183
184 bool shouldInsertFencesForAtomic(const Instruction *) const override;
185
186 AtomicOrdering
187 atomicOperationOrderAfterFenceSplit(const Instruction *I) const override;
188
189 Instruction *emitLeadingFence(IRBuilderBase &Builder, Instruction *Inst,
190 AtomicOrdering Ord) const override;
191 Instruction *emitTrailingFence(IRBuilderBase &Builder, Instruction *Inst,
192 AtomicOrdering Ord) const override;
193
194 unsigned getPreferredFPToIntOpcode(unsigned Op, EVT FromVT,
195 EVT ToVT) const override;
196
197 void computeKnownBitsForTargetNode(const SDValue Op, KnownBits &Known,
198 const APInt &DemandedElts,
199 const SelectionDAG &DAG,
200 unsigned Depth = 0) const override;
201 bool SimplifyDemandedBitsForTargetNode(SDValue Op, const APInt &DemandedBits,
202 const APInt &DemandedElts,
203 KnownBits &Known,
204 TargetLoweringOpt &TLO,
205 unsigned Depth = 0) const override;
206
207private:
208 const NVPTXSubtarget &STI; // cache the subtarget here
209 mutable unsigned GlobalUniqueCallSite;
210
211 SDValue getParamSymbol(SelectionDAG &DAG, int I, EVT T) const;
212 SDValue getCallParamSymbol(SelectionDAG &DAG, int I, EVT T) const;
213 SDValue LowerADDRSPACECAST(SDValue Op, SelectionDAG &DAG) const;
214 SDValue LowerBITCAST(SDValue Op, SelectionDAG &DAG) const;
215
216 SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
217 SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
218 SDValue LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
219 SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
220 SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
221 SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;
222
223 SDValue LowerFCOPYSIGN(SDValue Op, SelectionDAG &DAG) const;
224
225 SDValue LowerFROUND(SDValue Op, SelectionDAG &DAG) const;
226 SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
227 SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;
228
229 SDValue PromoteBinOpIfF32FTZ(SDValue Op, SelectionDAG &DAG) const;
230
231 SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
232 SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;
233
234 SDValue LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const;
235 SDValue LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const;
236
237 SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
238 SDValue LowerMLOAD(SDValue Op, SelectionDAG &DAG) const;
239 SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;
240 SDValue LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const;
241
242 SDValue LowerShiftRightParts(SDValue Op, SelectionDAG &DAG) const;
243 SDValue LowerShiftLeftParts(SDValue Op, SelectionDAG &DAG) const;
244
245 SDValue LowerVAARG(SDValue Op, SelectionDAG &DAG) const;
246 SDValue LowerVASTART(SDValue Op, SelectionDAG &DAG) const;
247
248 SDValue LowerCopyToReg_128(SDValue Op, SelectionDAG &DAG) const;
249 unsigned getNumRegisters(LLVMContext &Context, EVT VT,
250 std::optional<MVT> RegisterVT) const override;
251 bool
252 splitValueIntoRegisterParts(SelectionDAG &DAG, const SDLoc &DL, SDValue Val,
253 SDValue *Parts, unsigned NumParts, MVT PartVT,
254 std::optional<CallingConv::ID> CC) const override;
255
256 void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
257 SelectionDAG &DAG) const override;
258 SDValue PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const override;
259
260 Align getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx,
261 const DataLayout &DL) const;
262};
263
264} // namespace llvm
265
266#endif
267