1//===-- VVPISelLowering.cpp - VE DAG Lowering Implementation --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements the lowering and legalization of vector instructions to
// VVP_* layer SDNodes.
11//
12//===----------------------------------------------------------------------===//
13
14#include "VECustomDAG.h"
15#include "VEISelLowering.h"
16#include "VESelectionDAGInfo.h"
17
18using namespace llvm;
19
20#define DEBUG_TYPE "ve-lower"
21
22SDValue VETargetLowering::splitMaskArithmetic(SDValue Op,
23 SelectionDAG &DAG) const {
24 VECustomDAG CDAG(DAG, Op);
25 SDValue AVL =
26 CDAG.getConstant(Val: Op.getValueType().getVectorNumElements(), VT: MVT::i32);
27 SDValue A = Op->getOperand(Num: 0);
28 SDValue B = Op->getOperand(Num: 1);
29 SDValue LoA = CDAG.getUnpack(DestVT: MVT::v256i1, Vec: A, Part: PackElem::Lo, AVL);
30 SDValue HiA = CDAG.getUnpack(DestVT: MVT::v256i1, Vec: A, Part: PackElem::Hi, AVL);
31 SDValue LoB = CDAG.getUnpack(DestVT: MVT::v256i1, Vec: B, Part: PackElem::Lo, AVL);
32 SDValue HiB = CDAG.getUnpack(DestVT: MVT::v256i1, Vec: B, Part: PackElem::Hi, AVL);
33 unsigned Opc = Op.getOpcode();
34 auto LoRes = CDAG.getNode(OC: Opc, ResVT: MVT::v256i1, OpV: {LoA, LoB});
35 auto HiRes = CDAG.getNode(OC: Opc, ResVT: MVT::v256i1, OpV: {HiA, HiB});
36 return CDAG.getPack(DestVT: MVT::v512i1, LoVec: LoRes, HiVec: HiRes, AVL);
37}
38
39SDValue VETargetLowering::lowerToVVP(SDValue Op, SelectionDAG &DAG) const {
40 // Can we represent this as a VVP node.
41 const unsigned Opcode = Op->getOpcode();
42 auto VVPOpcodeOpt = getVVPOpcode(Opcode);
43 if (!VVPOpcodeOpt)
44 return SDValue();
45 unsigned VVPOpcode = *VVPOpcodeOpt;
46 const bool FromVP = ISD::isVPOpcode(Opcode);
47
48 // The representative and legalized vector type of this operation.
49 VECustomDAG CDAG(DAG, Op);
50 // Dispatch to complex lowering functions.
51 switch (VVPOpcode) {
52 case VEISD::VVP_LOAD:
53 case VEISD::VVP_STORE:
54 return lowerVVP_LOAD_STORE(Op, CDAG);
55 case VEISD::VVP_GATHER:
56 case VEISD::VVP_SCATTER:
57 return lowerVVP_GATHER_SCATTER(Op, CDAG);
58 }
59
60 EVT OpVecVT = *getIdiomaticVectorType(Op: Op.getNode());
61 EVT LegalVecVT = getTypeToTransformTo(Context&: *DAG.getContext(), VT: OpVecVT);
62 auto Packing = getTypePacking(LegalVecVT.getSimpleVT());
63
64 SDValue AVL;
65 SDValue Mask;
66
67 if (FromVP) {
68 // All upstream VP SDNodes always have a mask and avl.
69 auto MaskIdx = ISD::getVPMaskIdx(Opcode);
70 auto AVLIdx = ISD::getVPExplicitVectorLengthIdx(Opcode);
71 if (MaskIdx)
72 Mask = Op->getOperand(Num: *MaskIdx);
73 if (AVLIdx)
74 AVL = Op->getOperand(Num: *AVLIdx);
75 }
76
77 // Materialize default mask and avl.
78 if (!AVL)
79 AVL = CDAG.getConstant(Val: OpVecVT.getVectorNumElements(), VT: MVT::i32);
80 if (!Mask)
81 Mask = CDAG.getConstantMask(Packing, AllTrue: true);
82
83 assert(LegalVecVT.isSimple());
84 if (isVVPUnaryOp(Opcode: VVPOpcode))
85 return CDAG.getNode(OC: VVPOpcode, ResVT: LegalVecVT, OpV: {Op->getOperand(Num: 0), Mask, AVL});
86 if (isVVPBinaryOp(Opcode: VVPOpcode))
87 return CDAG.getNode(OC: VVPOpcode, ResVT: LegalVecVT,
88 OpV: {Op->getOperand(Num: 0), Op->getOperand(Num: 1), Mask, AVL});
89 if (isVVPReductionOp(Opcode: VVPOpcode)) {
90 auto SrcHasStart = hasReductionStartParam(VVPOC: Op->getOpcode());
91 SDValue StartV = SrcHasStart ? Op->getOperand(Num: 0) : SDValue();
92 SDValue VectorV = Op->getOperand(Num: SrcHasStart ? 1 : 0);
93 return CDAG.getLegalReductionOpVVP(VVPOpcode, ResVT: Op.getValueType(), StartV,
94 VectorV, Mask, AVL, Flags: Op->getFlags());
95 }
96
97 switch (VVPOpcode) {
98 default:
99 llvm_unreachable("lowerToVVP called for unexpected SDNode.");
100 case VEISD::VVP_FFMA: {
101 // VE has a swizzled operand order in FMA (compared to LLVM IR and
102 // SDNodes).
103 auto X = Op->getOperand(Num: 2);
104 auto Y = Op->getOperand(Num: 0);
105 auto Z = Op->getOperand(Num: 1);
106 return CDAG.getNode(OC: VVPOpcode, ResVT: LegalVecVT, OpV: {X, Y, Z, Mask, AVL});
107 }
108 case VEISD::VVP_SELECT: {
109 auto Mask = Op->getOperand(Num: 0);
110 auto OnTrue = Op->getOperand(Num: 1);
111 auto OnFalse = Op->getOperand(Num: 2);
112 return CDAG.getNode(OC: VVPOpcode, ResVT: LegalVecVT, OpV: {OnTrue, OnFalse, Mask, AVL});
113 }
114 case VEISD::VVP_SETCC: {
115 EVT LegalResVT = getTypeToTransformTo(Context&: *DAG.getContext(), VT: Op.getValueType());
116 auto LHS = Op->getOperand(Num: 0);
117 auto RHS = Op->getOperand(Num: 1);
118 auto Pred = Op->getOperand(Num: 2);
119 return CDAG.getNode(OC: VVPOpcode, ResVT: LegalResVT, OpV: {LHS, RHS, Pred, Mask, AVL});
120 }
121 }
122}
123
124SDValue VETargetLowering::lowerVVP_LOAD_STORE(SDValue Op,
125 VECustomDAG &CDAG) const {
126 auto VVPOpc = *getVVPOpcode(Opcode: Op->getOpcode());
127 const bool IsLoad = (VVPOpc == VEISD::VVP_LOAD);
128
129 // Shares.
130 SDValue BasePtr = getMemoryPtr(Op);
131 SDValue Mask = getNodeMask(Op);
132 SDValue Chain = getNodeChain(Op);
133 SDValue AVL = getNodeAVL(Op);
134 // Store specific.
135 SDValue Data = getStoredValue(Op);
136 // Load specific.
137 SDValue PassThru = getNodePassthru(Op);
138
139 SDValue StrideV = getLoadStoreStride(Op, CDAG);
140
141 auto DataVT = *getIdiomaticVectorType(Op: Op.getNode());
142 auto Packing = getTypePacking(DataVT);
143
144 // TODO: Infer lower AVL from mask.
145 if (!AVL)
146 AVL = CDAG.getConstant(Val: DataVT.getVectorNumElements(), VT: MVT::i32);
147
148 // Default to the all-true mask.
149 if (!Mask)
150 Mask = CDAG.getConstantMask(Packing, AllTrue: true);
151
152 if (IsLoad) {
153 MVT LegalDataVT = getLegalVectorType(
154 P: Packing, ElemVT: DataVT.getVectorElementType().getSimpleVT());
155
156 auto NewLoadV = CDAG.getNode(OC: VEISD::VVP_LOAD, ResVT: {LegalDataVT, MVT::Other},
157 OpV: {Chain, BasePtr, StrideV, Mask, AVL});
158
159 if (!PassThru || PassThru->isUndef())
160 return NewLoadV;
161
162 // Convert passthru to an explicit select node.
163 SDValue DataV = CDAG.getNode(OC: VEISD::VVP_SELECT, ResVT: DataVT,
164 OpV: {NewLoadV, PassThru, Mask, AVL});
165 SDValue NewLoadChainV = SDValue(NewLoadV.getNode(), 1);
166
167 // Merge them back into one node.
168 return CDAG.getMergeValues(Values: {DataV, NewLoadChainV});
169 }
170
171 // VVP_STORE
172 assert(VVPOpc == VEISD::VVP_STORE);
173 if (getTypeAction(Context&: *CDAG.getDAG()->getContext(), VT: Data.getValueType()) !=
174 TargetLowering::TypeLegal)
175 // Doesn't lower store instruction if an operand is not lowered yet.
176 // If it isn't, return SDValue(). In this way, LLVM will try to lower
177 // store instruction again after lowering all operands.
178 return SDValue();
179 return CDAG.getNode(OC: VEISD::VVP_STORE, VTL: Op.getNode()->getVTList(),
180 OpV: {Chain, Data, BasePtr, StrideV, Mask, AVL});
181}
182
183SDValue VETargetLowering::splitPackedLoadStore(SDValue Op,
184 VECustomDAG &CDAG) const {
185 auto VVPOC = *getVVPOpcode(Opcode: Op.getOpcode());
186 assert((VVPOC == VEISD::VVP_LOAD) || (VVPOC == VEISD::VVP_STORE));
187
188 MVT DataVT = getIdiomaticVectorType(Op: Op.getNode())->getSimpleVT();
189 assert(getTypePacking(DataVT) == Packing::Dense &&
190 "Can only split packed load/store");
191 MVT SplitDataVT = splitVectorType(VT: DataVT);
192
193 assert(!getNodePassthru(Op) &&
194 "Should have been folded in lowering to VVP layer");
195
196 // Analyze the operation
197 SDValue PackedMask = getNodeMask(Op);
198 SDValue PackedAVL = getAnnotatedNodeAVL(Op).first;
199 SDValue PackPtr = getMemoryPtr(Op);
200 SDValue PackData = getStoredValue(Op);
201 SDValue PackStride = getLoadStoreStride(Op, CDAG);
202
203 unsigned ChainResIdx = PackData ? 0 : 1;
204
205 SDValue PartOps[2];
206
207 SDValue UpperPartAVL; // we will use this for packing things back together
208 for (PackElem Part : {PackElem::Hi, PackElem::Lo}) {
209 // VP ops already have an explicit mask and AVL. When expanding from non-VP
210 // attach those additional inputs here.
211 auto SplitTM = CDAG.getTargetSplitMask(RawMask: PackedMask, RawAVL: PackedAVL, Part);
212
213 // Keep track of the (higher) lvl.
214 if (Part == PackElem::Hi)
215 UpperPartAVL = SplitTM.AVL;
216
217 // Attach non-predicating value operands
218 SmallVector<SDValue, 4> OpVec;
219
220 // Chain
221 OpVec.push_back(Elt: getNodeChain(Op));
222
223 // Data
224 if (PackData) {
225 SDValue PartData =
226 CDAG.getUnpack(DestVT: SplitDataVT, Vec: PackData, Part, AVL: SplitTM.AVL);
227 OpVec.push_back(Elt: PartData);
228 }
229
230 // Ptr & Stride
231 // Push (ptr + ElemBytes * <Part>, 2 * ElemBytes)
232 // Stride info
233 // EVT DataVT = LegalizeVectorType(getMemoryDataVT(Op), Op, DAG, Mode);
234 OpVec.push_back(Elt: CDAG.getSplitPtrOffset(Ptr: PackPtr, ByteStride: PackStride, Part));
235 OpVec.push_back(Elt: CDAG.getSplitPtrStride(PackStride));
236
237 // Add predicating args and generate part node
238 OpVec.push_back(Elt: SplitTM.Mask);
239 OpVec.push_back(Elt: SplitTM.AVL);
240
241 if (PackData) {
242 // Store
243 PartOps[(int)Part] = CDAG.getNode(OC: VVPOC, ResVT: MVT::Other, OpV: OpVec);
244 } else {
245 // Load
246 PartOps[(int)Part] =
247 CDAG.getNode(OC: VVPOC, ResVT: {SplitDataVT, MVT::Other}, OpV: OpVec);
248 }
249 }
250
251 // Merge the chains
252 SDValue LowChain = SDValue(PartOps[(int)PackElem::Lo].getNode(), ChainResIdx);
253 SDValue HiChain = SDValue(PartOps[(int)PackElem::Hi].getNode(), ChainResIdx);
254 SDValue FusedChains =
255 CDAG.getNode(OC: ISD::TokenFactor, ResVT: MVT::Other, OpV: {LowChain, HiChain});
256
257 // Chain only [store]
258 if (PackData)
259 return FusedChains;
260
261 // Re-pack into full packed vector result
262 MVT PackedVT =
263 getLegalVectorType(P: Packing::Dense, ElemVT: DataVT.getVectorElementType());
264 SDValue PackedVals = CDAG.getPack(DestVT: PackedVT, LoVec: PartOps[(int)PackElem::Lo],
265 HiVec: PartOps[(int)PackElem::Hi], AVL: UpperPartAVL);
266
267 return CDAG.getMergeValues(Values: {PackedVals, FusedChains});
268}
269
270SDValue VETargetLowering::lowerVVP_GATHER_SCATTER(SDValue Op,
271 VECustomDAG &CDAG) const {
272 EVT DataVT = *getIdiomaticVectorType(Op: Op.getNode());
273 auto Packing = getTypePacking(DataVT);
274 MVT LegalDataVT =
275 getLegalVectorType(P: Packing, ElemVT: DataVT.getVectorElementType().getSimpleVT());
276
277 SDValue AVL = getAnnotatedNodeAVL(Op).first;
278 SDValue Index = getGatherScatterIndex(Op);
279 SDValue BasePtr = getMemoryPtr(Op);
280 SDValue Mask = getNodeMask(Op);
281 SDValue Chain = getNodeChain(Op);
282 SDValue Scale = getGatherScatterScale(Op);
283 SDValue PassThru = getNodePassthru(Op);
284 SDValue StoredValue = getStoredValue(Op);
285 if (PassThru && PassThru->isUndef())
286 PassThru = SDValue();
287
288 bool IsScatter = (bool)StoredValue;
289
290 // TODO: Infer lower AVL from mask.
291 if (!AVL)
292 AVL = CDAG.getConstant(Val: DataVT.getVectorNumElements(), VT: MVT::i32);
293
294 // Default to the all-true mask.
295 if (!Mask)
296 Mask = CDAG.getConstantMask(Packing, AllTrue: true);
297
298 SDValue AddressVec =
299 CDAG.getGatherScatterAddress(BasePtr, Scale, Index, Mask, AVL);
300 if (IsScatter)
301 return CDAG.getNode(OC: VEISD::VVP_SCATTER, ResVT: MVT::Other,
302 OpV: {Chain, StoredValue, AddressVec, Mask, AVL});
303
304 // Gather.
305 SDValue NewLoadV = CDAG.getNode(OC: VEISD::VVP_GATHER, ResVT: {LegalDataVT, MVT::Other},
306 OpV: {Chain, AddressVec, Mask, AVL});
307
308 if (!PassThru)
309 return NewLoadV;
310
311 // TODO: Use vvp_select
312 SDValue DataV = CDAG.getNode(OC: VEISD::VVP_SELECT, ResVT: LegalDataVT,
313 OpV: {NewLoadV, PassThru, Mask, AVL});
314 SDValue NewLoadChainV = SDValue(NewLoadV.getNode(), 1);
315 return CDAG.getMergeValues(Values: {DataV, NewLoadChainV});
316}
317
318SDValue VETargetLowering::legalizeInternalLoadStoreOp(SDValue Op,
319 VECustomDAG &CDAG) const {
320 LLVM_DEBUG(dbgs() << "::legalizeInternalLoadStoreOp\n";);
321 MVT DataVT = getIdiomaticVectorType(Op: Op.getNode())->getSimpleVT();
322
323 // TODO: Recognize packable load,store.
324 if (isPackedVectorType(SomeVT: DataVT))
325 return splitPackedLoadStore(Op, CDAG);
326
327 return legalizePackedAVL(Op, CDAG);
328}
329
330SDValue VETargetLowering::legalizeInternalVectorOp(SDValue Op,
331 SelectionDAG &DAG) const {
332 LLVM_DEBUG(dbgs() << "::legalizeInternalVectorOp\n";);
333 VECustomDAG CDAG(DAG, Op);
334
335 // Dispatch to specialized legalization functions.
336 switch (Op->getOpcode()) {
337 case VEISD::VVP_LOAD:
338 case VEISD::VVP_STORE:
339 return legalizeInternalLoadStoreOp(Op, CDAG);
340 }
341
342 EVT IdiomVT = Op.getValueType();
343 if (isPackedVectorType(SomeVT: IdiomVT) &&
344 !supportsPackedMode(Opcode: Op.getOpcode(), IdiomVT))
345 return splitVectorOp(Op, CDAG);
346
347 // TODO: Implement odd/even splitting.
348 return legalizePackedAVL(Op, CDAG);
349}
350
351SDValue VETargetLowering::splitVectorOp(SDValue Op, VECustomDAG &CDAG) const {
352 MVT ResVT = splitVectorType(VT: Op.getValue(R: 0).getSimpleValueType());
353
354 auto AVLPos = getAVLPos(Op->getOpcode());
355 auto MaskPos = getMaskPos(Op->getOpcode());
356
357 SDValue PackedMask = getNodeMask(Op);
358 auto AVLPair = getAnnotatedNodeAVL(Op);
359 SDValue PackedAVL = AVLPair.first;
360 assert(!AVLPair.second && "Expecting non pack-legalized oepration");
361
362 // request the parts
363 SDValue PartOps[2];
364
365 SDValue UpperPartAVL; // we will use this for packing things back together
366 for (PackElem Part : {PackElem::Hi, PackElem::Lo}) {
367 // VP ops already have an explicit mask and AVL. When expanding from non-VP
368 // attach those additional inputs here.
369 auto SplitTM = CDAG.getTargetSplitMask(RawMask: PackedMask, RawAVL: PackedAVL, Part);
370
371 if (Part == PackElem::Hi)
372 UpperPartAVL = SplitTM.AVL;
373
374 // Attach non-predicating value operands
375 SmallVector<SDValue, 4> OpVec;
376 for (unsigned i = 0; i < Op.getNumOperands(); ++i) {
377 if (AVLPos && ((int)i) == *AVLPos)
378 continue;
379 if (MaskPos && ((int)i) == *MaskPos)
380 continue;
381
382 // Value operand
383 auto PackedOperand = Op.getOperand(i);
384 auto UnpackedOpVT = splitVectorType(VT: PackedOperand.getSimpleValueType());
385 SDValue PartV =
386 CDAG.getUnpack(DestVT: UnpackedOpVT, Vec: PackedOperand, Part, AVL: SplitTM.AVL);
387 OpVec.push_back(Elt: PartV);
388 }
389
390 // Add predicating args and generate part node.
391 OpVec.push_back(Elt: SplitTM.Mask);
392 OpVec.push_back(Elt: SplitTM.AVL);
393 // Emit legal VVP nodes.
394 PartOps[(int)Part] =
395 CDAG.getNode(OC: Op.getOpcode(), ResVT, OpV: OpVec, Flags: Op->getFlags());
396 }
397
398 // Re-package vectors.
399 return CDAG.getPack(DestVT: Op.getValueType(), LoVec: PartOps[(int)PackElem::Lo],
400 HiVec: PartOps[(int)PackElem::Hi], AVL: UpperPartAVL);
401}
402
403SDValue VETargetLowering::legalizePackedAVL(SDValue Op,
404 VECustomDAG &CDAG) const {
405 LLVM_DEBUG(dbgs() << "::legalizePackedAVL\n";);
406 // Only required for VEC and VVP ops.
407 if (!isVVPOrVEC(Op->getOpcode()))
408 return Op;
409
410 // Operation already has a legal AVL.
411 auto AVL = getNodeAVL(Op);
412 if (isLegalAVL(AVL))
413 return Op;
414
415 // Half and round up EVL for 32bit element types.
416 SDValue LegalAVL = AVL;
417 MVT IdiomVT = getIdiomaticVectorType(Op: Op.getNode())->getSimpleVT();
418 if (isPackedVectorType(SomeVT: IdiomVT)) {
419 assert(maySafelyIgnoreMask(Op) &&
420 "TODO Shift predication from EVL into Mask");
421
422 if (auto *ConstAVL = dyn_cast<ConstantSDNode>(Val&: AVL)) {
423 LegalAVL = CDAG.getConstant(Val: (ConstAVL->getZExtValue() + 1) / 2, VT: MVT::i32);
424 } else {
425 auto ConstOne = CDAG.getConstant(Val: 1, VT: MVT::i32);
426 auto PlusOne = CDAG.getNode(OC: ISD::ADD, ResVT: MVT::i32, OpV: {AVL, ConstOne});
427 LegalAVL = CDAG.getNode(OC: ISD::SRL, ResVT: MVT::i32, OpV: {PlusOne, ConstOne});
428 }
429 }
430
431 SDValue AnnotatedLegalAVL = CDAG.annotateLegalAVL(AVL: LegalAVL);
432
433 // Copy the operand list.
434 int NumOp = Op->getNumOperands();
435 auto AVLPos = getAVLPos(Op->getOpcode());
436 std::vector<SDValue> FixedOperands;
437 for (int i = 0; i < NumOp; ++i) {
438 if (AVLPos && (i == *AVLPos)) {
439 FixedOperands.push_back(x: AnnotatedLegalAVL);
440 continue;
441 }
442 FixedOperands.push_back(x: Op->getOperand(Num: i));
443 }
444
445 // Clone the operation with fixed operands.
446 auto Flags = Op->getFlags();
447 SDValue NewN =
448 CDAG.getNode(OC: Op->getOpcode(), VTL: Op->getVTList(), OpV: FixedOperands, Flags);
449 return NewN;
450}
451