1//===---------------- llvm/CodeGen/MatchContext.h --------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
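// This file declares the EmptyMatchContext and VPMatchContext classes. Both
// expose the same interface (match, getNode, legality queries) so that
// pattern-matching code can be written once and used for either plain
// SelectionDAG nodes or VP (vector-predicated) nodes.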
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef LLVM_LIB_CODEGEN_SELECTIONDAG_MATCHCONTEXT_H
14#define LLVM_LIB_CODEGEN_SELECTIONDAG_MATCHCONTEXT_H
15
16#include "llvm/CodeGen/SelectionDAG.h"
17#include "llvm/CodeGen/TargetLowering.h"
18
19using namespace llvm;
20
21namespace {
22class EmptyMatchContext {
23 SelectionDAG &DAG;
24 const TargetLowering &TLI;
25 SDNode *Root;
26
27public:
28 EmptyMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root)
29 : DAG(DAG), TLI(TLI), Root(Root) {}
30
31 unsigned getRootBaseOpcode() { return Root->getOpcode(); }
32 bool match(SDValue OpN, unsigned Opcode) const {
33 return Opcode == OpN->getOpcode();
34 }
35
36 // Same as SelectionDAG::getNode().
37 template <typename... ArgT> SDValue getNode(ArgT &&...Args) {
38 return DAG.getNode(std::forward<ArgT>(Args)...);
39 }
40
41 bool isOperationLegal(unsigned Op, EVT VT) const {
42 return TLI.isOperationLegal(Op, VT);
43 }
44
45 bool isOperationLegalOrCustom(unsigned Op, EVT VT,
46 bool LegalOnly = false) const {
47 return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly);
48 }
49};
50
51class VPMatchContext {
52 SelectionDAG &DAG;
53 const TargetLowering &TLI;
54 SDValue RootMaskOp;
55 SDValue RootVectorLenOp;
56 SDNode *Root;
57
58public:
59 VPMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *_Root)
60 : DAG(DAG), TLI(TLI), RootMaskOp(), RootVectorLenOp() {
61 Root = _Root;
62 assert(Root->isVPOpcode());
63 if (auto RootMaskPos = ISD::getVPMaskIdx(Opcode: Root->getOpcode()))
64 RootMaskOp = Root->getOperand(Num: *RootMaskPos);
65 else if (Root->getOpcode() == ISD::VP_SELECT)
66 RootMaskOp = DAG.getAllOnesConstant(DL: SDLoc(Root),
67 VT: Root->getOperand(Num: 0).getValueType());
68
69 if (auto RootVLenPos = ISD::getVPExplicitVectorLengthIdx(Opcode: Root->getOpcode()))
70 RootVectorLenOp = Root->getOperand(Num: *RootVLenPos);
71 }
72
73 unsigned getRootBaseOpcode() {
74 std::optional<unsigned> Opcode = ISD::getBaseOpcodeForVP(
75 Opcode: Root->getOpcode(), hasFPExcept: !Root->getFlags().hasNoFPExcept());
76 assert(Opcode.has_value());
77 return *Opcode;
78 }
79
80 /// whether \p OpVal is a node that is functionally compatible with the
81 /// NodeType \p Opc
82 bool match(SDValue OpVal, unsigned Opc) const {
83 if (!OpVal->isVPOpcode())
84 return OpVal->getOpcode() == Opc;
85
86 auto BaseOpc = ISD::getBaseOpcodeForVP(Opcode: OpVal->getOpcode(),
87 hasFPExcept: !OpVal->getFlags().hasNoFPExcept());
88 if (BaseOpc != Opc)
89 return false;
90
91 // Make sure the mask of OpVal is true mask or is same as Root's.
92 unsigned VPOpcode = OpVal->getOpcode();
93 if (auto MaskPos = ISD::getVPMaskIdx(Opcode: VPOpcode)) {
94 SDValue MaskOp = OpVal.getOperand(i: *MaskPos);
95 if (RootMaskOp != MaskOp &&
96 !ISD::isConstantSplatVectorAllOnes(N: MaskOp.getNode()))
97 return false;
98 }
99
100 // Make sure the EVL of OpVal is same as Root's.
101 if (auto VLenPos = ISD::getVPExplicitVectorLengthIdx(Opcode: VPOpcode))
102 if (RootVectorLenOp != OpVal.getOperand(i: *VLenPos))
103 return false;
104 return true;
105 }
106
107 // Specialize based on number of operands.
108 // TODO emit VP intrinsics where MaskOp/VectorLenOp != null
109 // SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return
110 // DAG.getNode(Opcode, DL, VT); }
111 SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand) {
112 unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
113 assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
114 ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
115 return DAG.getNode(Opcode: VPOpcode, DL, VT,
116 Ops: {Operand, RootMaskOp, RootVectorLenOp});
117 }
118
119 SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
120 SDValue N2) {
121 unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
122 assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
123 ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
124 return DAG.getNode(Opcode: VPOpcode, DL, VT, Ops: {N1, N2, RootMaskOp, RootVectorLenOp});
125 }
126
127 SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
128 SDValue N2, SDValue N3) {
129 unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
130 assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
131 ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
132 return DAG.getNode(Opcode: VPOpcode, DL, VT,
133 Ops: {N1, N2, N3, RootMaskOp, RootVectorLenOp});
134 }
135
136 SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand,
137 SDNodeFlags Flags) {
138 unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
139 assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
140 ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
141 return DAG.getNode(Opcode: VPOpcode, DL, VT, Ops: {Operand, RootMaskOp, RootVectorLenOp},
142 Flags);
143 }
144
145 SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
146 SDValue N2, SDNodeFlags Flags) {
147 unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
148 assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
149 ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
150 return DAG.getNode(Opcode: VPOpcode, DL, VT, Ops: {N1, N2, RootMaskOp, RootVectorLenOp},
151 Flags);
152 }
153
154 SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
155 SDValue N2, SDValue N3, SDNodeFlags Flags) {
156 unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
157 assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
158 ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
159 return DAG.getNode(Opcode: VPOpcode, DL, VT,
160 Ops: {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags);
161 }
162
163 bool isOperationLegal(unsigned Op, EVT VT) const {
164 unsigned VPOp = ISD::getVPForBaseOpcode(Opcode: Op);
165 return TLI.isOperationLegal(Op: VPOp, VT);
166 }
167
168 bool isOperationLegalOrCustom(unsigned Op, EVT VT,
169 bool LegalOnly = false) const {
170 unsigned VPOp = ISD::getVPForBaseOpcode(Opcode: Op);
171 return TLI.isOperationLegalOrCustom(Op: VPOp, VT, LegalOnly);
172 }
173};
174} // end anonymous namespace
175#endif
176