1 | //===---------------- llvm/CodeGen/MatchContext.h --------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file declares the EmptyMatchContext class and VPMatchContext class. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef LLVM_LIB_CODEGEN_SELECTIONDAG_MATCHCONTEXT_H |
14 | #define LLVM_LIB_CODEGEN_SELECTIONDAG_MATCHCONTEXT_H |
15 | |
16 | #include "llvm/CodeGen/SelectionDAG.h" |
17 | #include "llvm/CodeGen/TargetLowering.h" |
18 | |
19 | namespace llvm { |
20 | |
21 | class EmptyMatchContext { |
22 | SelectionDAG &DAG; |
23 | const TargetLowering &TLI; |
24 | SDNode *Root; |
25 | |
26 | public: |
27 | EmptyMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root) |
28 | : DAG(DAG), TLI(TLI), Root(Root) {} |
29 | |
30 | unsigned getRootBaseOpcode() { return Root->getOpcode(); } |
31 | bool match(SDValue OpN, unsigned Opcode) const { |
32 | return Opcode == OpN->getOpcode(); |
33 | } |
34 | |
35 | // Same as SelectionDAG::getNode(). |
36 | template <typename... ArgT> SDValue getNode(ArgT &&...Args) { |
37 | return DAG.getNode(std::forward<ArgT>(Args)...); |
38 | } |
39 | |
40 | bool isOperationLegal(unsigned Op, EVT VT) const { |
41 | return TLI.isOperationLegal(Op, VT); |
42 | } |
43 | |
44 | bool isOperationLegalOrCustom(unsigned Op, EVT VT, |
45 | bool LegalOnly = false) const { |
46 | return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly); |
47 | } |
48 | |
49 | unsigned getNumOperands(SDValue N) const { return N->getNumOperands(); } |
50 | }; |
51 | |
52 | class VPMatchContext { |
53 | SelectionDAG &DAG; |
54 | const TargetLowering &TLI; |
55 | SDValue RootMaskOp; |
56 | SDValue RootVectorLenOp; |
57 | SDNode *Root; |
58 | |
59 | public: |
60 | VPMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *_Root) |
61 | : DAG(DAG), TLI(TLI), RootMaskOp(), RootVectorLenOp() { |
62 | Root = _Root; |
63 | assert(Root->isVPOpcode()); |
64 | if (auto RootMaskPos = ISD::getVPMaskIdx(Opcode: Root->getOpcode())) |
65 | RootMaskOp = Root->getOperand(Num: *RootMaskPos); |
66 | else if (Root->getOpcode() == ISD::VP_SELECT) |
67 | RootMaskOp = DAG.getAllOnesConstant(DL: SDLoc(Root), |
68 | VT: Root->getOperand(Num: 0).getValueType()); |
69 | |
70 | if (auto RootVLenPos = ISD::getVPExplicitVectorLengthIdx(Opcode: Root->getOpcode())) |
71 | RootVectorLenOp = Root->getOperand(Num: *RootVLenPos); |
72 | } |
73 | |
74 | unsigned getRootBaseOpcode() { |
75 | std::optional<unsigned> Opcode = ISD::getBaseOpcodeForVP( |
76 | Opcode: Root->getOpcode(), hasFPExcept: !Root->getFlags().hasNoFPExcept()); |
77 | assert(Opcode.has_value()); |
78 | return *Opcode; |
79 | } |
80 | |
81 | /// whether \p OpVal is a node that is functionally compatible with the |
82 | /// NodeType \p Opc |
83 | bool match(SDValue OpVal, unsigned Opc) const { |
84 | if (!OpVal->isVPOpcode()) |
85 | return OpVal->getOpcode() == Opc; |
86 | |
87 | auto BaseOpc = ISD::getBaseOpcodeForVP(Opcode: OpVal->getOpcode(), |
88 | hasFPExcept: !OpVal->getFlags().hasNoFPExcept()); |
89 | if (BaseOpc != Opc) |
90 | return false; |
91 | |
92 | // Make sure the mask of OpVal is true mask or is same as Root's. |
93 | unsigned VPOpcode = OpVal->getOpcode(); |
94 | if (auto MaskPos = ISD::getVPMaskIdx(Opcode: VPOpcode)) { |
95 | SDValue MaskOp = OpVal.getOperand(i: *MaskPos); |
96 | if (RootMaskOp != MaskOp && |
97 | !ISD::isConstantSplatVectorAllOnes(N: MaskOp.getNode())) |
98 | return false; |
99 | } |
100 | |
101 | // Make sure the EVL of OpVal is same as Root's. |
102 | if (auto VLenPos = ISD::getVPExplicitVectorLengthIdx(Opcode: VPOpcode)) |
103 | if (RootVectorLenOp != OpVal.getOperand(i: *VLenPos)) |
104 | return false; |
105 | return true; |
106 | } |
107 | |
108 | // Specialize based on number of operands. |
109 | // TODO emit VP intrinsics where MaskOp/VectorLenOp != null |
110 | // SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return |
111 | // DAG.getNode(Opcode, DL, VT); } |
112 | SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand) { |
113 | unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode); |
114 | assert(ISD::getVPMaskIdx(VPOpcode) == 1 && |
115 | ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2); |
116 | return DAG.getNode(Opcode: VPOpcode, DL, VT, |
117 | Ops: {Operand, RootMaskOp, RootVectorLenOp}); |
118 | } |
119 | |
120 | SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, |
121 | SDValue N2) { |
122 | unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode); |
123 | assert(ISD::getVPMaskIdx(VPOpcode) == 2 && |
124 | ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3); |
125 | return DAG.getNode(Opcode: VPOpcode, DL, VT, Ops: {N1, N2, RootMaskOp, RootVectorLenOp}); |
126 | } |
127 | |
128 | SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, |
129 | SDValue N2, SDValue N3) { |
130 | unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode); |
131 | assert(ISD::getVPMaskIdx(VPOpcode) == 3 && |
132 | ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4); |
133 | return DAG.getNode(Opcode: VPOpcode, DL, VT, |
134 | Ops: {N1, N2, N3, RootMaskOp, RootVectorLenOp}); |
135 | } |
136 | |
137 | SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand, |
138 | SDNodeFlags Flags) { |
139 | unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode); |
140 | assert(ISD::getVPMaskIdx(VPOpcode) == 1 && |
141 | ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2); |
142 | return DAG.getNode(Opcode: VPOpcode, DL, VT, Ops: {Operand, RootMaskOp, RootVectorLenOp}, |
143 | Flags); |
144 | } |
145 | |
146 | SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, |
147 | SDValue N2, SDNodeFlags Flags) { |
148 | unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode); |
149 | assert(ISD::getVPMaskIdx(VPOpcode) == 2 && |
150 | ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3); |
151 | return DAG.getNode(Opcode: VPOpcode, DL, VT, Ops: {N1, N2, RootMaskOp, RootVectorLenOp}, |
152 | Flags); |
153 | } |
154 | |
155 | SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, |
156 | SDValue N2, SDValue N3, SDNodeFlags Flags) { |
157 | unsigned VPOpcode = *ISD::getVPForBaseOpcode(Opcode); |
158 | assert(ISD::getVPMaskIdx(VPOpcode) == 3 && |
159 | ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4); |
160 | return DAG.getNode(Opcode: VPOpcode, DL, VT, |
161 | Ops: {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags); |
162 | } |
163 | |
164 | bool isOperationLegal(unsigned Op, EVT VT) const { |
165 | unsigned VPOp = *ISD::getVPForBaseOpcode(Opcode: Op); |
166 | return TLI.isOperationLegal(Op: VPOp, VT); |
167 | } |
168 | |
169 | bool isOperationLegalOrCustom(unsigned Op, EVT VT, |
170 | bool LegalOnly = false) const { |
171 | unsigned VPOp = *ISD::getVPForBaseOpcode(Opcode: Op); |
172 | return TLI.isOperationLegalOrCustom(Op: VPOp, VT, LegalOnly); |
173 | } |
174 | |
175 | unsigned getNumOperands(SDValue N) const { |
176 | return N->isVPOpcode() ? N->getNumOperands() - 2 : N->getNumOperands(); |
177 | } |
178 | }; |
179 | |
180 | } // namespace llvm |
181 | |
182 | #endif // LLVM_LIB_CODEGEN_SELECTIONDAG_MATCHCONTEXT_H |
183 | |