1//===---------------- llvm/CodeGen/MatchContext.h --------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file declares the EmptyMatchContext class and VPMatchContext class.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef LLVM_LIB_CODEGEN_SELECTIONDAG_MATCHCONTEXT_H
14#define LLVM_LIB_CODEGEN_SELECTIONDAG_MATCHCONTEXT_H
15
16#include "llvm/CodeGen/SelectionDAG.h"
17#include "llvm/CodeGen/TargetLowering.h"
18
19namespace llvm {
20
21class EmptyMatchContext {
22 SelectionDAG &DAG;
23 const TargetLowering &TLI;
24 SDNode *Root;
25
26public:
27 EmptyMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root)
28 : DAG(DAG), TLI(TLI), Root(Root) {}
29
30 unsigned getRootBaseOpcode() { return Root->getOpcode(); }
31 bool match(SDValue OpN, unsigned Opcode) const {
32 return Opcode == OpN->getOpcode();
33 }
34
35 // Same as SelectionDAG::getNode().
36 template <typename... ArgT> SDValue getNode(ArgT &&...Args) {
37 return DAG.getNode(std::forward<ArgT>(Args)...);
38 }
39
40 bool isOperationLegal(unsigned Op, EVT VT) const {
41 return TLI.isOperationLegal(Op, VT);
42 }
43
44 bool isOperationLegalOrCustom(unsigned Op, EVT VT,
45 bool LegalOnly = false) const {
46 return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly);
47 }
48
49 unsigned getNumOperands(SDValue N) const { return N->getNumOperands(); }
50};
51
52class VPMatchContext {
53 SelectionDAG &DAG;
54 const TargetLowering &TLI;
55 SDValue RootMaskOp;
56 SDValue RootVectorLenOp;
57 SDNode *Root;
58
59public:
60 VPMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *_Root)
61 : DAG(DAG), TLI(TLI), RootMaskOp(), RootVectorLenOp() {
62 Root = _Root;
63 assert(Root->isVPOpcode());
64 if (auto RootMaskPos = ISD::getVPMaskIdx(Opcode: Root->getOpcode()))
65 RootMaskOp = Root->getOperand(Num: *RootMaskPos);
66 else if (Root->getOpcode() == ISD::VP_SELECT)
67 RootMaskOp = DAG.getAllOnesConstant(DL: SDLoc(Root),
68 VT: Root->getOperand(Num: 0).getValueType());
69
70 if (auto RootVLenPos = ISD::getVPExplicitVectorLengthIdx(Opcode: Root->getOpcode()))
71 RootVectorLenOp = Root->getOperand(Num: *RootVLenPos);
72 }
73
74 unsigned getRootBaseOpcode() {
75 std::optional<unsigned> Opcode = ISD::getBaseOpcodeForVP(
76 Opcode: Root->getOpcode(), hasFPExcept: !Root->getFlags().hasNoFPExcept());
77 assert(Opcode.has_value());
78 return *Opcode;
79 }
80
81 /// whether \p OpVal is a node that is functionally compatible with the
82 /// NodeType \p Opc
83 bool match(SDValue OpVal, unsigned Opc) const {
84 if (!OpVal->isVPOpcode())
85 return OpVal->getOpcode() == Opc;
86
87 auto BaseOpc = ISD::getBaseOpcodeForVP(Opcode: OpVal->getOpcode(),
88 hasFPExcept: !OpVal->getFlags().hasNoFPExcept());
89 if (BaseOpc != Opc)
90 return false;
91
92 // Make sure the mask of OpVal is true mask or is same as Root's.
93 unsigned VPOpcode = OpVal->getOpcode();
94 if (auto MaskPos = ISD::getVPMaskIdx(Opcode: VPOpcode)) {
95 SDValue MaskOp = OpVal.getOperand(i: *MaskPos);
96 if (RootMaskOp != MaskOp &&
97 !ISD::isConstantSplatVectorAllOnes(N: MaskOp.getNode()))
98 return false;
99 }
100
101 // Make sure the EVL of OpVal is same as Root's.
102 if (auto VLenPos = ISD::getVPExplicitVectorLengthIdx(Opcode: VPOpcode))
103 if (RootVectorLenOp != OpVal.getOperand(i: *VLenPos))
104 return false;
105 return true;
106 }
107
108 // Specialize based on number of operands.
109 // TODO emit VP intrinsics where MaskOp/VectorLenOp != null
110 // SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return
111 // DAG.getNode(Opcode, DL, VT); }
112 SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand) {
113 unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode);
114 assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
115 ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
116 return DAG.getNode(Opcode: VPOpcode, DL, VT,
117 Ops: {Operand, RootMaskOp, RootVectorLenOp});
118 }
119
120 SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
121 SDValue N2) {
122 unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode);
123 assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
124 ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
125 return DAG.getNode(Opcode: VPOpcode, DL, VT, Ops: {N1, N2, RootMaskOp, RootVectorLenOp});
126 }
127
128 SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
129 SDValue N2, SDValue N3) {
130 unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode);
131 assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
132 ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
133 return DAG.getNode(Opcode: VPOpcode, DL, VT,
134 Ops: {N1, N2, N3, RootMaskOp, RootVectorLenOp});
135 }
136
137 SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand,
138 SDNodeFlags Flags) {
139 unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode);
140 assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
141 ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
142 return DAG.getNode(Opcode: VPOpcode, DL, VT, Ops: {Operand, RootMaskOp, RootVectorLenOp},
143 Flags);
144 }
145
146 SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
147 SDValue N2, SDNodeFlags Flags) {
148 unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode);
149 assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
150 ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
151 return DAG.getNode(Opcode: VPOpcode, DL, VT, Ops: {N1, N2, RootMaskOp, RootVectorLenOp},
152 Flags);
153 }
154
155 SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
156 SDValue N2, SDValue N3, SDNodeFlags Flags) {
157 unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode);
158 assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
159 ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
160 return DAG.getNode(Opcode: VPOpcode, DL, VT,
161 Ops: {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags);
162 }
163
164 bool isOperationLegal(unsigned Op, EVT VT) const {
165 unsigned VPOp = *ISD::getVPForBaseOpcode(Opcode: Op);
166 return TLI.isOperationLegal(Op: VPOp, VT);
167 }
168
169 bool isOperationLegalOrCustom(unsigned Op, EVT VT,
170 bool LegalOnly = false) const {
171 unsigned VPOp = *ISD::getVPForBaseOpcode(Opcode: Op);
172 return TLI.isOperationLegalOrCustom(Op: VPOp, VT, LegalOnly);
173 }
174
175 unsigned getNumOperands(SDValue N) const {
176 return N->isVPOpcode() ? N->getNumOperands() - 2 : N->getNumOperands();
177 }
178};
179
180} // namespace llvm
181
182#endif // LLVM_LIB_CODEGEN_SELECTIONDAG_MATCHCONTEXT_H
183