//===-- NVPTXTargetTransformInfo.h - NVPTX specific TTI ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// This file describes a TargetTransformInfoImplBase conforming object
/// specific to the NVPTX target machine. It uses the target's detailed
/// information to provide more precise answers to certain TTI queries, while
/// letting the target-independent and default TTI implementations handle the
/// rest.
///
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIB_TARGET_NVPTX_NVPTXTARGETTRANSFORMINFO_H
#define LLVM_LIB_TARGET_NVPTX_NVPTXTARGETTRANSFORMINFO_H

#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "NVPTXTargetMachine.h"
#include "NVPTXUtilities.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/BasicTTIImpl.h"
#include "llvm/CodeGen/TargetLowering.h"
#include <optional>

namespace llvm {

class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
  typedef BasicTTIImplBase<NVPTXTTIImpl> BaseT;
  typedef TargetTransformInfo TTI;
  friend BaseT;

  const NVPTXSubtarget *ST;
  const NVPTXTargetLowering *TLI;

  const NVPTXSubtarget *getST() const { return ST; };
  const NVPTXTargetLowering *getTLI() const { return TLI; };

  /// \returns true if the result of the value could potentially be
  /// different across threads in a warp.
  bool isSourceOfDivergence(const Value *V) const;

public:
  explicit NVPTXTTIImpl(const NVPTXTargetMachine *TM, const Function &F)
      : BaseT(TM, F.getDataLayout()), ST(TM->getSubtargetImpl()),
        TLI(ST->getTargetLowering()) {}

  bool hasBranchDivergence(const Function *F = nullptr) const override {
    return true;
  }

  unsigned getFlatAddressSpace() const override {
    return AddressSpace::ADDRESS_SPACE_GENERIC;
  }

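  // Informal illustration (editorial note, not upstream documentation): per
  // the check below, e.g.
  //   canHaveNonUndefGlobalInitializerInAddressSpace(ADDRESS_SPACE_SHARED)
  // returns false, since shared, local, and param globals cannot carry a
  // meaningful initializer on NVPTX, while queries for the global address
  // space return true.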
  bool
  canHaveNonUndefGlobalInitializerInAddressSpace(unsigned AS) const override {
    return AS != AddressSpace::ADDRESS_SPACE_SHARED &&
           AS != AddressSpace::ADDRESS_SPACE_LOCAL && AS != ADDRESS_SPACE_PARAM;
  }

  std::optional<Instruction *>
  instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const override;

  // Loads and stores can be vectorized if the alignment is at least as big as
  // the load/store we want to vectorize.
  bool isLegalToVectorizeLoadChain(unsigned ChainSizeInBytes, Align Alignment,
                                   unsigned AddrSpace) const override {
    return Alignment >= ChainSizeInBytes;
  }
  bool isLegalToVectorizeStoreChain(unsigned ChainSizeInBytes, Align Alignment,
                                    unsigned AddrSpace) const override {
    return isLegalToVectorizeLoadChain(ChainSizeInBytes, Alignment, AddrSpace);
  }
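
  // Worked example (editorial sketch, not part of the TTI interface): a chain
  // of four i32 loads spans 16 bytes, so under the rule above it is treated as
  // vectorizable only when the pointer alignment is at least 16:
  //   ChainSizeInBytes = 16, Align(16) -> legal; Align(8) -> not legal.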

  // The PTX virtual ISA has an effectively unlimited number of registers of
  // all kinds, but the physical machine does not. We conservatively return 1
  // here, which is just enough to enable the vectorizers but disables
  // heuristics based on the number of registers.
  // FIXME: Return a more reasonable number, while keeping an eye on
  // LoopVectorizer's unrolling heuristics.
  unsigned getNumberOfRegisters(unsigned ClassID) const override { return 1; }

  // Only <2 x half> should be vectorized, so always return 32 for the vector
  // register size.
  TypeSize
  getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const override {
    return TypeSize::getFixed(32);
  }
  unsigned getMinVectorRegisterBitWidth() const override { return 32; }
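
  // Editorial note: a <2 x half> (or another 32-bit packed vector such as
  // <4 x i8>) fits in one 32-bit register, which is why both the fixed and
  // the minimum vector register widths above are reported as 32 bits.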

  bool shouldExpandReduction(const IntrinsicInst *II) const override {
    // Turn off the ExpandReductions pass for NVPTX, which doesn't have
    // advanced swizzling operations. Our backend/SelectionDAG can expand these
    // reductions with fewer movs.
    return false;
  }

  // We don't want to prevent inlining because of target-cpu and -features
  // attributes that were added to newer versions of LLVM/Clang: there are
  // no incompatible functions in PTX; ptxas will throw errors in such cases.
  bool areInlineCompatible(const Function *Caller,
                           const Function *Callee) const override {
    return true;
  }

  // Increase the inlining cost threshold by a factor of 11, reflecting that
  // calls are particularly expensive in NVPTX.
  unsigned getInliningThresholdMultiplier() const override { return 11; }

  InstructionCost
  getInstructionCost(const User *U, ArrayRef<const Value *> Operands,
                     TTI::TargetCostKind CostKind) const override;

  InstructionCost getArithmeticInstrCost(
      unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
      TTI::OperandValueInfo Op1Info = {TTI::OK_AnyValue, TTI::OP_None},
      TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
      ArrayRef<const Value *> Args = {},
      const Instruction *CxtI = nullptr) const override;

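  // Editorial sketch of the insert-cost model implemented below (the numbers
  // mirror the code rather than documented PTX costs):
  //   * building a vector purely from constants is folded and adds nothing;
  //   * a 32-bit packed vector (e.g. <2 x half>) is built with one 32-bit mov,
  //     so the insert side costs 1;
  //   * a <4 x i8> build is modeled as 3 PRMTs plus one i32 zext per demanded
  //     element, i.e. at most 3 + 4 = 7;
  //   * anything else falls back to BasicTTIImplBase's scalarization overhead.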
  InstructionCost
  getScalarizationOverhead(VectorType *InTy, const APInt &DemandedElts,
                           bool Insert, bool Extract,
                           TTI::TargetCostKind CostKind,
                           bool ForPoisonSrc = true, ArrayRef<Value *> VL = {},
                           TTI::VectorInstrContext VIC =
                               TTI::VectorInstrContext::None) const override {
    if (!InTy->getElementCount().isFixed())
      return InstructionCost::getInvalid();

    auto VT = getTLI()->getValueType(DL, InTy);
    auto NumElements = InTy->getElementCount().getFixedValue();
    InstructionCost Cost = 0;
    if (Insert && !VL.empty()) {
      bool AllConstant = all_of(seq(NumElements), [&](int Idx) {
        return !DemandedElts[Idx] || isa<Constant>(VL[Idx]);
      });
      if (AllConstant) {
        Cost += TTI::TCC_Free;
        Insert = false;
      }
    }
    if (Insert && NVPTX::isPackedVectorTy(VT) && VT.is32BitVector()) {
      // Can be built in a single 32-bit mov (64-bit regs are emulated in SASS
      // with 2x 32-bit regs).
      Cost += 1;
      Insert = false;
    }
    if (Insert && VT == MVT::v4i8) {
      Cost += 3; // 3 x PRMT
      for (auto Idx : seq(NumElements))
        if (DemandedElts[Idx])
          Cost += 1; // zext operand to i32
      Insert = false;
    }
    return Cost + BaseT::getScalarizationOverhead(InTy, DemandedElts, Insert,
                                                  Extract, CostKind,
                                                  ForPoisonSrc, VL);
  }

  void getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
                               TTI::UnrollingPreferences &UP,
                               OptimizationRemarkEmitter *ORE) const override;

  void getPeelingPreferences(Loop *L, ScalarEvolution &SE,
                             TTI::PeelingPreferences &PP) const override;

  bool hasVolatileVariant(Instruction *I, unsigned AddrSpace) const override {
    // Volatile loads/stores are only supported for shared and global address
    // spaces, or for generic AS that maps to them.
    if (!(AddrSpace == llvm::ADDRESS_SPACE_GENERIC ||
          AddrSpace == llvm::ADDRESS_SPACE_GLOBAL ||
          AddrSpace == llvm::ADDRESS_SPACE_SHARED))
      return false;

    switch (I->getOpcode()) {
    default:
      return false;
    case Instruction::Load:
    case Instruction::Store:
      return true;
    }
  }
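
  // Editorial example: under the rule above, a volatile load or store through
  // a generic, global, or shared pointer reports a volatile variant, while a
  // volatile access to, say, ADDRESS_SPACE_LOCAL does not.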

  bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
                                  Intrinsic::ID IID) const override;

  bool isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddrSpace,
                          TTI::MaskKind MaskKind) const override;

  bool isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddrSpace,
                         TTI::MaskKind MaskKind) const override;

  unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override;

  Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
                                          Value *NewV) const override;
  unsigned getAssumedAddrSpace(const Value *V) const override;

  void collectKernelLaunchBounds(
      const Function &F,
      SmallVectorImpl<std::pair<StringRef, int64_t>> &LB) const override;

  bool shouldBuildRelLookupTables() const override {
    // Self-referential globals are not supported.
    return false;
  }

  InstructionUniformity getInstructionUniformity(const Value *V) const override;
};

} // end namespace llvm

#endif