//===-- VECustomDAG.cpp - VE Custom DAG Nodes ----------------*- C++ -*-===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the interfaces that VE uses to lower LLVM code into a |
10 | // selection DAG. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "VECustomDAG.h" |
15 | |
16 | #ifndef DEBUG_TYPE |
17 | #define DEBUG_TYPE "vecustomdag" |
18 | #endif |
19 | |
20 | namespace llvm { |
21 | |
22 | bool isPackedVectorType(EVT SomeVT) { |
23 | if (!SomeVT.isVector()) |
24 | return false; |
25 | return SomeVT.getVectorNumElements() > StandardVectorWidth; |
26 | } |
27 | |
28 | MVT splitVectorType(MVT VT) { |
29 | if (!VT.isVector()) |
30 | return VT; |
31 | return MVT::getVectorVT(VT: VT.getVectorElementType(), NumElements: StandardVectorWidth); |
32 | } |
33 | |
34 | MVT getLegalVectorType(Packing P, MVT ElemVT) { |
35 | return MVT::getVectorVT(VT: ElemVT, NumElements: P == Packing::Normal ? StandardVectorWidth |
36 | : PackedVectorWidth); |
37 | } |
38 | |
39 | Packing getTypePacking(EVT VT) { |
40 | assert(VT.isVector()); |
41 | return isPackedVectorType(SomeVT: VT) ? Packing::Dense : Packing::Normal; |
42 | } |
43 | |
44 | bool isMaskType(EVT SomeVT) { |
45 | if (!SomeVT.isVector()) |
46 | return false; |
47 | return SomeVT.getVectorElementType() == MVT::i1; |
48 | } |
49 | |
50 | bool isMaskArithmetic(SDValue Op) { |
51 | switch (Op.getOpcode()) { |
52 | default: |
53 | return false; |
54 | case ISD::AND: |
55 | case ISD::XOR: |
56 | case ISD::OR: |
57 | return isMaskType(SomeVT: Op.getValueType()); |
58 | } |
59 | } |
60 | |
61 | /// \returns the VVP_* SDNode opcode corresponsing to \p OC. |
62 | std::optional<unsigned> getVVPOpcode(unsigned Opcode) { |
63 | switch (Opcode) { |
64 | case ISD::MLOAD: |
65 | return VEISD::VVP_LOAD; |
66 | case ISD::MSTORE: |
67 | return VEISD::VVP_STORE; |
68 | #define HANDLE_VP_TO_VVP(VPOPC, VVPNAME) \ |
69 | case ISD::VPOPC: \ |
70 | return VEISD::VVPNAME; |
71 | #define ADD_VVP_OP(VVPNAME, SDNAME) \ |
72 | case VEISD::VVPNAME: \ |
73 | case ISD::SDNAME: \ |
74 | return VEISD::VVPNAME; |
75 | #include "VVPNodes.def" |
76 | // TODO: Map those in VVPNodes.def too |
77 | case ISD::EXPERIMENTAL_VP_STRIDED_LOAD: |
78 | return VEISD::VVP_LOAD; |
79 | case ISD::EXPERIMENTAL_VP_STRIDED_STORE: |
80 | return VEISD::VVP_STORE; |
81 | } |
82 | return std::nullopt; |
83 | } |
84 | |
85 | bool maySafelyIgnoreMask(SDValue Op) { |
86 | auto VVPOpc = getVVPOpcode(Opcode: Op->getOpcode()); |
87 | auto Opc = VVPOpc.value_or(u: Op->getOpcode()); |
88 | |
89 | switch (Opc) { |
90 | case VEISD::VVP_SDIV: |
91 | case VEISD::VVP_UDIV: |
92 | case VEISD::VVP_FDIV: |
93 | case VEISD::VVP_SELECT: |
94 | return false; |
95 | |
96 | default: |
97 | return true; |
98 | } |
99 | } |
100 | |
101 | bool supportsPackedMode(unsigned Opcode, EVT IdiomVT) { |
102 | bool IsPackedOp = isPackedVectorType(SomeVT: IdiomVT); |
103 | bool IsMaskOp = isMaskType(SomeVT: IdiomVT); |
104 | switch (Opcode) { |
105 | default: |
106 | return false; |
107 | |
108 | case VEISD::VEC_BROADCAST: |
109 | return true; |
110 | #define REGISTER_PACKED(VVP_NAME) case VEISD::VVP_NAME: |
111 | #include "VVPNodes.def" |
112 | return IsPackedOp && !IsMaskOp; |
113 | } |
114 | } |
115 | |
116 | bool isPackingSupportOpcode(unsigned Opc) { |
117 | switch (Opc) { |
118 | case VEISD::VEC_PACK: |
119 | case VEISD::VEC_UNPACK_LO: |
120 | case VEISD::VEC_UNPACK_HI: |
121 | return true; |
122 | } |
123 | return false; |
124 | } |
125 | |
126 | bool isVVPOrVEC(unsigned Opcode) { |
127 | switch (Opcode) { |
128 | case VEISD::VEC_BROADCAST: |
129 | #define ADD_VVP_OP(VVPNAME, ...) case VEISD::VVPNAME: |
130 | #include "VVPNodes.def" |
131 | return true; |
132 | } |
133 | return false; |
134 | } |
135 | |
136 | bool isVVPUnaryOp(unsigned VVPOpcode) { |
137 | switch (VVPOpcode) { |
138 | #define ADD_UNARY_VVP_OP(VVPNAME, ...) \ |
139 | case VEISD::VVPNAME: \ |
140 | return true; |
141 | #include "VVPNodes.def" |
142 | } |
143 | return false; |
144 | } |
145 | |
146 | bool isVVPBinaryOp(unsigned VVPOpcode) { |
147 | switch (VVPOpcode) { |
148 | #define ADD_BINARY_VVP_OP(VVPNAME, ...) \ |
149 | case VEISD::VVPNAME: \ |
150 | return true; |
151 | #include "VVPNodes.def" |
152 | } |
153 | return false; |
154 | } |
155 | |
156 | bool isVVPReductionOp(unsigned Opcode) { |
157 | switch (Opcode) { |
158 | #define ADD_REDUCE_VVP_OP(VVP_NAME, SDNAME) case VEISD::VVP_NAME: |
159 | #include "VVPNodes.def" |
160 | return true; |
161 | } |
162 | return false; |
163 | } |
164 | |
165 | // Return the AVL operand position for this VVP or VEC Op. |
166 | std::optional<int> getAVLPos(unsigned Opc) { |
167 | // This is only available for VP SDNodes |
168 | auto PosOpt = ISD::getVPExplicitVectorLengthIdx(Opcode: Opc); |
169 | if (PosOpt) |
170 | return *PosOpt; |
171 | |
172 | // VVP Opcodes. |
173 | if (isVVPBinaryOp(VVPOpcode: Opc)) |
174 | return 3; |
175 | |
176 | // VM Opcodes. |
177 | switch (Opc) { |
178 | case VEISD::VEC_BROADCAST: |
179 | return 1; |
180 | case VEISD::VVP_SELECT: |
181 | return 3; |
182 | case VEISD::VVP_LOAD: |
183 | return 4; |
184 | case VEISD::VVP_STORE: |
185 | return 5; |
186 | } |
187 | |
188 | return std::nullopt; |
189 | } |
190 | |
191 | std::optional<int> getMaskPos(unsigned Opc) { |
192 | // This is only available for VP SDNodes |
193 | auto PosOpt = ISD::getVPMaskIdx(Opcode: Opc); |
194 | if (PosOpt) |
195 | return *PosOpt; |
196 | |
197 | // VVP Opcodes. |
198 | if (isVVPBinaryOp(VVPOpcode: Opc)) |
199 | return 2; |
200 | |
201 | // Other opcodes. |
202 | switch (Opc) { |
203 | case ISD::MSTORE: |
204 | return 4; |
205 | case ISD::MLOAD: |
206 | return 3; |
207 | case VEISD::VVP_SELECT: |
208 | return 2; |
209 | } |
210 | |
211 | return std::nullopt; |
212 | } |
213 | |
214 | bool isLegalAVL(SDValue AVL) { return AVL->getOpcode() == VEISD::LEGALAVL; } |
215 | |
216 | /// Node Properties { |
217 | |
218 | SDValue getNodeChain(SDValue Op) { |
219 | if (MemSDNode *MemN = dyn_cast<MemSDNode>(Val: Op.getNode())) |
220 | return MemN->getChain(); |
221 | |
222 | switch (Op->getOpcode()) { |
223 | case VEISD::VVP_LOAD: |
224 | case VEISD::VVP_STORE: |
225 | return Op->getOperand(Num: 0); |
226 | } |
227 | return SDValue(); |
228 | } |
229 | |
230 | SDValue getMemoryPtr(SDValue Op) { |
231 | if (auto *MemN = dyn_cast<MemSDNode>(Val: Op.getNode())) |
232 | return MemN->getBasePtr(); |
233 | |
234 | switch (Op->getOpcode()) { |
235 | case VEISD::VVP_LOAD: |
236 | return Op->getOperand(Num: 1); |
237 | case VEISD::VVP_STORE: |
238 | return Op->getOperand(Num: 2); |
239 | } |
240 | return SDValue(); |
241 | } |
242 | |
243 | std::optional<EVT> getIdiomaticVectorType(SDNode *Op) { |
244 | unsigned OC = Op->getOpcode(); |
245 | |
246 | // For memory ops -> the transfered data type |
247 | if (auto MemN = dyn_cast<MemSDNode>(Val: Op)) |
248 | return MemN->getMemoryVT(); |
249 | |
250 | switch (OC) { |
251 | // Standard ISD. |
252 | case ISD::SELECT: // not aliased with VVP_SELECT |
253 | case ISD::CONCAT_VECTORS: |
254 | case ISD::EXTRACT_SUBVECTOR: |
255 | case ISD::VECTOR_SHUFFLE: |
256 | case ISD::BUILD_VECTOR: |
257 | case ISD::SCALAR_TO_VECTOR: |
258 | return Op->getValueType(ResNo: 0); |
259 | } |
260 | |
261 | // Translate to VVP where possible. |
262 | unsigned OriginalOC = OC; |
263 | if (auto VVPOpc = getVVPOpcode(Opcode: OC)) |
264 | OC = *VVPOpc; |
265 | |
266 | if (isVVPReductionOp(Opcode: OC)) |
267 | return Op->getOperand(Num: hasReductionStartParam(VVPOC: OriginalOC) ? 1 : 0) |
268 | .getValueType(); |
269 | |
270 | switch (OC) { |
271 | default: |
272 | case VEISD::VVP_SETCC: |
273 | return Op->getOperand(Num: 0).getValueType(); |
274 | |
275 | case VEISD::VVP_SELECT: |
276 | #define ADD_BINARY_VVP_OP(VVP_NAME, ...) case VEISD::VVP_NAME: |
277 | #include "VVPNodes.def" |
278 | return Op->getValueType(ResNo: 0); |
279 | |
280 | case VEISD::VVP_LOAD: |
281 | return Op->getValueType(ResNo: 0); |
282 | |
283 | case VEISD::VVP_STORE: |
284 | return Op->getOperand(Num: 1)->getValueType(ResNo: 0); |
285 | |
286 | // VEC |
287 | case VEISD::VEC_BROADCAST: |
288 | return Op->getValueType(ResNo: 0); |
289 | } |
290 | } |
291 | |
292 | SDValue getLoadStoreStride(SDValue Op, VECustomDAG &CDAG) { |
293 | switch (Op->getOpcode()) { |
294 | case VEISD::VVP_STORE: |
295 | return Op->getOperand(Num: 3); |
296 | case VEISD::VVP_LOAD: |
297 | return Op->getOperand(Num: 2); |
298 | } |
299 | |
300 | if (auto *StoreN = dyn_cast<VPStridedStoreSDNode>(Val: Op.getNode())) |
301 | return StoreN->getStride(); |
302 | if (auto *StoreN = dyn_cast<VPStridedLoadSDNode>(Val: Op.getNode())) |
303 | return StoreN->getStride(); |
304 | |
305 | if (isa<MemSDNode>(Val: Op.getNode())) { |
306 | // Regular MLOAD/MSTORE/LOAD/STORE |
307 | // No stride argument -> use the contiguous element size as stride. |
308 | uint64_t ElemStride = getIdiomaticVectorType(Op: Op.getNode()) |
309 | ->getVectorElementType() |
310 | .getStoreSize(); |
311 | return CDAG.getConstant(Val: ElemStride, VT: MVT::i64); |
312 | } |
313 | return SDValue(); |
314 | } |
315 | |
316 | SDValue getGatherScatterIndex(SDValue Op) { |
317 | if (auto *N = dyn_cast<MaskedGatherScatterSDNode>(Val: Op.getNode())) |
318 | return N->getIndex(); |
319 | if (auto *N = dyn_cast<VPGatherScatterSDNode>(Val: Op.getNode())) |
320 | return N->getIndex(); |
321 | return SDValue(); |
322 | } |
323 | |
324 | SDValue getGatherScatterScale(SDValue Op) { |
325 | if (auto *N = dyn_cast<MaskedGatherScatterSDNode>(Val: Op.getNode())) |
326 | return N->getScale(); |
327 | if (auto *N = dyn_cast<VPGatherScatterSDNode>(Val: Op.getNode())) |
328 | return N->getScale(); |
329 | return SDValue(); |
330 | } |
331 | |
332 | SDValue getStoredValue(SDValue Op) { |
333 | switch (Op->getOpcode()) { |
334 | case ISD::EXPERIMENTAL_VP_STRIDED_STORE: |
335 | case VEISD::VVP_STORE: |
336 | return Op->getOperand(Num: 1); |
337 | } |
338 | if (auto *StoreN = dyn_cast<StoreSDNode>(Val: Op.getNode())) |
339 | return StoreN->getValue(); |
340 | if (auto *StoreN = dyn_cast<MaskedStoreSDNode>(Val: Op.getNode())) |
341 | return StoreN->getValue(); |
342 | if (auto *StoreN = dyn_cast<VPStridedStoreSDNode>(Val: Op.getNode())) |
343 | return StoreN->getValue(); |
344 | if (auto *StoreN = dyn_cast<VPStoreSDNode>(Val: Op.getNode())) |
345 | return StoreN->getValue(); |
346 | if (auto *StoreN = dyn_cast<MaskedScatterSDNode>(Val: Op.getNode())) |
347 | return StoreN->getValue(); |
348 | if (auto *StoreN = dyn_cast<VPScatterSDNode>(Val: Op.getNode())) |
349 | return StoreN->getValue(); |
350 | return SDValue(); |
351 | } |
352 | |
353 | SDValue getNodePassthru(SDValue Op) { |
354 | if (auto *N = dyn_cast<MaskedLoadSDNode>(Val: Op.getNode())) |
355 | return N->getPassThru(); |
356 | if (auto *N = dyn_cast<MaskedGatherSDNode>(Val: Op.getNode())) |
357 | return N->getPassThru(); |
358 | |
359 | return SDValue(); |
360 | } |
361 | |
362 | bool hasReductionStartParam(unsigned OPC) { |
363 | // TODO: Ordered reduction opcodes. |
364 | if (ISD::isVPReduction(Opcode: OPC)) |
365 | return true; |
366 | return false; |
367 | } |
368 | |
369 | unsigned getScalarReductionOpcode(unsigned VVPOC, bool IsMask) { |
370 | assert(!IsMask && "Mask reduction isel" ); |
371 | |
372 | switch (VVPOC) { |
373 | #define HANDLE_VVP_REDUCE_TO_SCALAR(VVP_RED_ISD, REDUCE_ISD) \ |
374 | case VEISD::VVP_RED_ISD: \ |
375 | return ISD::REDUCE_ISD; |
376 | #include "VVPNodes.def" |
377 | default: |
378 | break; |
379 | } |
380 | llvm_unreachable("Cannot not scalarize this reduction Opcode!" ); |
381 | } |
382 | |
383 | /// } Node Properties |
384 | |
385 | SDValue getNodeAVL(SDValue Op) { |
386 | auto PosOpt = getAVLPos(Opc: Op->getOpcode()); |
387 | return PosOpt ? Op->getOperand(Num: *PosOpt) : SDValue(); |
388 | } |
389 | |
390 | SDValue getNodeMask(SDValue Op) { |
391 | auto PosOpt = getMaskPos(Opc: Op->getOpcode()); |
392 | return PosOpt ? Op->getOperand(Num: *PosOpt) : SDValue(); |
393 | } |
394 | |
395 | std::pair<SDValue, bool> getAnnotatedNodeAVL(SDValue Op) { |
396 | SDValue AVL = getNodeAVL(Op); |
397 | if (!AVL) |
398 | return {SDValue(), true}; |
399 | if (isLegalAVL(AVL)) |
400 | return {AVL->getOperand(Num: 0), true}; |
401 | return {AVL, false}; |
402 | } |
403 | |
404 | SDValue VECustomDAG::getConstant(uint64_t Val, EVT VT, bool IsTarget, |
405 | bool IsOpaque) const { |
406 | return DAG.getConstant(Val, DL, VT, isTarget: IsTarget, isOpaque: IsOpaque); |
407 | } |
408 | |
409 | SDValue VECustomDAG::getConstantMask(Packing Packing, bool AllTrue) const { |
410 | auto MaskVT = getLegalVectorType(P: Packing, ElemVT: MVT::i1); |
411 | |
412 | // VEISelDAGtoDAG will replace this pattern with the constant-true VM. |
413 | auto TrueVal = DAG.getConstant(Val: -1, DL, VT: MVT::i32); |
414 | auto AVL = getConstant(Val: MaskVT.getVectorNumElements(), VT: MVT::i32); |
415 | auto Res = getNode(OC: VEISD::VEC_BROADCAST, ResVT: MaskVT, OpV: {TrueVal, AVL}); |
416 | if (AllTrue) |
417 | return Res; |
418 | |
419 | return DAG.getNOT(DL, Val: Res, VT: Res.getValueType()); |
420 | } |
421 | |
422 | SDValue VECustomDAG::getMaskBroadcast(EVT ResultVT, SDValue Scalar, |
423 | SDValue AVL) const { |
424 | // Constant mask splat. |
425 | if (auto BcConst = dyn_cast<ConstantSDNode>(Val&: Scalar)) |
426 | return getConstantMask(Packing: getTypePacking(VT: ResultVT), |
427 | AllTrue: BcConst->getSExtValue() != 0); |
428 | |
429 | // Expand the broadcast to a vector comparison. |
430 | auto ScalarBoolVT = Scalar.getSimpleValueType(); |
431 | assert(ScalarBoolVT == MVT::i32); |
432 | |
433 | // Cast to i32 ty. |
434 | SDValue CmpElem = DAG.getSExtOrTrunc(Op: Scalar, DL, VT: MVT::i32); |
435 | unsigned ElemCount = ResultVT.getVectorNumElements(); |
436 | MVT CmpVecTy = MVT::getVectorVT(VT: ScalarBoolVT, NumElements: ElemCount); |
437 | |
438 | // Broadcast to vector. |
439 | SDValue BCVec = |
440 | DAG.getNode(Opcode: VEISD::VEC_BROADCAST, DL, VT: CmpVecTy, Ops: {CmpElem, AVL}); |
441 | SDValue ZeroVec = |
442 | getBroadcast(ResultVT: CmpVecTy, Scalar: {DAG.getConstant(Val: 0, DL, VT: ScalarBoolVT)}, AVL); |
443 | |
444 | MVT BoolVecTy = MVT::getVectorVT(VT: MVT::i1, NumElements: ElemCount); |
445 | |
446 | // Broadcast(Data) != Broadcast(0) |
447 | // TODO: Use a VVP operation for this. |
448 | return DAG.getSetCC(DL, VT: BoolVecTy, LHS: BCVec, RHS: ZeroVec, Cond: ISD::CondCode::SETNE); |
449 | } |
450 | |
451 | SDValue VECustomDAG::getBroadcast(EVT ResultVT, SDValue Scalar, |
452 | SDValue AVL) const { |
453 | assert(ResultVT.isVector()); |
454 | auto ScaVT = Scalar.getValueType(); |
455 | |
456 | if (isMaskType(SomeVT: ResultVT)) |
457 | return getMaskBroadcast(ResultVT, Scalar, AVL); |
458 | |
459 | if (isPackedVectorType(SomeVT: ResultVT)) { |
460 | // v512x packed mode broadcast |
461 | // Replicate the scalar reg (f32 or i32) onto the opposing half of the full |
462 | // scalar register. If it's an I64 type, assume that this has already |
463 | // happened. |
464 | if (ScaVT == MVT::f32) { |
465 | Scalar = getNode(OC: VEISD::REPL_F32, ResVT: MVT::i64, OpV: Scalar); |
466 | } else if (ScaVT == MVT::i32) { |
467 | Scalar = getNode(OC: VEISD::REPL_I32, ResVT: MVT::i64, OpV: Scalar); |
468 | } |
469 | } |
470 | |
471 | return getNode(OC: VEISD::VEC_BROADCAST, ResVT: ResultVT, OpV: {Scalar, AVL}); |
472 | } |
473 | |
474 | SDValue VECustomDAG::annotateLegalAVL(SDValue AVL) const { |
475 | if (isLegalAVL(AVL)) |
476 | return AVL; |
477 | return getNode(OC: VEISD::LEGALAVL, ResVT: AVL.getValueType(), OpV: AVL); |
478 | } |
479 | |
480 | SDValue VECustomDAG::getUnpack(EVT DestVT, SDValue Vec, PackElem Part, |
481 | SDValue AVL) const { |
482 | assert(getAnnotatedNodeAVL(AVL).second && "Expected a pack-legalized AVL" ); |
483 | |
484 | // TODO: Peek through VEC_PACK and VEC_BROADCAST(REPL_<sth> ..) operands. |
485 | unsigned OC = |
486 | (Part == PackElem::Lo) ? VEISD::VEC_UNPACK_LO : VEISD::VEC_UNPACK_HI; |
487 | return DAG.getNode(Opcode: OC, DL, VT: DestVT, N1: Vec, N2: AVL); |
488 | } |
489 | |
490 | SDValue VECustomDAG::getPack(EVT DestVT, SDValue LoVec, SDValue HiVec, |
491 | SDValue AVL) const { |
492 | assert(getAnnotatedNodeAVL(AVL).second && "Expected a pack-legalized AVL" ); |
493 | |
494 | // TODO: Peek through VEC_UNPACK_LO|HI operands. |
495 | return DAG.getNode(Opcode: VEISD::VEC_PACK, DL, VT: DestVT, N1: LoVec, N2: HiVec, N3: AVL); |
496 | } |
497 | |
498 | VETargetMasks VECustomDAG::getTargetSplitMask(SDValue RawMask, SDValue RawAVL, |
499 | PackElem Part) const { |
500 | // Adjust AVL for this part |
501 | SDValue NewAVL; |
502 | SDValue OneV = getConstant(Val: 1, VT: MVT::i32); |
503 | if (Part == PackElem::Hi) |
504 | NewAVL = getNode(OC: ISD::ADD, ResVT: MVT::i32, OpV: {RawAVL, OneV}); |
505 | else |
506 | NewAVL = RawAVL; |
507 | NewAVL = getNode(OC: ISD::SRL, ResVT: MVT::i32, OpV: {NewAVL, OneV}); |
508 | |
509 | NewAVL = annotateLegalAVL(AVL: NewAVL); |
510 | |
511 | // Legalize Mask (unpack or all-true) |
512 | SDValue NewMask; |
513 | if (!RawMask) |
514 | NewMask = getConstantMask(Packing: Packing::Normal, AllTrue: true); |
515 | else |
516 | NewMask = getUnpack(DestVT: MVT::v256i1, Vec: RawMask, Part, AVL: NewAVL); |
517 | |
518 | return VETargetMasks(NewMask, NewAVL); |
519 | } |
520 | |
521 | SDValue VECustomDAG::getSplitPtrOffset(SDValue Ptr, SDValue ByteStride, |
522 | PackElem Part) const { |
523 | // High starts at base ptr but has more significant bits in the 64bit vector |
524 | // element. |
525 | if (Part == PackElem::Hi) |
526 | return Ptr; |
527 | return getNode(OC: ISD::ADD, ResVT: MVT::i64, OpV: {Ptr, ByteStride}); |
528 | } |
529 | |
530 | SDValue VECustomDAG::getSplitPtrStride(SDValue PackStride) const { |
531 | if (auto ConstBytes = dyn_cast<ConstantSDNode>(Val&: PackStride)) |
532 | return getConstant(Val: 2 * ConstBytes->getSExtValue(), VT: MVT::i64); |
533 | return getNode(OC: ISD::SHL, ResVT: MVT::i64, OpV: {PackStride, getConstant(Val: 1, VT: MVT::i32)}); |
534 | } |
535 | |
536 | SDValue VECustomDAG::getGatherScatterAddress(SDValue BasePtr, SDValue Scale, |
537 | SDValue Index, SDValue Mask, |
538 | SDValue AVL) const { |
539 | EVT IndexVT = Index.getValueType(); |
540 | |
541 | // Apply scale. |
542 | SDValue ScaledIndex; |
543 | if (!Scale || isOneConstant(V: Scale)) |
544 | ScaledIndex = Index; |
545 | else { |
546 | SDValue ScaleBroadcast = getBroadcast(ResultVT: IndexVT, Scalar: Scale, AVL); |
547 | ScaledIndex = |
548 | getNode(OC: VEISD::VVP_MUL, ResVT: IndexVT, OpV: {Index, ScaleBroadcast, Mask, AVL}); |
549 | } |
550 | |
551 | // Add basePtr. |
552 | if (isNullConstant(V: BasePtr)) |
553 | return ScaledIndex; |
554 | |
555 | // re-constitute pointer vector (basePtr + index * scale) |
556 | SDValue BaseBroadcast = getBroadcast(ResultVT: IndexVT, Scalar: BasePtr, AVL); |
557 | auto ResPtr = |
558 | getNode(OC: VEISD::VVP_ADD, ResVT: IndexVT, OpV: {BaseBroadcast, ScaledIndex, Mask, AVL}); |
559 | return ResPtr; |
560 | } |
561 | |
562 | SDValue VECustomDAG::getLegalReductionOpVVP(unsigned VVPOpcode, EVT ResVT, |
563 | SDValue StartV, SDValue VectorV, |
564 | SDValue Mask, SDValue AVL, |
565 | SDNodeFlags Flags) const { |
566 | |
567 | // Optionally attach the start param with a scalar op (where it is |
568 | // unsupported). |
569 | bool scalarizeStartParam = StartV && !hasReductionStartParam(OPC: VVPOpcode); |
570 | bool IsMaskReduction = isMaskType(SomeVT: VectorV.getValueType()); |
571 | assert(!IsMaskReduction && "TODO Implement" ); |
572 | auto AttachStartValue = [&](SDValue ReductionResV) { |
573 | if (!scalarizeStartParam) |
574 | return ReductionResV; |
575 | auto ScalarOC = getScalarReductionOpcode(VVPOC: VVPOpcode, IsMask: IsMaskReduction); |
576 | return getNode(OC: ScalarOC, ResVT, OpV: {StartV, ReductionResV}); |
577 | }; |
578 | |
579 | // Fixup: Always Use sequential 'fmul' reduction. |
580 | if (!scalarizeStartParam && StartV) { |
581 | assert(hasReductionStartParam(VVPOpcode)); |
582 | return AttachStartValue( |
583 | getNode(OC: VVPOpcode, ResVT, OpV: {StartV, VectorV, Mask, AVL}, Flags)); |
584 | } else |
585 | return AttachStartValue( |
586 | getNode(OC: VVPOpcode, ResVT, OpV: {VectorV, Mask, AVL}, Flags)); |
587 | } |
588 | |
589 | } // namespace llvm |
590 | |