//===-- VECustomDAG.cpp - VE Custom DAG Nodes -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the interfaces that VE uses to lower LLVM code into a
// selection DAG.
//
//===----------------------------------------------------------------------===//

#include "VECustomDAG.h"
#include "VESelectionDAGInfo.h"

#ifndef DEBUG_TYPE
#define DEBUG_TYPE "vecustomdag"
#endif

namespace llvm {

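/// Packed vector types use VE's packed mode, where two 32-bit elements share
/// each 64-bit vector-register lane. This doubles the usable vector length
/// from StandardVectorWidth (256) to PackedVectorWidth (512) elements.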
bool isPackedVectorType(EVT SomeVT) {
  if (!SomeVT.isVector())
    return false;
  return SomeVT.getVectorNumElements() > StandardVectorWidth;
}

MVT splitVectorType(MVT VT) {
  if (!VT.isVector())
    return VT;
  return MVT::getVectorVT(VT.getVectorElementType(), StandardVectorWidth);
}

MVT getLegalVectorType(Packing P, MVT ElemVT) {
  return MVT::getVectorVT(ElemVT, P == Packing::Normal ? StandardVectorWidth
                                                       : PackedVectorWidth);
}

Packing getTypePacking(EVT VT) {
  assert(VT.isVector());
  return isPackedVectorType(VT) ? Packing::Dense : Packing::Normal;
}

bool isMaskType(EVT SomeVT) {
  if (!SomeVT.isVector())
    return false;
  return SomeVT.getVectorElementType() == MVT::i1;
}

bool isMaskArithmetic(SDValue Op) {
  switch (Op.getOpcode()) {
  default:
    return false;
  case ISD::AND:
  case ISD::XOR:
  case ISD::OR:
    return isMaskType(Op.getValueType());
  }
}

/// \returns the VVP_* SDNode opcode corresponding to \p Opcode.
std::optional<unsigned> getVVPOpcode(unsigned Opcode) {
  switch (Opcode) {
  case ISD::MLOAD:
    return VEISD::VVP_LOAD;
  case ISD::MSTORE:
    return VEISD::VVP_STORE;
#define HANDLE_VP_TO_VVP(VPOPC, VVPNAME)                                       \
  case ISD::VPOPC:                                                             \
    return VEISD::VVPNAME;
#define ADD_VVP_OP(VVPNAME, SDNAME)                                            \
  case VEISD::VVPNAME:                                                         \
  case ISD::SDNAME:                                                            \
    return VEISD::VVPNAME;
#include "VVPNodes.def"
  // TODO: Map those in VVPNodes.def too.
  case ISD::EXPERIMENTAL_VP_STRIDED_LOAD:
    return VEISD::VVP_LOAD;
  case ISD::EXPERIMENTAL_VP_STRIDED_STORE:
    return VEISD::VVP_STORE;
  }
  return std::nullopt;
}

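/// Return true if the mask operand of \p Op can be ignored, i.e. the result
/// on masked-off lanes does not matter. Divisions must keep their mask since
/// masked-off lanes could otherwise trap (e.g. division by zero), and for
/// VVP_SELECT the mask is the select condition itself.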
bool maySafelyIgnoreMask(SDValue Op) {
  auto VVPOpc = getVVPOpcode(Op->getOpcode());
  auto Opc = VVPOpc.value_or(Op->getOpcode());

  switch (Opc) {
  case VEISD::VVP_SDIV:
  case VEISD::VVP_UDIV:
  case VEISD::VVP_FDIV:
  case VEISD::VVP_SELECT:
    return false;

  default:
    return true;
  }
}

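/// Whether the machine instructions for \p Opcode can operate directly on
/// packed (v512) vectors of type \p IdiomVT. Packed-capable ops are listed
/// via REGISTER_PACKED in VVPNodes.def; packed mask operations are excluded
/// here and are legalized differently.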
bool supportsPackedMode(unsigned Opcode, EVT IdiomVT) {
  bool IsPackedOp = isPackedVectorType(IdiomVT);
  bool IsMaskOp = isMaskType(IdiomVT);
  switch (Opcode) {
  default:
    return false;

  case VEISD::VEC_BROADCAST:
    return true;
#define REGISTER_PACKED(VVP_NAME) case VEISD::VVP_NAME:
#include "VVPNodes.def"
    return IsPackedOp && !IsMaskOp;
  }
}

bool isPackingSupportOpcode(unsigned Opc) {
  switch (Opc) {
  case VEISD::VEC_PACK:
  case VEISD::VEC_UNPACK_LO:
  case VEISD::VEC_UNPACK_HI:
    return true;
  }
  return false;
}

bool isVVPOrVEC(unsigned Opcode) {
  switch (Opcode) {
  case VEISD::VEC_BROADCAST:
#define ADD_VVP_OP(VVPNAME, ...) case VEISD::VVPNAME:
#include "VVPNodes.def"
    return true;
  }
  return false;
}

bool isVVPUnaryOp(unsigned VVPOpcode) {
  switch (VVPOpcode) {
#define ADD_UNARY_VVP_OP(VVPNAME, ...)                                         \
  case VEISD::VVPNAME:                                                         \
    return true;
#include "VVPNodes.def"
  }
  return false;
}

bool isVVPBinaryOp(unsigned VVPOpcode) {
  switch (VVPOpcode) {
#define ADD_BINARY_VVP_OP(VVPNAME, ...)                                        \
  case VEISD::VVPNAME:                                                         \
    return true;
#include "VVPNodes.def"
  }
  return false;
}

bool isVVPReductionOp(unsigned Opcode) {
  switch (Opcode) {
#define ADD_REDUCE_VVP_OP(VVP_NAME, SDNAME) case VEISD::VVP_NAME:
#include "VVPNodes.def"
    return true;
  }
  return false;
}

// Return the AVL (active vector length) operand position for this VVP or VEC
// op.
std::optional<int> getAVLPos(unsigned Opc) {
  // This is only available for VP SDNodes.
  auto PosOpt = ISD::getVPExplicitVectorLengthIdx(Opc);
  if (PosOpt)
    return *PosOpt;

  // VVP opcodes.
  if (isVVPBinaryOp(Opc))
    return 3;

  // VM opcodes.
  switch (Opc) {
  case VEISD::VEC_BROADCAST:
    return 1;
  case VEISD::VVP_SELECT:
    return 3;
  case VEISD::VVP_LOAD:
    return 4;
  case VEISD::VVP_STORE:
    return 5;
  }

  return std::nullopt;
}

std::optional<int> getMaskPos(unsigned Opc) {
  // This is only available for VP SDNodes.
  auto PosOpt = ISD::getVPMaskIdx(Opc);
  if (PosOpt)
    return *PosOpt;

  // VVP opcodes.
  if (isVVPBinaryOp(Opc))
    return 2;

  // Other opcodes.
  switch (Opc) {
  case ISD::MSTORE:
    return 4;
  case ISD::MLOAD:
    return 3;
  case VEISD::VVP_SELECT:
    return 2;
  }

  return std::nullopt;
}

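/// A VEISD::LEGALAVL node wraps an AVL operand to record that it has already
/// been legalized for the target: it counts 64-bit vector lanes (a packed
/// AVL is halved, rounding up) and is assumed to fit the hardware vector
/// length.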
bool isLegalAVL(SDValue AVL) { return AVL->getOpcode() == VEISD::LEGALAVL; }

/// Node Properties {

SDValue getNodeChain(SDValue Op) {
  if (MemSDNode *MemN = dyn_cast<MemSDNode>(Op.getNode()))
    return MemN->getChain();

  switch (Op->getOpcode()) {
  case VEISD::VVP_LOAD:
  case VEISD::VVP_STORE:
    return Op->getOperand(0);
  }
  return SDValue();
}

SDValue getMemoryPtr(SDValue Op) {
  if (auto *MemN = dyn_cast<MemSDNode>(Op.getNode()))
    return MemN->getBasePtr();

  switch (Op->getOpcode()) {
  case VEISD::VVP_LOAD:
    return Op->getOperand(1);
  case VEISD::VVP_STORE:
    return Op->getOperand(2);
  }
  return SDValue();
}

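/// Return the "idiomatic" vector type of \p Op: the vector type that best
/// describes the operation's natural width and drives widening/splitting
/// decisions. Depending on the opcode this is the result type, an operand
/// type (e.g. for setcc and stores), or the memory VT of a memory operation.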
std::optional<EVT> getIdiomaticVectorType(SDNode *Op) {
  unsigned OC = Op->getOpcode();

  // For memory ops -> the transferred data type.
  if (auto *MemN = dyn_cast<MemSDNode>(Op))
    return MemN->getMemoryVT();

  switch (OC) {
  // Standard ISD.
  case ISD::SELECT: // not aliased with VVP_SELECT
  case ISD::CONCAT_VECTORS:
  case ISD::EXTRACT_SUBVECTOR:
  case ISD::VECTOR_SHUFFLE:
  case ISD::BUILD_VECTOR:
  case ISD::SCALAR_TO_VECTOR:
    return Op->getValueType(0);
  }

  // Translate to VVP where possible.
  unsigned OriginalOC = OC;
  if (auto VVPOpc = getVVPOpcode(OC))
    OC = *VVPOpc;

  if (isVVPReductionOp(OC))
    return Op->getOperand(hasReductionStartParam(OriginalOC) ? 1 : 0)
        .getValueType();

  switch (OC) {
  default:
  case VEISD::VVP_SETCC:
    return Op->getOperand(0).getValueType();

  case VEISD::VVP_SELECT:
#define ADD_BINARY_VVP_OP(VVP_NAME, ...) case VEISD::VVP_NAME:
#include "VVPNodes.def"
    return Op->getValueType(0);

  case VEISD::VVP_LOAD:
    return Op->getValueType(0);

  case VEISD::VVP_STORE:
    return Op->getOperand(1)->getValueType(0);

  // VEC
  case VEISD::VEC_BROADCAST:
    return Op->getValueType(0);
  }
}

SDValue getLoadStoreStride(SDValue Op, VECustomDAG &CDAG) {
  switch (Op->getOpcode()) {
  case VEISD::VVP_STORE:
    return Op->getOperand(3);
  case VEISD::VVP_LOAD:
    return Op->getOperand(2);
  }

  if (auto *StoreN = dyn_cast<VPStridedStoreSDNode>(Op.getNode()))
    return StoreN->getStride();
  if (auto *LoadN = dyn_cast<VPStridedLoadSDNode>(Op.getNode()))
    return LoadN->getStride();

  if (isa<MemSDNode>(Op.getNode())) {
    // Regular MLOAD/MSTORE/LOAD/STORE.
    // No stride argument -> use the contiguous element size as stride.
    uint64_t ElemStride = getIdiomaticVectorType(Op.getNode())
                              ->getVectorElementType()
                              .getStoreSize();
    return CDAG.getConstant(ElemStride, MVT::i64);
  }
  return SDValue();
}

SDValue getGatherScatterIndex(SDValue Op) {
  if (auto *N = dyn_cast<MaskedGatherScatterSDNode>(Op.getNode()))
    return N->getIndex();
  if (auto *N = dyn_cast<VPGatherScatterSDNode>(Op.getNode()))
    return N->getIndex();
  return SDValue();
}

SDValue getGatherScatterScale(SDValue Op) {
  if (auto *N = dyn_cast<MaskedGatherScatterSDNode>(Op.getNode()))
    return N->getScale();
  if (auto *N = dyn_cast<VPGatherScatterSDNode>(Op.getNode()))
    return N->getScale();
  return SDValue();
}

SDValue getStoredValue(SDValue Op) {
  switch (Op->getOpcode()) {
  case ISD::EXPERIMENTAL_VP_STRIDED_STORE:
  case VEISD::VVP_STORE:
    return Op->getOperand(1);
  }
  if (auto *StoreN = dyn_cast<StoreSDNode>(Op.getNode()))
    return StoreN->getValue();
  if (auto *StoreN = dyn_cast<MaskedStoreSDNode>(Op.getNode()))
    return StoreN->getValue();
  if (auto *StoreN = dyn_cast<VPStridedStoreSDNode>(Op.getNode()))
    return StoreN->getValue();
  if (auto *StoreN = dyn_cast<VPStoreSDNode>(Op.getNode()))
    return StoreN->getValue();
  if (auto *StoreN = dyn_cast<MaskedScatterSDNode>(Op.getNode()))
    return StoreN->getValue();
  if (auto *StoreN = dyn_cast<VPScatterSDNode>(Op.getNode()))
    return StoreN->getValue();
  return SDValue();
}

SDValue getNodePassthru(SDValue Op) {
  if (auto *N = dyn_cast<MaskedLoadSDNode>(Op.getNode()))
    return N->getPassThru();
  if (auto *N = dyn_cast<MaskedGatherSDNode>(Op.getNode()))
    return N->getPassThru();

  return SDValue();
}

bool hasReductionStartParam(unsigned OPC) {
  // TODO: Ordered reduction opcodes.
  if (ISD::isVPReduction(OPC))
    return true;
  return false;
}

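/// Map a VVP reduction opcode to the scalar binary ISD opcode that combines
/// a start value with the reduced result (per the HANDLE_VVP_REDUCE_TO_SCALAR
/// table in VVPNodes.def).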
unsigned getScalarReductionOpcode(unsigned VVPOC, bool IsMask) {
  assert(!IsMask && "Mask reduction isel");

  switch (VVPOC) {
#define HANDLE_VVP_REDUCE_TO_SCALAR(VVP_RED_ISD, REDUCE_ISD)                   \
  case VEISD::VVP_RED_ISD:                                                     \
    return ISD::REDUCE_ISD;
#include "VVPNodes.def"
  default:
    break;
  }
  llvm_unreachable("Cannot scalarize this reduction opcode!");
}

/// } Node Properties

SDValue getNodeAVL(SDValue Op) {
  auto PosOpt = getAVLPos(Op->getOpcode());
  return PosOpt ? Op->getOperand(*PosOpt) : SDValue();
}

SDValue getNodeMask(SDValue Op) {
  auto PosOpt = getMaskPos(Op->getOpcode());
  return PosOpt ? Op->getOperand(*PosOpt) : SDValue();
}

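/// Return the AVL operand of \p Op together with a flag that is true iff the
/// AVL is known legal. A LEGALAVL wrapper is peeled off; a missing AVL is
/// trivially "legal" since the op is not length-predicated.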
std::pair<SDValue, bool> getAnnotatedNodeAVL(SDValue Op) {
  SDValue AVL = getNodeAVL(Op);
  if (!AVL)
    return {SDValue(), true};
  if (isLegalAVL(AVL))
    return {AVL->getOperand(0), true};
  return {AVL, false};
}

SDValue VECustomDAG::getConstant(uint64_t Val, EVT VT, bool IsTarget,
                                 bool IsOpaque) const {
  return DAG.getConstant(Val, DL, VT, IsTarget, IsOpaque);
}

SDValue VECustomDAG::getConstantMask(Packing Packing, bool AllTrue) const {
  auto MaskVT = getLegalVectorType(Packing, MVT::i1);

  // VEISelDAGToDAG will replace this pattern with the constant-true VM.
  auto TrueVal = DAG.getAllOnesConstant(DL, MVT::i32);
  auto AVL = getConstant(MaskVT.getVectorNumElements(), MVT::i32);
  auto Res = getNode(VEISD::VEC_BROADCAST, MaskVT, {TrueVal, AVL});
  if (AllTrue)
    return Res;

  return DAG.getNOT(DL, Res, Res.getValueType());
}

SDValue VECustomDAG::getMaskBroadcast(EVT ResultVT, SDValue Scalar,
                                      SDValue AVL) const {
  // Constant mask splat.
  if (auto *BcConst = dyn_cast<ConstantSDNode>(Scalar))
    return getConstantMask(getTypePacking(ResultVT),
                           BcConst->getSExtValue() != 0);

  // Expand the broadcast to a vector comparison.
  auto ScalarBoolVT = Scalar.getSimpleValueType();
  assert(ScalarBoolVT == MVT::i32);

  // Cast to i32 ty.
  SDValue CmpElem = DAG.getSExtOrTrunc(Scalar, DL, MVT::i32);
  unsigned ElemCount = ResultVT.getVectorNumElements();
  MVT CmpVecTy = MVT::getVectorVT(ScalarBoolVT, ElemCount);

  // Broadcast to vector.
  SDValue BCVec =
      DAG.getNode(VEISD::VEC_BROADCAST, DL, CmpVecTy, {CmpElem, AVL});
  SDValue ZeroVec =
      getBroadcast(CmpVecTy, DAG.getConstant(0, DL, ScalarBoolVT), AVL);

  MVT BoolVecTy = MVT::getVectorVT(MVT::i1, ElemCount);

  // Broadcast(Data) != Broadcast(0)
  // TODO: Use a VVP operation for this.
  return DAG.getSetCC(DL, BoolVecTy, BCVec, ZeroVec, ISD::CondCode::SETNE);
}

SDValue VECustomDAG::getBroadcast(EVT ResultVT, SDValue Scalar,
                                  SDValue AVL) const {
  assert(ResultVT.isVector());
  auto ScaVT = Scalar.getValueType();

  if (isMaskType(ResultVT))
    return getMaskBroadcast(ResultVT, Scalar, AVL);

  if (isPackedVectorType(ResultVT)) {
    // v512x packed-mode broadcast.
    // Replicate the scalar reg (f32 or i32) onto the opposing half of the
    // full scalar register. If it's an i64 type, assume that this has
    // already happened.
    if (ScaVT == MVT::f32) {
      Scalar = getNode(VEISD::REPL_F32, MVT::i64, Scalar);
    } else if (ScaVT == MVT::i32) {
      Scalar = getNode(VEISD::REPL_I32, MVT::i64, Scalar);
    }
  }

  return getNode(VEISD::VEC_BROADCAST, ResultVT, {Scalar, AVL});
}

SDValue VECustomDAG::annotateLegalAVL(SDValue AVL) const {
  if (isLegalAVL(AVL))
    return AVL;
  return getNode(VEISD::LEGALAVL, AVL.getValueType(), AVL);
}

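/// VEC_UNPACK_LO/HI extract a normal v256 vector from the low/high 32-bit
/// halves of the 64-bit lanes of a packed v512 vector; VEC_PACK is the
/// inverse and re-interleaves two v256 vectors into one packed vector. The
/// AVL operand counts 64-bit lanes, hence the pack-legalized assertions
/// below.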
SDValue VECustomDAG::getUnpack(EVT DestVT, SDValue Vec, PackElem Part,
                               SDValue AVL) const {
  assert(getAnnotatedNodeAVL(AVL).second && "Expected a pack-legalized AVL");

  // TODO: Peek through VEC_PACK and VEC_BROADCAST(REPL_<sth> ..) operands.
  unsigned OC =
      (Part == PackElem::Lo) ? VEISD::VEC_UNPACK_LO : VEISD::VEC_UNPACK_HI;
  return DAG.getNode(OC, DL, DestVT, Vec, AVL);
}

SDValue VECustomDAG::getPack(EVT DestVT, SDValue LoVec, SDValue HiVec,
                             SDValue AVL) const {
  assert(getAnnotatedNodeAVL(AVL).second && "Expected a pack-legalized AVL");

  // TODO: Peek through VEC_UNPACK_LO|HI operands.
  return DAG.getNode(VEISD::VEC_PACK, DL, DestVT, LoVec, HiVec, AVL);
}

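/// Compute the mask and AVL for one part of a split packed operation. Packed
/// elements are interleaved over the Hi/Lo lane halves, with element 0 in the
/// Hi half, so a packed AVL of N covers ceil(N/2) Hi lanes and floor(N/2) Lo
/// lanes; e.g. N = 5 gives 3 Hi lanes (elements 0, 2, 4) and 2 Lo lanes
/// (elements 1, 3).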
VETargetMasks VECustomDAG::getTargetSplitMask(SDValue RawMask, SDValue RawAVL,
                                              PackElem Part) const {
  // Adjust the AVL for this part.
  SDValue NewAVL;
  SDValue OneV = getConstant(1, MVT::i32);
  if (Part == PackElem::Hi)
    NewAVL = getNode(ISD::ADD, MVT::i32, {RawAVL, OneV});
  else
    NewAVL = RawAVL;
  NewAVL = getNode(ISD::SRL, MVT::i32, {NewAVL, OneV});

  NewAVL = annotateLegalAVL(NewAVL);

  // Legalize the mask (unpack or all-true).
  SDValue NewMask;
  if (!RawMask)
    NewMask = getConstantMask(Packing::Normal, true);
  else
    NewMask = getUnpack(MVT::v256i1, RawMask, Part, NewAVL);

  return VETargetMasks(NewMask, NewAVL);
}

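/// When a packed load/store is split, each part accesses every other
/// element: the Hi part starts at the base pointer (element 0) and the Lo
/// part one element further, both stepping at twice the original stride (see
/// getSplitPtrStride below).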
SDValue VECustomDAG::getSplitPtrOffset(SDValue Ptr, SDValue ByteStride,
                                       PackElem Part) const {
  // The Hi part starts at the base pointer; it holds the more significant
  // bits within each 64-bit vector element.
  if (Part == PackElem::Hi)
    return Ptr;
  return getNode(ISD::ADD, MVT::i64, {Ptr, ByteStride});
}

SDValue VECustomDAG::getSplitPtrStride(SDValue PackStride) const {
  if (auto *ConstBytes = dyn_cast<ConstantSDNode>(PackStride))
    return getConstant(2 * ConstBytes->getSExtValue(), MVT::i64);
  return getNode(ISD::SHL, MVT::i64, {PackStride, getConstant(1, MVT::i32)});
}

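/// Materialize the per-lane addresses of a gather/scatter access:
///   Addr[i] = BasePtr + Index[i] * Scale
/// Scale and BasePtr are broadcast and combined with masked VVP arithmetic;
/// a unit scale and a null base pointer are folded away.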
SDValue VECustomDAG::getGatherScatterAddress(SDValue BasePtr, SDValue Scale,
                                             SDValue Index, SDValue Mask,
                                             SDValue AVL) const {
  EVT IndexVT = Index.getValueType();

  // Apply the scale.
  SDValue ScaledIndex;
  if (!Scale || isOneConstant(Scale)) {
    ScaledIndex = Index;
  } else {
    SDValue ScaleBroadcast = getBroadcast(IndexVT, Scale, AVL);
    ScaledIndex =
        getNode(VEISD::VVP_MUL, IndexVT, {Index, ScaleBroadcast, Mask, AVL});
  }

  // Add the base pointer.
  if (isNullConstant(BasePtr))
    return ScaledIndex;

  // Reconstitute the pointer vector (BasePtr + Index * Scale).
  SDValue BaseBroadcast = getBroadcast(IndexVT, BasePtr, AVL);
  return getNode(VEISD::VVP_ADD, IndexVT,
                 {BaseBroadcast, ScaledIndex, Mask, AVL});
}

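/// Build a legal VVP reduction. If the VVP opcode takes a start parameter it
/// is passed through directly; otherwise the reduction is emitted without it
/// and the start value is folded into the scalar result afterwards via the
/// matching scalar opcode (see getScalarReductionOpcode).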
SDValue VECustomDAG::getLegalReductionOpVVP(unsigned VVPOpcode, EVT ResVT,
                                            SDValue StartV, SDValue VectorV,
                                            SDValue Mask, SDValue AVL,
                                            SDNodeFlags Flags) const {
  // Optionally attach the start param with a scalar op (where it is
  // unsupported).
  bool ScalarizeStartParam = StartV && !hasReductionStartParam(VVPOpcode);
  bool IsMaskReduction = isMaskType(VectorV.getValueType());
  assert(!IsMaskReduction && "TODO Implement");
  auto AttachStartValue = [&](SDValue ReductionResV) {
    if (!ScalarizeStartParam)
      return ReductionResV;
    auto ScalarOC = getScalarReductionOpcode(VVPOpcode, IsMaskReduction);
    return getNode(ScalarOC, ResVT, {StartV, ReductionResV});
  };

  // Fixup: Always use the sequential 'fmul' reduction.
  if (!ScalarizeStartParam && StartV) {
    assert(hasReductionStartParam(VVPOpcode));
    return AttachStartValue(
        getNode(VVPOpcode, ResVT, {StartV, VectorV, Mask, AVL}, Flags));
  }
  return AttachStartValue(
      getNode(VVPOpcode, ResVT, {VectorV, Mask, AVL}, Flags));
}

} // namespace llvm