1//===- RISCVTargetTransformInfo.h - RISC-V specific TTI ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8/// \file
9/// This file defines a TargetTransformInfoImplBase conforming object specific
10/// to the RISC-V target machine. It uses the target's detailed information to
11/// provide more precise answers to certain TTI queries, while letting the
12/// target independent and default TTI implementations handle the rest.
13///
14//===----------------------------------------------------------------------===//
15
16#ifndef LLVM_LIB_TARGET_RISCV_RISCVTARGETTRANSFORMINFO_H
17#define LLVM_LIB_TARGET_RISCV_RISCVTARGETTRANSFORMINFO_H
18
19#include "RISCVSubtarget.h"
20#include "RISCVTargetMachine.h"
21#include "llvm/Analysis/TargetTransformInfo.h"
22#include "llvm/CodeGen/BasicTTIImpl.h"
23#include "llvm/IR/Function.h"
24#include <optional>
25
26namespace llvm {
27
28class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
29 using BaseT = BasicTTIImplBase<RISCVTTIImpl>;
30 using TTI = TargetTransformInfo;
31
32 friend BaseT;
33
34 const RISCVSubtarget *ST;
35 const RISCVTargetLowering *TLI;
36
37 const RISCVSubtarget *getST() const { return ST; }
38 const RISCVTargetLowering *getTLI() const { return TLI; }
39
40 /// This function returns an estimate for VL to be used in VL based terms
41 /// of the cost model. For fixed length vectors, this is simply the
42 /// vector length. For scalable vectors, we return results consistent
43 /// with getVScaleForTuning under the assumption that clients are also
44 /// using that when comparing costs between scalar and vector representation.
45 /// This does unfortunately mean that we can both undershoot and overshot
46 /// the true cost significantly if getVScaleForTuning is wildly off for the
47 /// actual target hardware.
48 unsigned getEstimatedVLFor(VectorType *Ty) const;
49
50 /// This function calculates the costs for one or more RVV opcodes based
51 /// on the vtype and the cost kind.
52 /// \param Opcodes A list of opcodes of the RVV instruction to evaluate.
53 /// \param VT The MVT of vtype associated with the RVV instructions.
54 /// For widening/narrowing instructions where the result and source types
55 /// differ, it is important to check the spec to determine whether the vtype
56 /// refers to the result or source type.
57 /// \param CostKind The type of cost to compute.
58 InstructionCost getRISCVInstructionCost(ArrayRef<unsigned> OpCodes, MVT VT,
59 TTI::TargetCostKind CostKind) const;
60
61 // Return the cost of generating a PC relative address
62 InstructionCost
63 getStaticDataAddrGenerationCost(const TTI::TargetCostKind CostKind) const;
64
65 /// Return the cost of accessing a constant pool entry of the specified
66 /// type.
67 InstructionCost getConstantPoolLoadCost(Type *Ty,
68 TTI::TargetCostKind CostKind) const;
69
70 /// If this shuffle can be lowered as a masked slide pair (at worst),
71 /// return a cost for it.
72 InstructionCost getSlideCost(FixedVectorType *Tp, ArrayRef<int> Mask,
73 TTI::TargetCostKind CostKind) const;
74
75public:
76 explicit RISCVTTIImpl(const RISCVTargetMachine *TM, const Function &F)
77 : BaseT(TM, F.getDataLayout()), ST(TM->getSubtargetImpl(F)),
78 TLI(ST->getTargetLowering()) {}
79
80 /// Return the cost of materializing an immediate for a value operand of
81 /// a store instruction.
82 InstructionCost getStoreImmCost(Type *VecTy, TTI::OperandValueInfo OpInfo,
83 TTI::TargetCostKind CostKind) const;
84
85 InstructionCost getIntImmCost(const APInt &Imm, Type *Ty,
86 TTI::TargetCostKind CostKind) const override;
87 InstructionCost getIntImmCostInst(unsigned Opcode, unsigned Idx,
88 const APInt &Imm, Type *Ty,
89 TTI::TargetCostKind CostKind,
90 Instruction *Inst = nullptr) const override;
91 InstructionCost
92 getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx, const APInt &Imm,
93 Type *Ty, TTI::TargetCostKind CostKind) const override;
94
95 /// \name EVL Support for predicated vectorization.
96 /// Whether the target supports the %evl parameter of VP intrinsic efficiently
97 /// in hardware. (see LLVM Language Reference - "Vector Predication
98 /// Intrinsics",
99 /// https://llvm.org/docs/LangRef.html#vector-predication-intrinsics and
100 /// "IR-level VP intrinsics",
101 /// https://llvm.org/docs/Proposals/VectorPredication.html#ir-level-vp-intrinsics).
102 bool hasActiveVectorLength() const override;
103
104 TargetTransformInfo::PopcntSupportKind
105 getPopcntSupport(unsigned TyWidth) const override;
106
107 InstructionCost getPartialReductionCost(
108 unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
109 ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
110 TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
111 TTI::TargetCostKind CostKind,
112 std::optional<FastMathFlags> FMF) const override;
113
114 bool shouldExpandReduction(const IntrinsicInst *II) const override;
115 bool supportsScalableVectors() const override {
116 // VLEN=32 support is incomplete.
117 return ST->hasVInstructions() &&
118 (ST->getRealMinVLen() >= RISCV::RVVBitsPerBlock);
119 }
120 bool enableOrderedReductions() const override { return true; }
121 bool enableScalableVectorization() const override {
122 return ST->hasVInstructions();
123 }
124 bool preferPredicateOverEpilogue(TailFoldingInfo *TFI) const override {
125 return ST->hasVInstructions();
126 }
127 TailFoldingStyle getPreferredTailFoldingStyle() const override {
128 return ST->hasVInstructions() ? TailFoldingStyle::DataWithEVL
129 : TailFoldingStyle::None;
130 }
131 std::optional<unsigned> getMaxVScale() const override;
132 std::optional<unsigned> getVScaleForTuning() const override;
133
134 TypeSize
135 getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const override;
136
137 unsigned getRegUsageForType(Type *Ty) const override;
138
139 unsigned getMaximumVF(unsigned ElemWidth, unsigned Opcode) const override;
140
141 bool preferAlternateOpcodeVectorization() const override;
142
143 bool preferEpilogueVectorization(ElementCount Iters) const override {
144 // Epilogue vectorization is usually unprofitable - tail folding or
145 // a smaller VF would have been better. This a blunt hammer - we
146 // should re-examine this once vectorization is better tuned.
147 return false;
148 }
149
150 bool shouldConsiderVectorizationRegPressure() const override { return true; }
151
152 InstructionCost
153 getMemIntrinsicInstrCost(const MemIntrinsicCostAttributes &MICA,
154 TTI::TargetCostKind CostKind) const override;
155
156 InstructionCost getMaskedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
157 TTI::TargetCostKind CostKind) const;
158
159 InstructionCost
160 getPointersChainCost(ArrayRef<const Value *> Ptrs, const Value *Base,
161 const TTI::PointersChainInfo &Info, Type *AccessTy,
162 TTI::TargetCostKind CostKind) const override;
163
164 void getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
165 TTI::UnrollingPreferences &UP,
166 OptimizationRemarkEmitter *ORE) const override;
167
168 void getPeelingPreferences(Loop *L, ScalarEvolution &SE,
169 TTI::PeelingPreferences &PP) const override;
170
171 bool getTgtMemIntrinsic(IntrinsicInst *Inst,
172 MemIntrinsicInfo &Info) const override;
173
174 unsigned getMinVectorRegisterBitWidth() const override {
175 return ST->useRVVForFixedLengthVectors() ? 16 : 0;
176 }
177
178 InstructionCost
179 getShuffleCost(TTI::ShuffleKind Kind, VectorType *DstTy, VectorType *SrcTy,
180 ArrayRef<int> Mask, TTI::TargetCostKind CostKind, int Index,
181 VectorType *SubTp, ArrayRef<const Value *> Args = {},
182 const Instruction *CxtI = nullptr) const override;
183
184 InstructionCost
185 getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
186 bool Insert, bool Extract,
187 TTI::TargetCostKind CostKind,
188 bool ForPoisonSrc = true, ArrayRef<Value *> VL = {},
189 TTI::VectorInstrContext VIC =
190 TTI::VectorInstrContext::None) const override;
191
192 InstructionCost
193 getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
194 TTI::TargetCostKind CostKind) const override;
195
196 InstructionCost
197 getAddressComputationCost(Type *PTy, ScalarEvolution *SE, const SCEV *Ptr,
198 TTI::TargetCostKind CostKind) const override;
199
200 InstructionCost getInterleavedMemoryOpCost(
201 unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
202 Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
203 bool UseMaskForCond = false, bool UseMaskForGaps = false) const override;
204
205 InstructionCost getGatherScatterOpCost(const MemIntrinsicCostAttributes &MICA,
206 TTI::TargetCostKind CostKind) const;
207
208 InstructionCost
209 getExpandCompressMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
210 TTI::TargetCostKind CostKind) const;
211
212 InstructionCost getStridedMemoryOpCost(const MemIntrinsicCostAttributes &MICA,
213 TTI::TargetCostKind CostKind) const;
214
215 InstructionCost
216 getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) const override;
217
218 InstructionCost
219 getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
220 TTI::CastContextHint CCH, TTI::TargetCostKind CostKind,
221 const Instruction *I = nullptr) const override;
222
223 InstructionCost
224 getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty, FastMathFlags FMF,
225 TTI::TargetCostKind CostKind) const override;
226
227 InstructionCost
228 getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
229 std::optional<FastMathFlags> FMF,
230 TTI::TargetCostKind CostKind) const override;
231
232 InstructionCost
233 getExtendedReductionCost(unsigned Opcode, bool IsUnsigned, Type *ResTy,
234 VectorType *ValTy, std::optional<FastMathFlags> FMF,
235 TTI::TargetCostKind CostKind) const override;
236
237 InstructionCost getMemoryOpCost(
238 unsigned Opcode, Type *Src, Align Alignment, unsigned AddressSpace,
239 TTI::TargetCostKind CostKind,
240 TTI::OperandValueInfo OpdInfo = {.Kind: TTI::OK_AnyValue, .Properties: TTI::OP_None},
241 const Instruction *I = nullptr) const override;
242
243 InstructionCost getCmpSelInstrCost(
244 unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred,
245 TTI::TargetCostKind CostKind,
246 TTI::OperandValueInfo Op1Info = {.Kind: TTI::OK_AnyValue, .Properties: TTI::OP_None},
247 TTI::OperandValueInfo Op2Info = {.Kind: TTI::OK_AnyValue, .Properties: TTI::OP_None},
248 const Instruction *I = nullptr) const override;
249
250 InstructionCost getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind,
251 const Instruction *I = nullptr) const override;
252
253 using BaseT::getVectorInstrCost;
254 InstructionCost
255 getVectorInstrCost(unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind,
256 unsigned Index, const Value *Op0, const Value *Op1,
257 TTI::VectorInstrContext VIC =
258 TTI::VectorInstrContext::None) const override;
259
260 InstructionCost
261 getIndexedVectorInstrCostFromEnd(unsigned Opcode, Type *Val,
262 TTI::TargetCostKind CostKind,
263 unsigned Index) const override;
264
265 InstructionCost getArithmeticInstrCost(
266 unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
267 TTI::OperandValueInfo Op1Info = {.Kind: TTI::OK_AnyValue, .Properties: TTI::OP_None},
268 TTI::OperandValueInfo Op2Info = {.Kind: TTI::OK_AnyValue, .Properties: TTI::OP_None},
269 ArrayRef<const Value *> Args = {},
270 const Instruction *CxtI = nullptr) const override;
271
272 bool isElementTypeLegalForScalableVector(Type *Ty) const override {
273 return TLI->isLegalElementTypeForRVV(ScalarTy: TLI->getValueType(DL, Ty));
274 }
275
276 bool isLegalMaskedLoadStore(Type *DataType, Align Alignment) const {
277 if (!ST->hasVInstructions())
278 return false;
279
280 EVT DataTypeVT = TLI->getValueType(DL, Ty: DataType);
281
282 // Only support fixed vectors if we know the minimum vector size.
283 if (DataTypeVT.isFixedLengthVector() && !ST->useRVVForFixedLengthVectors())
284 return false;
285
286 EVT ElemType = DataTypeVT.getScalarType();
287 if (!ST->enableUnalignedVectorMem() && Alignment < ElemType.getStoreSize())
288 return false;
289
290 return TLI->isLegalElementTypeForRVV(ScalarTy: ElemType);
291 }
292
293 bool isLegalMaskedLoad(Type *DataType, Align Alignment,
294 unsigned /*AddressSpace*/,
295 TTI::MaskKind /*MaskKind*/) const override {
296 return isLegalMaskedLoadStore(DataType, Alignment);
297 }
298 bool isLegalMaskedStore(Type *DataType, Align Alignment,
299 unsigned /*AddressSpace*/,
300 TTI::MaskKind /*MaskKind*/) const override {
301 return isLegalMaskedLoadStore(DataType, Alignment);
302 }
303
304 bool isLegalMaskedGatherScatter(Type *DataType, Align Alignment) const {
305 if (!ST->hasVInstructions())
306 return false;
307
308 EVT DataTypeVT = TLI->getValueType(DL, Ty: DataType);
309
310 // Only support fixed vectors if we know the minimum vector size.
311 if (DataTypeVT.isFixedLengthVector() && !ST->useRVVForFixedLengthVectors())
312 return false;
313
314 // We also need to check if the vector of address is valid.
315 EVT PointerTypeVT = EVT(TLI->getPointerTy(DL));
316 if (DataTypeVT.isScalableVector() &&
317 !TLI->isLegalElementTypeForRVV(ScalarTy: PointerTypeVT))
318 return false;
319
320 EVT ElemType = DataTypeVT.getScalarType();
321 if (!ST->enableUnalignedVectorMem() && Alignment < ElemType.getStoreSize())
322 return false;
323
324 return TLI->isLegalElementTypeForRVV(ScalarTy: ElemType);
325 }
326
327 bool isLegalMaskedGather(Type *DataType, Align Alignment) const override {
328 return isLegalMaskedGatherScatter(DataType, Alignment);
329 }
330 bool isLegalMaskedScatter(Type *DataType, Align Alignment) const override {
331 return isLegalMaskedGatherScatter(DataType, Alignment);
332 }
333
334 bool forceScalarizeMaskedGather(VectorType *VTy,
335 Align Alignment) const override {
336 // Scalarize masked gather for RV64 if EEW=64 indices aren't supported.
337 return ST->is64Bit() && !ST->hasVInstructionsI64();
338 }
339
340 bool forceScalarizeMaskedScatter(VectorType *VTy,
341 Align Alignment) const override {
342 // Scalarize masked scatter for RV64 if EEW=64 indices aren't supported.
343 return ST->is64Bit() && !ST->hasVInstructionsI64();
344 }
345
346 bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const override {
347 EVT DataTypeVT = TLI->getValueType(DL, Ty: DataType);
348 return TLI->isLegalStridedLoadStore(DataType: DataTypeVT, Alignment);
349 }
350
351 bool isLegalInterleavedAccessType(VectorType *VTy, unsigned Factor,
352 Align Alignment,
353 unsigned AddrSpace) const override {
354 return TLI->isLegalInterleavedAccessType(VTy, Factor, Alignment, AddrSpace,
355 DL);
356 }
357
358 bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) const override;
359
360 bool isLegalMaskedCompressStore(Type *DataTy, Align Alignment) const override;
361
362 /// \returns How the target needs this vector-predicated operation to be
363 /// transformed.
364 TargetTransformInfo::VPLegalization
365 getVPLegalizationStrategy(const VPIntrinsic &PI) const override {
366 using VPLegalization = TargetTransformInfo::VPLegalization;
367 if (!ST->hasVInstructions() ||
368 (PI.getIntrinsicID() == Intrinsic::vp_reduce_mul &&
369 cast<VectorType>(Val: PI.getArgOperand(i: 1)->getType())
370 ->getElementType()
371 ->getIntegerBitWidth() != 1))
372 return VPLegalization(VPLegalization::Discard, VPLegalization::Convert);
373 return VPLegalization(VPLegalization::Legal, VPLegalization::Legal);
374 }
375
376 bool isLegalToVectorizeReduction(const RecurrenceDescriptor &RdxDesc,
377 ElementCount VF) const override {
378 if (!VF.isScalable())
379 return true;
380
381 Type *Ty = RdxDesc.getRecurrenceType();
382 if (!TLI->isLegalElementTypeForRVV(ScalarTy: TLI->getValueType(DL, Ty)))
383 return false;
384
385 switch (RdxDesc.getRecurrenceKind()) {
386 case RecurKind::Add:
387 case RecurKind::Sub:
388 case RecurKind::AddChainWithSubs:
389 case RecurKind::And:
390 case RecurKind::Or:
391 case RecurKind::Xor:
392 case RecurKind::SMin:
393 case RecurKind::SMax:
394 case RecurKind::UMin:
395 case RecurKind::UMax:
396 case RecurKind::FMin:
397 case RecurKind::FMax:
398 return true;
399 case RecurKind::AnyOf:
400 case RecurKind::FAdd:
401 case RecurKind::FMulAdd:
402 // We can't promote f16/bf16 fadd reductions and scalable vectors can't be
403 // expanded.
404 if (Ty->isBFloatTy() || (Ty->isHalfTy() && !ST->hasVInstructionsF16()))
405 return false;
406 return true;
407 default:
408 return false;
409 }
410 }
411
412 unsigned getMaxInterleaveFactor(ElementCount VF) const override {
413 // Don't interleave if the loop has been vectorized with scalable vectors.
414 if (VF.isScalable())
415 return 1;
416 // If the loop will not be vectorized, don't interleave the loop.
417 // Let regular unroll to unroll the loop.
418 return VF.isScalar() ? 1 : ST->getMaxInterleaveFactor();
419 }
420
421 bool enableInterleavedAccessVectorization() const override { return true; }
422
423 bool enableMaskedInterleavedAccessVectorization() const override {
424 return ST->hasVInstructions();
425 }
426
427 unsigned getMinTripCountTailFoldingThreshold() const override;
428
429 enum RISCVRegisterClass { GPRRC, FPRRC, VRRC };
430 unsigned getNumberOfRegisters(unsigned ClassID) const override {
431 switch (ClassID) {
432 case RISCVRegisterClass::GPRRC:
433 // 31 = 32 GPR - x0 (zero register)
434 // FIXME: Should we exclude fixed registers like SP, TP or GP?
435 return 31;
436 case RISCVRegisterClass::FPRRC:
437 if (ST->hasStdExtF())
438 return 32;
439 return 0;
440 case RISCVRegisterClass::VRRC:
441 // Although there are 32 vector registers, v0 is special in that it is the
442 // only register that can be used to hold a mask.
443 // FIXME: Should we conservatively return 31 as the number of usable
444 // vector registers?
445 return ST->hasVInstructions() ? 32 : 0;
446 }
447 llvm_unreachable("unknown register class");
448 }
449
450 TTI::AddressingModeKind
451 getPreferredAddressingMode(const Loop *L, ScalarEvolution *SE) const override;
452
453 unsigned getRegisterClassForType(bool Vector,
454 Type *Ty = nullptr) const override {
455 if (Vector)
456 return RISCVRegisterClass::VRRC;
457 if (!Ty)
458 return RISCVRegisterClass::GPRRC;
459
460 Type *ScalarTy = Ty->getScalarType();
461 if ((ScalarTy->isHalfTy() && ST->hasStdExtZfhmin()) ||
462 (ScalarTy->isFloatTy() && ST->hasStdExtF()) ||
463 (ScalarTy->isDoubleTy() && ST->hasStdExtD())) {
464 return RISCVRegisterClass::FPRRC;
465 }
466
467 return RISCVRegisterClass::GPRRC;
468 }
469
470 const char *getRegisterClassName(unsigned ClassID) const override {
471 switch (ClassID) {
472 case RISCVRegisterClass::GPRRC:
473 return "RISCV::GPRRC";
474 case RISCVRegisterClass::FPRRC:
475 return "RISCV::FPRRC";
476 case RISCVRegisterClass::VRRC:
477 return "RISCV::VRRC";
478 }
479 llvm_unreachable("unknown register class");
480 }
481
482 bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
483 const TargetTransformInfo::LSRCost &C2) const override;
484
485 bool shouldConsiderAddressTypePromotion(
486 const Instruction &I,
487 bool &AllowPromotionWithoutCommonHeader) const override;
488 std::optional<unsigned> getMinPageSize() const override { return 4096; }
489 /// Return true if the (vector) instruction I will be lowered to an
490 /// instruction with a scalar splat operand for the given Operand number.
491 bool canSplatOperand(Instruction *I, int Operand) const;
492 /// Return true if a vector instruction will lower to a target instruction
493 /// able to splat the given operand.
494 bool canSplatOperand(unsigned Opcode, int Operand) const;
495
496 bool isProfitableToSinkOperands(Instruction *I,
497 SmallVectorImpl<Use *> &Ops) const override;
498
499 TTI::MemCmpExpansionOptions
500 enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const override;
501
502 bool enableSelectOptimize() const override {
503 return ST->enableSelectOptimize();
504 }
505
506 bool shouldTreatInstructionLikeSelect(const Instruction *I) const override;
507
508 bool
509 shouldCopyAttributeWhenOutliningFrom(const Function *Caller,
510 const Attribute &Attr) const override;
511
512 std::optional<Instruction *>
513 instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const override;
514};
515
516} // end namespace llvm
517
518#endif // LLVM_LIB_TARGET_RISCV_RISCVTARGETTRANSFORMINFO_H
519