1//===-- AArch64SelectionDAGInfo.cpp - AArch64 SelectionDAG Info -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the AArch64SelectionDAGInfo class.
10//
11//===----------------------------------------------------------------------===//
12
13#include "AArch64SelectionDAGInfo.h"
14#include "AArch64MachineFunctionInfo.h"
15
16#define GET_SDNODE_DESC
17#include "AArch64GenSDNodeInfo.inc"
18#undef GET_SDNODE_DESC
19
20using namespace llvm;
21
22#define DEBUG_TYPE "aarch64-selectiondag-info"
23
24static cl::opt<bool>
25 LowerToSMERoutines("aarch64-lower-to-sme-routines", cl::Hidden,
26 cl::desc("Enable AArch64 SME memory operations "
27 "to lower to librt functions"),
28 cl::init(Val: true));
29
30static cl::opt<bool> UseMOPS("aarch64-use-mops", cl::Hidden,
31 cl::desc("Enable AArch64 MOPS instructions "
32 "for memcpy/memset/memmove"),
33 cl::init(Val: true));
34
35AArch64SelectionDAGInfo::AArch64SelectionDAGInfo()
36 : SelectionDAGGenTargetInfo(AArch64GenSDNodeInfo) {}
37
/// Verify structural invariants of an AArch64-specific SelectionDAG node.
///
/// The bulk of the checking is delegated to the TableGen-driven
/// SelectionDAGGenTargetInfo::verifyTargetNode. The second switch adds
/// hand-written assertions for nodes whose constraints the generated checks
/// do not yet express; it is compiled only when NDEBUG is not defined.
void AArch64SelectionDAGInfo::verifyTargetNode(const SelectionDAG &DAG,
                                               const SDNode *N) const {
  // Nodes listed here skip the generated verification entirely.
  switch (N->getOpcode()) {
  case AArch64ISD::WrapperLarge:
    // operand #0 must have type i32, but has type i64
    // NOTE(review): the line above appears to record the diagnostic the
    // generated verifier would report for WrapperLarge — confirm.
    return;
  }

  SelectionDAGGenTargetInfo::verifyTargetNode(DAG, N);

#ifndef NDEBUG
  // Some additional checks not yet implemented by verifyTargetNode.
  switch (N->getOpcode()) {
  case AArch64ISD::CTTZ_ELTS:
    // Operands 0 and 1 must have identical types.
    assert(N->getOperand(0).getValueType() == N->getOperand(1).getValueType() &&
           "Expected the general-predicate and mask to have matching types");
    break;
  case AArch64ISD::SADDWT:
  case AArch64ISD::SADDWB:
  case AArch64ISD::UADDWT:
  case AArch64ISD::UADDWB: {
    // Widening adds: result and operand 0 share one integer vector type;
    // operand 1 has the same total width but twice as many lanes.
    EVT VT = N->getValueType(0);
    EVT Op0VT = N->getOperand(0).getValueType();
    EVT Op1VT = N->getOperand(1).getValueType();
    assert(VT.isVector() && Op0VT.isVector() && Op1VT.isVector() &&
           VT.isInteger() && Op0VT.isInteger() && Op1VT.isInteger() &&
           "Expected integer vectors!");
    assert(VT == Op0VT &&
           "Expected result and first input to have the same type!");
    assert(Op0VT.getSizeInBits() == Op1VT.getSizeInBits() &&
           "Expected vectors of equal size!");
    assert(Op0VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount() &&
           "Expected result vector and first input vector to have half the "
           "lanes of the second input vector!");
    break;
  }
  case AArch64ISD::SUNPKLO:
  case AArch64ISD::SUNPKHI:
  case AArch64ISD::UUNPKLO:
  case AArch64ISD::UUNPKHI: {
    // Unpacks: the integer vector result has the same total width as its
    // input but half as many lanes.
    EVT VT = N->getValueType(0);
    EVT OpVT = N->getOperand(0).getValueType();
    assert(OpVT.isVector() && VT.isVector() && OpVT.isInteger() &&
           VT.isInteger() && "Expected integer vectors!");
    assert(OpVT.getSizeInBits() == VT.getSizeInBits() &&
           "Expected vectors of equal size!");
    assert(OpVT.getVectorElementCount() == VT.getVectorElementCount() * 2 &&
           "Expected result vector with half the lanes of its input!");
    break;
  }
  case AArch64ISD::TRN1:
  case AArch64ISD::TRN2:
  case AArch64ISD::UZP1:
  case AArch64ISD::UZP2:
  case AArch64ISD::ZIP1:
  case AArch64ISD::ZIP2: {
    // Two-input permutes: result and both operands share one vector type.
    EVT VT = N->getValueType(0);
    EVT Op0VT = N->getOperand(0).getValueType();
    EVT Op1VT = N->getOperand(1).getValueType();
    assert(VT.isVector() && Op0VT.isVector() && Op1VT.isVector() &&
           "Expected vectors!");
    assert(VT == Op0VT && VT == Op1VT && "Expected matching vectors!");
    break;
  }
  case AArch64ISD::RSHRNB_I: {
    // Result has the same total width as operand 0 and twice as many lanes;
    // operand 1 must be a constant (presumably the shift amount — confirm).
    EVT VT = N->getValueType(0);
    EVT Op0VT = N->getOperand(0).getValueType();
    assert(VT.isVector() && VT.isInteger() &&
           "Expected integer vector result type!");
    assert(Op0VT.isVector() && Op0VT.isInteger() &&
           "Expected first operand to be an integer vector!");
    assert(VT.getSizeInBits() == Op0VT.getSizeInBits() &&
           "Expected vectors of equal size!");
    assert(VT.getVectorElementCount() == Op0VT.getVectorElementCount() * 2 &&
           "Expected input vector with half the lanes of its result!");
    assert(isa<ConstantSDNode>(N->getOperand(1)) &&
           "Expected second operand to be a constant!");
    break;
  }
  }
#endif
}
120
121SDValue AArch64SelectionDAGInfo::EmitMOPS(unsigned Opcode, SelectionDAG &DAG,
122 const SDLoc &DL, SDValue Chain,
123 SDValue Dst, SDValue SrcOrValue,
124 SDValue Size, Align Alignment,
125 bool isVolatile,
126 MachinePointerInfo DstPtrInfo,
127 MachinePointerInfo SrcPtrInfo) const {
128
129 // Get the constant size of the copy/set.
130 uint64_t ConstSize = 0;
131 if (auto *C = dyn_cast<ConstantSDNode>(Val&: Size))
132 ConstSize = C->getZExtValue();
133
134 const bool IsSet = Opcode == AArch64::MOPSMemorySetPseudo ||
135 Opcode == AArch64::MOPSMemorySetTaggingPseudo;
136
137 MachineFunction &MF = DAG.getMachineFunction();
138
139 auto Vol =
140 isVolatile ? MachineMemOperand::MOVolatile : MachineMemOperand::MONone;
141 auto DstFlags = MachineMemOperand::MOStore | Vol;
142 auto *DstOp =
143 MF.getMachineMemOperand(PtrInfo: DstPtrInfo, F: DstFlags, Size: ConstSize, BaseAlignment: Alignment);
144
145 if (IsSet) {
146 // Extend value to i64, if required.
147 if (SrcOrValue.getValueType() != MVT::i64)
148 SrcOrValue = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: MVT::i64, Operand: SrcOrValue);
149 SDValue Ops[] = {Dst, Size, SrcOrValue, Chain};
150 const EVT ResultTys[] = {MVT::i64, MVT::i64, MVT::Other};
151 MachineSDNode *Node = DAG.getMachineNode(Opcode, dl: DL, ResultTys, Ops);
152 DAG.setNodeMemRefs(N: Node, NewMemRefs: {DstOp});
153 return SDValue(Node, 2);
154 } else {
155 SDValue Ops[] = {Dst, SrcOrValue, Size, Chain};
156 const EVT ResultTys[] = {MVT::i64, MVT::i64, MVT::i64, MVT::Other};
157 MachineSDNode *Node = DAG.getMachineNode(Opcode, dl: DL, ResultTys, Ops);
158
159 auto SrcFlags = MachineMemOperand::MOLoad | Vol;
160 auto *SrcOp =
161 MF.getMachineMemOperand(PtrInfo: SrcPtrInfo, F: SrcFlags, Size: ConstSize, BaseAlignment: Alignment);
162 DAG.setNodeMemRefs(N: Node, NewMemRefs: {DstOp, SrcOp});
163 return SDValue(Node, 3);
164 }
165}
166
167SDValue AArch64SelectionDAGInfo::EmitStreamingCompatibleMemLibCall(
168 SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, SDValue Op0, SDValue Op1,
169 SDValue Size, RTLIB::Libcall LC) const {
170 const AArch64Subtarget &STI =
171 DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
172 const AArch64TargetLowering *TLI = STI.getTargetLowering();
173 TargetLowering::ArgListTy Args;
174 Args.emplace_back(args&: Op0, args: PointerType::getUnqual(C&: *DAG.getContext()));
175
176 bool UsesResult = false;
177 RTLIB::Libcall NewLC;
178 switch (LC) {
179 case RTLIB::MEMCPY: {
180 NewLC = RTLIB::SC_MEMCPY;
181 Args.emplace_back(args&: Op1, args: PointerType::getUnqual(C&: *DAG.getContext()));
182 break;
183 }
184 case RTLIB::MEMMOVE: {
185 NewLC = RTLIB::SC_MEMMOVE;
186 Args.emplace_back(args&: Op1, args: PointerType::getUnqual(C&: *DAG.getContext()));
187 break;
188 }
189 case RTLIB::MEMSET: {
190 NewLC = RTLIB::SC_MEMSET;
191 Args.emplace_back(args: DAG.getZExtOrTrunc(Op: Op1, DL, VT: MVT::i32),
192 args: Type::getInt32Ty(C&: *DAG.getContext()));
193 break;
194 }
195 case RTLIB::MEMCHR: {
196 UsesResult = true;
197 NewLC = RTLIB::SC_MEMCHR;
198 Args.emplace_back(args: DAG.getZExtOrTrunc(Op: Op1, DL, VT: MVT::i32),
199 args: Type::getInt32Ty(C&: *DAG.getContext()));
200 break;
201 }
202 default:
203 return SDValue();
204 }
205
206 RTLIB::LibcallImpl NewLCImpl = DAG.getLibcalls().getLibcallImpl(Call: NewLC);
207 if (NewLCImpl == RTLIB::Unsupported)
208 return SDValue();
209
210 EVT PointerVT = TLI->getPointerTy(DL: DAG.getDataLayout());
211 SDValue Symbol = DAG.getExternalSymbol(LCImpl: NewLCImpl, VT: PointerVT);
212 Args.emplace_back(args&: Size, args: DAG.getDataLayout().getIntPtrType(C&: *DAG.getContext()));
213
214 TargetLowering::CallLoweringInfo CLI(DAG);
215 PointerType *RetTy = PointerType::getUnqual(C&: *DAG.getContext());
216 CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
217 CC: DAG.getLibcalls().getLibcallImplCallingConv(Call: NewLCImpl), ResultType: RetTy, Target: Symbol,
218 ArgsList: std::move(Args));
219
220 auto [Result, ChainOut] = TLI->LowerCallTo(CLI);
221 return UsesResult ? DAG.getMergeValues(Ops: {Result, ChainOut}, dl: DL) : ChainOut;
222}
223
224SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemcpy(
225 SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, SDValue Dst, SDValue Src,
226 SDValue Size, Align Alignment, bool isVolatile, bool AlwaysInline,
227 MachinePointerInfo DstPtrInfo, MachinePointerInfo SrcPtrInfo) const {
228 const AArch64Subtarget &STI =
229 DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
230
231 if (UseMOPS && STI.hasMOPS())
232 return EmitMOPS(Opcode: AArch64::MOPSMemoryCopyPseudo, DAG, DL, Chain, Dst, SrcOrValue: Src,
233 Size, Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);
234
235 auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
236 SMEAttrs Attrs = AFI->getSMEFnAttrs();
237 if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
238 return EmitStreamingCompatibleMemLibCall(DAG, DL, Chain, Op0: Dst, Op1: Src, Size,
239 LC: RTLIB::MEMCPY);
240 return SDValue();
241}
242
243SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemset(
244 SelectionDAG &DAG, const SDLoc &dl, SDValue Chain, SDValue Dst, SDValue Src,
245 SDValue Size, Align Alignment, bool isVolatile, bool AlwaysInline,
246 MachinePointerInfo DstPtrInfo) const {
247 const AArch64Subtarget &STI =
248 DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
249
250 if (UseMOPS && STI.hasMOPS())
251 return EmitMOPS(Opcode: AArch64::MOPSMemorySetPseudo, DAG, DL: dl, Chain, Dst, SrcOrValue: Src,
252 Size, Alignment, isVolatile, DstPtrInfo,
253 SrcPtrInfo: MachinePointerInfo{});
254
255 auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
256 SMEAttrs Attrs = AFI->getSMEFnAttrs();
257 if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
258 return EmitStreamingCompatibleMemLibCall(DAG, DL: dl, Chain, Op0: Dst, Op1: Src, Size,
259 LC: RTLIB::MEMSET);
260 return SDValue();
261}
262
263SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemmove(
264 SelectionDAG &DAG, const SDLoc &dl, SDValue Chain, SDValue Dst, SDValue Src,
265 SDValue Size, Align Alignment, bool isVolatile,
266 MachinePointerInfo DstPtrInfo, MachinePointerInfo SrcPtrInfo) const {
267 const AArch64Subtarget &STI =
268 DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
269
270 if (UseMOPS && STI.hasMOPS())
271 return EmitMOPS(Opcode: AArch64::MOPSMemoryMovePseudo, DAG, DL: dl, Chain, Dst, SrcOrValue: Src,
272 Size, Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);
273
274 auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
275 SMEAttrs Attrs = AFI->getSMEFnAttrs();
276 if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
277 return EmitStreamingCompatibleMemLibCall(DAG, DL: dl, Chain, Op0: Dst, Op1: Src, Size,
278 LC: RTLIB::MEMMOVE);
279 return SDValue();
280}
281
282std::pair<SDValue, SDValue> AArch64SelectionDAGInfo::EmitTargetCodeForMemchr(
283 SelectionDAG &DAG, const SDLoc &dl, SDValue Chain, SDValue Src,
284 SDValue Char, SDValue Length, MachinePointerInfo SrcPtrInfo) const {
285 auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
286 SMEAttrs Attrs = AFI->getSMEFnAttrs();
287 if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody()) {
288 SDValue Result = EmitStreamingCompatibleMemLibCall(
289 DAG, DL: dl, Chain, Op0: Src, Op1: Char, Size: Length, LC: RTLIB::MEMCHR);
290 return std::make_pair(x: Result.getValue(R: 0), y: Result.getValue(R: 1));
291 }
292 return std::make_pair(x: SDValue(), y: SDValue());
293}
294
/// Tag-store sizes (in bytes) at or above this value are emitted as a single
/// STGloop/STZGloop pseudo instead of an unrolled ST(Z)G/ST(Z)2G sequence.
static constexpr int kSetTagLoopThreshold = 176;
296
297static SDValue EmitUnrolledSetTag(SelectionDAG &DAG, const SDLoc &dl,
298 SDValue Chain, SDValue Ptr, uint64_t ObjSize,
299 const MachineMemOperand *BaseMemOperand,
300 bool ZeroData) {
301 MachineFunction &MF = DAG.getMachineFunction();
302 unsigned ObjSizeScaled = ObjSize / 16;
303
304 SDValue TagSrc = Ptr;
305 if (Ptr.getOpcode() == ISD::FrameIndex) {
306 int FI = cast<FrameIndexSDNode>(Val&: Ptr)->getIndex();
307 Ptr = DAG.getTargetFrameIndex(FI, VT: MVT::i64);
308 // A frame index operand may end up as [SP + offset] => it is fine to use SP
309 // register as the tag source.
310 TagSrc = DAG.getRegister(Reg: AArch64::SP, VT: MVT::i64);
311 }
312
313 const unsigned OpCode1 = ZeroData ? AArch64ISD::STZG : AArch64ISD::STG;
314 const unsigned OpCode2 = ZeroData ? AArch64ISD::STZ2G : AArch64ISD::ST2G;
315
316 SmallVector<SDValue, 8> OutChains;
317 unsigned OffsetScaled = 0;
318 while (OffsetScaled < ObjSizeScaled) {
319 if (ObjSizeScaled - OffsetScaled >= 2) {
320 SDValue AddrNode = DAG.getMemBasePlusOffset(
321 Base: Ptr, Offset: TypeSize::getFixed(ExactSize: OffsetScaled * 16), DL: dl);
322 SDValue St = DAG.getMemIntrinsicNode(
323 Opcode: OpCode2, dl, VTList: DAG.getVTList(VT: MVT::Other),
324 Ops: {Chain, TagSrc, AddrNode},
325 MemVT: MVT::v4i64,
326 MMO: MF.getMachineMemOperand(MMO: BaseMemOperand, Offset: OffsetScaled * 16, Size: 16 * 2));
327 OffsetScaled += 2;
328 OutChains.push_back(Elt: St);
329 continue;
330 }
331
332 if (ObjSizeScaled - OffsetScaled > 0) {
333 SDValue AddrNode = DAG.getMemBasePlusOffset(
334 Base: Ptr, Offset: TypeSize::getFixed(ExactSize: OffsetScaled * 16), DL: dl);
335 SDValue St = DAG.getMemIntrinsicNode(
336 Opcode: OpCode1, dl, VTList: DAG.getVTList(VT: MVT::Other),
337 Ops: {Chain, TagSrc, AddrNode},
338 MemVT: MVT::v2i64,
339 MMO: MF.getMachineMemOperand(MMO: BaseMemOperand, Offset: OffsetScaled * 16, Size: 16));
340 OffsetScaled += 1;
341 OutChains.push_back(Elt: St);
342 }
343 }
344
345 SDValue Res = DAG.getNode(Opcode: ISD::TokenFactor, DL: dl, VT: MVT::Other, Ops: OutChains);
346 return Res;
347}
348
349SDValue AArch64SelectionDAGInfo::EmitTargetCodeForSetTag(
350 SelectionDAG &DAG, const SDLoc &dl, SDValue Chain, SDValue Addr,
351 SDValue Size, MachinePointerInfo DstPtrInfo, bool ZeroData) const {
352 uint64_t ObjSize = Size->getAsZExtVal();
353 assert(ObjSize % 16 == 0);
354
355 MachineFunction &MF = DAG.getMachineFunction();
356 MachineMemOperand *BaseMemOperand = MF.getMachineMemOperand(
357 PtrInfo: DstPtrInfo, F: MachineMemOperand::MOStore, Size: ObjSize, BaseAlignment: Align(16));
358
359 bool UseSetTagRangeLoop =
360 kSetTagLoopThreshold >= 0 && (int)ObjSize >= kSetTagLoopThreshold;
361 if (!UseSetTagRangeLoop)
362 return EmitUnrolledSetTag(DAG, dl, Chain, Ptr: Addr, ObjSize, BaseMemOperand,
363 ZeroData);
364
365 const EVT ResTys[] = {MVT::i64, MVT::i64, MVT::Other};
366
367 unsigned Opcode;
368 if (Addr.getOpcode() == ISD::FrameIndex) {
369 int FI = cast<FrameIndexSDNode>(Val&: Addr)->getIndex();
370 Addr = DAG.getTargetFrameIndex(FI, VT: MVT::i64);
371 Opcode = ZeroData ? AArch64::STZGloop : AArch64::STGloop;
372 } else {
373 Opcode = ZeroData ? AArch64::STZGloop_wback : AArch64::STGloop_wback;
374 }
375 SDValue Ops[] = {DAG.getTargetConstant(Val: ObjSize, DL: dl, VT: MVT::i64), Addr, Chain};
376 SDNode *St = DAG.getMachineNode(Opcode, dl, ResultTys: ResTys, Ops);
377
378 DAG.setNodeMemRefs(N: cast<MachineSDNode>(Val: St), NewMemRefs: {BaseMemOperand});
379 return SDValue(St, 2);
380}
381