1 | //===---------------- llvm/CodeGen/MatchContext.h --------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file declares the EmptyMatchContext class and VPMatchContext class. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef LLVM_LIB_CODEGEN_SELECTIONDAG_MATCHCONTEXT_H |
14 | #define LLVM_LIB_CODEGEN_SELECTIONDAG_MATCHCONTEXT_H |
15 | |
16 | #include "llvm/CodeGen/SelectionDAG.h" |
17 | #include "llvm/CodeGen/TargetLowering.h" |
18 | |
19 | using namespace llvm; |
20 | |
21 | namespace { |
22 | class EmptyMatchContext { |
23 | SelectionDAG &DAG; |
24 | const TargetLowering &TLI; |
25 | SDNode *Root; |
26 | |
27 | public: |
28 | EmptyMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root) |
29 | : DAG(DAG), TLI(TLI), Root(Root) {} |
30 | |
31 | unsigned getRootBaseOpcode() { return Root->getOpcode(); } |
32 | bool match(SDValue OpN, unsigned Opcode) const { |
33 | return Opcode == OpN->getOpcode(); |
34 | } |
35 | |
36 | // Same as SelectionDAG::getNode(). |
37 | template <typename... ArgT> SDValue getNode(ArgT &&...Args) { |
38 | return DAG.getNode(std::forward<ArgT>(Args)...); |
39 | } |
40 | |
41 | bool isOperationLegal(unsigned Op, EVT VT) const { |
42 | return TLI.isOperationLegal(Op, VT); |
43 | } |
44 | |
45 | bool isOperationLegalOrCustom(unsigned Op, EVT VT, |
46 | bool LegalOnly = false) const { |
47 | return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly); |
48 | } |
49 | }; |
50 | |
51 | class VPMatchContext { |
52 | SelectionDAG &DAG; |
53 | const TargetLowering &TLI; |
54 | SDValue RootMaskOp; |
55 | SDValue RootVectorLenOp; |
56 | SDNode *Root; |
57 | |
58 | public: |
59 | VPMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *_Root) |
60 | : DAG(DAG), TLI(TLI), RootMaskOp(), RootVectorLenOp() { |
61 | Root = _Root; |
62 | assert(Root->isVPOpcode()); |
63 | if (auto RootMaskPos = ISD::getVPMaskIdx(Opcode: Root->getOpcode())) |
64 | RootMaskOp = Root->getOperand(Num: *RootMaskPos); |
65 | else if (Root->getOpcode() == ISD::VP_SELECT) |
66 | RootMaskOp = DAG.getAllOnesConstant(DL: SDLoc(Root), |
67 | VT: Root->getOperand(Num: 0).getValueType()); |
68 | |
69 | if (auto RootVLenPos = ISD::getVPExplicitVectorLengthIdx(Opcode: Root->getOpcode())) |
70 | RootVectorLenOp = Root->getOperand(Num: *RootVLenPos); |
71 | } |
72 | |
73 | unsigned getRootBaseOpcode() { |
74 | std::optional<unsigned> Opcode = ISD::getBaseOpcodeForVP( |
75 | Opcode: Root->getOpcode(), hasFPExcept: !Root->getFlags().hasNoFPExcept()); |
76 | assert(Opcode.has_value()); |
77 | return *Opcode; |
78 | } |
79 | |
80 | /// whether \p OpVal is a node that is functionally compatible with the |
81 | /// NodeType \p Opc |
82 | bool match(SDValue OpVal, unsigned Opc) const { |
83 | if (!OpVal->isVPOpcode()) |
84 | return OpVal->getOpcode() == Opc; |
85 | |
86 | auto BaseOpc = ISD::getBaseOpcodeForVP(Opcode: OpVal->getOpcode(), |
87 | hasFPExcept: !OpVal->getFlags().hasNoFPExcept()); |
88 | if (BaseOpc != Opc) |
89 | return false; |
90 | |
91 | // Make sure the mask of OpVal is true mask or is same as Root's. |
92 | unsigned VPOpcode = OpVal->getOpcode(); |
93 | if (auto MaskPos = ISD::getVPMaskIdx(Opcode: VPOpcode)) { |
94 | SDValue MaskOp = OpVal.getOperand(i: *MaskPos); |
95 | if (RootMaskOp != MaskOp && |
96 | !ISD::isConstantSplatVectorAllOnes(N: MaskOp.getNode())) |
97 | return false; |
98 | } |
99 | |
100 | // Make sure the EVL of OpVal is same as Root's. |
101 | if (auto VLenPos = ISD::getVPExplicitVectorLengthIdx(Opcode: VPOpcode)) |
102 | if (RootVectorLenOp != OpVal.getOperand(i: *VLenPos)) |
103 | return false; |
104 | return true; |
105 | } |
106 | |
107 | // Specialize based on number of operands. |
108 | // TODO emit VP intrinsics where MaskOp/VectorLenOp != null |
109 | // SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return |
110 | // DAG.getNode(Opcode, DL, VT); } |
111 | SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand) { |
112 | unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode); |
113 | assert(ISD::getVPMaskIdx(VPOpcode) == 1 && |
114 | ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2); |
115 | return DAG.getNode(Opcode: VPOpcode, DL, VT, |
116 | Ops: {Operand, RootMaskOp, RootVectorLenOp}); |
117 | } |
118 | |
119 | SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, |
120 | SDValue N2) { |
121 | unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode); |
122 | assert(ISD::getVPMaskIdx(VPOpcode) == 2 && |
123 | ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3); |
124 | return DAG.getNode(Opcode: VPOpcode, DL, VT, Ops: {N1, N2, RootMaskOp, RootVectorLenOp}); |
125 | } |
126 | |
127 | SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, |
128 | SDValue N2, SDValue N3) { |
129 | unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode); |
130 | assert(ISD::getVPMaskIdx(VPOpcode) == 3 && |
131 | ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4); |
132 | return DAG.getNode(Opcode: VPOpcode, DL, VT, |
133 | Ops: {N1, N2, N3, RootMaskOp, RootVectorLenOp}); |
134 | } |
135 | |
136 | SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand, |
137 | SDNodeFlags Flags) { |
138 | unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode); |
139 | assert(ISD::getVPMaskIdx(VPOpcode) == 1 && |
140 | ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2); |
141 | return DAG.getNode(Opcode: VPOpcode, DL, VT, Ops: {Operand, RootMaskOp, RootVectorLenOp}, |
142 | Flags); |
143 | } |
144 | |
145 | SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, |
146 | SDValue N2, SDNodeFlags Flags) { |
147 | unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode); |
148 | assert(ISD::getVPMaskIdx(VPOpcode) == 2 && |
149 | ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3); |
150 | return DAG.getNode(Opcode: VPOpcode, DL, VT, Ops: {N1, N2, RootMaskOp, RootVectorLenOp}, |
151 | Flags); |
152 | } |
153 | |
154 | SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, |
155 | SDValue N2, SDValue N3, SDNodeFlags Flags) { |
156 | unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode); |
157 | assert(ISD::getVPMaskIdx(VPOpcode) == 3 && |
158 | ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4); |
159 | return DAG.getNode(Opcode: VPOpcode, DL, VT, |
160 | Ops: {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags); |
161 | } |
162 | |
163 | bool isOperationLegal(unsigned Op, EVT VT) const { |
164 | unsigned VPOp = ISD::getVPForBaseOpcode(Opcode: Op); |
165 | return TLI.isOperationLegal(Op: VPOp, VT); |
166 | } |
167 | |
168 | bool isOperationLegalOrCustom(unsigned Op, EVT VT, |
169 | bool LegalOnly = false) const { |
170 | unsigned VPOp = ISD::getVPForBaseOpcode(Opcode: Op); |
171 | return TLI.isOperationLegalOrCustom(Op: VPOp, VT, LegalOnly); |
172 | } |
173 | }; |
174 | } // end anonymous namespace |
175 | #endif |
176 | |