//===-- VECustomDAG.cpp - VE Custom DAG Nodes ----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the interfaces that VE uses to lower LLVM code into a
10// selection DAG.
11//
12//===----------------------------------------------------------------------===//
13
14#include "VECustomDAG.h"
15
16#ifndef DEBUG_TYPE
17#define DEBUG_TYPE "vecustomdag"
18#endif
19
20namespace llvm {
21
22bool isPackedVectorType(EVT SomeVT) {
23 if (!SomeVT.isVector())
24 return false;
25 return SomeVT.getVectorNumElements() > StandardVectorWidth;
26}
27
28MVT splitVectorType(MVT VT) {
29 if (!VT.isVector())
30 return VT;
31 return MVT::getVectorVT(VT: VT.getVectorElementType(), NumElements: StandardVectorWidth);
32}
33
34MVT getLegalVectorType(Packing P, MVT ElemVT) {
35 return MVT::getVectorVT(VT: ElemVT, NumElements: P == Packing::Normal ? StandardVectorWidth
36 : PackedVectorWidth);
37}
38
39Packing getTypePacking(EVT VT) {
40 assert(VT.isVector());
41 return isPackedVectorType(SomeVT: VT) ? Packing::Dense : Packing::Normal;
42}
43
44bool isMaskType(EVT SomeVT) {
45 if (!SomeVT.isVector())
46 return false;
47 return SomeVT.getVectorElementType() == MVT::i1;
48}
49
50bool isMaskArithmetic(SDValue Op) {
51 switch (Op.getOpcode()) {
52 default:
53 return false;
54 case ISD::AND:
55 case ISD::XOR:
56 case ISD::OR:
57 return isMaskType(SomeVT: Op.getValueType());
58 }
59}
60
61/// \returns the VVP_* SDNode opcode corresponsing to \p OC.
62std::optional<unsigned> getVVPOpcode(unsigned Opcode) {
63 switch (Opcode) {
64 case ISD::MLOAD:
65 return VEISD::VVP_LOAD;
66 case ISD::MSTORE:
67 return VEISD::VVP_STORE;
68#define HANDLE_VP_TO_VVP(VPOPC, VVPNAME) \
69 case ISD::VPOPC: \
70 return VEISD::VVPNAME;
71#define ADD_VVP_OP(VVPNAME, SDNAME) \
72 case VEISD::VVPNAME: \
73 case ISD::SDNAME: \
74 return VEISD::VVPNAME;
75#include "VVPNodes.def"
76 // TODO: Map those in VVPNodes.def too
77 case ISD::EXPERIMENTAL_VP_STRIDED_LOAD:
78 return VEISD::VVP_LOAD;
79 case ISD::EXPERIMENTAL_VP_STRIDED_STORE:
80 return VEISD::VVP_STORE;
81 }
82 return std::nullopt;
83}
84
85bool maySafelyIgnoreMask(SDValue Op) {
86 auto VVPOpc = getVVPOpcode(Opcode: Op->getOpcode());
87 auto Opc = VVPOpc.value_or(u: Op->getOpcode());
88
89 switch (Opc) {
90 case VEISD::VVP_SDIV:
91 case VEISD::VVP_UDIV:
92 case VEISD::VVP_FDIV:
93 case VEISD::VVP_SELECT:
94 return false;
95
96 default:
97 return true;
98 }
99}
100
101bool supportsPackedMode(unsigned Opcode, EVT IdiomVT) {
102 bool IsPackedOp = isPackedVectorType(SomeVT: IdiomVT);
103 bool IsMaskOp = isMaskType(SomeVT: IdiomVT);
104 switch (Opcode) {
105 default:
106 return false;
107
108 case VEISD::VEC_BROADCAST:
109 return true;
110#define REGISTER_PACKED(VVP_NAME) case VEISD::VVP_NAME:
111#include "VVPNodes.def"
112 return IsPackedOp && !IsMaskOp;
113 }
114}
115
116bool isPackingSupportOpcode(unsigned Opc) {
117 switch (Opc) {
118 case VEISD::VEC_PACK:
119 case VEISD::VEC_UNPACK_LO:
120 case VEISD::VEC_UNPACK_HI:
121 return true;
122 }
123 return false;
124}
125
126bool isVVPOrVEC(unsigned Opcode) {
127 switch (Opcode) {
128 case VEISD::VEC_BROADCAST:
129#define ADD_VVP_OP(VVPNAME, ...) case VEISD::VVPNAME:
130#include "VVPNodes.def"
131 return true;
132 }
133 return false;
134}
135
136bool isVVPUnaryOp(unsigned VVPOpcode) {
137 switch (VVPOpcode) {
138#define ADD_UNARY_VVP_OP(VVPNAME, ...) \
139 case VEISD::VVPNAME: \
140 return true;
141#include "VVPNodes.def"
142 }
143 return false;
144}
145
146bool isVVPBinaryOp(unsigned VVPOpcode) {
147 switch (VVPOpcode) {
148#define ADD_BINARY_VVP_OP(VVPNAME, ...) \
149 case VEISD::VVPNAME: \
150 return true;
151#include "VVPNodes.def"
152 }
153 return false;
154}
155
156bool isVVPReductionOp(unsigned Opcode) {
157 switch (Opcode) {
158#define ADD_REDUCE_VVP_OP(VVP_NAME, SDNAME) case VEISD::VVP_NAME:
159#include "VVPNodes.def"
160 return true;
161 }
162 return false;
163}
164
165// Return the AVL operand position for this VVP or VEC Op.
166std::optional<int> getAVLPos(unsigned Opc) {
167 // This is only available for VP SDNodes
168 auto PosOpt = ISD::getVPExplicitVectorLengthIdx(Opcode: Opc);
169 if (PosOpt)
170 return *PosOpt;
171
172 // VVP Opcodes.
173 if (isVVPBinaryOp(VVPOpcode: Opc))
174 return 3;
175
176 // VM Opcodes.
177 switch (Opc) {
178 case VEISD::VEC_BROADCAST:
179 return 1;
180 case VEISD::VVP_SELECT:
181 return 3;
182 case VEISD::VVP_LOAD:
183 return 4;
184 case VEISD::VVP_STORE:
185 return 5;
186 }
187
188 return std::nullopt;
189}
190
191std::optional<int> getMaskPos(unsigned Opc) {
192 // This is only available for VP SDNodes
193 auto PosOpt = ISD::getVPMaskIdx(Opcode: Opc);
194 if (PosOpt)
195 return *PosOpt;
196
197 // VVP Opcodes.
198 if (isVVPBinaryOp(VVPOpcode: Opc))
199 return 2;
200
201 // Other opcodes.
202 switch (Opc) {
203 case ISD::MSTORE:
204 return 4;
205 case ISD::MLOAD:
206 return 3;
207 case VEISD::VVP_SELECT:
208 return 2;
209 }
210
211 return std::nullopt;
212}
213
214bool isLegalAVL(SDValue AVL) { return AVL->getOpcode() == VEISD::LEGALAVL; }
215
216/// Node Properties {
217
218SDValue getNodeChain(SDValue Op) {
219 if (MemSDNode *MemN = dyn_cast<MemSDNode>(Val: Op.getNode()))
220 return MemN->getChain();
221
222 switch (Op->getOpcode()) {
223 case VEISD::VVP_LOAD:
224 case VEISD::VVP_STORE:
225 return Op->getOperand(Num: 0);
226 }
227 return SDValue();
228}
229
230SDValue getMemoryPtr(SDValue Op) {
231 if (auto *MemN = dyn_cast<MemSDNode>(Val: Op.getNode()))
232 return MemN->getBasePtr();
233
234 switch (Op->getOpcode()) {
235 case VEISD::VVP_LOAD:
236 return Op->getOperand(Num: 1);
237 case VEISD::VVP_STORE:
238 return Op->getOperand(Num: 2);
239 }
240 return SDValue();
241}
242
243std::optional<EVT> getIdiomaticVectorType(SDNode *Op) {
244 unsigned OC = Op->getOpcode();
245
246 // For memory ops -> the transfered data type
247 if (auto MemN = dyn_cast<MemSDNode>(Val: Op))
248 return MemN->getMemoryVT();
249
250 switch (OC) {
251 // Standard ISD.
252 case ISD::SELECT: // not aliased with VVP_SELECT
253 case ISD::CONCAT_VECTORS:
254 case ISD::EXTRACT_SUBVECTOR:
255 case ISD::VECTOR_SHUFFLE:
256 case ISD::BUILD_VECTOR:
257 case ISD::SCALAR_TO_VECTOR:
258 return Op->getValueType(ResNo: 0);
259 }
260
261 // Translate to VVP where possible.
262 unsigned OriginalOC = OC;
263 if (auto VVPOpc = getVVPOpcode(Opcode: OC))
264 OC = *VVPOpc;
265
266 if (isVVPReductionOp(Opcode: OC))
267 return Op->getOperand(Num: hasReductionStartParam(VVPOC: OriginalOC) ? 1 : 0)
268 .getValueType();
269
270 switch (OC) {
271 default:
272 case VEISD::VVP_SETCC:
273 return Op->getOperand(Num: 0).getValueType();
274
275 case VEISD::VVP_SELECT:
276#define ADD_BINARY_VVP_OP(VVP_NAME, ...) case VEISD::VVP_NAME:
277#include "VVPNodes.def"
278 return Op->getValueType(ResNo: 0);
279
280 case VEISD::VVP_LOAD:
281 return Op->getValueType(ResNo: 0);
282
283 case VEISD::VVP_STORE:
284 return Op->getOperand(Num: 1)->getValueType(ResNo: 0);
285
286 // VEC
287 case VEISD::VEC_BROADCAST:
288 return Op->getValueType(ResNo: 0);
289 }
290}
291
292SDValue getLoadStoreStride(SDValue Op, VECustomDAG &CDAG) {
293 switch (Op->getOpcode()) {
294 case VEISD::VVP_STORE:
295 return Op->getOperand(Num: 3);
296 case VEISD::VVP_LOAD:
297 return Op->getOperand(Num: 2);
298 }
299
300 if (auto *StoreN = dyn_cast<VPStridedStoreSDNode>(Val: Op.getNode()))
301 return StoreN->getStride();
302 if (auto *StoreN = dyn_cast<VPStridedLoadSDNode>(Val: Op.getNode()))
303 return StoreN->getStride();
304
305 if (isa<MemSDNode>(Val: Op.getNode())) {
306 // Regular MLOAD/MSTORE/LOAD/STORE
307 // No stride argument -> use the contiguous element size as stride.
308 uint64_t ElemStride = getIdiomaticVectorType(Op: Op.getNode())
309 ->getVectorElementType()
310 .getStoreSize();
311 return CDAG.getConstant(Val: ElemStride, VT: MVT::i64);
312 }
313 return SDValue();
314}
315
316SDValue getGatherScatterIndex(SDValue Op) {
317 if (auto *N = dyn_cast<MaskedGatherScatterSDNode>(Val: Op.getNode()))
318 return N->getIndex();
319 if (auto *N = dyn_cast<VPGatherScatterSDNode>(Val: Op.getNode()))
320 return N->getIndex();
321 return SDValue();
322}
323
324SDValue getGatherScatterScale(SDValue Op) {
325 if (auto *N = dyn_cast<MaskedGatherScatterSDNode>(Val: Op.getNode()))
326 return N->getScale();
327 if (auto *N = dyn_cast<VPGatherScatterSDNode>(Val: Op.getNode()))
328 return N->getScale();
329 return SDValue();
330}
331
332SDValue getStoredValue(SDValue Op) {
333 switch (Op->getOpcode()) {
334 case ISD::EXPERIMENTAL_VP_STRIDED_STORE:
335 case VEISD::VVP_STORE:
336 return Op->getOperand(Num: 1);
337 }
338 if (auto *StoreN = dyn_cast<StoreSDNode>(Val: Op.getNode()))
339 return StoreN->getValue();
340 if (auto *StoreN = dyn_cast<MaskedStoreSDNode>(Val: Op.getNode()))
341 return StoreN->getValue();
342 if (auto *StoreN = dyn_cast<VPStridedStoreSDNode>(Val: Op.getNode()))
343 return StoreN->getValue();
344 if (auto *StoreN = dyn_cast<VPStoreSDNode>(Val: Op.getNode()))
345 return StoreN->getValue();
346 if (auto *StoreN = dyn_cast<MaskedScatterSDNode>(Val: Op.getNode()))
347 return StoreN->getValue();
348 if (auto *StoreN = dyn_cast<VPScatterSDNode>(Val: Op.getNode()))
349 return StoreN->getValue();
350 return SDValue();
351}
352
353SDValue getNodePassthru(SDValue Op) {
354 if (auto *N = dyn_cast<MaskedLoadSDNode>(Val: Op.getNode()))
355 return N->getPassThru();
356 if (auto *N = dyn_cast<MaskedGatherSDNode>(Val: Op.getNode()))
357 return N->getPassThru();
358
359 return SDValue();
360}
361
362bool hasReductionStartParam(unsigned OPC) {
363 // TODO: Ordered reduction opcodes.
364 if (ISD::isVPReduction(Opcode: OPC))
365 return true;
366 return false;
367}
368
369unsigned getScalarReductionOpcode(unsigned VVPOC, bool IsMask) {
370 assert(!IsMask && "Mask reduction isel");
371
372 switch (VVPOC) {
373#define HANDLE_VVP_REDUCE_TO_SCALAR(VVP_RED_ISD, REDUCE_ISD) \
374 case VEISD::VVP_RED_ISD: \
375 return ISD::REDUCE_ISD;
376#include "VVPNodes.def"
377 default:
378 break;
379 }
380 llvm_unreachable("Cannot not scalarize this reduction Opcode!");
381}
382
383/// } Node Properties
384
385SDValue getNodeAVL(SDValue Op) {
386 auto PosOpt = getAVLPos(Opc: Op->getOpcode());
387 return PosOpt ? Op->getOperand(Num: *PosOpt) : SDValue();
388}
389
390SDValue getNodeMask(SDValue Op) {
391 auto PosOpt = getMaskPos(Opc: Op->getOpcode());
392 return PosOpt ? Op->getOperand(Num: *PosOpt) : SDValue();
393}
394
395std::pair<SDValue, bool> getAnnotatedNodeAVL(SDValue Op) {
396 SDValue AVL = getNodeAVL(Op);
397 if (!AVL)
398 return {SDValue(), true};
399 if (isLegalAVL(AVL))
400 return {AVL->getOperand(Num: 0), true};
401 return {AVL, false};
402}
403
404SDValue VECustomDAG::getConstant(uint64_t Val, EVT VT, bool IsTarget,
405 bool IsOpaque) const {
406 return DAG.getConstant(Val, DL, VT, isTarget: IsTarget, isOpaque: IsOpaque);
407}
408
409SDValue VECustomDAG::getConstantMask(Packing Packing, bool AllTrue) const {
410 auto MaskVT = getLegalVectorType(P: Packing, ElemVT: MVT::i1);
411
412 // VEISelDAGtoDAG will replace this pattern with the constant-true VM.
413 auto TrueVal = DAG.getConstant(Val: -1, DL, VT: MVT::i32);
414 auto AVL = getConstant(Val: MaskVT.getVectorNumElements(), VT: MVT::i32);
415 auto Res = getNode(OC: VEISD::VEC_BROADCAST, ResVT: MaskVT, OpV: {TrueVal, AVL});
416 if (AllTrue)
417 return Res;
418
419 return DAG.getNOT(DL, Val: Res, VT: Res.getValueType());
420}
421
422SDValue VECustomDAG::getMaskBroadcast(EVT ResultVT, SDValue Scalar,
423 SDValue AVL) const {
424 // Constant mask splat.
425 if (auto BcConst = dyn_cast<ConstantSDNode>(Val&: Scalar))
426 return getConstantMask(Packing: getTypePacking(VT: ResultVT),
427 AllTrue: BcConst->getSExtValue() != 0);
428
429 // Expand the broadcast to a vector comparison.
430 auto ScalarBoolVT = Scalar.getSimpleValueType();
431 assert(ScalarBoolVT == MVT::i32);
432
433 // Cast to i32 ty.
434 SDValue CmpElem = DAG.getSExtOrTrunc(Op: Scalar, DL, VT: MVT::i32);
435 unsigned ElemCount = ResultVT.getVectorNumElements();
436 MVT CmpVecTy = MVT::getVectorVT(VT: ScalarBoolVT, NumElements: ElemCount);
437
438 // Broadcast to vector.
439 SDValue BCVec =
440 DAG.getNode(Opcode: VEISD::VEC_BROADCAST, DL, VT: CmpVecTy, Ops: {CmpElem, AVL});
441 SDValue ZeroVec =
442 getBroadcast(ResultVT: CmpVecTy, Scalar: {DAG.getConstant(Val: 0, DL, VT: ScalarBoolVT)}, AVL);
443
444 MVT BoolVecTy = MVT::getVectorVT(VT: MVT::i1, NumElements: ElemCount);
445
446 // Broadcast(Data) != Broadcast(0)
447 // TODO: Use a VVP operation for this.
448 return DAG.getSetCC(DL, VT: BoolVecTy, LHS: BCVec, RHS: ZeroVec, Cond: ISD::CondCode::SETNE);
449}
450
451SDValue VECustomDAG::getBroadcast(EVT ResultVT, SDValue Scalar,
452 SDValue AVL) const {
453 assert(ResultVT.isVector());
454 auto ScaVT = Scalar.getValueType();
455
456 if (isMaskType(SomeVT: ResultVT))
457 return getMaskBroadcast(ResultVT, Scalar, AVL);
458
459 if (isPackedVectorType(SomeVT: ResultVT)) {
460 // v512x packed mode broadcast
461 // Replicate the scalar reg (f32 or i32) onto the opposing half of the full
462 // scalar register. If it's an I64 type, assume that this has already
463 // happened.
464 if (ScaVT == MVT::f32) {
465 Scalar = getNode(OC: VEISD::REPL_F32, ResVT: MVT::i64, OpV: Scalar);
466 } else if (ScaVT == MVT::i32) {
467 Scalar = getNode(OC: VEISD::REPL_I32, ResVT: MVT::i64, OpV: Scalar);
468 }
469 }
470
471 return getNode(OC: VEISD::VEC_BROADCAST, ResVT: ResultVT, OpV: {Scalar, AVL});
472}
473
474SDValue VECustomDAG::annotateLegalAVL(SDValue AVL) const {
475 if (isLegalAVL(AVL))
476 return AVL;
477 return getNode(OC: VEISD::LEGALAVL, ResVT: AVL.getValueType(), OpV: AVL);
478}
479
480SDValue VECustomDAG::getUnpack(EVT DestVT, SDValue Vec, PackElem Part,
481 SDValue AVL) const {
482 assert(getAnnotatedNodeAVL(AVL).second && "Expected a pack-legalized AVL");
483
484 // TODO: Peek through VEC_PACK and VEC_BROADCAST(REPL_<sth> ..) operands.
485 unsigned OC =
486 (Part == PackElem::Lo) ? VEISD::VEC_UNPACK_LO : VEISD::VEC_UNPACK_HI;
487 return DAG.getNode(Opcode: OC, DL, VT: DestVT, N1: Vec, N2: AVL);
488}
489
490SDValue VECustomDAG::getPack(EVT DestVT, SDValue LoVec, SDValue HiVec,
491 SDValue AVL) const {
492 assert(getAnnotatedNodeAVL(AVL).second && "Expected a pack-legalized AVL");
493
494 // TODO: Peek through VEC_UNPACK_LO|HI operands.
495 return DAG.getNode(Opcode: VEISD::VEC_PACK, DL, VT: DestVT, N1: LoVec, N2: HiVec, N3: AVL);
496}
497
498VETargetMasks VECustomDAG::getTargetSplitMask(SDValue RawMask, SDValue RawAVL,
499 PackElem Part) const {
500 // Adjust AVL for this part
501 SDValue NewAVL;
502 SDValue OneV = getConstant(Val: 1, VT: MVT::i32);
503 if (Part == PackElem::Hi)
504 NewAVL = getNode(OC: ISD::ADD, ResVT: MVT::i32, OpV: {RawAVL, OneV});
505 else
506 NewAVL = RawAVL;
507 NewAVL = getNode(OC: ISD::SRL, ResVT: MVT::i32, OpV: {NewAVL, OneV});
508
509 NewAVL = annotateLegalAVL(AVL: NewAVL);
510
511 // Legalize Mask (unpack or all-true)
512 SDValue NewMask;
513 if (!RawMask)
514 NewMask = getConstantMask(Packing: Packing::Normal, AllTrue: true);
515 else
516 NewMask = getUnpack(DestVT: MVT::v256i1, Vec: RawMask, Part, AVL: NewAVL);
517
518 return VETargetMasks(NewMask, NewAVL);
519}
520
521SDValue VECustomDAG::getSplitPtrOffset(SDValue Ptr, SDValue ByteStride,
522 PackElem Part) const {
523 // High starts at base ptr but has more significant bits in the 64bit vector
524 // element.
525 if (Part == PackElem::Hi)
526 return Ptr;
527 return getNode(OC: ISD::ADD, ResVT: MVT::i64, OpV: {Ptr, ByteStride});
528}
529
530SDValue VECustomDAG::getSplitPtrStride(SDValue PackStride) const {
531 if (auto ConstBytes = dyn_cast<ConstantSDNode>(Val&: PackStride))
532 return getConstant(Val: 2 * ConstBytes->getSExtValue(), VT: MVT::i64);
533 return getNode(OC: ISD::SHL, ResVT: MVT::i64, OpV: {PackStride, getConstant(Val: 1, VT: MVT::i32)});
534}
535
536SDValue VECustomDAG::getGatherScatterAddress(SDValue BasePtr, SDValue Scale,
537 SDValue Index, SDValue Mask,
538 SDValue AVL) const {
539 EVT IndexVT = Index.getValueType();
540
541 // Apply scale.
542 SDValue ScaledIndex;
543 if (!Scale || isOneConstant(V: Scale))
544 ScaledIndex = Index;
545 else {
546 SDValue ScaleBroadcast = getBroadcast(ResultVT: IndexVT, Scalar: Scale, AVL);
547 ScaledIndex =
548 getNode(OC: VEISD::VVP_MUL, ResVT: IndexVT, OpV: {Index, ScaleBroadcast, Mask, AVL});
549 }
550
551 // Add basePtr.
552 if (isNullConstant(V: BasePtr))
553 return ScaledIndex;
554
555 // re-constitute pointer vector (basePtr + index * scale)
556 SDValue BaseBroadcast = getBroadcast(ResultVT: IndexVT, Scalar: BasePtr, AVL);
557 auto ResPtr =
558 getNode(OC: VEISD::VVP_ADD, ResVT: IndexVT, OpV: {BaseBroadcast, ScaledIndex, Mask, AVL});
559 return ResPtr;
560}
561
562SDValue VECustomDAG::getLegalReductionOpVVP(unsigned VVPOpcode, EVT ResVT,
563 SDValue StartV, SDValue VectorV,
564 SDValue Mask, SDValue AVL,
565 SDNodeFlags Flags) const {
566
567 // Optionally attach the start param with a scalar op (where it is
568 // unsupported).
569 bool scalarizeStartParam = StartV && !hasReductionStartParam(OPC: VVPOpcode);
570 bool IsMaskReduction = isMaskType(SomeVT: VectorV.getValueType());
571 assert(!IsMaskReduction && "TODO Implement");
572 auto AttachStartValue = [&](SDValue ReductionResV) {
573 if (!scalarizeStartParam)
574 return ReductionResV;
575 auto ScalarOC = getScalarReductionOpcode(VVPOC: VVPOpcode, IsMask: IsMaskReduction);
576 return getNode(OC: ScalarOC, ResVT, OpV: {StartV, ReductionResV});
577 };
578
579 // Fixup: Always Use sequential 'fmul' reduction.
580 if (!scalarizeStartParam && StartV) {
581 assert(hasReductionStartParam(VVPOpcode));
582 return AttachStartValue(
583 getNode(OC: VVPOpcode, ResVT, OpV: {StartV, VectorV, Mask, AVL}, Flags));
584 } else
585 return AttachStartValue(
586 getNode(OC: VVPOpcode, ResVT, OpV: {VectorV, Mask, AVL}, Flags));
587}
588
589} // namespace llvm
590