1//==--------------- llvm/CodeGen/SDPatternMatch.h ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8/// \file
9/// Contains matchers for matching SelectionDAG nodes and values.
10///
11//===----------------------------------------------------------------------===//
12
13#ifndef LLVM_CODEGEN_SDPATTERNMATCH_H
14#define LLVM_CODEGEN_SDPATTERNMATCH_H
15
16#include "llvm/ADT/APInt.h"
17#include "llvm/ADT/ArrayRef.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/SmallBitVector.h"
20#include "llvm/CodeGen/SelectionDAG.h"
21#include "llvm/CodeGen/SelectionDAGNodes.h"
22#include "llvm/CodeGen/TargetLowering.h"
23#include "llvm/Support/KnownBits.h"
24
25#include <type_traits>
26
27namespace llvm {
28namespace SDPatternMatch {
29
30/// MatchContext can repurpose existing patterns to behave differently under
31/// a certain context. For instance, `m_Opc(ISD::ADD)` matches plain ADD nodes
32/// in normal circumstances, but matches VP_ADD nodes under a custom
33/// VPMatchContext. This design is meant to facilitate code / pattern reusing.
34class BasicMatchContext {
35 const SelectionDAG *DAG;
36 const TargetLowering *TLI;
37
38public:
39 explicit BasicMatchContext(const SelectionDAG *DAG)
40 : DAG(DAG), TLI(DAG ? &DAG->getTargetLoweringInfo() : nullptr) {}
41
42 explicit BasicMatchContext(const TargetLowering *TLI)
43 : DAG(nullptr), TLI(TLI) {}
44
45 // A valid MatchContext has to implement the following functions.
46
47 const SelectionDAG *getDAG() const { return DAG; }
48
49 const TargetLowering *getTLI() const { return TLI; }
50
51 /// Return true if N effectively has opcode Opcode.
52 bool match(SDValue N, unsigned Opcode) const {
53 return N->getOpcode() == Opcode;
54 }
55
56 unsigned getNumOperands(SDValue N) const { return N->getNumOperands(); }
57};
58
59template <typename Pattern, typename MatchContext>
60[[nodiscard]] bool sd_context_match(SDValue N, const MatchContext &Ctx,
61 Pattern &&P) {
62 return P.match(Ctx, N);
63}
64
65template <typename Pattern, typename MatchContext>
66[[nodiscard]] bool sd_context_match(SDNode *N, const MatchContext &Ctx,
67 Pattern &&P) {
68 return sd_context_match(SDValue(N, 0), Ctx, P);
69}
70
71template <typename Pattern>
72[[nodiscard]] bool sd_match(SDNode *N, const SelectionDAG *DAG, Pattern &&P) {
73 return sd_context_match(N, BasicMatchContext(DAG), P);
74}
75
76template <typename Pattern>
77[[nodiscard]] bool sd_match(SDValue N, const SelectionDAG *DAG, Pattern &&P) {
78 return sd_context_match(N, BasicMatchContext(DAG), P);
79}
80
81template <typename Pattern>
82[[nodiscard]] bool sd_match(SDNode *N, Pattern &&P) {
83 return sd_match(N, nullptr, P);
84}
85
86template <typename Pattern>
87[[nodiscard]] bool sd_match(SDValue N, Pattern &&P) {
88 return sd_match(N, nullptr, P);
89}
90
91// === Utilities ===
92struct Value_match {
93 SDValue MatchVal;
94
95 Value_match() = default;
96
97 explicit Value_match(SDValue Match) : MatchVal(Match) {}
98
99 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
100 if (MatchVal)
101 return MatchVal == N;
102 return N.getNode();
103 }
104};
105
106/// Match any valid SDValue.
107inline Value_match m_Value() { return Value_match(); }
108
109inline Value_match m_Specific(SDValue N) {
110 assert(N);
111 return Value_match(N);
112}
113
114template <unsigned ResNo, typename Pattern> struct Result_match {
115 Pattern P;
116
117 explicit Result_match(const Pattern &P) : P(P) {}
118
119 template <typename MatchContext>
120 bool match(const MatchContext &Ctx, SDValue N) {
121 return N.getResNo() == ResNo && P.match(Ctx, N);
122 }
123};
124
125/// Match only if the SDValue is a certain result at ResNo.
126template <unsigned ResNo, typename Pattern>
127inline Result_match<ResNo, Pattern> m_Result(const Pattern &P) {
128 return Result_match<ResNo, Pattern>(P);
129}
130
131struct DeferredValue_match {
132 SDValue &MatchVal;
133
134 explicit DeferredValue_match(SDValue &Match) : MatchVal(Match) {}
135
136 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
137 return N == MatchVal;
138 }
139};
140
141/// Similar to m_Specific, but the specific value to match is determined by
142/// another sub-pattern in the same sd_match() expression. For instance,
143/// We cannot match `(add V, V)` with `m_Add(m_Value(X), m_Specific(X))` since
144/// `X` is not initialized at the time it got copied into `m_Specific`. Instead,
145/// we should use `m_Add(m_Value(X), m_Deferred(X))`.
146inline DeferredValue_match m_Deferred(SDValue &V) {
147 return DeferredValue_match(V);
148}
149
150struct Opcode_match {
151 unsigned Opcode;
152
153 explicit Opcode_match(unsigned Opc) : Opcode(Opc) {}
154
155 template <typename MatchContext>
156 bool match(const MatchContext &Ctx, SDValue N) {
157 return Ctx.match(N, Opcode);
158 }
159};
160
161// === Patterns combinators ===
162template <typename... Preds> struct And {
163 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
164 return true;
165 }
166};
167
168template <typename Pred, typename... Preds>
169struct And<Pred, Preds...> : And<Preds...> {
170 Pred P;
171 And(const Pred &p, const Preds &...preds) : And<Preds...>(preds...), P(p) {}
172
173 template <typename MatchContext>
174 bool match(const MatchContext &Ctx, SDValue N) {
175 return P.match(Ctx, N) && And<Preds...>::match(Ctx, N);
176 }
177};
178
179template <typename... Preds> struct Or {
180 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
181 return false;
182 }
183};
184
185template <typename Pred, typename... Preds>
186struct Or<Pred, Preds...> : Or<Preds...> {
187 Pred P;
188 Or(const Pred &p, const Preds &...preds) : Or<Preds...>(preds...), P(p) {}
189
190 template <typename MatchContext>
191 bool match(const MatchContext &Ctx, SDValue N) {
192 return P.match(Ctx, N) || Or<Preds...>::match(Ctx, N);
193 }
194};
195
196template <typename Pred> struct Not {
197 Pred P;
198
199 explicit Not(const Pred &P) : P(P) {}
200
201 template <typename MatchContext>
202 bool match(const MatchContext &Ctx, SDValue N) {
203 return !P.match(Ctx, N);
204 }
205};
206// Explicit deduction guide.
207template <typename Pred> Not(const Pred &P) -> Not<Pred>;
208
209/// Match if the inner pattern does NOT match.
210template <typename Pred> inline Not<Pred> m_Unless(const Pred &P) {
211 return Not{P};
212}
213
214template <typename... Preds> And<Preds...> m_AllOf(const Preds &...preds) {
215 return And<Preds...>(preds...);
216}
217
218template <typename... Preds> Or<Preds...> m_AnyOf(const Preds &...preds) {
219 return Or<Preds...>(preds...);
220}
221
222template <typename... Preds> auto m_NoneOf(const Preds &...preds) {
223 return m_Unless(m_AnyOf(preds...));
224}
225
226inline Opcode_match m_Opc(unsigned Opcode) { return Opcode_match(Opcode); }
227
228inline auto m_Undef() {
229 return m_AnyOf(preds: Opcode_match(ISD::UNDEF), preds: Opcode_match(ISD::POISON));
230}
231
232inline Opcode_match m_Poison() { return Opcode_match(ISD::POISON); }
233
234template <unsigned NumUses, typename Pattern> struct NUses_match {
235 Pattern P;
236
237 explicit NUses_match(const Pattern &P) : P(P) {}
238
239 template <typename MatchContext>
240 bool match(const MatchContext &Ctx, SDValue N) {
241 // SDNode::hasNUsesOfValue is pretty expensive when the SDNode produces
242 // multiple results, hence we check the subsequent pattern here before
243 // checking the number of value users.
244 return P.match(Ctx, N) && N->hasNUsesOfValue(NUses: NumUses, Value: N.getResNo());
245 }
246};
247
248template <typename Pattern>
249inline NUses_match<1, Pattern> m_OneUse(const Pattern &P) {
250 return NUses_match<1, Pattern>(P);
251}
252template <unsigned N, typename Pattern>
253inline NUses_match<N, Pattern> m_NUses(const Pattern &P) {
254 return NUses_match<N, Pattern>(P);
255}
256
257inline NUses_match<1, Value_match> m_OneUse() {
258 return NUses_match<1, Value_match>(m_Value());
259}
260template <unsigned N> inline NUses_match<N, Value_match> m_NUses() {
261 return NUses_match<N, Value_match>(m_Value());
262}
263
264template <typename PredPattern> struct Value_bind {
265 SDValue &BindVal;
266 PredPattern Pred;
267
268 Value_bind(SDValue &N, const PredPattern &P) : BindVal(N), Pred(P) {}
269
270 template <typename MatchContext>
271 bool match(const MatchContext &Ctx, SDValue N) {
272 if (!Pred.match(Ctx, N))
273 return false;
274
275 BindVal = N;
276 return true;
277 }
278};
279
280inline auto m_Value(SDValue &N) {
281 return Value_bind<Value_match>(N, m_Value());
282}
283/// Conditionally bind an SDValue based on the predicate.
284template <typename PredPattern>
285inline auto m_Value(SDValue &N, const PredPattern &P) {
286 return Value_bind<PredPattern>(N, P);
287}
288
289template <typename Pattern, typename PredFuncT> struct TLI_pred_match {
290 Pattern P;
291 PredFuncT PredFunc;
292
293 TLI_pred_match(const PredFuncT &Pred, const Pattern &P)
294 : P(P), PredFunc(Pred) {}
295
296 template <typename MatchContext>
297 bool match(const MatchContext &Ctx, SDValue N) {
298 assert(Ctx.getTLI() && "TargetLowering is required for this pattern.");
299 return PredFunc(*Ctx.getTLI(), N) && P.match(Ctx, N);
300 }
301};
302
303// Explicit deduction guide.
304template <typename PredFuncT, typename Pattern>
305TLI_pred_match(const PredFuncT &Pred, const Pattern &P)
306 -> TLI_pred_match<Pattern, PredFuncT>;
307
308/// Match legal SDNodes based on the information provided by TargetLowering.
309template <typename Pattern> inline auto m_LegalOp(const Pattern &P) {
310 return TLI_pred_match{[](const TargetLowering &TLI, SDValue N) {
311 return TLI.isOperationLegal(Op: N->getOpcode(),
312 VT: N.getValueType());
313 },
314 P};
315}
316
317/// Switch to a different MatchContext for subsequent patterns.
318template <typename NewMatchContext, typename Pattern> struct SwitchContext {
319 const NewMatchContext &Ctx;
320 Pattern P;
321
322 template <typename OrigMatchContext>
323 bool match(const OrigMatchContext &, SDValue N) {
324 return P.match(Ctx, N);
325 }
326};
327
328template <typename MatchContext, typename Pattern>
329inline SwitchContext<MatchContext, Pattern> m_Context(const MatchContext &Ctx,
330 Pattern &&P) {
331 return SwitchContext<MatchContext, Pattern>{Ctx, std::move(P)};
332}
333
334// === Value type ===
335
336template <typename Pattern> struct ValueType_bind {
337 EVT &BindVT;
338 Pattern P;
339
340 explicit ValueType_bind(EVT &Bind, const Pattern &P) : BindVT(Bind), P(P) {}
341
342 template <typename MatchContext>
343 bool match(const MatchContext &Ctx, SDValue N) {
344 BindVT = N.getValueType();
345 return P.match(Ctx, N);
346 }
347};
348
349template <typename Pattern>
350ValueType_bind(const Pattern &P) -> ValueType_bind<Pattern>;
351
352/// Retreive the ValueType of the current SDValue.
353inline auto m_VT(EVT &VT) { return ValueType_bind(VT, m_Value()); }
354
355template <typename Pattern> inline auto m_VT(EVT &VT, const Pattern &P) {
356 return ValueType_bind(VT, P);
357}
358
359template <typename Pattern, typename PredFuncT> struct ValueType_match {
360 PredFuncT PredFunc;
361 Pattern P;
362
363 ValueType_match(const PredFuncT &Pred, const Pattern &P)
364 : PredFunc(Pred), P(P) {}
365
366 template <typename MatchContext>
367 bool match(const MatchContext &Ctx, SDValue N) {
368 return PredFunc(N.getValueType()) && P.match(Ctx, N);
369 }
370};
371
372// Explicit deduction guide.
373template <typename PredFuncT, typename Pattern>
374ValueType_match(const PredFuncT &Pred, const Pattern &P)
375 -> ValueType_match<Pattern, PredFuncT>;
376
377/// Match a specific ValueType.
378template <typename Pattern>
379inline auto m_SpecificVT(EVT RefVT, const Pattern &P) {
380 return ValueType_match{[=](EVT VT) { return VT == RefVT; }, P};
381}
382inline auto m_SpecificVT(EVT RefVT) {
383 return ValueType_match{[=](EVT VT) { return VT == RefVT; }, m_Value()};
384}
385
386inline auto m_Glue() { return m_SpecificVT(RefVT: MVT::Glue); }
387inline auto m_OtherVT() { return m_SpecificVT(RefVT: MVT::Other); }
388
389/// Match a scalar ValueType.
390template <typename Pattern>
391inline auto m_SpecificScalarVT(EVT RefVT, const Pattern &P) {
392 return ValueType_match{[=](EVT VT) { return VT.getScalarType() == RefVT; },
393 P};
394}
395inline auto m_SpecificScalarVT(EVT RefVT) {
396 return ValueType_match{[=](EVT VT) { return VT.getScalarType() == RefVT; },
397 m_Value()};
398}
399
400/// Match a vector ValueType.
401template <typename Pattern>
402inline auto m_SpecificVectorElementVT(EVT RefVT, const Pattern &P) {
403 return ValueType_match{[=](EVT VT) {
404 return VT.isVector() &&
405 VT.getVectorElementType() == RefVT;
406 },
407 P};
408}
409inline auto m_SpecificVectorElementVT(EVT RefVT) {
410 return ValueType_match{[=](EVT VT) {
411 return VT.isVector() &&
412 VT.getVectorElementType() == RefVT;
413 },
414 m_Value()};
415}
416
417/// Match any integer ValueTypes.
418template <typename Pattern> inline auto m_IntegerVT(const Pattern &P) {
419 return ValueType_match{[](EVT VT) { return VT.isInteger(); }, P};
420}
421inline auto m_IntegerVT() {
422 return ValueType_match{[](EVT VT) { return VT.isInteger(); }, m_Value()};
423}
424
425/// Match any floating point ValueTypes.
426template <typename Pattern> inline auto m_FloatingPointVT(const Pattern &P) {
427 return ValueType_match{[](EVT VT) { return VT.isFloatingPoint(); }, P};
428}
429inline auto m_FloatingPointVT() {
430 return ValueType_match{[](EVT VT) { return VT.isFloatingPoint(); },
431 m_Value()};
432}
433
434/// Match any vector ValueTypes.
435template <typename Pattern> inline auto m_VectorVT(const Pattern &P) {
436 return ValueType_match{[](EVT VT) { return VT.isVector(); }, P};
437}
438inline auto m_VectorVT() {
439 return ValueType_match{[](EVT VT) { return VT.isVector(); }, m_Value()};
440}
441
442/// Match fixed-length vector ValueTypes.
443template <typename Pattern> inline auto m_FixedVectorVT(const Pattern &P) {
444 return ValueType_match{[](EVT VT) { return VT.isFixedLengthVector(); }, P};
445}
446inline auto m_FixedVectorVT() {
447 return ValueType_match{[](EVT VT) { return VT.isFixedLengthVector(); },
448 m_Value()};
449}
450
451/// Match scalable vector ValueTypes.
452template <typename Pattern> inline auto m_ScalableVectorVT(const Pattern &P) {
453 return ValueType_match{[](EVT VT) { return VT.isScalableVector(); }, P};
454}
455inline auto m_ScalableVectorVT() {
456 return ValueType_match{[](EVT VT) { return VT.isScalableVector(); },
457 m_Value()};
458}
459
460/// Match legal ValueTypes based on the information provided by TargetLowering.
461template <typename Pattern> inline auto m_LegalType(const Pattern &P) {
462 return TLI_pred_match{[](const TargetLowering &TLI, SDValue N) {
463 return TLI.isTypeLegal(VT: N.getValueType());
464 },
465 P};
466}
467
468// === Generic node matching ===
469template <unsigned OpIdx, typename... OpndPreds> struct Operands_match {
470 template <typename MatchContext>
471 bool match(const MatchContext &Ctx, SDValue N) {
472 // Returns false if there are more operands than predicates;
473 // Ignores the last two operands if both the Context and the Node are VP
474 return Ctx.getNumOperands(N) == OpIdx;
475 }
476};
477
478template <unsigned OpIdx, typename OpndPred, typename... OpndPreds>
479struct Operands_match<OpIdx, OpndPred, OpndPreds...>
480 : Operands_match<OpIdx + 1, OpndPreds...> {
481 OpndPred P;
482
483 Operands_match(const OpndPred &p, const OpndPreds &...preds)
484 : Operands_match<OpIdx + 1, OpndPreds...>(preds...), P(p) {}
485
486 template <typename MatchContext>
487 bool match(const MatchContext &Ctx, SDValue N) {
488 if (OpIdx < N->getNumOperands())
489 return P.match(Ctx, N->getOperand(Num: OpIdx)) &&
490 Operands_match<OpIdx + 1, OpndPreds...>::match(Ctx, N);
491
492 // This is the case where there are more predicates than operands.
493 return false;
494 }
495};
496
497template <typename... OpndPreds>
498auto m_Node(unsigned Opcode, const OpndPreds &...preds) {
499 return m_AllOf(m_Opc(Opcode), Operands_match<0, OpndPreds...>(preds...));
500}
501
502/// Provide number of operands that are not chain or glue, as well as the first
503/// index of such operand.
504template <bool ExcludeChain> struct EffectiveOperands {
505 unsigned Size = 0;
506 unsigned FirstIndex = 0;
507
508 template <typename MatchContext>
509 explicit EffectiveOperands(SDValue N, const MatchContext &Ctx) {
510 const unsigned TotalNumOps = Ctx.getNumOperands(N);
511 FirstIndex = TotalNumOps;
512 for (unsigned I = 0; I < TotalNumOps; ++I) {
513 // Count the number of non-chain and non-glue nodes (we ignore chain
514 // and glue by default) and retreive the operand index offset.
515 EVT VT = N->getOperand(Num: I).getValueType();
516 if (VT != MVT::Glue && VT != MVT::Other) {
517 ++Size;
518 if (FirstIndex == TotalNumOps)
519 FirstIndex = I;
520 }
521 }
522 }
523};
524
525template <> struct EffectiveOperands<false> {
526 unsigned Size = 0;
527 unsigned FirstIndex = 0;
528
529 template <typename MatchContext>
530 explicit EffectiveOperands(SDValue N, const MatchContext &Ctx)
531 : Size(Ctx.getNumOperands(N)) {}
532};
533
534// === Ternary operations ===
535template <typename T0_P, typename T1_P, typename T2_P, bool Commutable = false,
536 bool ExcludeChain = false>
537struct TernaryOpc_match {
538 unsigned Opcode;
539 T0_P Op0;
540 T1_P Op1;
541 T2_P Op2;
542
543 TernaryOpc_match(unsigned Opc, const T0_P &Op0, const T1_P &Op1,
544 const T2_P &Op2)
545 : Opcode(Opc), Op0(Op0), Op1(Op1), Op2(Op2) {}
546
547 template <typename MatchContext>
548 bool match(const MatchContext &Ctx, SDValue N) {
549 if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
550 EffectiveOperands<ExcludeChain> EO(N, Ctx);
551 assert(EO.Size == 3);
552 return ((Op0.match(Ctx, N->getOperand(Num: EO.FirstIndex)) &&
553 Op1.match(Ctx, N->getOperand(Num: EO.FirstIndex + 1))) ||
554 (Commutable && Op0.match(Ctx, N->getOperand(Num: EO.FirstIndex + 1)) &&
555 Op1.match(Ctx, N->getOperand(Num: EO.FirstIndex)))) &&
556 Op2.match(Ctx, N->getOperand(Num: EO.FirstIndex + 2));
557 }
558
559 return false;
560 }
561};
562
563template <typename T0_P, typename T1_P, typename T2_P>
564inline TernaryOpc_match<T0_P, T1_P, T2_P>
565m_SetCC(const T0_P &LHS, const T1_P &RHS, const T2_P &CC) {
566 return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::SETCC, LHS, RHS, CC);
567}
568
569template <typename T0_P, typename T1_P, typename T2_P>
570inline TernaryOpc_match<T0_P, T1_P, T2_P, true, false>
571m_c_SetCC(const T0_P &LHS, const T1_P &RHS, const T2_P &CC) {
572 return TernaryOpc_match<T0_P, T1_P, T2_P, true, false>(ISD::SETCC, LHS, RHS,
573 CC);
574}
575
576template <typename T0_P, typename T1_P, typename T2_P>
577inline TernaryOpc_match<T0_P, T1_P, T2_P>
578m_Select(const T0_P &Cond, const T1_P &T, const T2_P &F) {
579 return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::SELECT, Cond, T, F);
580}
581
582template <typename T0_P, typename T1_P, typename T2_P>
583inline TernaryOpc_match<T0_P, T1_P, T2_P>
584m_VSelect(const T0_P &Cond, const T1_P &T, const T2_P &F) {
585 return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::VSELECT, Cond, T, F);
586}
587
588template <typename T0_P, typename T1_P, typename T2_P>
589inline auto m_SelectLike(const T0_P &Cond, const T1_P &T, const T2_P &F) {
590 return m_AnyOf(m_Select(Cond, T, F), m_VSelect(Cond, T, F));
591}
592
593template <typename T0_P, typename T1_P, typename T2_P>
594inline Result_match<0, TernaryOpc_match<T0_P, T1_P, T2_P>>
595m_Load(const T0_P &Ch, const T1_P &Ptr, const T2_P &Offset) {
596 return m_Result<0>(
597 TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::LOAD, Ch, Ptr, Offset));
598}
599
600template <typename T0_P, typename T1_P, typename T2_P>
601inline TernaryOpc_match<T0_P, T1_P, T2_P>
602m_InsertElt(const T0_P &Vec, const T1_P &Val, const T2_P &Idx) {
603 return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::INSERT_VECTOR_ELT, Vec, Val,
604 Idx);
605}
606
607template <typename LHS, typename RHS, typename IDX>
608inline TernaryOpc_match<LHS, RHS, IDX>
609m_InsertSubvector(const LHS &Base, const RHS &Sub, const IDX &Idx) {
610 return TernaryOpc_match<LHS, RHS, IDX>(ISD::INSERT_SUBVECTOR, Base, Sub, Idx);
611}
612
613template <typename T0_P, typename T1_P, typename T2_P>
614inline TernaryOpc_match<T0_P, T1_P, T2_P>
615m_TernaryOp(unsigned Opc, const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
616 return TernaryOpc_match<T0_P, T1_P, T2_P>(Opc, Op0, Op1, Op2);
617}
618
619template <typename T0_P, typename T1_P, typename T2_P>
620inline TernaryOpc_match<T0_P, T1_P, T2_P, true>
621m_c_TernaryOp(unsigned Opc, const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
622 return TernaryOpc_match<T0_P, T1_P, T2_P, true>(Opc, Op0, Op1, Op2);
623}
624
625template <typename LTy, typename RTy, typename TTy, typename FTy, typename CCTy>
626inline auto m_SelectCC(const LTy &L, const RTy &R, const TTy &T, const FTy &F,
627 const CCTy &CC) {
628 return m_Node(ISD::SELECT_CC, L, R, T, F, CC);
629}
630
631template <typename LTy, typename RTy, typename TTy, typename FTy, typename CCTy>
632inline auto m_SelectCCLike(const LTy &L, const RTy &R, const TTy &T,
633 const FTy &F, const CCTy &CC) {
634 return m_AnyOf(m_Select(m_SetCC(L, R, CC), T, F), m_SelectCC(L, R, T, F, CC));
635}
636
637// === Binary operations ===
638template <typename LHS_P, typename RHS_P, bool Commutable = false,
639 bool ExcludeChain = false>
640struct BinaryOpc_match {
641 unsigned Opcode;
642 LHS_P LHS;
643 RHS_P RHS;
644 SDNodeFlags Flags;
645 BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R,
646 SDNodeFlags Flgs = SDNodeFlags())
647 : Opcode(Opc), LHS(L), RHS(R), Flags(Flgs) {}
648
649 template <typename MatchContext>
650 bool match(const MatchContext &Ctx, SDValue N) {
651 if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
652 EffectiveOperands<ExcludeChain> EO(N, Ctx);
653 assert(EO.Size == 2);
654 if (!((LHS.match(Ctx, N->getOperand(Num: EO.FirstIndex)) &&
655 RHS.match(Ctx, N->getOperand(Num: EO.FirstIndex + 1))) ||
656 (Commutable && LHS.match(Ctx, N->getOperand(Num: EO.FirstIndex + 1)) &&
657 RHS.match(Ctx, N->getOperand(Num: EO.FirstIndex)))))
658 return false;
659
660 return (Flags & N->getFlags()) == Flags;
661 }
662
663 return false;
664 }
665};
666
667/// Matching while capturing mask
668template <typename T0, typename T1, typename T2> struct SDShuffle_match {
669 T0 Op1;
670 T1 Op2;
671 T2 Mask;
672
673 SDShuffle_match(const T0 &Op1, const T1 &Op2, const T2 &Mask)
674 : Op1(Op1), Op2(Op2), Mask(Mask) {}
675
676 template <typename MatchContext>
677 bool match(const MatchContext &Ctx, SDValue N) {
678 if (auto *I = dyn_cast<ShuffleVectorSDNode>(Val&: N)) {
679 return Op1.match(Ctx, I->getOperand(Num: 0)) &&
680 Op2.match(Ctx, I->getOperand(Num: 1)) && Mask.match(I->getMask());
681 }
682 return false;
683 }
684};
685struct m_Mask {
686 ArrayRef<int> &MaskRef;
687 m_Mask(ArrayRef<int> &MaskRef) : MaskRef(MaskRef) {}
688 bool match(ArrayRef<int> Mask) {
689 MaskRef = Mask;
690 return true;
691 }
692};
693
694struct m_SpecificMask {
695 ArrayRef<int> MaskRef;
696 m_SpecificMask(ArrayRef<int> MaskRef) : MaskRef(MaskRef) {}
697 bool match(ArrayRef<int> Mask) { return MaskRef == Mask; }
698};
699
700template <typename LHS_P, typename RHS_P, typename Pred_t,
701 bool Commutable = false, bool ExcludeChain = false>
702struct MaxMin_match {
703 using PredType = Pred_t;
704 LHS_P LHS;
705 RHS_P RHS;
706
707 MaxMin_match(const LHS_P &L, const RHS_P &R) : LHS(L), RHS(R) {}
708
709 template <typename MatchContext>
710 bool match(const MatchContext &Ctx, SDValue N) {
711 auto MatchMinMax = [&](SDValue L, SDValue R, SDValue TrueValue,
712 SDValue FalseValue, ISD::CondCode CC) {
713 if ((TrueValue != L || FalseValue != R) &&
714 (TrueValue != R || FalseValue != L))
715 return false;
716
717 ISD::CondCode Cond =
718 TrueValue == L ? CC : getSetCCInverse(Operation: CC, Type: L.getValueType());
719 if (!Pred_t::match(Cond))
720 return false;
721
722 return (LHS.match(Ctx, L) && RHS.match(Ctx, R)) ||
723 (Commutable && LHS.match(Ctx, R) && RHS.match(Ctx, L));
724 };
725
726 if (sd_context_match(N, Ctx, m_Opc(Opcode: ISD::SELECT)) ||
727 sd_context_match(N, Ctx, m_Opc(Opcode: ISD::VSELECT))) {
728 EffectiveOperands<ExcludeChain> EO_SELECT(N, Ctx);
729 assert(EO_SELECT.Size == 3);
730 SDValue Cond = N->getOperand(Num: EO_SELECT.FirstIndex);
731 SDValue TrueValue = N->getOperand(Num: EO_SELECT.FirstIndex + 1);
732 SDValue FalseValue = N->getOperand(Num: EO_SELECT.FirstIndex + 2);
733
734 if (sd_context_match(Cond, Ctx, m_Opc(Opcode: ISD::SETCC))) {
735 EffectiveOperands<ExcludeChain> EO_SETCC(Cond, Ctx);
736 assert(EO_SETCC.Size == 3);
737 SDValue L = Cond->getOperand(Num: EO_SETCC.FirstIndex);
738 SDValue R = Cond->getOperand(Num: EO_SETCC.FirstIndex + 1);
739 auto *CondNode =
740 cast<CondCodeSDNode>(Cond->getOperand(Num: EO_SETCC.FirstIndex + 2));
741 return MatchMinMax(L, R, TrueValue, FalseValue, CondNode->get());
742 }
743 }
744
745 if (sd_context_match(N, Ctx, m_Opc(Opcode: ISD::SELECT_CC))) {
746 EffectiveOperands<ExcludeChain> EO_SELECT(N, Ctx);
747 assert(EO_SELECT.Size == 5);
748 SDValue L = N->getOperand(Num: EO_SELECT.FirstIndex);
749 SDValue R = N->getOperand(Num: EO_SELECT.FirstIndex + 1);
750 SDValue TrueValue = N->getOperand(Num: EO_SELECT.FirstIndex + 2);
751 SDValue FalseValue = N->getOperand(Num: EO_SELECT.FirstIndex + 3);
752 auto *CondNode =
753 cast<CondCodeSDNode>(N->getOperand(Num: EO_SELECT.FirstIndex + 4));
754 return MatchMinMax(L, R, TrueValue, FalseValue, CondNode->get());
755 }
756
757 return false;
758 }
759};
760
761// Helper class for identifying signed max predicates.
762struct smax_pred_ty {
763 static bool match(ISD::CondCode Cond) {
764 return Cond == ISD::CondCode::SETGT || Cond == ISD::CondCode::SETGE;
765 }
766};
767
768// Helper class for identifying unsigned max predicates.
769struct umax_pred_ty {
770 static bool match(ISD::CondCode Cond) {
771 return Cond == ISD::CondCode::SETUGT || Cond == ISD::CondCode::SETUGE;
772 }
773};
774
775// Helper class for identifying signed min predicates.
776struct smin_pred_ty {
777 static bool match(ISD::CondCode Cond) {
778 return Cond == ISD::CondCode::SETLT || Cond == ISD::CondCode::SETLE;
779 }
780};
781
782// Helper class for identifying unsigned min predicates.
783struct umin_pred_ty {
784 static bool match(ISD::CondCode Cond) {
785 return Cond == ISD::CondCode::SETULT || Cond == ISD::CondCode::SETULE;
786 }
787};
788
789template <typename LHS, typename RHS>
790inline BinaryOpc_match<LHS, RHS> m_BinOp(unsigned Opc, const LHS &L,
791 const RHS &R,
792 SDNodeFlags Flgs = SDNodeFlags()) {
793 return BinaryOpc_match<LHS, RHS>(Opc, L, R, Flgs);
794}
795template <typename LHS, typename RHS>
796inline BinaryOpc_match<LHS, RHS, true>
797m_c_BinOp(unsigned Opc, const LHS &L, const RHS &R,
798 SDNodeFlags Flgs = SDNodeFlags()) {
799 return BinaryOpc_match<LHS, RHS, true>(Opc, L, R, Flgs);
800}
801
802template <typename LHS, typename RHS>
803inline BinaryOpc_match<LHS, RHS, false, true>
804m_ChainedBinOp(unsigned Opc, const LHS &L, const RHS &R) {
805 return BinaryOpc_match<LHS, RHS, false, true>(Opc, L, R);
806}
807template <typename LHS, typename RHS>
808inline BinaryOpc_match<LHS, RHS, true, true>
809m_c_ChainedBinOp(unsigned Opc, const LHS &L, const RHS &R) {
810 return BinaryOpc_match<LHS, RHS, true, true>(Opc, L, R);
811}
812
813// Common binary operations
814template <typename LHS, typename RHS>
815inline BinaryOpc_match<LHS, RHS, true> m_Add(const LHS &L, const RHS &R) {
816 return BinaryOpc_match<LHS, RHS, true>(ISD::ADD, L, R);
817}
818
819template <typename LHS, typename RHS>
820inline BinaryOpc_match<LHS, RHS> m_Sub(const LHS &L, const RHS &R) {
821 return BinaryOpc_match<LHS, RHS>(ISD::SUB, L, R);
822}
823
824template <typename LHS, typename RHS>
825inline BinaryOpc_match<LHS, RHS, true> m_Mul(const LHS &L, const RHS &R) {
826 return BinaryOpc_match<LHS, RHS, true>(ISD::MUL, L, R);
827}
828
829template <typename LHS, typename RHS>
830inline BinaryOpc_match<LHS, RHS, true> m_And(const LHS &L, const RHS &R) {
831 return BinaryOpc_match<LHS, RHS, true>(ISD::AND, L, R);
832}
833
834template <typename LHS, typename RHS>
835inline BinaryOpc_match<LHS, RHS, true> m_Or(const LHS &L, const RHS &R) {
836 return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R);
837}
838
839template <typename LHS, typename RHS>
840inline BinaryOpc_match<LHS, RHS, true> m_DisjointOr(const LHS &L,
841 const RHS &R) {
842 return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R, SDNodeFlags::Disjoint);
843}
844
845template <typename LHS, typename RHS>
846inline auto m_AddLike(const LHS &L, const RHS &R) {
847 return m_AnyOf(m_Add(L, R), m_DisjointOr(L, R));
848}
849
850template <typename LHS, typename RHS>
851inline BinaryOpc_match<LHS, RHS, true> m_Xor(const LHS &L, const RHS &R) {
852 return BinaryOpc_match<LHS, RHS, true>(ISD::XOR, L, R);
853}
854
855template <typename LHS, typename RHS>
856inline auto m_BitwiseLogic(const LHS &L, const RHS &R) {
857 return m_AnyOf(m_And(L, R), m_Or(L, R), m_Xor(L, R));
858}
859
860template <unsigned Opc, typename Pred, typename LHS, typename RHS>
861inline auto m_MaxMinLike(const LHS &L, const RHS &R) {
862 return m_AnyOf(BinaryOpc_match<LHS, RHS, true>(Opc, L, R),
863 MaxMin_match<LHS, RHS, Pred, true>(L, R));
864}
865
866template <typename LHS, typename RHS>
867inline BinaryOpc_match<LHS, RHS, true> m_SMin(const LHS &L, const RHS &R) {
868 return BinaryOpc_match<LHS, RHS, true>(ISD::SMIN, L, R);
869}
870
871template <typename LHS, typename RHS>
872inline auto m_SMinLike(const LHS &L, const RHS &R) {
873 return m_AnyOf(
874 m_MaxMinLike<ISD::SMIN, smin_pred_ty>(L, R),
875 m_MaxMinLike<ISD::UMIN, umin_pred_ty>(m_NonNegative(L), m_NonNegative(R)),
876 m_MaxMinLike<ISD::UMIN, umin_pred_ty>(m_Negative(L), m_Negative(R)));
877}
878
879template <typename LHS, typename RHS>
880inline BinaryOpc_match<LHS, RHS, true> m_SMax(const LHS &L, const RHS &R) {
881 return BinaryOpc_match<LHS, RHS, true>(ISD::SMAX, L, R);
882}
883
884template <typename LHS, typename RHS>
885inline auto m_SMaxLike(const LHS &L, const RHS &R) {
886 return m_AnyOf(
887 m_MaxMinLike<ISD::SMAX, smax_pred_ty>(L, R),
888 m_MaxMinLike<ISD::UMAX, umax_pred_ty>(m_NonNegative(L), m_NonNegative(R)),
889 m_MaxMinLike<ISD::UMAX, umax_pred_ty>(m_Negative(L), m_Negative(R)));
890}
891
892template <typename LHS, typename RHS>
893inline BinaryOpc_match<LHS, RHS, true> m_UMin(const LHS &L, const RHS &R) {
894 return BinaryOpc_match<LHS, RHS, true>(ISD::UMIN, L, R);
895}
896
897template <typename LHS, typename RHS>
898inline auto m_UMinLike(const LHS &L, const RHS &R) {
899 return m_AnyOf(
900 m_MaxMinLike<ISD::UMIN, umin_pred_ty>(L, R),
901 m_MaxMinLike<ISD::SMIN, smin_pred_ty>(m_NonNegative(L), m_NonNegative(R)),
902 m_MaxMinLike<ISD::SMIN, smin_pred_ty>(m_Negative(L), m_Negative(R)));
903}
904
905template <typename LHS, typename RHS>
906inline BinaryOpc_match<LHS, RHS, true> m_UMax(const LHS &L, const RHS &R) {
907 return BinaryOpc_match<LHS, RHS, true>(ISD::UMAX, L, R);
908}
909
910template <typename LHS, typename RHS>
911inline auto m_UMaxLike(const LHS &L, const RHS &R) {
912 return m_AnyOf(
913 m_MaxMinLike<ISD::UMAX, umax_pred_ty>(L, R),
914 m_MaxMinLike<ISD::SMAX, smax_pred_ty>(m_NonNegative(L), m_NonNegative(R)),
915 m_MaxMinLike<ISD::SMAX, smax_pred_ty>(m_Negative(L), m_Negative(R)));
916}
917
918template <typename LHS, typename RHS>
919inline BinaryOpc_match<LHS, RHS> m_UDiv(const LHS &L, const RHS &R) {
920 return BinaryOpc_match<LHS, RHS>(ISD::UDIV, L, R);
921}
922template <typename LHS, typename RHS>
923inline BinaryOpc_match<LHS, RHS> m_SDiv(const LHS &L, const RHS &R) {
924 return BinaryOpc_match<LHS, RHS>(ISD::SDIV, L, R);
925}
926
927template <typename LHS, typename RHS>
928inline BinaryOpc_match<LHS, RHS> m_URem(const LHS &L, const RHS &R) {
929 return BinaryOpc_match<LHS, RHS>(ISD::UREM, L, R);
930}
931template <typename LHS, typename RHS>
932inline BinaryOpc_match<LHS, RHS> m_SRem(const LHS &L, const RHS &R) {
933 return BinaryOpc_match<LHS, RHS>(ISD::SREM, L, R);
934}
935
936template <typename LHS, typename RHS>
937inline BinaryOpc_match<LHS, RHS> m_Shl(const LHS &L, const RHS &R) {
938 return BinaryOpc_match<LHS, RHS>(ISD::SHL, L, R);
939}
940
941template <typename LHS, typename RHS>
942inline BinaryOpc_match<LHS, RHS> m_Sra(const LHS &L, const RHS &R) {
943 return BinaryOpc_match<LHS, RHS>(ISD::SRA, L, R);
944}
945template <typename LHS, typename RHS>
946inline BinaryOpc_match<LHS, RHS> m_Srl(const LHS &L, const RHS &R) {
947 return BinaryOpc_match<LHS, RHS>(ISD::SRL, L, R);
948}
949template <typename LHS, typename RHS>
950inline auto m_ExactSr(const LHS &L, const RHS &R) {
951 return m_AnyOf(BinaryOpc_match<LHS, RHS>(ISD::SRA, L, R, SDNodeFlags::Exact),
952 BinaryOpc_match<LHS, RHS>(ISD::SRL, L, R, SDNodeFlags::Exact));
953}
954
955template <typename LHS, typename RHS>
956inline BinaryOpc_match<LHS, RHS> m_Rotl(const LHS &L, const RHS &R) {
957 return BinaryOpc_match<LHS, RHS>(ISD::ROTL, L, R);
958}
959
960template <typename LHS, typename RHS>
961inline BinaryOpc_match<LHS, RHS> m_Rotr(const LHS &L, const RHS &R) {
962 return BinaryOpc_match<LHS, RHS>(ISD::ROTR, L, R);
963}
964
965template <typename LHS, typename RHS>
966inline BinaryOpc_match<LHS, RHS, true> m_Clmul(const LHS &L, const RHS &R) {
967 return BinaryOpc_match<LHS, RHS, true>(ISD::CLMUL, L, R);
968}
969
970template <typename LHS, typename RHS>
971inline BinaryOpc_match<LHS, RHS, true> m_FAdd(const LHS &L, const RHS &R) {
972 return BinaryOpc_match<LHS, RHS, true>(ISD::FADD, L, R);
973}
974
975template <typename LHS, typename RHS>
976inline BinaryOpc_match<LHS, RHS> m_FSub(const LHS &L, const RHS &R) {
977 return BinaryOpc_match<LHS, RHS>(ISD::FSUB, L, R);
978}
979
980template <typename LHS, typename RHS>
981inline BinaryOpc_match<LHS, RHS, true> m_FMul(const LHS &L, const RHS &R) {
982 return BinaryOpc_match<LHS, RHS, true>(ISD::FMUL, L, R);
983}
984
985template <typename LHS, typename RHS>
986inline BinaryOpc_match<LHS, RHS> m_FDiv(const LHS &L, const RHS &R) {
987 return BinaryOpc_match<LHS, RHS>(ISD::FDIV, L, R);
988}
989
990template <typename LHS, typename RHS>
991inline BinaryOpc_match<LHS, RHS> m_FRem(const LHS &L, const RHS &R) {
992 return BinaryOpc_match<LHS, RHS>(ISD::FREM, L, R);
993}
994
995template <typename V1_t, typename V2_t>
996inline BinaryOpc_match<V1_t, V2_t> m_Shuffle(const V1_t &v1, const V2_t &v2) {
997 return BinaryOpc_match<V1_t, V2_t>(ISD::VECTOR_SHUFFLE, v1, v2);
998}
999
1000template <typename V1_t, typename V2_t, typename Mask_t>
1001inline SDShuffle_match<V1_t, V2_t, Mask_t>
1002m_Shuffle(const V1_t &v1, const V2_t &v2, const Mask_t &mask) {
1003 return SDShuffle_match<V1_t, V2_t, Mask_t>(v1, v2, mask);
1004}
1005
1006template <typename LHS, typename RHS>
1007inline BinaryOpc_match<LHS, RHS> m_ExtractElt(const LHS &Vec, const RHS &Idx) {
1008 return BinaryOpc_match<LHS, RHS>(ISD::EXTRACT_VECTOR_ELT, Vec, Idx);
1009}
1010
1011template <typename LHS, typename RHS>
1012inline BinaryOpc_match<LHS, RHS> m_ExtractSubvector(const LHS &Vec,
1013 const RHS &Idx) {
1014 return BinaryOpc_match<LHS, RHS>(ISD::EXTRACT_SUBVECTOR, Vec, Idx);
1015}
1016
1017// === Unary operations ===
1018template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
1019 unsigned Opcode;
1020 Opnd_P Opnd;
1021 SDNodeFlags Flags;
1022 UnaryOpc_match(unsigned Opc, const Opnd_P &Op,
1023 SDNodeFlags Flgs = SDNodeFlags())
1024 : Opcode(Opc), Opnd(Op), Flags(Flgs) {}
1025
1026 template <typename MatchContext>
1027 bool match(const MatchContext &Ctx, SDValue N) {
1028 if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
1029 EffectiveOperands<ExcludeChain> EO(N, Ctx);
1030 assert(EO.Size == 1);
1031 if (!Opnd.match(Ctx, N->getOperand(Num: EO.FirstIndex)))
1032 return false;
1033
1034 return (Flags & N->getFlags()) == Flags;
1035 }
1036
1037 return false;
1038 }
1039};
1040
1041template <typename Opnd>
1042inline UnaryOpc_match<Opnd> m_UnaryOp(unsigned Opc, const Opnd &Op) {
1043 return UnaryOpc_match<Opnd>(Opc, Op);
1044}
1045template <typename Opnd>
1046inline UnaryOpc_match<Opnd, true> m_ChainedUnaryOp(unsigned Opc,
1047 const Opnd &Op) {
1048 return UnaryOpc_match<Opnd, true>(Opc, Op);
1049}
1050
1051template <typename Opnd> inline UnaryOpc_match<Opnd> m_BitCast(const Opnd &Op) {
1052 return UnaryOpc_match<Opnd>(ISD::BITCAST, Op);
1053}
1054
1055template <typename Opnd>
1056inline UnaryOpc_match<Opnd> m_BSwap(const Opnd &Op) {
1057 return UnaryOpc_match<Opnd>(ISD::BSWAP, Op);
1058}
1059
1060template <typename Opnd>
1061inline UnaryOpc_match<Opnd> m_BitReverse(const Opnd &Op) {
1062 return UnaryOpc_match<Opnd>(ISD::BITREVERSE, Op);
1063}
1064
1065template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) {
1066 return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op);
1067}
1068
1069template <typename Opnd>
1070inline UnaryOpc_match<Opnd> m_NNegZExt(const Opnd &Op) {
1071 return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op, SDNodeFlags::NonNeg);
1072}
1073
1074template <typename Opnd> inline auto m_SExt(const Opnd &Op) {
1075 return UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op);
1076}
1077
1078template <typename Opnd> inline UnaryOpc_match<Opnd> m_AnyExt(const Opnd &Op) {
1079 return UnaryOpc_match<Opnd>(ISD::ANY_EXTEND, Op);
1080}
1081
1082template <typename Opnd> inline UnaryOpc_match<Opnd> m_Trunc(const Opnd &Op) {
1083 return UnaryOpc_match<Opnd>(ISD::TRUNCATE, Op);
1084}
1085
1086template <typename Opnd> inline UnaryOpc_match<Opnd> m_Abs(const Opnd &Op) {
1087 return UnaryOpc_match<Opnd>(ISD::ABS, Op);
1088}
1089
1090template <typename Opnd> inline UnaryOpc_match<Opnd> m_FAbs(const Opnd &Op) {
1091 return UnaryOpc_match<Opnd>(ISD::FABS, Op);
1092}
1093
1094/// Match a zext or identity
1095/// Allows to peek through optional extensions
1096template <typename Opnd> inline auto m_ZExtOrSelf(const Opnd &Op) {
1097 return m_AnyOf(m_ZExt(Op), Op);
1098}
1099
1100/// Match a sext or identity
1101/// Allows to peek through optional extensions
1102template <typename Opnd> inline auto m_SExtOrSelf(const Opnd &Op) {
1103 return m_AnyOf(m_SExt(Op), Op);
1104}
1105
1106template <typename Opnd> inline auto m_SExtLike(const Opnd &Op) {
1107 return m_AnyOf(m_SExt(Op), m_NNegZExt(Op));
1108}
1109
1110/// Match a aext or identity
1111/// Allows to peek through optional extensions
1112template <typename Opnd>
1113inline Or<UnaryOpc_match<Opnd>, Opnd> m_AExtOrSelf(const Opnd &Op) {
1114 return Or<UnaryOpc_match<Opnd>, Opnd>(m_AnyExt(Op), Op);
1115}
1116
1117/// Match a trunc or identity
1118/// Allows to peek through optional truncations
1119template <typename Opnd>
1120inline Or<UnaryOpc_match<Opnd>, Opnd> m_TruncOrSelf(const Opnd &Op) {
1121 return Or<UnaryOpc_match<Opnd>, Opnd>(m_Trunc(Op), Op);
1122}
1123
1124template <typename Opnd> inline UnaryOpc_match<Opnd> m_VScale(const Opnd &Op) {
1125 return UnaryOpc_match<Opnd>(ISD::VSCALE, Op);
1126}
1127
1128template <typename Opnd> inline UnaryOpc_match<Opnd> m_FPToUI(const Opnd &Op) {
1129 return UnaryOpc_match<Opnd>(ISD::FP_TO_UINT, Op);
1130}
1131
1132template <typename Opnd> inline UnaryOpc_match<Opnd> m_FPToSI(const Opnd &Op) {
1133 return UnaryOpc_match<Opnd>(ISD::FP_TO_SINT, Op);
1134}
1135
1136template <typename Opnd> inline UnaryOpc_match<Opnd> m_Ctpop(const Opnd &Op) {
1137 return UnaryOpc_match<Opnd>(ISD::CTPOP, Op);
1138}
1139
1140template <typename Opnd> inline UnaryOpc_match<Opnd> m_Ctlz(const Opnd &Op) {
1141 return UnaryOpc_match<Opnd>(ISD::CTLZ, Op);
1142}
1143
1144template <typename Opnd> inline UnaryOpc_match<Opnd> m_Cttz(const Opnd &Op) {
1145 return UnaryOpc_match<Opnd>(ISD::CTTZ, Op);
1146}
1147
1148template <typename Opnd> inline UnaryOpc_match<Opnd> m_FNeg(const Opnd &Op) {
1149 return UnaryOpc_match<Opnd>(ISD::FNEG, Op);
1150}
1151
1152// === Constants ===
1153struct ConstantInt_match {
1154 APInt *BindVal;
1155
1156 explicit ConstantInt_match(APInt *V) : BindVal(V) {}
1157
1158 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
1159 // The logics here are similar to that in
1160 // SelectionDAG::isConstantIntBuildVectorOrConstantInt, but the latter also
1161 // treats GlobalAddressSDNode as a constant, which is difficult to turn into
1162 // APInt.
1163 if (auto *C = dyn_cast_or_null<ConstantSDNode>(Val: N.getNode())) {
1164 if (BindVal)
1165 *BindVal = C->getAPIntValue();
1166 return true;
1167 }
1168
1169 APInt Discard;
1170 return ISD::isConstantSplatVector(N: N.getNode(),
1171 SplatValue&: BindVal ? *BindVal : Discard);
1172 }
1173};
1174
1175template <typename T> struct Constant64_match {
1176 static_assert(sizeof(T) == 8, "T must be 64 bits wide");
1177
1178 T &BindVal;
1179
1180 explicit Constant64_match(T &V) : BindVal(V) {}
1181
1182 template <typename MatchContext>
1183 bool match(const MatchContext &Ctx, SDValue N) {
1184 APInt V;
1185 if (!ConstantInt_match(&V).match(Ctx, N))
1186 return false;
1187
1188 if constexpr (std::is_signed_v<T>) {
1189 if (std::optional<int64_t> TrySExt = V.trySExtValue()) {
1190 BindVal = *TrySExt;
1191 return true;
1192 }
1193 }
1194
1195 if constexpr (std::is_unsigned_v<T>) {
1196 if (std::optional<uint64_t> TryZExt = V.tryZExtValue()) {
1197 BindVal = *TryZExt;
1198 return true;
1199 }
1200 }
1201
1202 return false;
1203 }
1204};
1205
1206/// Match any integer constants or splat of an integer constant.
1207inline ConstantInt_match m_ConstInt() { return ConstantInt_match(nullptr); }
1208/// Match any integer constants or splat of an integer constant; return the
1209/// specific constant or constant splat value.
1210inline ConstantInt_match m_ConstInt(APInt &V) { return ConstantInt_match(&V); }
1211/// Match any integer constants or splat of an integer constant that can fit in
1212/// 64 bits; return the specific constant or constant splat value, zero-extended
1213/// to 64 bits.
1214inline Constant64_match<uint64_t> m_ConstInt(uint64_t &V) {
1215 return Constant64_match<uint64_t>(V);
1216}
1217/// Match any integer constants or splat of an integer constant that can fit in
1218/// 64 bits; return the specific constant or constant splat value, sign-extended
1219/// to 64 bits.
1220inline Constant64_match<int64_t> m_ConstInt(int64_t &V) {
1221 return Constant64_match<int64_t>(V);
1222}
1223
1224struct SpecificInt_match {
1225 APInt IntVal;
1226
1227 explicit SpecificInt_match(APInt APV) : IntVal(std::move(APV)) {}
1228
1229 template <typename MatchContext>
1230 bool match(const MatchContext &Ctx, SDValue N) {
1231 APInt ConstInt;
1232 if (sd_context_match(N, Ctx, m_ConstInt(V&: ConstInt)))
1233 return APInt::isSameValue(I1: IntVal, I2: ConstInt);
1234 return false;
1235 }
1236};
1237
1238/// Match a specific integer constant or constant splat value.
1239inline SpecificInt_match m_SpecificInt(APInt V) {
1240 return SpecificInt_match(std::move(V));
1241}
1242inline SpecificInt_match m_SpecificInt(uint64_t V) {
1243 return SpecificInt_match(APInt(64, V));
1244}
1245
1246struct SpecificFP_match {
1247 APFloat Val;
1248
1249 explicit SpecificFP_match(APFloat V) : Val(V) {}
1250
1251 template <typename MatchContext>
1252 bool match(const MatchContext &Ctx, SDValue V) {
1253 if (const auto *CFP = dyn_cast<ConstantFPSDNode>(Val: V.getNode()))
1254 return CFP->isExactlyValue(V: Val);
1255 if (ConstantFPSDNode *C = isConstOrConstSplatFP(N: V, /*AllowUndefs=*/AllowUndefs: true))
1256 return C->getValueAPF().compare(RHS: Val) == APFloat::cmpEqual;
1257 return false;
1258 }
1259};
1260
1261/// Match a specific float constant.
1262inline SpecificFP_match m_SpecificFP(APFloat V) { return SpecificFP_match(V); }
1263
1264inline SpecificFP_match m_SpecificFP(double V) {
1265 return SpecificFP_match(APFloat(V));
1266}
1267
1268struct Negative_match {
1269 template <typename MatchContext>
1270 bool match(const MatchContext &Ctx, SDValue N) {
1271 const SelectionDAG *DAG = Ctx.getDAG();
1272 return DAG && DAG->computeKnownBits(Op: N).isNegative();
1273 }
1274};
1275
1276struct NonNegative_match {
1277 template <typename MatchContext>
1278 bool match(const MatchContext &Ctx, SDValue N) {
1279 const SelectionDAG *DAG = Ctx.getDAG();
1280 return DAG && DAG->computeKnownBits(Op: N).isNonNegative();
1281 }
1282};
1283
1284struct StrictlyPositive_match {
1285 template <typename MatchContext>
1286 bool match(const MatchContext &Ctx, SDValue N) {
1287 const SelectionDAG *DAG = Ctx.getDAG();
1288 return DAG && DAG->computeKnownBits(Op: N).isStrictlyPositive();
1289 }
1290};
1291
1292struct NonPositive_match {
1293 template <typename MatchContext>
1294 bool match(const MatchContext &Ctx, SDValue N) {
1295 const SelectionDAG *DAG = Ctx.getDAG();
1296 return DAG && DAG->computeKnownBits(Op: N).isNonPositive();
1297 }
1298};
1299
1300struct NonZero_match {
1301 template <typename MatchContext>
1302 bool match(const MatchContext &Ctx, SDValue N) {
1303 const SelectionDAG *DAG = Ctx.getDAG();
1304 return DAG && DAG->computeKnownBits(Op: N).isNonZero();
1305 }
1306};
1307
1308struct Zero_match {
1309 bool AllowUndefs;
1310
1311 explicit Zero_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
1312
1313 template <typename MatchContext>
1314 bool match(const MatchContext &, SDValue N) const {
1315 return isZeroOrZeroSplat(N, AllowUndefs);
1316 }
1317};
1318
1319struct Ones_match {
1320 bool AllowUndefs;
1321
1322 Ones_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
1323
1324 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
1325 return isOnesOrOnesSplat(N, AllowUndefs);
1326 }
1327};
1328
1329struct AllOnes_match {
1330 bool AllowUndefs;
1331
1332 AllOnes_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
1333
1334 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
1335 return isAllOnesOrAllOnesSplat(V: N, AllowUndefs);
1336 }
1337};
1338
1339inline Negative_match m_Negative() { return Negative_match(); }
1340template <typename Pattern> inline auto m_Negative(const Pattern &P) {
1341 return m_AllOf(m_Negative(), P);
1342}
1343inline NonNegative_match m_NonNegative() { return NonNegative_match(); }
1344template <typename Pattern> inline auto m_NonNegative(const Pattern &P) {
1345 return m_AllOf(m_NonNegative(), P);
1346}
1347inline StrictlyPositive_match m_StrictlyPositive() {
1348 return StrictlyPositive_match();
1349}
1350template <typename Pattern> inline auto m_StrictlyPositive(const Pattern &P) {
1351 return m_AllOf(m_StrictlyPositive(), P);
1352}
1353inline NonPositive_match m_NonPositive() { return NonPositive_match(); }
1354template <typename Pattern> inline auto m_NonPositive(const Pattern &P) {
1355 return m_AllOf(m_NonPositive(), P);
1356}
1357inline NonZero_match m_NonZero() { return NonZero_match(); }
1358template <typename Pattern> inline auto m_NonZero(const Pattern &P) {
1359 return m_AllOf(m_NonZero(), P);
1360}
1361inline Ones_match m_One(bool AllowUndefs = false) {
1362 return Ones_match(AllowUndefs);
1363}
1364inline Zero_match m_Zero(bool AllowUndefs = false) {
1365 return Zero_match(AllowUndefs);
1366}
1367inline AllOnes_match m_AllOnes(bool AllowUndefs = false) {
1368 return AllOnes_match(AllowUndefs);
1369}
1370
1371/// Match true boolean value based on the information provided by
1372/// TargetLowering.
1373inline auto m_True() {
1374 return TLI_pred_match{
1375 [](const TargetLowering &TLI, SDValue N) {
1376 APInt ConstVal;
1377 if (sd_match(N, P: m_ConstInt(V&: ConstVal)))
1378 switch (TLI.getBooleanContents(Type: N.getValueType())) {
1379 case TargetLowering::ZeroOrOneBooleanContent:
1380 return ConstVal.isOne();
1381 case TargetLowering::ZeroOrNegativeOneBooleanContent:
1382 return ConstVal.isAllOnes();
1383 case TargetLowering::UndefinedBooleanContent:
1384 return (ConstVal & 0x01) == 1;
1385 }
1386
1387 return false;
1388 },
1389 m_Value()};
1390}
1391/// Match false boolean value based on the information provided by
1392/// TargetLowering.
1393inline auto m_False() {
1394 return TLI_pred_match{
1395 [](const TargetLowering &TLI, SDValue N) {
1396 APInt ConstVal;
1397 if (sd_match(N, P: m_ConstInt(V&: ConstVal)))
1398 switch (TLI.getBooleanContents(Type: N.getValueType())) {
1399 case TargetLowering::ZeroOrOneBooleanContent:
1400 case TargetLowering::ZeroOrNegativeOneBooleanContent:
1401 return ConstVal.isZero();
1402 case TargetLowering::UndefinedBooleanContent:
1403 return (ConstVal & 0x01) == 0;
1404 }
1405
1406 return false;
1407 },
1408 m_Value()};
1409}
1410
1411struct CondCode_match {
1412 std::optional<ISD::CondCode> CCToMatch;
1413 ISD::CondCode *BindCC = nullptr;
1414
1415 explicit CondCode_match(ISD::CondCode CC) : CCToMatch(CC) {}
1416
1417 explicit CondCode_match(ISD::CondCode *CC) : BindCC(CC) {}
1418
1419 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
1420 if (auto *CC = dyn_cast<CondCodeSDNode>(Val: N.getNode())) {
1421 if (CCToMatch && *CCToMatch != CC->get())
1422 return false;
1423
1424 if (BindCC)
1425 *BindCC = CC->get();
1426 return true;
1427 }
1428
1429 return false;
1430 }
1431};
1432
1433/// Match any conditional code SDNode.
1434inline CondCode_match m_CondCode() { return CondCode_match(nullptr); }
1435/// Match any conditional code SDNode and return its ISD::CondCode value.
1436inline CondCode_match m_CondCode(ISD::CondCode &CC) {
1437 return CondCode_match(&CC);
1438}
1439/// Match a conditional code SDNode with a specific ISD::CondCode.
1440inline CondCode_match m_SpecificCondCode(ISD::CondCode CC) {
1441 return CondCode_match(CC);
1442}
1443
1444/// Match a negate as a sub(0, v)
1445template <typename ValTy>
1446inline BinaryOpc_match<Zero_match, ValTy, false> m_Neg(const ValTy &V) {
1447 return m_Sub(m_Zero(), V);
1448}
1449
1450/// Match a Not as a xor(v, -1) or xor(-1, v)
1451template <typename ValTy>
1452inline BinaryOpc_match<ValTy, AllOnes_match, true> m_Not(const ValTy &V) {
1453 return m_Xor(V, m_AllOnes());
1454}
1455
1456template <unsigned IntrinsicId, typename... OpndPreds>
1457inline auto m_IntrinsicWOChain(const OpndPreds &...Opnds) {
1458 return m_Node(ISD::INTRINSIC_WO_CHAIN, m_SpecificInt(V: IntrinsicId), Opnds...);
1459}
1460
1461struct SpecificNeg_match {
1462 SDValue V;
1463
1464 explicit SpecificNeg_match(SDValue V) : V(V) {}
1465
1466 template <typename MatchContext>
1467 bool match(const MatchContext &Ctx, SDValue N) {
1468 if (sd_context_match(N, Ctx, m_Neg(V: m_Specific(N: V))))
1469 return true;
1470
1471 return ISD::matchBinaryPredicate(
1472 LHS: V, RHS: N, Match: [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
1473 return LHS->getAPIntValue() == -RHS->getAPIntValue();
1474 });
1475 }
1476};
1477
1478/// Match a negation of a specific value V, either as sub(0, V) or as
1479/// constant(s) that are the negation of V's constant(s).
1480inline SpecificNeg_match m_SpecificNeg(SDValue V) {
1481 return SpecificNeg_match(V);
1482}
1483
1484template <typename... PatternTs> struct ReassociatableOpc_match {
1485 unsigned Opcode;
1486 std::tuple<PatternTs...> Patterns;
1487 constexpr static size_t NumPatterns =
1488 std::tuple_size_v<std::tuple<PatternTs...>>;
1489
1490 SDNodeFlags Flags;
1491
1492 ReassociatableOpc_match(unsigned Opcode, const PatternTs &...Patterns)
1493 : Opcode(Opcode), Patterns(Patterns...) {}
1494
1495 ReassociatableOpc_match(unsigned Opcode, SDNodeFlags Flags,
1496 const PatternTs &...Patterns)
1497 : Opcode(Opcode), Patterns(Patterns...), Flags(Flags) {}
1498
1499 template <typename MatchContext>
1500 bool match(const MatchContext &Ctx, SDValue N) {
1501 std::array<SDValue, NumPatterns> Leaves;
1502 size_t LeavesIdx = 0;
1503 if (!(collectLeaves(V: N, Leaves, LeafIdx&: LeavesIdx) && (LeavesIdx == NumPatterns)))
1504 return false;
1505
1506 Bitset<NumPatterns> Used;
1507 return std::apply(
1508 [&](auto &...P) -> bool {
1509 return reassociatableMatchHelper(Ctx, Leaves, Used, P...);
1510 },
1511 Patterns);
1512 }
1513
1514 bool collectLeaves(SDValue V, std::array<SDValue, NumPatterns> &Leaves,
1515 std::size_t &LeafIdx) {
1516 if (V->getOpcode() == Opcode && (Flags & V->getFlags()) == Flags) {
1517 for (size_t I = 0, N = V->getNumOperands(); I < N; I++)
1518 if ((LeafIdx == NumPatterns) ||
1519 !collectLeaves(V: V->getOperand(Num: I), Leaves, LeafIdx))
1520 return false;
1521 } else {
1522 Leaves[LeafIdx] = V;
1523 LeafIdx++;
1524 }
1525 return true;
1526 }
1527
1528 // Searchs for a matching leaf for every sub-pattern.
1529 template <typename MatchContext, typename PatternHd, typename... PatternTl>
1530 [[nodiscard]] inline bool
1531 reassociatableMatchHelper(const MatchContext &Ctx, ArrayRef<SDValue> Leaves,
1532 Bitset<NumPatterns> &Used, PatternHd &HeadPattern,
1533 PatternTl &...TailPatterns) {
1534 for (size_t Match = 0, N = Used.size(); Match < N; Match++) {
1535 if (Used[Match] || !(sd_context_match(Leaves[Match], Ctx, HeadPattern)))
1536 continue;
1537 Used.set(Match);
1538 if (reassociatableMatchHelper(Ctx, Leaves, Used, TailPatterns...))
1539 return true;
1540 Used.reset(Match);
1541 }
1542 return false;
1543 }
1544
1545 template <typename MatchContext>
1546 [[nodiscard]] inline bool
1547 reassociatableMatchHelper(const MatchContext &Ctx, ArrayRef<SDValue> Leaves,
1548 Bitset<NumPatterns> &Used) {
1549 return true;
1550 }
1551};
1552
1553template <typename... PatternTs>
1554inline ReassociatableOpc_match<PatternTs...>
1555m_ReassociatableAdd(const PatternTs &...Patterns) {
1556 return ReassociatableOpc_match<PatternTs...>(ISD::ADD, Patterns...);
1557}
1558
1559template <typename... PatternTs>
1560inline ReassociatableOpc_match<PatternTs...>
1561m_ReassociatableOr(const PatternTs &...Patterns) {
1562 return ReassociatableOpc_match<PatternTs...>(ISD::OR, Patterns...);
1563}
1564
1565template <typename... PatternTs>
1566inline ReassociatableOpc_match<PatternTs...>
1567m_ReassociatableAnd(const PatternTs &...Patterns) {
1568 return ReassociatableOpc_match<PatternTs...>(ISD::AND, Patterns...);
1569}
1570
1571template <typename... PatternTs>
1572inline ReassociatableOpc_match<PatternTs...>
1573m_ReassociatableMul(const PatternTs &...Patterns) {
1574 return ReassociatableOpc_match<PatternTs...>(ISD::MUL, Patterns...);
1575}
1576
1577template <typename... PatternTs>
1578inline ReassociatableOpc_match<PatternTs...>
1579m_ReassociatableNSWAdd(const PatternTs &...Patterns) {
1580 return ReassociatableOpc_match<PatternTs...>(
1581 ISD::ADD, SDNodeFlags::NoSignedWrap, Patterns...);
1582}
1583
1584template <typename... PatternTs>
1585inline ReassociatableOpc_match<PatternTs...>
1586m_ReassociatableNUWAdd(const PatternTs &...Patterns) {
1587 return ReassociatableOpc_match<PatternTs...>(
1588 ISD::ADD, SDNodeFlags::NoUnsignedWrap, Patterns...);
1589}
1590
1591} // namespace SDPatternMatch
1592} // namespace llvm
1593#endif
1594