1//==--------------- llvm/CodeGen/SDPatternMatch.h ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8/// \file
9/// Contains matchers for matching SelectionDAG nodes and values.
10///
11//===----------------------------------------------------------------------===//
12
13#ifndef LLVM_CODEGEN_SDPATTERNMATCH_H
14#define LLVM_CODEGEN_SDPATTERNMATCH_H
15
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/Support/KnownBits.h"

#include <type_traits>
#include <utility>
26
27namespace llvm {
28namespace SDPatternMatch {
29
30/// MatchContext can repurpose existing patterns to behave differently under
31/// a certain context. For instance, `m_Opc(ISD::ADD)` matches plain ADD nodes
32/// in normal circumstances, but matches VP_ADD nodes under a custom
33/// VPMatchContext. This design is meant to facilitate code / pattern reusing.
34class BasicMatchContext {
35 const SelectionDAG *DAG;
36 const TargetLowering *TLI;
37
38public:
39 explicit BasicMatchContext(const SelectionDAG *DAG)
40 : DAG(DAG), TLI(DAG ? &DAG->getTargetLoweringInfo() : nullptr) {}
41
42 explicit BasicMatchContext(const TargetLowering *TLI)
43 : DAG(nullptr), TLI(TLI) {}
44
45 // A valid MatchContext has to implement the following functions.
46
47 const SelectionDAG *getDAG() const { return DAG; }
48
49 const TargetLowering *getTLI() const { return TLI; }
50
51 /// Return true if N effectively has opcode Opcode.
52 bool match(SDValue N, unsigned Opcode) const {
53 return N->getOpcode() == Opcode;
54 }
55
56 unsigned getNumOperands(SDValue N) const { return N->getNumOperands(); }
57};
58
59template <typename Pattern, typename MatchContext>
60[[nodiscard]] bool sd_context_match(SDValue N, const MatchContext &Ctx,
61 Pattern &&P) {
62 return P.match(Ctx, N);
63}
64
65template <typename Pattern, typename MatchContext>
66[[nodiscard]] bool sd_context_match(SDNode *N, const MatchContext &Ctx,
67 Pattern &&P) {
68 return sd_context_match(SDValue(N, 0), Ctx, P);
69}
70
71template <typename Pattern>
72[[nodiscard]] bool sd_match(SDNode *N, const SelectionDAG *DAG, Pattern &&P) {
73 return sd_context_match(N, BasicMatchContext(DAG), P);
74}
75
76template <typename Pattern>
77[[nodiscard]] bool sd_match(SDValue N, const SelectionDAG *DAG, Pattern &&P) {
78 return sd_context_match(N, BasicMatchContext(DAG), P);
79}
80
81template <typename Pattern>
82[[nodiscard]] bool sd_match(SDNode *N, Pattern &&P) {
83 return sd_match(N, nullptr, P);
84}
85
86template <typename Pattern>
87[[nodiscard]] bool sd_match(SDValue N, Pattern &&P) {
88 return sd_match(N, nullptr, P);
89}
90
91// === Utilities ===
92struct Value_match {
93 SDValue MatchVal;
94
95 Value_match() = default;
96
97 explicit Value_match(SDValue Match) : MatchVal(Match) {}
98
99 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
100 if (MatchVal)
101 return MatchVal == N;
102 return N.getNode();
103 }
104};
105
106/// Match any valid SDValue.
107inline Value_match m_Value() { return Value_match(); }
108
109inline Value_match m_Specific(SDValue N) {
110 assert(N);
111 return Value_match(N);
112}
113
114template <unsigned ResNo, typename Pattern> struct Result_match {
115 Pattern P;
116
117 explicit Result_match(const Pattern &P) : P(P) {}
118
119 template <typename MatchContext>
120 bool match(const MatchContext &Ctx, SDValue N) {
121 return N.getResNo() == ResNo && P.match(Ctx, N);
122 }
123};
124
125/// Match only if the SDValue is a certain result at ResNo.
126template <unsigned ResNo, typename Pattern>
127inline Result_match<ResNo, Pattern> m_Result(const Pattern &P) {
128 return Result_match<ResNo, Pattern>(P);
129}
130
131struct DeferredValue_match {
132 SDValue &MatchVal;
133
134 explicit DeferredValue_match(SDValue &Match) : MatchVal(Match) {}
135
136 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
137 return N == MatchVal;
138 }
139};
140
141/// Similar to m_Specific, but the specific value to match is determined by
142/// another sub-pattern in the same sd_match() expression. For instance,
143/// We cannot match `(add V, V)` with `m_Add(m_Value(X), m_Specific(X))` since
144/// `X` is not initialized at the time it got copied into `m_Specific`. Instead,
145/// we should use `m_Add(m_Value(X), m_Deferred(X))`.
146inline DeferredValue_match m_Deferred(SDValue &V) {
147 return DeferredValue_match(V);
148}
149
150struct Opcode_match {
151 unsigned Opcode;
152
153 explicit Opcode_match(unsigned Opc) : Opcode(Opc) {}
154
155 template <typename MatchContext>
156 bool match(const MatchContext &Ctx, SDValue N) {
157 return Ctx.match(N, Opcode);
158 }
159};
160
161// === Patterns combinators ===
162template <typename... Preds> struct And {
163 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
164 return true;
165 }
166};
167
168template <typename Pred, typename... Preds>
169struct And<Pred, Preds...> : And<Preds...> {
170 Pred P;
171 And(const Pred &p, const Preds &...preds) : And<Preds...>(preds...), P(p) {}
172
173 template <typename MatchContext>
174 bool match(const MatchContext &Ctx, SDValue N) {
175 return P.match(Ctx, N) && And<Preds...>::match(Ctx, N);
176 }
177};
178
179template <typename... Preds> struct Or {
180 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
181 return false;
182 }
183};
184
185template <typename Pred, typename... Preds>
186struct Or<Pred, Preds...> : Or<Preds...> {
187 Pred P;
188 Or(const Pred &p, const Preds &...preds) : Or<Preds...>(preds...), P(p) {}
189
190 template <typename MatchContext>
191 bool match(const MatchContext &Ctx, SDValue N) {
192 return P.match(Ctx, N) || Or<Preds...>::match(Ctx, N);
193 }
194};
195
196template <typename Pred> struct Not {
197 Pred P;
198
199 explicit Not(const Pred &P) : P(P) {}
200
201 template <typename MatchContext>
202 bool match(const MatchContext &Ctx, SDValue N) {
203 return !P.match(Ctx, N);
204 }
205};
206// Explicit deduction guide.
207template <typename Pred> Not(const Pred &P) -> Not<Pred>;
208
209/// Match if the inner pattern does NOT match.
210template <typename Pred> inline Not<Pred> m_Unless(const Pred &P) {
211 return Not{P};
212}
213
214template <typename... Preds> And<Preds...> m_AllOf(const Preds &...preds) {
215 return And<Preds...>(preds...);
216}
217
218template <typename... Preds> Or<Preds...> m_AnyOf(const Preds &...preds) {
219 return Or<Preds...>(preds...);
220}
221
222template <typename... Preds> auto m_NoneOf(const Preds &...preds) {
223 return m_Unless(m_AnyOf(preds...));
224}
225
226inline Opcode_match m_Opc(unsigned Opcode) { return Opcode_match(Opcode); }
227
228inline auto m_Undef() {
229 return m_AnyOf(preds: Opcode_match(ISD::UNDEF), preds: Opcode_match(ISD::POISON));
230}
231
232inline Opcode_match m_Poison() { return Opcode_match(ISD::POISON); }
233
234template <unsigned NumUses, typename Pattern> struct NUses_match {
235 Pattern P;
236
237 explicit NUses_match(const Pattern &P) : P(P) {}
238
239 template <typename MatchContext>
240 bool match(const MatchContext &Ctx, SDValue N) {
241 // SDNode::hasNUsesOfValue is pretty expensive when the SDNode produces
242 // multiple results, hence we check the subsequent pattern here before
243 // checking the number of value users.
244 return P.match(Ctx, N) && N->hasNUsesOfValue(NUses: NumUses, Value: N.getResNo());
245 }
246};
247
248template <typename Pattern>
249inline NUses_match<1, Pattern> m_OneUse(const Pattern &P) {
250 return NUses_match<1, Pattern>(P);
251}
252template <unsigned N, typename Pattern>
253inline NUses_match<N, Pattern> m_NUses(const Pattern &P) {
254 return NUses_match<N, Pattern>(P);
255}
256
257inline NUses_match<1, Value_match> m_OneUse() {
258 return NUses_match<1, Value_match>(m_Value());
259}
260template <unsigned N> inline NUses_match<N, Value_match> m_NUses() {
261 return NUses_match<N, Value_match>(m_Value());
262}
263
264template <typename PredPattern> struct Value_bind {
265 SDValue &BindVal;
266 PredPattern Pred;
267
268 Value_bind(SDValue &N, const PredPattern &P) : BindVal(N), Pred(P) {}
269
270 template <typename MatchContext>
271 bool match(const MatchContext &Ctx, SDValue N) {
272 if (!Pred.match(Ctx, N))
273 return false;
274
275 BindVal = N;
276 return true;
277 }
278};
279
280inline auto m_Value(SDValue &N) {
281 return Value_bind<Value_match>(N, m_Value());
282}
283/// Conditionally bind an SDValue based on the predicate.
284template <typename PredPattern>
285inline auto m_Value(SDValue &N, const PredPattern &P) {
286 return Value_bind<PredPattern>(N, P);
287}
288
289template <typename Pattern, typename PredFuncT> struct TLI_pred_match {
290 Pattern P;
291 PredFuncT PredFunc;
292
293 TLI_pred_match(const PredFuncT &Pred, const Pattern &P)
294 : P(P), PredFunc(Pred) {}
295
296 template <typename MatchContext>
297 bool match(const MatchContext &Ctx, SDValue N) {
298 assert(Ctx.getTLI() && "TargetLowering is required for this pattern.");
299 return PredFunc(*Ctx.getTLI(), N) && P.match(Ctx, N);
300 }
301};
302
303// Explicit deduction guide.
304template <typename PredFuncT, typename Pattern>
305TLI_pred_match(const PredFuncT &Pred, const Pattern &P)
306 -> TLI_pred_match<Pattern, PredFuncT>;
307
308/// Match legal SDNodes based on the information provided by TargetLowering.
309template <typename Pattern> inline auto m_LegalOp(const Pattern &P) {
310 return TLI_pred_match{[](const TargetLowering &TLI, SDValue N) {
311 return TLI.isOperationLegal(Op: N->getOpcode(),
312 VT: N.getValueType());
313 },
314 P};
315}
316
317/// Switch to a different MatchContext for subsequent patterns.
318template <typename NewMatchContext, typename Pattern> struct SwitchContext {
319 const NewMatchContext &Ctx;
320 Pattern P;
321
322 template <typename OrigMatchContext>
323 bool match(const OrigMatchContext &, SDValue N) {
324 return P.match(Ctx, N);
325 }
326};
327
328template <typename MatchContext, typename Pattern>
329inline SwitchContext<MatchContext, Pattern> m_Context(const MatchContext &Ctx,
330 Pattern &&P) {
331 return SwitchContext<MatchContext, Pattern>{Ctx, std::move(P)};
332}
333
334// === Value type ===
335struct ValueType_bind {
336 EVT &BindVT;
337
338 explicit ValueType_bind(EVT &Bind) : BindVT(Bind) {}
339
340 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
341 BindVT = N.getValueType();
342 return true;
343 }
344};
345
346/// Retreive the ValueType of the current SDValue.
347inline ValueType_bind m_VT(EVT &VT) { return ValueType_bind(VT); }
348
349template <typename Pattern, typename PredFuncT> struct ValueType_match {
350 PredFuncT PredFunc;
351 Pattern P;
352
353 ValueType_match(const PredFuncT &Pred, const Pattern &P)
354 : PredFunc(Pred), P(P) {}
355
356 template <typename MatchContext>
357 bool match(const MatchContext &Ctx, SDValue N) {
358 return PredFunc(N.getValueType()) && P.match(Ctx, N);
359 }
360};
361
362// Explicit deduction guide.
363template <typename PredFuncT, typename Pattern>
364ValueType_match(const PredFuncT &Pred, const Pattern &P)
365 -> ValueType_match<Pattern, PredFuncT>;
366
367/// Match a specific ValueType.
368template <typename Pattern>
369inline auto m_SpecificVT(EVT RefVT, const Pattern &P) {
370 return ValueType_match{[=](EVT VT) { return VT == RefVT; }, P};
371}
372inline auto m_SpecificVT(EVT RefVT) {
373 return ValueType_match{[=](EVT VT) { return VT == RefVT; }, m_Value()};
374}
375
376inline auto m_Glue() { return m_SpecificVT(RefVT: MVT::Glue); }
377inline auto m_OtherVT() { return m_SpecificVT(RefVT: MVT::Other); }
378
379/// Match a scalar ValueType.
380template <typename Pattern>
381inline auto m_SpecificScalarVT(EVT RefVT, const Pattern &P) {
382 return ValueType_match{[=](EVT VT) { return VT.getScalarType() == RefVT; },
383 P};
384}
385inline auto m_SpecificScalarVT(EVT RefVT) {
386 return ValueType_match{[=](EVT VT) { return VT.getScalarType() == RefVT; },
387 m_Value()};
388}
389
390/// Match a vector ValueType.
391template <typename Pattern>
392inline auto m_SpecificVectorElementVT(EVT RefVT, const Pattern &P) {
393 return ValueType_match{[=](EVT VT) {
394 return VT.isVector() &&
395 VT.getVectorElementType() == RefVT;
396 },
397 P};
398}
399inline auto m_SpecificVectorElementVT(EVT RefVT) {
400 return ValueType_match{[=](EVT VT) {
401 return VT.isVector() &&
402 VT.getVectorElementType() == RefVT;
403 },
404 m_Value()};
405}
406
407/// Match any integer ValueTypes.
408template <typename Pattern> inline auto m_IntegerVT(const Pattern &P) {
409 return ValueType_match{[](EVT VT) { return VT.isInteger(); }, P};
410}
411inline auto m_IntegerVT() {
412 return ValueType_match{[](EVT VT) { return VT.isInteger(); }, m_Value()};
413}
414
415/// Match any floating point ValueTypes.
416template <typename Pattern> inline auto m_FloatingPointVT(const Pattern &P) {
417 return ValueType_match{[](EVT VT) { return VT.isFloatingPoint(); }, P};
418}
419inline auto m_FloatingPointVT() {
420 return ValueType_match{[](EVT VT) { return VT.isFloatingPoint(); },
421 m_Value()};
422}
423
424/// Match any vector ValueTypes.
425template <typename Pattern> inline auto m_VectorVT(const Pattern &P) {
426 return ValueType_match{[](EVT VT) { return VT.isVector(); }, P};
427}
428inline auto m_VectorVT() {
429 return ValueType_match{[](EVT VT) { return VT.isVector(); }, m_Value()};
430}
431
432/// Match fixed-length vector ValueTypes.
433template <typename Pattern> inline auto m_FixedVectorVT(const Pattern &P) {
434 return ValueType_match{[](EVT VT) { return VT.isFixedLengthVector(); }, P};
435}
436inline auto m_FixedVectorVT() {
437 return ValueType_match{[](EVT VT) { return VT.isFixedLengthVector(); },
438 m_Value()};
439}
440
441/// Match scalable vector ValueTypes.
442template <typename Pattern> inline auto m_ScalableVectorVT(const Pattern &P) {
443 return ValueType_match{[](EVT VT) { return VT.isScalableVector(); }, P};
444}
445inline auto m_ScalableVectorVT() {
446 return ValueType_match{[](EVT VT) { return VT.isScalableVector(); },
447 m_Value()};
448}
449
450/// Match legal ValueTypes based on the information provided by TargetLowering.
451template <typename Pattern> inline auto m_LegalType(const Pattern &P) {
452 return TLI_pred_match{[](const TargetLowering &TLI, SDValue N) {
453 return TLI.isTypeLegal(VT: N.getValueType());
454 },
455 P};
456}
457
458// === Generic node matching ===
459template <unsigned OpIdx, typename... OpndPreds> struct Operands_match {
460 template <typename MatchContext>
461 bool match(const MatchContext &Ctx, SDValue N) {
462 // Returns false if there are more operands than predicates;
463 // Ignores the last two operands if both the Context and the Node are VP
464 return Ctx.getNumOperands(N) == OpIdx;
465 }
466};
467
468template <unsigned OpIdx, typename OpndPred, typename... OpndPreds>
469struct Operands_match<OpIdx, OpndPred, OpndPreds...>
470 : Operands_match<OpIdx + 1, OpndPreds...> {
471 OpndPred P;
472
473 Operands_match(const OpndPred &p, const OpndPreds &...preds)
474 : Operands_match<OpIdx + 1, OpndPreds...>(preds...), P(p) {}
475
476 template <typename MatchContext>
477 bool match(const MatchContext &Ctx, SDValue N) {
478 if (OpIdx < N->getNumOperands())
479 return P.match(Ctx, N->getOperand(Num: OpIdx)) &&
480 Operands_match<OpIdx + 1, OpndPreds...>::match(Ctx, N);
481
482 // This is the case where there are more predicates than operands.
483 return false;
484 }
485};
486
487template <typename... OpndPreds>
488auto m_Node(unsigned Opcode, const OpndPreds &...preds) {
489 return m_AllOf(m_Opc(Opcode), Operands_match<0, OpndPreds...>(preds...));
490}
491
492/// Provide number of operands that are not chain or glue, as well as the first
493/// index of such operand.
494template <bool ExcludeChain> struct EffectiveOperands {
495 unsigned Size = 0;
496 unsigned FirstIndex = 0;
497
498 template <typename MatchContext>
499 explicit EffectiveOperands(SDValue N, const MatchContext &Ctx) {
500 const unsigned TotalNumOps = Ctx.getNumOperands(N);
501 FirstIndex = TotalNumOps;
502 for (unsigned I = 0; I < TotalNumOps; ++I) {
503 // Count the number of non-chain and non-glue nodes (we ignore chain
504 // and glue by default) and retreive the operand index offset.
505 EVT VT = N->getOperand(Num: I).getValueType();
506 if (VT != MVT::Glue && VT != MVT::Other) {
507 ++Size;
508 if (FirstIndex == TotalNumOps)
509 FirstIndex = I;
510 }
511 }
512 }
513};
514
515template <> struct EffectiveOperands<false> {
516 unsigned Size = 0;
517 unsigned FirstIndex = 0;
518
519 template <typename MatchContext>
520 explicit EffectiveOperands(SDValue N, const MatchContext &Ctx)
521 : Size(Ctx.getNumOperands(N)) {}
522};
523
524// === Ternary operations ===
525template <typename T0_P, typename T1_P, typename T2_P, bool Commutable = false,
526 bool ExcludeChain = false>
527struct TernaryOpc_match {
528 unsigned Opcode;
529 T0_P Op0;
530 T1_P Op1;
531 T2_P Op2;
532
533 TernaryOpc_match(unsigned Opc, const T0_P &Op0, const T1_P &Op1,
534 const T2_P &Op2)
535 : Opcode(Opc), Op0(Op0), Op1(Op1), Op2(Op2) {}
536
537 template <typename MatchContext>
538 bool match(const MatchContext &Ctx, SDValue N) {
539 if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
540 EffectiveOperands<ExcludeChain> EO(N, Ctx);
541 assert(EO.Size == 3);
542 return ((Op0.match(Ctx, N->getOperand(Num: EO.FirstIndex)) &&
543 Op1.match(Ctx, N->getOperand(Num: EO.FirstIndex + 1))) ||
544 (Commutable && Op0.match(Ctx, N->getOperand(Num: EO.FirstIndex + 1)) &&
545 Op1.match(Ctx, N->getOperand(Num: EO.FirstIndex)))) &&
546 Op2.match(Ctx, N->getOperand(Num: EO.FirstIndex + 2));
547 }
548
549 return false;
550 }
551};
552
553template <typename T0_P, typename T1_P, typename T2_P>
554inline TernaryOpc_match<T0_P, T1_P, T2_P>
555m_SetCC(const T0_P &LHS, const T1_P &RHS, const T2_P &CC) {
556 return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::SETCC, LHS, RHS, CC);
557}
558
559template <typename T0_P, typename T1_P, typename T2_P>
560inline TernaryOpc_match<T0_P, T1_P, T2_P, true, false>
561m_c_SetCC(const T0_P &LHS, const T1_P &RHS, const T2_P &CC) {
562 return TernaryOpc_match<T0_P, T1_P, T2_P, true, false>(ISD::SETCC, LHS, RHS,
563 CC);
564}
565
566template <typename T0_P, typename T1_P, typename T2_P>
567inline TernaryOpc_match<T0_P, T1_P, T2_P>
568m_Select(const T0_P &Cond, const T1_P &T, const T2_P &F) {
569 return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::SELECT, Cond, T, F);
570}
571
572template <typename T0_P, typename T1_P, typename T2_P>
573inline TernaryOpc_match<T0_P, T1_P, T2_P>
574m_VSelect(const T0_P &Cond, const T1_P &T, const T2_P &F) {
575 return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::VSELECT, Cond, T, F);
576}
577
578template <typename T0_P, typename T1_P, typename T2_P>
579inline auto m_SelectLike(const T0_P &Cond, const T1_P &T, const T2_P &F) {
580 return m_AnyOf(m_Select(Cond, T, F), m_VSelect(Cond, T, F));
581}
582
583template <typename T0_P, typename T1_P, typename T2_P>
584inline Result_match<0, TernaryOpc_match<T0_P, T1_P, T2_P>>
585m_Load(const T0_P &Ch, const T1_P &Ptr, const T2_P &Offset) {
586 return m_Result<0>(
587 TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::LOAD, Ch, Ptr, Offset));
588}
589
590template <typename T0_P, typename T1_P, typename T2_P>
591inline TernaryOpc_match<T0_P, T1_P, T2_P>
592m_InsertElt(const T0_P &Vec, const T1_P &Val, const T2_P &Idx) {
593 return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::INSERT_VECTOR_ELT, Vec, Val,
594 Idx);
595}
596
597template <typename LHS, typename RHS, typename IDX>
598inline TernaryOpc_match<LHS, RHS, IDX>
599m_InsertSubvector(const LHS &Base, const RHS &Sub, const IDX &Idx) {
600 return TernaryOpc_match<LHS, RHS, IDX>(ISD::INSERT_SUBVECTOR, Base, Sub, Idx);
601}
602
603template <typename T0_P, typename T1_P, typename T2_P>
604inline TernaryOpc_match<T0_P, T1_P, T2_P>
605m_TernaryOp(unsigned Opc, const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
606 return TernaryOpc_match<T0_P, T1_P, T2_P>(Opc, Op0, Op1, Op2);
607}
608
609template <typename T0_P, typename T1_P, typename T2_P>
610inline TernaryOpc_match<T0_P, T1_P, T2_P, true>
611m_c_TernaryOp(unsigned Opc, const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
612 return TernaryOpc_match<T0_P, T1_P, T2_P, true>(Opc, Op0, Op1, Op2);
613}
614
615template <typename LTy, typename RTy, typename TTy, typename FTy, typename CCTy>
616inline auto m_SelectCC(const LTy &L, const RTy &R, const TTy &T, const FTy &F,
617 const CCTy &CC) {
618 return m_Node(ISD::SELECT_CC, L, R, T, F, CC);
619}
620
621template <typename LTy, typename RTy, typename TTy, typename FTy, typename CCTy>
622inline auto m_SelectCCLike(const LTy &L, const RTy &R, const TTy &T,
623 const FTy &F, const CCTy &CC) {
624 return m_AnyOf(m_Select(m_SetCC(L, R, CC), T, F), m_SelectCC(L, R, T, F, CC));
625}
626
627// === Binary operations ===
628template <typename LHS_P, typename RHS_P, bool Commutable = false,
629 bool ExcludeChain = false>
630struct BinaryOpc_match {
631 unsigned Opcode;
632 LHS_P LHS;
633 RHS_P RHS;
634 SDNodeFlags Flags;
635 BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R,
636 SDNodeFlags Flgs = SDNodeFlags())
637 : Opcode(Opc), LHS(L), RHS(R), Flags(Flgs) {}
638
639 template <typename MatchContext>
640 bool match(const MatchContext &Ctx, SDValue N) {
641 if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
642 EffectiveOperands<ExcludeChain> EO(N, Ctx);
643 assert(EO.Size == 2);
644 if (!((LHS.match(Ctx, N->getOperand(Num: EO.FirstIndex)) &&
645 RHS.match(Ctx, N->getOperand(Num: EO.FirstIndex + 1))) ||
646 (Commutable && LHS.match(Ctx, N->getOperand(Num: EO.FirstIndex + 1)) &&
647 RHS.match(Ctx, N->getOperand(Num: EO.FirstIndex)))))
648 return false;
649
650 return (Flags & N->getFlags()) == Flags;
651 }
652
653 return false;
654 }
655};
656
657/// Matching while capturing mask
658template <typename T0, typename T1, typename T2> struct SDShuffle_match {
659 T0 Op1;
660 T1 Op2;
661 T2 Mask;
662
663 SDShuffle_match(const T0 &Op1, const T1 &Op2, const T2 &Mask)
664 : Op1(Op1), Op2(Op2), Mask(Mask) {}
665
666 template <typename MatchContext>
667 bool match(const MatchContext &Ctx, SDValue N) {
668 if (auto *I = dyn_cast<ShuffleVectorSDNode>(Val&: N)) {
669 return Op1.match(Ctx, I->getOperand(Num: 0)) &&
670 Op2.match(Ctx, I->getOperand(Num: 1)) && Mask.match(I->getMask());
671 }
672 return false;
673 }
674};
675struct m_Mask {
676 ArrayRef<int> &MaskRef;
677 m_Mask(ArrayRef<int> &MaskRef) : MaskRef(MaskRef) {}
678 bool match(ArrayRef<int> Mask) {
679 MaskRef = Mask;
680 return true;
681 }
682};
683
684struct m_SpecificMask {
685 ArrayRef<int> MaskRef;
686 m_SpecificMask(ArrayRef<int> MaskRef) : MaskRef(MaskRef) {}
687 bool match(ArrayRef<int> Mask) { return MaskRef == Mask; }
688};
689
690template <typename LHS_P, typename RHS_P, typename Pred_t,
691 bool Commutable = false, bool ExcludeChain = false>
692struct MaxMin_match {
693 using PredType = Pred_t;
694 LHS_P LHS;
695 RHS_P RHS;
696
697 MaxMin_match(const LHS_P &L, const RHS_P &R) : LHS(L), RHS(R) {}
698
699 template <typename MatchContext>
700 bool match(const MatchContext &Ctx, SDValue N) {
701 auto MatchMinMax = [&](SDValue L, SDValue R, SDValue TrueValue,
702 SDValue FalseValue, ISD::CondCode CC) {
703 if ((TrueValue != L || FalseValue != R) &&
704 (TrueValue != R || FalseValue != L))
705 return false;
706
707 ISD::CondCode Cond =
708 TrueValue == L ? CC : getSetCCInverse(Operation: CC, Type: L.getValueType());
709 if (!Pred_t::match(Cond))
710 return false;
711
712 return (LHS.match(Ctx, L) && RHS.match(Ctx, R)) ||
713 (Commutable && LHS.match(Ctx, R) && RHS.match(Ctx, L));
714 };
715
716 if (sd_context_match(N, Ctx, m_Opc(Opcode: ISD::SELECT)) ||
717 sd_context_match(N, Ctx, m_Opc(Opcode: ISD::VSELECT))) {
718 EffectiveOperands<ExcludeChain> EO_SELECT(N, Ctx);
719 assert(EO_SELECT.Size == 3);
720 SDValue Cond = N->getOperand(Num: EO_SELECT.FirstIndex);
721 SDValue TrueValue = N->getOperand(Num: EO_SELECT.FirstIndex + 1);
722 SDValue FalseValue = N->getOperand(Num: EO_SELECT.FirstIndex + 2);
723
724 if (sd_context_match(Cond, Ctx, m_Opc(Opcode: ISD::SETCC))) {
725 EffectiveOperands<ExcludeChain> EO_SETCC(Cond, Ctx);
726 assert(EO_SETCC.Size == 3);
727 SDValue L = Cond->getOperand(Num: EO_SETCC.FirstIndex);
728 SDValue R = Cond->getOperand(Num: EO_SETCC.FirstIndex + 1);
729 auto *CondNode =
730 cast<CondCodeSDNode>(Cond->getOperand(Num: EO_SETCC.FirstIndex + 2));
731 return MatchMinMax(L, R, TrueValue, FalseValue, CondNode->get());
732 }
733 }
734
735 if (sd_context_match(N, Ctx, m_Opc(Opcode: ISD::SELECT_CC))) {
736 EffectiveOperands<ExcludeChain> EO_SELECT(N, Ctx);
737 assert(EO_SELECT.Size == 5);
738 SDValue L = N->getOperand(Num: EO_SELECT.FirstIndex);
739 SDValue R = N->getOperand(Num: EO_SELECT.FirstIndex + 1);
740 SDValue TrueValue = N->getOperand(Num: EO_SELECT.FirstIndex + 2);
741 SDValue FalseValue = N->getOperand(Num: EO_SELECT.FirstIndex + 3);
742 auto *CondNode =
743 cast<CondCodeSDNode>(N->getOperand(Num: EO_SELECT.FirstIndex + 4));
744 return MatchMinMax(L, R, TrueValue, FalseValue, CondNode->get());
745 }
746
747 return false;
748 }
749};
750
751// Helper class for identifying signed max predicates.
752struct smax_pred_ty {
753 static bool match(ISD::CondCode Cond) {
754 return Cond == ISD::CondCode::SETGT || Cond == ISD::CondCode::SETGE;
755 }
756};
757
758// Helper class for identifying unsigned max predicates.
759struct umax_pred_ty {
760 static bool match(ISD::CondCode Cond) {
761 return Cond == ISD::CondCode::SETUGT || Cond == ISD::CondCode::SETUGE;
762 }
763};
764
765// Helper class for identifying signed min predicates.
766struct smin_pred_ty {
767 static bool match(ISD::CondCode Cond) {
768 return Cond == ISD::CondCode::SETLT || Cond == ISD::CondCode::SETLE;
769 }
770};
771
772// Helper class for identifying unsigned min predicates.
773struct umin_pred_ty {
774 static bool match(ISD::CondCode Cond) {
775 return Cond == ISD::CondCode::SETULT || Cond == ISD::CondCode::SETULE;
776 }
777};
778
779template <typename LHS, typename RHS>
780inline BinaryOpc_match<LHS, RHS> m_BinOp(unsigned Opc, const LHS &L,
781 const RHS &R,
782 SDNodeFlags Flgs = SDNodeFlags()) {
783 return BinaryOpc_match<LHS, RHS>(Opc, L, R, Flgs);
784}
785template <typename LHS, typename RHS>
786inline BinaryOpc_match<LHS, RHS, true>
787m_c_BinOp(unsigned Opc, const LHS &L, const RHS &R,
788 SDNodeFlags Flgs = SDNodeFlags()) {
789 return BinaryOpc_match<LHS, RHS, true>(Opc, L, R, Flgs);
790}
791
792template <typename LHS, typename RHS>
793inline BinaryOpc_match<LHS, RHS, false, true>
794m_ChainedBinOp(unsigned Opc, const LHS &L, const RHS &R) {
795 return BinaryOpc_match<LHS, RHS, false, true>(Opc, L, R);
796}
797template <typename LHS, typename RHS>
798inline BinaryOpc_match<LHS, RHS, true, true>
799m_c_ChainedBinOp(unsigned Opc, const LHS &L, const RHS &R) {
800 return BinaryOpc_match<LHS, RHS, true, true>(Opc, L, R);
801}
802
803// Common binary operations
804template <typename LHS, typename RHS>
805inline BinaryOpc_match<LHS, RHS, true> m_Add(const LHS &L, const RHS &R) {
806 return BinaryOpc_match<LHS, RHS, true>(ISD::ADD, L, R);
807}
808
809template <typename LHS, typename RHS>
810inline BinaryOpc_match<LHS, RHS> m_Sub(const LHS &L, const RHS &R) {
811 return BinaryOpc_match<LHS, RHS>(ISD::SUB, L, R);
812}
813
814template <typename LHS, typename RHS>
815inline BinaryOpc_match<LHS, RHS, true> m_Mul(const LHS &L, const RHS &R) {
816 return BinaryOpc_match<LHS, RHS, true>(ISD::MUL, L, R);
817}
818
819template <typename LHS, typename RHS>
820inline BinaryOpc_match<LHS, RHS, true> m_And(const LHS &L, const RHS &R) {
821 return BinaryOpc_match<LHS, RHS, true>(ISD::AND, L, R);
822}
823
824template <typename LHS, typename RHS>
825inline BinaryOpc_match<LHS, RHS, true> m_Or(const LHS &L, const RHS &R) {
826 return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R);
827}
828
829template <typename LHS, typename RHS>
830inline BinaryOpc_match<LHS, RHS, true> m_DisjointOr(const LHS &L,
831 const RHS &R) {
832 return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R, SDNodeFlags::Disjoint);
833}
834
835template <typename LHS, typename RHS>
836inline auto m_AddLike(const LHS &L, const RHS &R) {
837 return m_AnyOf(m_Add(L, R), m_DisjointOr(L, R));
838}
839
840template <typename LHS, typename RHS>
841inline BinaryOpc_match<LHS, RHS, true> m_Xor(const LHS &L, const RHS &R) {
842 return BinaryOpc_match<LHS, RHS, true>(ISD::XOR, L, R);
843}
844
845template <typename LHS, typename RHS>
846inline auto m_BitwiseLogic(const LHS &L, const RHS &R) {
847 return m_AnyOf(m_And(L, R), m_Or(L, R), m_Xor(L, R));
848}
849
850template <unsigned Opc, typename Pred, typename LHS, typename RHS>
851inline auto m_MaxMinLike(const LHS &L, const RHS &R) {
852 return m_AnyOf(BinaryOpc_match<LHS, RHS, true>(Opc, L, R),
853 MaxMin_match<LHS, RHS, Pred, true>(L, R));
854}
855
856template <typename LHS, typename RHS>
857inline BinaryOpc_match<LHS, RHS, true> m_SMin(const LHS &L, const RHS &R) {
858 return BinaryOpc_match<LHS, RHS, true>(ISD::SMIN, L, R);
859}
860
861template <typename LHS, typename RHS>
862inline auto m_SMinLike(const LHS &L, const RHS &R) {
863 return m_AnyOf(
864 m_MaxMinLike<ISD::SMIN, smin_pred_ty>(L, R),
865 m_MaxMinLike<ISD::UMIN, umin_pred_ty>(m_NonNegative(L), m_NonNegative(R)),
866 m_MaxMinLike<ISD::UMIN, umin_pred_ty>(m_Negative(L), m_Negative(R)));
867}
868
869template <typename LHS, typename RHS>
870inline BinaryOpc_match<LHS, RHS, true> m_SMax(const LHS &L, const RHS &R) {
871 return BinaryOpc_match<LHS, RHS, true>(ISD::SMAX, L, R);
872}
873
874template <typename LHS, typename RHS>
875inline auto m_SMaxLike(const LHS &L, const RHS &R) {
876 return m_AnyOf(
877 m_MaxMinLike<ISD::SMAX, smax_pred_ty>(L, R),
878 m_MaxMinLike<ISD::UMAX, umax_pred_ty>(m_NonNegative(L), m_NonNegative(R)),
879 m_MaxMinLike<ISD::UMAX, umax_pred_ty>(m_Negative(L), m_Negative(R)));
880}
881
882template <typename LHS, typename RHS>
883inline BinaryOpc_match<LHS, RHS, true> m_UMin(const LHS &L, const RHS &R) {
884 return BinaryOpc_match<LHS, RHS, true>(ISD::UMIN, L, R);
885}
886
887template <typename LHS, typename RHS>
888inline auto m_UMinLike(const LHS &L, const RHS &R) {
889 return m_AnyOf(
890 m_MaxMinLike<ISD::UMIN, umin_pred_ty>(L, R),
891 m_MaxMinLike<ISD::SMIN, smin_pred_ty>(m_NonNegative(L), m_NonNegative(R)),
892 m_MaxMinLike<ISD::SMIN, smin_pred_ty>(m_Negative(L), m_Negative(R)));
893}
894
895template <typename LHS, typename RHS>
896inline BinaryOpc_match<LHS, RHS, true> m_UMax(const LHS &L, const RHS &R) {
897 return BinaryOpc_match<LHS, RHS, true>(ISD::UMAX, L, R);
898}
899
900template <typename LHS, typename RHS>
901inline auto m_UMaxLike(const LHS &L, const RHS &R) {
902 return m_AnyOf(
903 m_MaxMinLike<ISD::UMAX, umax_pred_ty>(L, R),
904 m_MaxMinLike<ISD::SMAX, smax_pred_ty>(m_NonNegative(L), m_NonNegative(R)),
905 m_MaxMinLike<ISD::SMAX, smax_pred_ty>(m_Negative(L), m_Negative(R)));
906}
907
908template <typename LHS, typename RHS>
909inline BinaryOpc_match<LHS, RHS> m_UDiv(const LHS &L, const RHS &R) {
910 return BinaryOpc_match<LHS, RHS>(ISD::UDIV, L, R);
911}
912template <typename LHS, typename RHS>
913inline BinaryOpc_match<LHS, RHS> m_SDiv(const LHS &L, const RHS &R) {
914 return BinaryOpc_match<LHS, RHS>(ISD::SDIV, L, R);
915}
916
917template <typename LHS, typename RHS>
918inline BinaryOpc_match<LHS, RHS> m_URem(const LHS &L, const RHS &R) {
919 return BinaryOpc_match<LHS, RHS>(ISD::UREM, L, R);
920}
921template <typename LHS, typename RHS>
922inline BinaryOpc_match<LHS, RHS> m_SRem(const LHS &L, const RHS &R) {
923 return BinaryOpc_match<LHS, RHS>(ISD::SREM, L, R);
924}
925
926template <typename LHS, typename RHS>
927inline BinaryOpc_match<LHS, RHS> m_Shl(const LHS &L, const RHS &R) {
928 return BinaryOpc_match<LHS, RHS>(ISD::SHL, L, R);
929}
930
931template <typename LHS, typename RHS>
932inline BinaryOpc_match<LHS, RHS> m_Sra(const LHS &L, const RHS &R) {
933 return BinaryOpc_match<LHS, RHS>(ISD::SRA, L, R);
934}
935template <typename LHS, typename RHS>
936inline BinaryOpc_match<LHS, RHS> m_Srl(const LHS &L, const RHS &R) {
937 return BinaryOpc_match<LHS, RHS>(ISD::SRL, L, R);
938}
939template <typename LHS, typename RHS>
940inline auto m_ExactSr(const LHS &L, const RHS &R) {
941 return m_AnyOf(BinaryOpc_match<LHS, RHS>(ISD::SRA, L, R, SDNodeFlags::Exact),
942 BinaryOpc_match<LHS, RHS>(ISD::SRL, L, R, SDNodeFlags::Exact));
943}
944
945template <typename LHS, typename RHS>
946inline BinaryOpc_match<LHS, RHS> m_Rotl(const LHS &L, const RHS &R) {
947 return BinaryOpc_match<LHS, RHS>(ISD::ROTL, L, R);
948}
949
950template <typename LHS, typename RHS>
951inline BinaryOpc_match<LHS, RHS> m_Rotr(const LHS &L, const RHS &R) {
952 return BinaryOpc_match<LHS, RHS>(ISD::ROTR, L, R);
953}
954
955template <typename LHS, typename RHS>
956inline BinaryOpc_match<LHS, RHS, true> m_Clmul(const LHS &L, const RHS &R) {
957 return BinaryOpc_match<LHS, RHS, true>(ISD::CLMUL, L, R);
958}
959
960template <typename LHS, typename RHS>
961inline BinaryOpc_match<LHS, RHS, true> m_FAdd(const LHS &L, const RHS &R) {
962 return BinaryOpc_match<LHS, RHS, true>(ISD::FADD, L, R);
963}
964
965template <typename LHS, typename RHS>
966inline BinaryOpc_match<LHS, RHS> m_FSub(const LHS &L, const RHS &R) {
967 return BinaryOpc_match<LHS, RHS>(ISD::FSUB, L, R);
968}
969
970template <typename LHS, typename RHS>
971inline BinaryOpc_match<LHS, RHS, true> m_FMul(const LHS &L, const RHS &R) {
972 return BinaryOpc_match<LHS, RHS, true>(ISD::FMUL, L, R);
973}
974
975template <typename LHS, typename RHS>
976inline BinaryOpc_match<LHS, RHS> m_FDiv(const LHS &L, const RHS &R) {
977 return BinaryOpc_match<LHS, RHS>(ISD::FDIV, L, R);
978}
979
980template <typename LHS, typename RHS>
981inline BinaryOpc_match<LHS, RHS> m_FRem(const LHS &L, const RHS &R) {
982 return BinaryOpc_match<LHS, RHS>(ISD::FREM, L, R);
983}
984
985template <typename V1_t, typename V2_t>
986inline BinaryOpc_match<V1_t, V2_t> m_Shuffle(const V1_t &v1, const V2_t &v2) {
987 return BinaryOpc_match<V1_t, V2_t>(ISD::VECTOR_SHUFFLE, v1, v2);
988}
989
990template <typename V1_t, typename V2_t, typename Mask_t>
991inline SDShuffle_match<V1_t, V2_t, Mask_t>
992m_Shuffle(const V1_t &v1, const V2_t &v2, const Mask_t &mask) {
993 return SDShuffle_match<V1_t, V2_t, Mask_t>(v1, v2, mask);
994}
995
996template <typename LHS, typename RHS>
997inline BinaryOpc_match<LHS, RHS> m_ExtractElt(const LHS &Vec, const RHS &Idx) {
998 return BinaryOpc_match<LHS, RHS>(ISD::EXTRACT_VECTOR_ELT, Vec, Idx);
999}
1000
1001template <typename LHS, typename RHS>
1002inline BinaryOpc_match<LHS, RHS> m_ExtractSubvector(const LHS &Vec,
1003 const RHS &Idx) {
1004 return BinaryOpc_match<LHS, RHS>(ISD::EXTRACT_SUBVECTOR, Vec, Idx);
1005}
1006
1007// === Unary operations ===
1008template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
1009 unsigned Opcode;
1010 Opnd_P Opnd;
1011 SDNodeFlags Flags;
1012 UnaryOpc_match(unsigned Opc, const Opnd_P &Op,
1013 SDNodeFlags Flgs = SDNodeFlags())
1014 : Opcode(Opc), Opnd(Op), Flags(Flgs) {}
1015
1016 template <typename MatchContext>
1017 bool match(const MatchContext &Ctx, SDValue N) {
1018 if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
1019 EffectiveOperands<ExcludeChain> EO(N, Ctx);
1020 assert(EO.Size == 1);
1021 if (!Opnd.match(Ctx, N->getOperand(Num: EO.FirstIndex)))
1022 return false;
1023
1024 return (Flags & N->getFlags()) == Flags;
1025 }
1026
1027 return false;
1028 }
1029};
1030
1031template <typename Opnd>
1032inline UnaryOpc_match<Opnd> m_UnaryOp(unsigned Opc, const Opnd &Op) {
1033 return UnaryOpc_match<Opnd>(Opc, Op);
1034}
1035template <typename Opnd>
1036inline UnaryOpc_match<Opnd, true> m_ChainedUnaryOp(unsigned Opc,
1037 const Opnd &Op) {
1038 return UnaryOpc_match<Opnd, true>(Opc, Op);
1039}
1040
1041template <typename Opnd> inline UnaryOpc_match<Opnd> m_BitCast(const Opnd &Op) {
1042 return UnaryOpc_match<Opnd>(ISD::BITCAST, Op);
1043}
1044
1045template <typename Opnd>
1046inline UnaryOpc_match<Opnd> m_BSwap(const Opnd &Op) {
1047 return UnaryOpc_match<Opnd>(ISD::BSWAP, Op);
1048}
1049
1050template <typename Opnd>
1051inline UnaryOpc_match<Opnd> m_BitReverse(const Opnd &Op) {
1052 return UnaryOpc_match<Opnd>(ISD::BITREVERSE, Op);
1053}
1054
1055template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) {
1056 return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op);
1057}
1058
1059template <typename Opnd>
1060inline UnaryOpc_match<Opnd> m_NNegZExt(const Opnd &Op) {
1061 return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op, SDNodeFlags::NonNeg);
1062}
1063
1064template <typename Opnd> inline auto m_SExt(const Opnd &Op) {
1065 return UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op);
1066}
1067
1068template <typename Opnd> inline UnaryOpc_match<Opnd> m_AnyExt(const Opnd &Op) {
1069 return UnaryOpc_match<Opnd>(ISD::ANY_EXTEND, Op);
1070}
1071
1072template <typename Opnd> inline UnaryOpc_match<Opnd> m_Trunc(const Opnd &Op) {
1073 return UnaryOpc_match<Opnd>(ISD::TRUNCATE, Op);
1074}
1075
1076template <typename Opnd> inline UnaryOpc_match<Opnd> m_Abs(const Opnd &Op) {
1077 return UnaryOpc_match<Opnd>(ISD::ABS, Op);
1078}
1079
1080template <typename Opnd> inline UnaryOpc_match<Opnd> m_FAbs(const Opnd &Op) {
1081 return UnaryOpc_match<Opnd>(ISD::FABS, Op);
1082}
1083
1084/// Match a zext or identity
1085/// Allows to peek through optional extensions
1086template <typename Opnd> inline auto m_ZExtOrSelf(const Opnd &Op) {
1087 return m_AnyOf(m_ZExt(Op), Op);
1088}
1089
1090/// Match a sext or identity
1091/// Allows to peek through optional extensions
1092template <typename Opnd> inline auto m_SExtOrSelf(const Opnd &Op) {
1093 return m_AnyOf(m_SExt(Op), Op);
1094}
1095
1096template <typename Opnd> inline auto m_SExtLike(const Opnd &Op) {
1097 return m_AnyOf(m_SExt(Op), m_NNegZExt(Op));
1098}
1099
1100/// Match a aext or identity
1101/// Allows to peek through optional extensions
1102template <typename Opnd>
1103inline Or<UnaryOpc_match<Opnd>, Opnd> m_AExtOrSelf(const Opnd &Op) {
1104 return Or<UnaryOpc_match<Opnd>, Opnd>(m_AnyExt(Op), Op);
1105}
1106
1107/// Match a trunc or identity
1108/// Allows to peek through optional truncations
1109template <typename Opnd>
1110inline Or<UnaryOpc_match<Opnd>, Opnd> m_TruncOrSelf(const Opnd &Op) {
1111 return Or<UnaryOpc_match<Opnd>, Opnd>(m_Trunc(Op), Op);
1112}
1113
1114template <typename Opnd> inline UnaryOpc_match<Opnd> m_VScale(const Opnd &Op) {
1115 return UnaryOpc_match<Opnd>(ISD::VSCALE, Op);
1116}
1117
1118template <typename Opnd> inline UnaryOpc_match<Opnd> m_FPToUI(const Opnd &Op) {
1119 return UnaryOpc_match<Opnd>(ISD::FP_TO_UINT, Op);
1120}
1121
1122template <typename Opnd> inline UnaryOpc_match<Opnd> m_FPToSI(const Opnd &Op) {
1123 return UnaryOpc_match<Opnd>(ISD::FP_TO_SINT, Op);
1124}
1125
1126template <typename Opnd> inline UnaryOpc_match<Opnd> m_Ctpop(const Opnd &Op) {
1127 return UnaryOpc_match<Opnd>(ISD::CTPOP, Op);
1128}
1129
1130template <typename Opnd> inline UnaryOpc_match<Opnd> m_Ctlz(const Opnd &Op) {
1131 return UnaryOpc_match<Opnd>(ISD::CTLZ, Op);
1132}
1133
1134template <typename Opnd> inline UnaryOpc_match<Opnd> m_Cttz(const Opnd &Op) {
1135 return UnaryOpc_match<Opnd>(ISD::CTTZ, Op);
1136}
1137
1138template <typename Opnd> inline UnaryOpc_match<Opnd> m_FNeg(const Opnd &Op) {
1139 return UnaryOpc_match<Opnd>(ISD::FNEG, Op);
1140}
1141
1142// === Constants ===
1143struct ConstantInt_match {
1144 APInt *BindVal;
1145
1146 explicit ConstantInt_match(APInt *V) : BindVal(V) {}
1147
1148 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
1149 // The logics here are similar to that in
1150 // SelectionDAG::isConstantIntBuildVectorOrConstantInt, but the latter also
1151 // treats GlobalAddressSDNode as a constant, which is difficult to turn into
1152 // APInt.
1153 if (auto *C = dyn_cast_or_null<ConstantSDNode>(Val: N.getNode())) {
1154 if (BindVal)
1155 *BindVal = C->getAPIntValue();
1156 return true;
1157 }
1158
1159 APInt Discard;
1160 return ISD::isConstantSplatVector(N: N.getNode(),
1161 SplatValue&: BindVal ? *BindVal : Discard);
1162 }
1163};
1164
1165template <typename T> struct Constant64_match {
1166 static_assert(sizeof(T) == 8, "T must be 64 bits wide");
1167
1168 T &BindVal;
1169
1170 explicit Constant64_match(T &V) : BindVal(V) {}
1171
1172 template <typename MatchContext>
1173 bool match(const MatchContext &Ctx, SDValue N) {
1174 APInt V;
1175 if (!ConstantInt_match(&V).match(Ctx, N))
1176 return false;
1177
1178 if constexpr (std::is_signed_v<T>) {
1179 if (std::optional<int64_t> TrySExt = V.trySExtValue()) {
1180 BindVal = *TrySExt;
1181 return true;
1182 }
1183 }
1184
1185 if constexpr (std::is_unsigned_v<T>) {
1186 if (std::optional<uint64_t> TryZExt = V.tryZExtValue()) {
1187 BindVal = *TryZExt;
1188 return true;
1189 }
1190 }
1191
1192 return false;
1193 }
1194};
1195
1196/// Match any integer constants or splat of an integer constant.
1197inline ConstantInt_match m_ConstInt() { return ConstantInt_match(nullptr); }
1198/// Match any integer constants or splat of an integer constant; return the
1199/// specific constant or constant splat value.
1200inline ConstantInt_match m_ConstInt(APInt &V) { return ConstantInt_match(&V); }
1201/// Match any integer constants or splat of an integer constant that can fit in
1202/// 64 bits; return the specific constant or constant splat value, zero-extended
1203/// to 64 bits.
1204inline Constant64_match<uint64_t> m_ConstInt(uint64_t &V) {
1205 return Constant64_match<uint64_t>(V);
1206}
1207/// Match any integer constants or splat of an integer constant that can fit in
1208/// 64 bits; return the specific constant or constant splat value, sign-extended
1209/// to 64 bits.
1210inline Constant64_match<int64_t> m_ConstInt(int64_t &V) {
1211 return Constant64_match<int64_t>(V);
1212}
1213
1214struct SpecificInt_match {
1215 APInt IntVal;
1216
1217 explicit SpecificInt_match(APInt APV) : IntVal(std::move(APV)) {}
1218
1219 template <typename MatchContext>
1220 bool match(const MatchContext &Ctx, SDValue N) {
1221 APInt ConstInt;
1222 if (sd_context_match(N, Ctx, m_ConstInt(V&: ConstInt)))
1223 return APInt::isSameValue(I1: IntVal, I2: ConstInt);
1224 return false;
1225 }
1226};
1227
1228/// Match a specific integer constant or constant splat value.
1229inline SpecificInt_match m_SpecificInt(APInt V) {
1230 return SpecificInt_match(std::move(V));
1231}
1232inline SpecificInt_match m_SpecificInt(uint64_t V) {
1233 return SpecificInt_match(APInt(64, V));
1234}
1235
1236struct SpecificFP_match {
1237 APFloat Val;
1238
1239 explicit SpecificFP_match(APFloat V) : Val(V) {}
1240
1241 template <typename MatchContext>
1242 bool match(const MatchContext &Ctx, SDValue V) {
1243 if (const auto *CFP = dyn_cast<ConstantFPSDNode>(Val: V.getNode()))
1244 return CFP->isExactlyValue(V: Val);
1245 if (ConstantFPSDNode *C = isConstOrConstSplatFP(N: V, /*AllowUndefs=*/AllowUndefs: true))
1246 return C->getValueAPF().compare(RHS: Val) == APFloat::cmpEqual;
1247 return false;
1248 }
1249};
1250
1251/// Match a specific float constant.
1252inline SpecificFP_match m_SpecificFP(APFloat V) { return SpecificFP_match(V); }
1253
1254inline SpecificFP_match m_SpecificFP(double V) {
1255 return SpecificFP_match(APFloat(V));
1256}
1257
1258struct Negative_match {
1259 template <typename MatchContext>
1260 bool match(const MatchContext &Ctx, SDValue N) {
1261 const SelectionDAG *DAG = Ctx.getDAG();
1262 return DAG && DAG->computeKnownBits(Op: N).isNegative();
1263 }
1264};
1265
1266struct NonNegative_match {
1267 template <typename MatchContext>
1268 bool match(const MatchContext &Ctx, SDValue N) {
1269 const SelectionDAG *DAG = Ctx.getDAG();
1270 return DAG && DAG->computeKnownBits(Op: N).isNonNegative();
1271 }
1272};
1273
1274struct StrictlyPositive_match {
1275 template <typename MatchContext>
1276 bool match(const MatchContext &Ctx, SDValue N) {
1277 const SelectionDAG *DAG = Ctx.getDAG();
1278 return DAG && DAG->computeKnownBits(Op: N).isStrictlyPositive();
1279 }
1280};
1281
1282struct NonPositive_match {
1283 template <typename MatchContext>
1284 bool match(const MatchContext &Ctx, SDValue N) {
1285 const SelectionDAG *DAG = Ctx.getDAG();
1286 return DAG && DAG->computeKnownBits(Op: N).isNonPositive();
1287 }
1288};
1289
1290struct NonZero_match {
1291 template <typename MatchContext>
1292 bool match(const MatchContext &Ctx, SDValue N) {
1293 const SelectionDAG *DAG = Ctx.getDAG();
1294 return DAG && DAG->computeKnownBits(Op: N).isNonZero();
1295 }
1296};
1297
1298struct Zero_match {
1299 bool AllowUndefs;
1300
1301 explicit Zero_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
1302
1303 template <typename MatchContext>
1304 bool match(const MatchContext &, SDValue N) const {
1305 return isZeroOrZeroSplat(N, AllowUndefs);
1306 }
1307};
1308
1309struct Ones_match {
1310 bool AllowUndefs;
1311
1312 Ones_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
1313
1314 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
1315 return isOnesOrOnesSplat(N, AllowUndefs);
1316 }
1317};
1318
1319struct AllOnes_match {
1320 bool AllowUndefs;
1321
1322 AllOnes_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
1323
1324 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
1325 return isAllOnesOrAllOnesSplat(V: N, AllowUndefs);
1326 }
1327};
1328
1329inline Negative_match m_Negative() { return Negative_match(); }
1330template <typename Pattern> inline auto m_Negative(const Pattern &P) {
1331 return m_AllOf(m_Negative(), P);
1332}
1333inline NonNegative_match m_NonNegative() { return NonNegative_match(); }
1334template <typename Pattern> inline auto m_NonNegative(const Pattern &P) {
1335 return m_AllOf(m_NonNegative(), P);
1336}
1337inline StrictlyPositive_match m_StrictlyPositive() {
1338 return StrictlyPositive_match();
1339}
1340template <typename Pattern> inline auto m_StrictlyPositive(const Pattern &P) {
1341 return m_AllOf(m_StrictlyPositive(), P);
1342}
1343inline NonPositive_match m_NonPositive() { return NonPositive_match(); }
1344template <typename Pattern> inline auto m_NonPositive(const Pattern &P) {
1345 return m_AllOf(m_NonPositive(), P);
1346}
1347inline NonZero_match m_NonZero() { return NonZero_match(); }
1348template <typename Pattern> inline auto m_NonZero(const Pattern &P) {
1349 return m_AllOf(m_NonZero(), P);
1350}
1351inline Ones_match m_One(bool AllowUndefs = false) {
1352 return Ones_match(AllowUndefs);
1353}
1354inline Zero_match m_Zero(bool AllowUndefs = false) {
1355 return Zero_match(AllowUndefs);
1356}
1357inline AllOnes_match m_AllOnes(bool AllowUndefs = false) {
1358 return AllOnes_match(AllowUndefs);
1359}
1360
1361/// Match true boolean value based on the information provided by
1362/// TargetLowering.
1363inline auto m_True() {
1364 return TLI_pred_match{
1365 [](const TargetLowering &TLI, SDValue N) {
1366 APInt ConstVal;
1367 if (sd_match(N, P: m_ConstInt(V&: ConstVal)))
1368 switch (TLI.getBooleanContents(Type: N.getValueType())) {
1369 case TargetLowering::ZeroOrOneBooleanContent:
1370 return ConstVal.isOne();
1371 case TargetLowering::ZeroOrNegativeOneBooleanContent:
1372 return ConstVal.isAllOnes();
1373 case TargetLowering::UndefinedBooleanContent:
1374 return (ConstVal & 0x01) == 1;
1375 }
1376
1377 return false;
1378 },
1379 m_Value()};
1380}
1381/// Match false boolean value based on the information provided by
1382/// TargetLowering.
1383inline auto m_False() {
1384 return TLI_pred_match{
1385 [](const TargetLowering &TLI, SDValue N) {
1386 APInt ConstVal;
1387 if (sd_match(N, P: m_ConstInt(V&: ConstVal)))
1388 switch (TLI.getBooleanContents(Type: N.getValueType())) {
1389 case TargetLowering::ZeroOrOneBooleanContent:
1390 case TargetLowering::ZeroOrNegativeOneBooleanContent:
1391 return ConstVal.isZero();
1392 case TargetLowering::UndefinedBooleanContent:
1393 return (ConstVal & 0x01) == 0;
1394 }
1395
1396 return false;
1397 },
1398 m_Value()};
1399}
1400
1401struct CondCode_match {
1402 std::optional<ISD::CondCode> CCToMatch;
1403 ISD::CondCode *BindCC = nullptr;
1404
1405 explicit CondCode_match(ISD::CondCode CC) : CCToMatch(CC) {}
1406
1407 explicit CondCode_match(ISD::CondCode *CC) : BindCC(CC) {}
1408
1409 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
1410 if (auto *CC = dyn_cast<CondCodeSDNode>(Val: N.getNode())) {
1411 if (CCToMatch && *CCToMatch != CC->get())
1412 return false;
1413
1414 if (BindCC)
1415 *BindCC = CC->get();
1416 return true;
1417 }
1418
1419 return false;
1420 }
1421};
1422
1423/// Match any conditional code SDNode.
1424inline CondCode_match m_CondCode() { return CondCode_match(nullptr); }
1425/// Match any conditional code SDNode and return its ISD::CondCode value.
1426inline CondCode_match m_CondCode(ISD::CondCode &CC) {
1427 return CondCode_match(&CC);
1428}
1429/// Match a conditional code SDNode with a specific ISD::CondCode.
1430inline CondCode_match m_SpecificCondCode(ISD::CondCode CC) {
1431 return CondCode_match(CC);
1432}
1433
1434/// Match a negate as a sub(0, v)
1435template <typename ValTy>
1436inline BinaryOpc_match<Zero_match, ValTy, false> m_Neg(const ValTy &V) {
1437 return m_Sub(m_Zero(), V);
1438}
1439
1440/// Match a Not as a xor(v, -1) or xor(-1, v)
1441template <typename ValTy>
1442inline BinaryOpc_match<ValTy, AllOnes_match, true> m_Not(const ValTy &V) {
1443 return m_Xor(V, m_AllOnes());
1444}
1445
1446template <unsigned IntrinsicId, typename... OpndPreds>
1447inline auto m_IntrinsicWOChain(const OpndPreds &...Opnds) {
1448 return m_Node(ISD::INTRINSIC_WO_CHAIN, m_SpecificInt(V: IntrinsicId), Opnds...);
1449}
1450
1451struct SpecificNeg_match {
1452 SDValue V;
1453
1454 explicit SpecificNeg_match(SDValue V) : V(V) {}
1455
1456 template <typename MatchContext>
1457 bool match(const MatchContext &Ctx, SDValue N) {
1458 if (sd_context_match(N, Ctx, m_Neg(V: m_Specific(N: V))))
1459 return true;
1460
1461 return ISD::matchBinaryPredicate(
1462 LHS: V, RHS: N, Match: [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
1463 return LHS->getAPIntValue() == -RHS->getAPIntValue();
1464 });
1465 }
1466};
1467
1468/// Match a negation of a specific value V, either as sub(0, V) or as
1469/// constant(s) that are the negation of V's constant(s).
1470inline SpecificNeg_match m_SpecificNeg(SDValue V) {
1471 return SpecificNeg_match(V);
1472}
1473
1474template <typename... PatternTs> struct ReassociatableOpc_match {
1475 unsigned Opcode;
1476 std::tuple<PatternTs...> Patterns;
1477 constexpr static size_t NumPatterns =
1478 std::tuple_size_v<std::tuple<PatternTs...>>;
1479
1480 SDNodeFlags Flags;
1481
1482 ReassociatableOpc_match(unsigned Opcode, const PatternTs &...Patterns)
1483 : Opcode(Opcode), Patterns(Patterns...) {}
1484
1485 ReassociatableOpc_match(unsigned Opcode, SDNodeFlags Flags,
1486 const PatternTs &...Patterns)
1487 : Opcode(Opcode), Patterns(Patterns...), Flags(Flags) {}
1488
1489 template <typename MatchContext>
1490 bool match(const MatchContext &Ctx, SDValue N) {
1491 std::array<SDValue, NumPatterns> Leaves;
1492 size_t LeavesIdx = 0;
1493 if (!(collectLeaves(V: N, Leaves, LeafIdx&: LeavesIdx) && (LeavesIdx == NumPatterns)))
1494 return false;
1495
1496 Bitset<NumPatterns> Used;
1497 return std::apply(
1498 [&](auto &...P) -> bool {
1499 return reassociatableMatchHelper(Ctx, Leaves, Used, P...);
1500 },
1501 Patterns);
1502 }
1503
1504 bool collectLeaves(SDValue V, std::array<SDValue, NumPatterns> &Leaves,
1505 std::size_t &LeafIdx) {
1506 if (V->getOpcode() == Opcode && (Flags & V->getFlags()) == Flags) {
1507 for (size_t I = 0, N = V->getNumOperands(); I < N; I++)
1508 if ((LeafIdx == NumPatterns) ||
1509 !collectLeaves(V: V->getOperand(Num: I), Leaves, LeafIdx))
1510 return false;
1511 } else {
1512 Leaves[LeafIdx] = V;
1513 LeafIdx++;
1514 }
1515 return true;
1516 }
1517
1518 // Searchs for a matching leaf for every sub-pattern.
1519 template <typename MatchContext, typename PatternHd, typename... PatternTl>
1520 [[nodiscard]] inline bool
1521 reassociatableMatchHelper(const MatchContext &Ctx, ArrayRef<SDValue> Leaves,
1522 Bitset<NumPatterns> &Used, PatternHd &HeadPattern,
1523 PatternTl &...TailPatterns) {
1524 for (size_t Match = 0, N = Used.size(); Match < N; Match++) {
1525 if (Used[Match] || !(sd_context_match(Leaves[Match], Ctx, HeadPattern)))
1526 continue;
1527 Used.set(Match);
1528 if (reassociatableMatchHelper(Ctx, Leaves, Used, TailPatterns...))
1529 return true;
1530 Used.reset(Match);
1531 }
1532 return false;
1533 }
1534
1535 template <typename MatchContext>
1536 [[nodiscard]] inline bool
1537 reassociatableMatchHelper(const MatchContext &Ctx, ArrayRef<SDValue> Leaves,
1538 Bitset<NumPatterns> &Used) {
1539 return true;
1540 }
1541};
1542
1543template <typename... PatternTs>
1544inline ReassociatableOpc_match<PatternTs...>
1545m_ReassociatableAdd(const PatternTs &...Patterns) {
1546 return ReassociatableOpc_match<PatternTs...>(ISD::ADD, Patterns...);
1547}
1548
1549template <typename... PatternTs>
1550inline ReassociatableOpc_match<PatternTs...>
1551m_ReassociatableOr(const PatternTs &...Patterns) {
1552 return ReassociatableOpc_match<PatternTs...>(ISD::OR, Patterns...);
1553}
1554
1555template <typename... PatternTs>
1556inline ReassociatableOpc_match<PatternTs...>
1557m_ReassociatableAnd(const PatternTs &...Patterns) {
1558 return ReassociatableOpc_match<PatternTs...>(ISD::AND, Patterns...);
1559}
1560
1561template <typename... PatternTs>
1562inline ReassociatableOpc_match<PatternTs...>
1563m_ReassociatableMul(const PatternTs &...Patterns) {
1564 return ReassociatableOpc_match<PatternTs...>(ISD::MUL, Patterns...);
1565}
1566
1567template <typename... PatternTs>
1568inline ReassociatableOpc_match<PatternTs...>
1569m_ReassociatableNSWAdd(const PatternTs &...Patterns) {
1570 return ReassociatableOpc_match<PatternTs...>(
1571 ISD::ADD, SDNodeFlags::NoSignedWrap, Patterns...);
1572}
1573
1574template <typename... PatternTs>
1575inline ReassociatableOpc_match<PatternTs...>
1576m_ReassociatableNUWAdd(const PatternTs &...Patterns) {
1577 return ReassociatableOpc_match<PatternTs...>(
1578 ISD::ADD, SDNodeFlags::NoUnsignedWrap, Patterns...);
1579}
1580
1581} // namespace SDPatternMatch
1582} // namespace llvm
1583#endif
1584