//===-- WebAssemblyTargetTransformInfo.cpp - WebAssembly-specific TTI -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file defines the WebAssembly-specific TargetTransformInfo
/// implementation.
///
//===----------------------------------------------------------------------===//

#include "WebAssemblyTargetTransformInfo.h"

#include "llvm/CodeGen/CostTable.h"

using namespace llvm;

#define DEBUG_TYPE "wasmtti"

TargetTransformInfo::PopcntSupportKind
WebAssemblyTTIImpl::getPopcntSupport(unsigned TyWidth) const {
  assert(isPowerOf2_32(TyWidth) && "Ty width must be power of 2");
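  // WebAssembly provides dedicated i32.popcnt and i64.popcnt instructions, so
  // population count is a single cheap operation at any legal width.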
  return TargetTransformInfo::PSK_FastHardware;
}

unsigned WebAssemblyTTIImpl::getNumberOfRegisters(unsigned ClassID) const {
  unsigned Result = BaseT::getNumberOfRegisters(ClassID);

  // For SIMD, use at least 16 registers, as a rough guess.
  bool Vector = (ClassID == 1);
  if (Vector)
    Result = std::max(Result, 16u);

  return Result;
}

TypeSize WebAssemblyTTIImpl::getRegisterBitWidth(
    TargetTransformInfo::RegisterKind K) const {
  switch (K) {
  case TargetTransformInfo::RGK_Scalar:
    return TypeSize::getFixed(64);
  case TargetTransformInfo::RGK_FixedWidthVector:
    return TypeSize::getFixed(getST()->hasSIMD128() ? 128 : 64);
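  // WebAssembly has no scalable vectors; SIMD128 registers are a fixed 128
  // bits wide, so report a zero-width scalable register.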
  case TargetTransformInfo::RGK_ScalableVector:
    return TypeSize::getScalable(0);
  }

  llvm_unreachable("Unsupported register kind");
}

InstructionCost WebAssemblyTTIImpl::getArithmeticInstrCost(
    unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
    TTI::OperandValueInfo Op1Info, TTI::OperandValueInfo Op2Info,
    ArrayRef<const Value *> Args, const Instruction *CxtI) const {

  InstructionCost Cost =
      BasicTTIImplBase<WebAssemblyTTIImpl>::getArithmeticInstrCost(
          Opcode, Ty, CostKind, Op1Info, Op2Info);

  if (auto *VTy = dyn_cast<VectorType>(Ty)) {
    switch (Opcode) {
    case Instruction::LShr:
    case Instruction::AShr:
    case Instruction::Shl:
      // SIMD128's shifts currently only accept a scalar shift count. For each
      // element, we'll need to extract, op, insert. The following is a rough
      // approximation.
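      // For example, a v4i32 shift by a non-uniform amount is costed as
      // 4 * (extract lane + scalar shift + replace lane).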
      if (!Op2Info.isUniform())
        Cost =
            cast<FixedVectorType>(VTy)->getNumElements() *
            (TargetTransformInfo::TCC_Basic +
             getArithmeticInstrCost(Opcode, VTy->getElementType(), CostKind) +
             TargetTransformInfo::TCC_Basic);
      break;
    }
  }
  return Cost;
}

InstructionCost WebAssemblyTTIImpl::getCastInstrCost(
    unsigned Opcode, Type *Dst, Type *Src, TTI::CastContextHint CCH,
    TTI::TargetCostKind CostKind, const Instruction *I) const {
  int ISD = TLI->InstructionOpcodeToISD(Opcode);
  auto SrcTy = TLI->getValueType(DL, Src);
  auto DstTy = TLI->getValueType(DL, Dst);

  if (!SrcTy.isSimple() || !DstTy.isSimple()) {
    return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
  }

  if (!ST->hasSIMD128()) {
    return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
  }

  auto DstVT = DstTy.getSimpleVT();
  auto SrcVT = SrcTy.getSimpleVT();

  if (I && I->hasOneUser()) {
    auto *SingleUser = cast<Instruction>(*I->user_begin());
    int UserISD = TLI->InstructionOpcodeToISD(SingleUser->getOpcode());

    // extmul_low support
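    // A sign/zero-extension whose only user is a multiply can be folded into
    // a single extending multiply such as i16x8.extmul_low_i8x16_s or
    // i32x4.extmul_low_i16x8_u, making the extension itself free.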
    if (UserISD == ISD::MUL &&
        (ISD == ISD::ZERO_EXTEND || ISD == ISD::SIGN_EXTEND)) {
      // Free low extensions.
      if ((SrcVT == MVT::v8i8 && DstVT == MVT::v8i16) ||
          (SrcVT == MVT::v4i16 && DstVT == MVT::v4i32) ||
          (SrcVT == MVT::v2i32 && DstVT == MVT::v2i64)) {
        return 0;
      }
      // Will require an additional extlow operation for the intermediate
      // i16/i32 value.
      if ((SrcVT == MVT::v4i8 && DstVT == MVT::v4i32) ||
          (SrcVT == MVT::v2i16 && DstVT == MVT::v2i64)) {
        return 1;
      }
    }
  }

  // extend_low
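  // A 64-bit-wide source extends to a full 128-bit result with a single
  // extend_low instruction (cost 1); the cost-2 entries below model narrower
  // sources that need two chained extend_low steps (e.g. v2i16 -> v2i32 ->
  // v2i64).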
  static constexpr TypeConversionCostTblEntry ConversionTbl[] = {
      {ISD::SIGN_EXTEND, MVT::v2i64, MVT::v2i32, 1},
      {ISD::ZERO_EXTEND, MVT::v2i64, MVT::v2i32, 1},
      {ISD::SIGN_EXTEND, MVT::v4i32, MVT::v4i16, 1},
      {ISD::ZERO_EXTEND, MVT::v4i32, MVT::v4i16, 1},
      {ISD::SIGN_EXTEND, MVT::v8i16, MVT::v8i8, 1},
      {ISD::ZERO_EXTEND, MVT::v8i16, MVT::v8i8, 1},
      {ISD::SIGN_EXTEND, MVT::v2i64, MVT::v2i16, 2},
      {ISD::ZERO_EXTEND, MVT::v2i64, MVT::v2i16, 2},
      {ISD::SIGN_EXTEND, MVT::v4i32, MVT::v4i8, 2},
      {ISD::ZERO_EXTEND, MVT::v4i32, MVT::v4i8, 2},
  };

  if (const auto *Entry =
          ConvertCostTableLookup(ConversionTbl, ISD, DstVT, SrcVT)) {
    return Entry->Cost;
  }

  return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
}

InstructionCost WebAssemblyTTIImpl::getMemoryOpCost(
    unsigned Opcode, Type *Ty, Align Alignment, unsigned AddressSpace,
    TTI::TargetCostKind CostKind, TTI::OperandValueInfo OpInfo,
    const Instruction *I) const {
  if (!ST->hasSIMD128() || !isa<FixedVectorType>(Ty)) {
    return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
                                  CostKind);
  }

  int ISD = TLI->InstructionOpcodeToISD(Opcode);
  if (ISD != ISD::LOAD) {
    return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
                                  CostKind);
  }

  EVT VT = TLI->getValueType(DL, Ty, /*AllowUnknown=*/true);
  // Type legalization can't handle structs.
  if (VT == MVT::Other)
    return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
                                  CostKind);

  auto LT = getTypeLegalizationCost(Ty);
  if (!LT.first.isValid())
    return InstructionCost::getInvalid();

  // 128-bit loads are a single instruction. 32-bit and 64-bit vector loads can
  // be lowered to load32_zero and load64_zero respectively. Assume SIMD loads
  // are twice as expensive as scalar.
  unsigned Width = VT.getSizeInBits();
  switch (Width) {
  default:
    break;
  case 32:
  case 64:
  case 128:
    return 2;
  }

  return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace, CostKind);
}

InstructionCost WebAssemblyTTIImpl::getVectorInstrCost(
    unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
    const Value *Op0, const Value *Op1) const {
  InstructionCost Cost = BasicTTIImplBase::getVectorInstrCost(
      Opcode, Val, CostKind, Index, Op0, Op1);

  // SIMD128's insert/extract currently only take constant indices.
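  // A variable index (Index == -1u) cannot be encoded as a lane immediate and
  // is typically lowered through memory, so penalize it heavily.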
  if (Index == -1u)
    return Cost + 25 * TargetTransformInfo::TCC_Expensive;

  return Cost;
}

InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
    unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
    ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
    TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
    TTI::TargetCostKind CostKind) const {
  InstructionCost Invalid = InstructionCost::getInvalid();
  if (!VF.isFixed() || !ST->hasSIMD128())
    return Invalid;

  if (CostKind != TTI::TCK_RecipThroughput)
    return Invalid;

  InstructionCost Cost(TTI::TCC_Basic);

  // Possible options:
  // - i16x8.extadd_pairwise_i8x16_sx
  // - i32x4.extadd_pairwise_i16x8_sx
  // - i32x4.dot_i16x8_s
  // Only try to support dot, for now.

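  // The checks below admit only the dot-product shape
  // add(acc, mul(ext(A), ext(B))), where both inputs have the same type and
  // the same extension kind.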
  if (Opcode != Instruction::Add)
    return Invalid;

  if (!BinOp || *BinOp != Instruction::Mul)
    return Invalid;

  if (InputTypeA != InputTypeB)
    return Invalid;

  if (OpAExtend != OpBExtend)
    return Invalid;

  EVT InputEVT = EVT::getEVT(InputTypeA);
  EVT AccumEVT = EVT::getEVT(AccumType);

  // TODO: Add i64 accumulator.
  if (AccumEVT != MVT::i32)
    return Invalid;

  // Signed i16 inputs can lower directly to dot; unsigned inputs need a
  // longer sequence.
  if (InputEVT == MVT::i16 && VF.getFixedValue() == 8)
    return OpAExtend == TTI::PR_SignExtend ? Cost : Cost * 2;

  // i8 inputs double the size of the lowered sequence.
  if (InputEVT == MVT::i8 && VF.getFixedValue() == 16)
    return OpAExtend == TTI::PR_SignExtend ? Cost * 2 : Cost * 4;

  return Invalid;
}

TTI::ReductionShuffle WebAssemblyTTIImpl::getPreferredExpandedReductionShuffle(
    const IntrinsicInst *II) const {

  switch (II->getIntrinsicID()) {
  default:
    break;
  case Intrinsic::vector_reduce_fadd:
    return TTI::ReductionShuffle::Pairwise;
  }
  return TTI::ReductionShuffle::SplitHalf;
}

void WebAssemblyTTIImpl::getUnrollingPreferences(
    Loop *L, ScalarEvolution &SE, TTI::UnrollingPreferences &UP,
    OptimizationRemarkEmitter *ORE) const {
  // Scan the loop: don't unroll loops with calls. This is a standard approach
  // for most (all?) targets.
  for (BasicBlock *BB : L->blocks())
    for (Instruction &I : *BB)
      if (isa<CallInst>(I) || isa<InvokeInst>(I))
        if (const Function *F = cast<CallBase>(I).getCalledFunction())
          if (isLoweredToCall(F))
            return;

  // The chosen threshold is within the range of 'LoopMicroOpBufferSize' of
  // the various microarchitectures that use the BasicTTI implementation and
  // has been selected through heuristics across multiple cores and runtimes.
  UP.Partial = UP.Runtime = UP.UpperBound = true;
  UP.PartialThreshold = 30;

  // Avoid unrolling when optimizing for size.
  UP.OptSizeThreshold = 0;
  UP.PartialOptSizeThreshold = 0;

  // Set the number of instructions optimized away when a "back edge" becomes
  // a "fall through" to the default value of 2.
  UP.BEInsns = 2;
}

bool WebAssemblyTTIImpl::supportsTailCalls() const {
  return getST()->hasTailCall();
}

bool WebAssemblyTTIImpl::isProfitableToSinkOperands(
    Instruction *I, SmallVectorImpl<Use *> &Ops) const {
  using namespace llvm::PatternMatch;

  if (!I->getType()->isVectorTy() || !I->isShift())
    return false;

  Value *V = I->getOperand(1);
  // We don't need to sink a constant splat.
  if (isa<Constant>(V))
    return false;

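  // Match a shift amount that is a splat of an inserted scalar (insertelement
  // followed by a zero-mask shuffle). Sinking both into the shift's block
  // keeps the splat next to its user, so instruction selection can use the
  // scalar value directly as the SIMD128 shift count.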
  if (match(V, m_Shuffle(m_InsertElt(m_Value(), m_Value(), m_ZeroInt()),
                         m_Value(), m_ZeroMask()))) {
    // Sink insert
    Ops.push_back(&cast<Instruction>(V)->getOperandUse(0));
    // Sink shuffle
    Ops.push_back(&I->getOperandUse(1));
    return true;
  }

  return false;
}