1//===-- VVPISelLowering.cpp - VE DAG Lowering Implementation --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements the lowering and legalization of vector instructions to
// VVP_* layer SDNodes.
11//
12//===----------------------------------------------------------------------===//
13
14#include "VECustomDAG.h"
15#include "VEISelLowering.h"
16
17using namespace llvm;
18
19#define DEBUG_TYPE "ve-lower"
20
21SDValue VETargetLowering::splitMaskArithmetic(SDValue Op,
22 SelectionDAG &DAG) const {
23 VECustomDAG CDAG(DAG, Op);
24 SDValue AVL =
25 CDAG.getConstant(Val: Op.getValueType().getVectorNumElements(), VT: MVT::i32);
26 SDValue A = Op->getOperand(Num: 0);
27 SDValue B = Op->getOperand(Num: 1);
28 SDValue LoA = CDAG.getUnpack(DestVT: MVT::v256i1, Vec: A, Part: PackElem::Lo, AVL);
29 SDValue HiA = CDAG.getUnpack(DestVT: MVT::v256i1, Vec: A, Part: PackElem::Hi, AVL);
30 SDValue LoB = CDAG.getUnpack(DestVT: MVT::v256i1, Vec: B, Part: PackElem::Lo, AVL);
31 SDValue HiB = CDAG.getUnpack(DestVT: MVT::v256i1, Vec: B, Part: PackElem::Hi, AVL);
32 unsigned Opc = Op.getOpcode();
33 auto LoRes = CDAG.getNode(OC: Opc, ResVT: MVT::v256i1, OpV: {LoA, LoB});
34 auto HiRes = CDAG.getNode(OC: Opc, ResVT: MVT::v256i1, OpV: {HiA, HiB});
35 return CDAG.getPack(DestVT: MVT::v512i1, LoVec: LoRes, HiVec: HiRes, AVL);
36}
37
38SDValue VETargetLowering::lowerToVVP(SDValue Op, SelectionDAG &DAG) const {
39 // Can we represent this as a VVP node.
40 const unsigned Opcode = Op->getOpcode();
41 auto VVPOpcodeOpt = getVVPOpcode(Opcode);
42 if (!VVPOpcodeOpt)
43 return SDValue();
44 unsigned VVPOpcode = *VVPOpcodeOpt;
45 const bool FromVP = ISD::isVPOpcode(Opcode);
46
47 // The representative and legalized vector type of this operation.
48 VECustomDAG CDAG(DAG, Op);
49 // Dispatch to complex lowering functions.
50 switch (VVPOpcode) {
51 case VEISD::VVP_LOAD:
52 case VEISD::VVP_STORE:
53 return lowerVVP_LOAD_STORE(Op, CDAG);
54 case VEISD::VVP_GATHER:
55 case VEISD::VVP_SCATTER:
56 return lowerVVP_GATHER_SCATTER(Op, CDAG);
57 }
58
59 EVT OpVecVT = *getIdiomaticVectorType(Op: Op.getNode());
60 EVT LegalVecVT = getTypeToTransformTo(Context&: *DAG.getContext(), VT: OpVecVT);
61 auto Packing = getTypePacking(LegalVecVT.getSimpleVT());
62
63 SDValue AVL;
64 SDValue Mask;
65
66 if (FromVP) {
67 // All upstream VP SDNodes always have a mask and avl.
68 auto MaskIdx = ISD::getVPMaskIdx(Opcode);
69 auto AVLIdx = ISD::getVPExplicitVectorLengthIdx(Opcode);
70 if (MaskIdx)
71 Mask = Op->getOperand(Num: *MaskIdx);
72 if (AVLIdx)
73 AVL = Op->getOperand(Num: *AVLIdx);
74 }
75
76 // Materialize default mask and avl.
77 if (!AVL)
78 AVL = CDAG.getConstant(Val: OpVecVT.getVectorNumElements(), VT: MVT::i32);
79 if (!Mask)
80 Mask = CDAG.getConstantMask(Packing, AllTrue: true);
81
82 assert(LegalVecVT.isSimple());
83 if (isVVPUnaryOp(Opcode: VVPOpcode))
84 return CDAG.getNode(OC: VVPOpcode, ResVT: LegalVecVT, OpV: {Op->getOperand(Num: 0), Mask, AVL});
85 if (isVVPBinaryOp(Opcode: VVPOpcode))
86 return CDAG.getNode(OC: VVPOpcode, ResVT: LegalVecVT,
87 OpV: {Op->getOperand(Num: 0), Op->getOperand(Num: 1), Mask, AVL});
88 if (isVVPReductionOp(Opcode: VVPOpcode)) {
89 auto SrcHasStart = hasReductionStartParam(VVPOC: Op->getOpcode());
90 SDValue StartV = SrcHasStart ? Op->getOperand(Num: 0) : SDValue();
91 SDValue VectorV = Op->getOperand(Num: SrcHasStart ? 1 : 0);
92 return CDAG.getLegalReductionOpVVP(VVPOpcode, ResVT: Op.getValueType(), StartV,
93 VectorV, Mask, AVL, Flags: Op->getFlags());
94 }
95
96 switch (VVPOpcode) {
97 default:
98 llvm_unreachable("lowerToVVP called for unexpected SDNode.");
99 case VEISD::VVP_FFMA: {
100 // VE has a swizzled operand order in FMA (compared to LLVM IR and
101 // SDNodes).
102 auto X = Op->getOperand(Num: 2);
103 auto Y = Op->getOperand(Num: 0);
104 auto Z = Op->getOperand(Num: 1);
105 return CDAG.getNode(OC: VVPOpcode, ResVT: LegalVecVT, OpV: {X, Y, Z, Mask, AVL});
106 }
107 case VEISD::VVP_SELECT: {
108 auto Mask = Op->getOperand(Num: 0);
109 auto OnTrue = Op->getOperand(Num: 1);
110 auto OnFalse = Op->getOperand(Num: 2);
111 return CDAG.getNode(OC: VVPOpcode, ResVT: LegalVecVT, OpV: {OnTrue, OnFalse, Mask, AVL});
112 }
113 case VEISD::VVP_SETCC: {
114 EVT LegalResVT = getTypeToTransformTo(Context&: *DAG.getContext(), VT: Op.getValueType());
115 auto LHS = Op->getOperand(Num: 0);
116 auto RHS = Op->getOperand(Num: 1);
117 auto Pred = Op->getOperand(Num: 2);
118 return CDAG.getNode(OC: VVPOpcode, ResVT: LegalResVT, OpV: {LHS, RHS, Pred, Mask, AVL});
119 }
120 }
121}
122
123SDValue VETargetLowering::lowerVVP_LOAD_STORE(SDValue Op,
124 VECustomDAG &CDAG) const {
125 auto VVPOpc = *getVVPOpcode(Opcode: Op->getOpcode());
126 const bool IsLoad = (VVPOpc == VEISD::VVP_LOAD);
127
128 // Shares.
129 SDValue BasePtr = getMemoryPtr(Op);
130 SDValue Mask = getNodeMask(Op);
131 SDValue Chain = getNodeChain(Op);
132 SDValue AVL = getNodeAVL(Op);
133 // Store specific.
134 SDValue Data = getStoredValue(Op);
135 // Load specific.
136 SDValue PassThru = getNodePassthru(Op);
137
138 SDValue StrideV = getLoadStoreStride(Op, CDAG);
139
140 auto DataVT = *getIdiomaticVectorType(Op: Op.getNode());
141 auto Packing = getTypePacking(DataVT);
142
143 // TODO: Infer lower AVL from mask.
144 if (!AVL)
145 AVL = CDAG.getConstant(Val: DataVT.getVectorNumElements(), VT: MVT::i32);
146
147 // Default to the all-true mask.
148 if (!Mask)
149 Mask = CDAG.getConstantMask(Packing, AllTrue: true);
150
151 if (IsLoad) {
152 MVT LegalDataVT = getLegalVectorType(
153 P: Packing, ElemVT: DataVT.getVectorElementType().getSimpleVT());
154
155 auto NewLoadV = CDAG.getNode(OC: VEISD::VVP_LOAD, ResVT: {LegalDataVT, MVT::Other},
156 OpV: {Chain, BasePtr, StrideV, Mask, AVL});
157
158 if (!PassThru || PassThru->isUndef())
159 return NewLoadV;
160
161 // Convert passthru to an explicit select node.
162 SDValue DataV = CDAG.getNode(OC: VEISD::VVP_SELECT, ResVT: DataVT,
163 OpV: {NewLoadV, PassThru, Mask, AVL});
164 SDValue NewLoadChainV = SDValue(NewLoadV.getNode(), 1);
165
166 // Merge them back into one node.
167 return CDAG.getMergeValues(Values: {DataV, NewLoadChainV});
168 }
169
170 // VVP_STORE
171 assert(VVPOpc == VEISD::VVP_STORE);
172 if (getTypeAction(Context&: *CDAG.getDAG()->getContext(), VT: Data.getValueType()) !=
173 TargetLowering::TypeLegal)
174 // Doesn't lower store instruction if an operand is not lowered yet.
175 // If it isn't, return SDValue(). In this way, LLVM will try to lower
176 // store instruction again after lowering all operands.
177 return SDValue();
178 return CDAG.getNode(OC: VEISD::VVP_STORE, VTL: Op.getNode()->getVTList(),
179 OpV: {Chain, Data, BasePtr, StrideV, Mask, AVL});
180}
181
182SDValue VETargetLowering::splitPackedLoadStore(SDValue Op,
183 VECustomDAG &CDAG) const {
184 auto VVPOC = *getVVPOpcode(Opcode: Op.getOpcode());
185 assert((VVPOC == VEISD::VVP_LOAD) || (VVPOC == VEISD::VVP_STORE));
186
187 MVT DataVT = getIdiomaticVectorType(Op: Op.getNode())->getSimpleVT();
188 assert(getTypePacking(DataVT) == Packing::Dense &&
189 "Can only split packed load/store");
190 MVT SplitDataVT = splitVectorType(VT: DataVT);
191
192 assert(!getNodePassthru(Op) &&
193 "Should have been folded in lowering to VVP layer");
194
195 // Analyze the operation
196 SDValue PackedMask = getNodeMask(Op);
197 SDValue PackedAVL = getAnnotatedNodeAVL(Op).first;
198 SDValue PackPtr = getMemoryPtr(Op);
199 SDValue PackData = getStoredValue(Op);
200 SDValue PackStride = getLoadStoreStride(Op, CDAG);
201
202 unsigned ChainResIdx = PackData ? 0 : 1;
203
204 SDValue PartOps[2];
205
206 SDValue UpperPartAVL; // we will use this for packing things back together
207 for (PackElem Part : {PackElem::Hi, PackElem::Lo}) {
208 // VP ops already have an explicit mask and AVL. When expanding from non-VP
209 // attach those additional inputs here.
210 auto SplitTM = CDAG.getTargetSplitMask(RawMask: PackedMask, RawAVL: PackedAVL, Part);
211
212 // Keep track of the (higher) lvl.
213 if (Part == PackElem::Hi)
214 UpperPartAVL = SplitTM.AVL;
215
216 // Attach non-predicating value operands
217 SmallVector<SDValue, 4> OpVec;
218
219 // Chain
220 OpVec.push_back(Elt: getNodeChain(Op));
221
222 // Data
223 if (PackData) {
224 SDValue PartData =
225 CDAG.getUnpack(DestVT: SplitDataVT, Vec: PackData, Part, AVL: SplitTM.AVL);
226 OpVec.push_back(Elt: PartData);
227 }
228
229 // Ptr & Stride
230 // Push (ptr + ElemBytes * <Part>, 2 * ElemBytes)
231 // Stride info
232 // EVT DataVT = LegalizeVectorType(getMemoryDataVT(Op), Op, DAG, Mode);
233 OpVec.push_back(Elt: CDAG.getSplitPtrOffset(Ptr: PackPtr, ByteStride: PackStride, Part));
234 OpVec.push_back(Elt: CDAG.getSplitPtrStride(PackStride));
235
236 // Add predicating args and generate part node
237 OpVec.push_back(Elt: SplitTM.Mask);
238 OpVec.push_back(Elt: SplitTM.AVL);
239
240 if (PackData) {
241 // Store
242 PartOps[(int)Part] = CDAG.getNode(OC: VVPOC, ResVT: MVT::Other, OpV: OpVec);
243 } else {
244 // Load
245 PartOps[(int)Part] =
246 CDAG.getNode(OC: VVPOC, ResVT: {SplitDataVT, MVT::Other}, OpV: OpVec);
247 }
248 }
249
250 // Merge the chains
251 SDValue LowChain = SDValue(PartOps[(int)PackElem::Lo].getNode(), ChainResIdx);
252 SDValue HiChain = SDValue(PartOps[(int)PackElem::Hi].getNode(), ChainResIdx);
253 SDValue FusedChains =
254 CDAG.getNode(OC: ISD::TokenFactor, ResVT: MVT::Other, OpV: {LowChain, HiChain});
255
256 // Chain only [store]
257 if (PackData)
258 return FusedChains;
259
260 // Re-pack into full packed vector result
261 MVT PackedVT =
262 getLegalVectorType(P: Packing::Dense, ElemVT: DataVT.getVectorElementType());
263 SDValue PackedVals = CDAG.getPack(DestVT: PackedVT, LoVec: PartOps[(int)PackElem::Lo],
264 HiVec: PartOps[(int)PackElem::Hi], AVL: UpperPartAVL);
265
266 return CDAG.getMergeValues(Values: {PackedVals, FusedChains});
267}
268
269SDValue VETargetLowering::lowerVVP_GATHER_SCATTER(SDValue Op,
270 VECustomDAG &CDAG) const {
271 EVT DataVT = *getIdiomaticVectorType(Op: Op.getNode());
272 auto Packing = getTypePacking(DataVT);
273 MVT LegalDataVT =
274 getLegalVectorType(P: Packing, ElemVT: DataVT.getVectorElementType().getSimpleVT());
275
276 SDValue AVL = getAnnotatedNodeAVL(Op).first;
277 SDValue Index = getGatherScatterIndex(Op);
278 SDValue BasePtr = getMemoryPtr(Op);
279 SDValue Mask = getNodeMask(Op);
280 SDValue Chain = getNodeChain(Op);
281 SDValue Scale = getGatherScatterScale(Op);
282 SDValue PassThru = getNodePassthru(Op);
283 SDValue StoredValue = getStoredValue(Op);
284 if (PassThru && PassThru->isUndef())
285 PassThru = SDValue();
286
287 bool IsScatter = (bool)StoredValue;
288
289 // TODO: Infer lower AVL from mask.
290 if (!AVL)
291 AVL = CDAG.getConstant(Val: DataVT.getVectorNumElements(), VT: MVT::i32);
292
293 // Default to the all-true mask.
294 if (!Mask)
295 Mask = CDAG.getConstantMask(Packing, AllTrue: true);
296
297 SDValue AddressVec =
298 CDAG.getGatherScatterAddress(BasePtr, Scale, Index, Mask, AVL);
299 if (IsScatter)
300 return CDAG.getNode(OC: VEISD::VVP_SCATTER, ResVT: MVT::Other,
301 OpV: {Chain, StoredValue, AddressVec, Mask, AVL});
302
303 // Gather.
304 SDValue NewLoadV = CDAG.getNode(OC: VEISD::VVP_GATHER, ResVT: {LegalDataVT, MVT::Other},
305 OpV: {Chain, AddressVec, Mask, AVL});
306
307 if (!PassThru)
308 return NewLoadV;
309
310 // TODO: Use vvp_select
311 SDValue DataV = CDAG.getNode(OC: VEISD::VVP_SELECT, ResVT: LegalDataVT,
312 OpV: {NewLoadV, PassThru, Mask, AVL});
313 SDValue NewLoadChainV = SDValue(NewLoadV.getNode(), 1);
314 return CDAG.getMergeValues(Values: {DataV, NewLoadChainV});
315}
316
317SDValue VETargetLowering::legalizeInternalLoadStoreOp(SDValue Op,
318 VECustomDAG &CDAG) const {
319 LLVM_DEBUG(dbgs() << "::legalizeInternalLoadStoreOp\n";);
320 MVT DataVT = getIdiomaticVectorType(Op: Op.getNode())->getSimpleVT();
321
322 // TODO: Recognize packable load,store.
323 if (isPackedVectorType(SomeVT: DataVT))
324 return splitPackedLoadStore(Op, CDAG);
325
326 return legalizePackedAVL(Op, CDAG);
327}
328
329SDValue VETargetLowering::legalizeInternalVectorOp(SDValue Op,
330 SelectionDAG &DAG) const {
331 LLVM_DEBUG(dbgs() << "::legalizeInternalVectorOp\n";);
332 VECustomDAG CDAG(DAG, Op);
333
334 // Dispatch to specialized legalization functions.
335 switch (Op->getOpcode()) {
336 case VEISD::VVP_LOAD:
337 case VEISD::VVP_STORE:
338 return legalizeInternalLoadStoreOp(Op, CDAG);
339 }
340
341 EVT IdiomVT = Op.getValueType();
342 if (isPackedVectorType(SomeVT: IdiomVT) &&
343 !supportsPackedMode(Opcode: Op.getOpcode(), IdiomVT))
344 return splitVectorOp(Op, CDAG);
345
346 // TODO: Implement odd/even splitting.
347 return legalizePackedAVL(Op, CDAG);
348}
349
350SDValue VETargetLowering::splitVectorOp(SDValue Op, VECustomDAG &CDAG) const {
351 MVT ResVT = splitVectorType(VT: Op.getValue(R: 0).getSimpleValueType());
352
353 auto AVLPos = getAVLPos(Op->getOpcode());
354 auto MaskPos = getMaskPos(Op->getOpcode());
355
356 SDValue PackedMask = getNodeMask(Op);
357 auto AVLPair = getAnnotatedNodeAVL(Op);
358 SDValue PackedAVL = AVLPair.first;
359 assert(!AVLPair.second && "Expecting non pack-legalized oepration");
360
361 // request the parts
362 SDValue PartOps[2];
363
364 SDValue UpperPartAVL; // we will use this for packing things back together
365 for (PackElem Part : {PackElem::Hi, PackElem::Lo}) {
366 // VP ops already have an explicit mask and AVL. When expanding from non-VP
367 // attach those additional inputs here.
368 auto SplitTM = CDAG.getTargetSplitMask(RawMask: PackedMask, RawAVL: PackedAVL, Part);
369
370 if (Part == PackElem::Hi)
371 UpperPartAVL = SplitTM.AVL;
372
373 // Attach non-predicating value operands
374 SmallVector<SDValue, 4> OpVec;
375 for (unsigned i = 0; i < Op.getNumOperands(); ++i) {
376 if (AVLPos && ((int)i) == *AVLPos)
377 continue;
378 if (MaskPos && ((int)i) == *MaskPos)
379 continue;
380
381 // Value operand
382 auto PackedOperand = Op.getOperand(i);
383 auto UnpackedOpVT = splitVectorType(VT: PackedOperand.getSimpleValueType());
384 SDValue PartV =
385 CDAG.getUnpack(DestVT: UnpackedOpVT, Vec: PackedOperand, Part, AVL: SplitTM.AVL);
386 OpVec.push_back(Elt: PartV);
387 }
388
389 // Add predicating args and generate part node.
390 OpVec.push_back(Elt: SplitTM.Mask);
391 OpVec.push_back(Elt: SplitTM.AVL);
392 // Emit legal VVP nodes.
393 PartOps[(int)Part] =
394 CDAG.getNode(OC: Op.getOpcode(), ResVT, OpV: OpVec, Flags: Op->getFlags());
395 }
396
397 // Re-package vectors.
398 return CDAG.getPack(DestVT: Op.getValueType(), LoVec: PartOps[(int)PackElem::Lo],
399 HiVec: PartOps[(int)PackElem::Hi], AVL: UpperPartAVL);
400}
401
402SDValue VETargetLowering::legalizePackedAVL(SDValue Op,
403 VECustomDAG &CDAG) const {
404 LLVM_DEBUG(dbgs() << "::legalizePackedAVL\n";);
405 // Only required for VEC and VVP ops.
406 if (!isVVPOrVEC(Op->getOpcode()))
407 return Op;
408
409 // Operation already has a legal AVL.
410 auto AVL = getNodeAVL(Op);
411 if (isLegalAVL(AVL))
412 return Op;
413
414 // Half and round up EVL for 32bit element types.
415 SDValue LegalAVL = AVL;
416 MVT IdiomVT = getIdiomaticVectorType(Op: Op.getNode())->getSimpleVT();
417 if (isPackedVectorType(SomeVT: IdiomVT)) {
418 assert(maySafelyIgnoreMask(Op) &&
419 "TODO Shift predication from EVL into Mask");
420
421 if (auto *ConstAVL = dyn_cast<ConstantSDNode>(Val&: AVL)) {
422 LegalAVL = CDAG.getConstant(Val: (ConstAVL->getZExtValue() + 1) / 2, VT: MVT::i32);
423 } else {
424 auto ConstOne = CDAG.getConstant(Val: 1, VT: MVT::i32);
425 auto PlusOne = CDAG.getNode(OC: ISD::ADD, ResVT: MVT::i32, OpV: {AVL, ConstOne});
426 LegalAVL = CDAG.getNode(OC: ISD::SRL, ResVT: MVT::i32, OpV: {PlusOne, ConstOne});
427 }
428 }
429
430 SDValue AnnotatedLegalAVL = CDAG.annotateLegalAVL(AVL: LegalAVL);
431
432 // Copy the operand list.
433 int NumOp = Op->getNumOperands();
434 auto AVLPos = getAVLPos(Op->getOpcode());
435 std::vector<SDValue> FixedOperands;
436 for (int i = 0; i < NumOp; ++i) {
437 if (AVLPos && (i == *AVLPos)) {
438 FixedOperands.push_back(x: AnnotatedLegalAVL);
439 continue;
440 }
441 FixedOperands.push_back(x: Op->getOperand(Num: i));
442 }
443
444 // Clone the operation with fixed operands.
445 auto Flags = Op->getFlags();
446 SDValue NewN =
447 CDAG.getNode(OC: Op->getOpcode(), VTL: Op->getVTList(), OpV: FixedOperands, Flags);
448 return NewN;
449}
450