1//==--------------- llvm/CodeGen/SDPatternMatch.h ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8/// \file
9/// Contains matchers for matching SelectionDAG nodes and values.
10///
11//===----------------------------------------------------------------------===//
12
13#ifndef LLVM_CODEGEN_SDPATTERNMATCH_H
14#define LLVM_CODEGEN_SDPATTERNMATCH_H
15
16#include "llvm/ADT/APInt.h"
17#include "llvm/ADT/ArrayRef.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/SmallBitVector.h"
20#include "llvm/CodeGen/SelectionDAG.h"
21#include "llvm/CodeGen/SelectionDAGNodes.h"
22#include "llvm/CodeGen/TargetLowering.h"
23#include "llvm/Support/KnownBits.h"
24
25#include <type_traits>
26
27namespace llvm {
28namespace SDPatternMatch {
29
30/// MatchContext can repurpose existing patterns to behave differently under
31/// a certain context. For instance, `m_SpecificOpc(ISD::ADD)` matches plain ADD
32/// nodes in normal circumstances, but matches VP_ADD nodes under a custom
33/// VPMatchContext. This design is meant to facilitate code / pattern reusing.
34class BasicMatchContext {
35 const SelectionDAG *DAG;
36 const TargetLowering *TLI;
37
38public:
39 explicit BasicMatchContext(const SelectionDAG *DAG)
40 : DAG(DAG), TLI(DAG ? &DAG->getTargetLoweringInfo() : nullptr) {}
41
42 explicit BasicMatchContext(const TargetLowering *TLI)
43 : DAG(nullptr), TLI(TLI) {}
44
45 // A valid MatchContext has to implement the following functions.
46
47 const SelectionDAG *getDAG() const { return DAG; }
48
49 const TargetLowering *getTLI() const { return TLI; }
50
51 /// Return true if N effectively has opcode Opcode.
52 bool match(SDValue N, unsigned Opcode) const {
53 return N->getOpcode() == Opcode;
54 }
55
56 unsigned getNumOperands(SDValue N) const { return N->getNumOperands(); }
57};
58
59template <typename Pattern, typename MatchContext>
60[[nodiscard]] bool sd_context_match(SDValue N, const MatchContext &Ctx,
61 Pattern &&P) {
62 return P.match(Ctx, N);
63}
64
65template <typename Pattern, typename MatchContext>
66[[nodiscard]] bool sd_context_match(SDNode *N, const MatchContext &Ctx,
67 Pattern &&P) {
68 return sd_context_match(SDValue(N, 0), Ctx, P);
69}
70
71template <typename Pattern>
72[[nodiscard]] bool sd_match(SDNode *N, const SelectionDAG *DAG, Pattern &&P) {
73 return sd_context_match(N, BasicMatchContext(DAG), P);
74}
75
76template <typename Pattern>
77[[nodiscard]] bool sd_match(SDValue N, const SelectionDAG *DAG, Pattern &&P) {
78 return sd_context_match(N, BasicMatchContext(DAG), P);
79}
80
81template <typename Pattern>
82[[nodiscard]] bool sd_match(SDNode *N, Pattern &&P) {
83 return sd_match(N, nullptr, P);
84}
85
86template <typename Pattern>
87[[nodiscard]] bool sd_match(SDValue N, Pattern &&P) {
88 return sd_match(N, nullptr, P);
89}
90
91// === Utilities ===
92struct Value_match {
93 SDValue MatchVal;
94
95 Value_match() = default;
96
97 explicit Value_match(SDValue Match) : MatchVal(Match) {}
98
99 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
100 if (MatchVal)
101 return MatchVal == N;
102 return N.getNode();
103 }
104};
105
106/// Match any valid SDValue.
107inline Value_match m_Value() { return Value_match(); }
108
109inline Value_match m_Specific(SDValue N) {
110 assert(N);
111 return Value_match(N);
112}
113
114template <unsigned ResNo, typename Pattern> struct Result_match {
115 Pattern P;
116
117 explicit Result_match(const Pattern &P) : P(P) {}
118
119 template <typename MatchContext>
120 bool match(const MatchContext &Ctx, SDValue N) {
121 return N.getResNo() == ResNo && P.match(Ctx, N);
122 }
123};
124
125/// Match only if the SDValue is a certain result at ResNo.
126template <unsigned ResNo, typename Pattern>
127inline Result_match<ResNo, Pattern> m_Result(const Pattern &P) {
128 return Result_match<ResNo, Pattern>(P);
129}
130
131struct DeferredValue_match {
132 SDValue &MatchVal;
133
134 explicit DeferredValue_match(SDValue &Match) : MatchVal(Match) {}
135
136 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
137 return N == MatchVal;
138 }
139};
140
141/// Similar to m_Specific, but the specific value to match is determined by
142/// another sub-pattern in the same sd_match() expression. For instance,
143/// We cannot match `(add V, V)` with `m_Add(m_Value(X), m_Specific(X))` since
144/// `X` is not initialized at the time it got copied into `m_Specific`. Instead,
145/// we should use `m_Add(m_Value(X), m_Deferred(X))`.
146inline DeferredValue_match m_Deferred(SDValue &V) {
147 return DeferredValue_match(V);
148}
149
150struct Opcode_match {
151 unsigned Opcode;
152
153 explicit Opcode_match(unsigned Opc) : Opcode(Opc) {}
154
155 template <typename MatchContext>
156 bool match(const MatchContext &Ctx, SDValue N) {
157 return Ctx.match(N, Opcode);
158 }
159};
160
161// === Patterns combinators ===
162template <typename... Preds> struct And {
163 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
164 return true;
165 }
166};
167
168template <typename Pred, typename... Preds>
169struct And<Pred, Preds...> : And<Preds...> {
170 Pred P;
171 And(const Pred &p, const Preds &...preds) : And<Preds...>(preds...), P(p) {}
172
173 template <typename MatchContext>
174 bool match(const MatchContext &Ctx, SDValue N) {
175 return P.match(Ctx, N) && And<Preds...>::match(Ctx, N);
176 }
177};
178
179template <typename... Preds> struct Or {
180 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
181 return false;
182 }
183};
184
185template <typename Pred, typename... Preds>
186struct Or<Pred, Preds...> : Or<Preds...> {
187 Pred P;
188 Or(const Pred &p, const Preds &...preds) : Or<Preds...>(preds...), P(p) {}
189
190 template <typename MatchContext>
191 bool match(const MatchContext &Ctx, SDValue N) {
192 return P.match(Ctx, N) || Or<Preds...>::match(Ctx, N);
193 }
194};
195
196template <typename Pred> struct Not {
197 Pred P;
198
199 explicit Not(const Pred &P) : P(P) {}
200
201 template <typename MatchContext>
202 bool match(const MatchContext &Ctx, SDValue N) {
203 return !P.match(Ctx, N);
204 }
205};
206// Explicit deduction guide.
207template <typename Pred> Not(const Pred &P) -> Not<Pred>;
208
209/// Match if the inner pattern does NOT match.
210template <typename Pred> inline Not<Pred> m_Unless(const Pred &P) {
211 return Not{P};
212}
213
214template <typename... Preds> And<Preds...> m_AllOf(const Preds &...preds) {
215 return And<Preds...>(preds...);
216}
217
218template <typename... Preds> Or<Preds...> m_AnyOf(const Preds &...preds) {
219 return Or<Preds...>(preds...);
220}
221
222template <typename... Preds> auto m_NoneOf(const Preds &...preds) {
223 return m_Unless(m_AnyOf(preds...));
224}
225
226inline Opcode_match m_SpecificOpc(unsigned Opcode) {
227 return Opcode_match(Opcode);
228}
229
230inline auto m_Undef() {
231 return m_AnyOf(preds: Opcode_match(ISD::UNDEF), preds: Opcode_match(ISD::POISON));
232}
233
234inline Opcode_match m_Poison() { return Opcode_match(ISD::POISON); }
235
236template <unsigned NumUses, typename Pattern> struct NUses_match {
237 Pattern P;
238
239 explicit NUses_match(const Pattern &P) : P(P) {}
240
241 template <typename MatchContext>
242 bool match(const MatchContext &Ctx, SDValue N) {
243 // SDNode::hasNUsesOfValue is pretty expensive when the SDNode produces
244 // multiple results, hence we check the subsequent pattern here before
245 // checking the number of value users.
246 return P.match(Ctx, N) && N->hasNUsesOfValue(NUses: NumUses, Value: N.getResNo());
247 }
248};
249
250template <typename Pattern>
251inline NUses_match<1, Pattern> m_OneUse(const Pattern &P) {
252 return NUses_match<1, Pattern>(P);
253}
254template <unsigned N, typename Pattern>
255inline NUses_match<N, Pattern> m_NUses(const Pattern &P) {
256 return NUses_match<N, Pattern>(P);
257}
258
259inline NUses_match<1, Value_match> m_OneUse() {
260 return NUses_match<1, Value_match>(m_Value());
261}
262template <unsigned N> inline NUses_match<N, Value_match> m_NUses() {
263 return NUses_match<N, Value_match>(m_Value());
264}
265
266template <typename PredPattern> struct Value_bind {
267 SDValue &BindVal;
268 PredPattern Pred;
269
270 Value_bind(SDValue &N, const PredPattern &P) : BindVal(N), Pred(P) {}
271
272 template <typename MatchContext>
273 bool match(const MatchContext &Ctx, SDValue N) {
274 if (!Pred.match(Ctx, N))
275 return false;
276
277 BindVal = N;
278 return true;
279 }
280};
281
282inline auto m_Value(SDValue &N) {
283 return Value_bind<Value_match>(N, m_Value());
284}
285/// Conditionally bind an SDValue based on the predicate.
286template <typename PredPattern>
287inline auto m_Value(SDValue &N, const PredPattern &P) {
288 return Value_bind<PredPattern>(N, P);
289}
290
291template <typename Pattern, typename PredFuncT> struct TLI_pred_match {
292 Pattern P;
293 PredFuncT PredFunc;
294
295 TLI_pred_match(const PredFuncT &Pred, const Pattern &P)
296 : P(P), PredFunc(Pred) {}
297
298 template <typename MatchContext>
299 bool match(const MatchContext &Ctx, SDValue N) {
300 assert(Ctx.getTLI() && "TargetLowering is required for this pattern.");
301 return PredFunc(*Ctx.getTLI(), N) && P.match(Ctx, N);
302 }
303};
304
305// Explicit deduction guide.
306template <typename PredFuncT, typename Pattern>
307TLI_pred_match(const PredFuncT &Pred, const Pattern &P)
308 -> TLI_pred_match<Pattern, PredFuncT>;
309
310/// Match legal SDNodes based on the information provided by TargetLowering.
311template <typename Pattern> inline auto m_LegalOp(const Pattern &P) {
312 return TLI_pred_match{[](const TargetLowering &TLI, SDValue N) {
313 return TLI.isOperationLegal(Op: N->getOpcode(),
314 VT: N.getValueType());
315 },
316 P};
317}
318
319/// Switch to a different MatchContext for subsequent patterns.
320template <typename NewMatchContext, typename Pattern> struct SwitchContext {
321 const NewMatchContext &Ctx;
322 Pattern P;
323
324 template <typename OrigMatchContext>
325 bool match(const OrigMatchContext &, SDValue N) {
326 return P.match(Ctx, N);
327 }
328};
329
330template <typename MatchContext, typename Pattern>
331inline SwitchContext<MatchContext, Pattern> m_Context(const MatchContext &Ctx,
332 Pattern &&P) {
333 return SwitchContext<MatchContext, Pattern>{Ctx, std::move(P)};
334}
335
336// === Value type ===
337
338template <typename Pattern> struct ValueType_bind {
339 EVT &BindVT;
340 Pattern P;
341
342 explicit ValueType_bind(EVT &Bind, const Pattern &P) : BindVT(Bind), P(P) {}
343
344 template <typename MatchContext>
345 bool match(const MatchContext &Ctx, SDValue N) {
346 BindVT = N.getValueType();
347 return P.match(Ctx, N);
348 }
349};
350
351template <typename Pattern>
352ValueType_bind(const Pattern &P) -> ValueType_bind<Pattern>;
353
354/// Retreive the ValueType of the current SDValue.
355inline auto m_VT(EVT &VT) { return ValueType_bind(VT, m_Value()); }
356
357template <typename Pattern> inline auto m_VT(EVT &VT, const Pattern &P) {
358 return ValueType_bind(VT, P);
359}
360
361template <typename Pattern, typename PredFuncT> struct ValueType_match {
362 PredFuncT PredFunc;
363 Pattern P;
364
365 ValueType_match(const PredFuncT &Pred, const Pattern &P)
366 : PredFunc(Pred), P(P) {}
367
368 template <typename MatchContext>
369 bool match(const MatchContext &Ctx, SDValue N) {
370 return PredFunc(N.getValueType()) && P.match(Ctx, N);
371 }
372};
373
374// Explicit deduction guide.
375template <typename PredFuncT, typename Pattern>
376ValueType_match(const PredFuncT &Pred, const Pattern &P)
377 -> ValueType_match<Pattern, PredFuncT>;
378
379/// Match a specific ValueType.
380template <typename Pattern>
381inline auto m_SpecificVT(EVT RefVT, const Pattern &P) {
382 return ValueType_match{[=](EVT VT) { return VT == RefVT; }, P};
383}
384inline auto m_SpecificVT(EVT RefVT) {
385 return ValueType_match{[=](EVT VT) { return VT == RefVT; }, m_Value()};
386}
387
388inline auto m_Glue() { return m_SpecificVT(RefVT: MVT::Glue); }
389inline auto m_OtherVT() { return m_SpecificVT(RefVT: MVT::Other); }
390
391/// Match a scalar ValueType.
392template <typename Pattern>
393inline auto m_SpecificScalarVT(EVT RefVT, const Pattern &P) {
394 return ValueType_match{[=](EVT VT) { return VT.getScalarType() == RefVT; },
395 P};
396}
397inline auto m_SpecificScalarVT(EVT RefVT) {
398 return ValueType_match{[=](EVT VT) { return VT.getScalarType() == RefVT; },
399 m_Value()};
400}
401
402/// Match a vector ValueType.
403template <typename Pattern>
404inline auto m_SpecificVectorElementVT(EVT RefVT, const Pattern &P) {
405 return ValueType_match{[=](EVT VT) {
406 return VT.isVector() &&
407 VT.getVectorElementType() == RefVT;
408 },
409 P};
410}
411inline auto m_SpecificVectorElementVT(EVT RefVT) {
412 return ValueType_match{[=](EVT VT) {
413 return VT.isVector() &&
414 VT.getVectorElementType() == RefVT;
415 },
416 m_Value()};
417}
418
419/// Match any integer ValueTypes.
420template <typename Pattern> inline auto m_IntegerVT(const Pattern &P) {
421 return ValueType_match{[](EVT VT) { return VT.isInteger(); }, P};
422}
423inline auto m_IntegerVT() {
424 return ValueType_match{[](EVT VT) { return VT.isInteger(); }, m_Value()};
425}
426
427/// Match any floating point ValueTypes.
428template <typename Pattern> inline auto m_FloatingPointVT(const Pattern &P) {
429 return ValueType_match{[](EVT VT) { return VT.isFloatingPoint(); }, P};
430}
431inline auto m_FloatingPointVT() {
432 return ValueType_match{[](EVT VT) { return VT.isFloatingPoint(); },
433 m_Value()};
434}
435
436/// Match any vector ValueTypes.
437template <typename Pattern> inline auto m_VectorVT(const Pattern &P) {
438 return ValueType_match{[](EVT VT) { return VT.isVector(); }, P};
439}
440inline auto m_VectorVT() {
441 return ValueType_match{[](EVT VT) { return VT.isVector(); }, m_Value()};
442}
443
444/// Match fixed-length vector ValueTypes.
445template <typename Pattern> inline auto m_FixedVectorVT(const Pattern &P) {
446 return ValueType_match{[](EVT VT) { return VT.isFixedLengthVector(); }, P};
447}
448inline auto m_FixedVectorVT() {
449 return ValueType_match{[](EVT VT) { return VT.isFixedLengthVector(); },
450 m_Value()};
451}
452
453/// Match scalable vector ValueTypes.
454template <typename Pattern> inline auto m_ScalableVectorVT(const Pattern &P) {
455 return ValueType_match{[](EVT VT) { return VT.isScalableVector(); }, P};
456}
457inline auto m_ScalableVectorVT() {
458 return ValueType_match{[](EVT VT) { return VT.isScalableVector(); },
459 m_Value()};
460}
461
462/// Match legal ValueTypes based on the information provided by TargetLowering.
463template <typename Pattern> inline auto m_LegalType(const Pattern &P) {
464 return TLI_pred_match{[](const TargetLowering &TLI, SDValue N) {
465 return TLI.isTypeLegal(VT: N.getValueType());
466 },
467 P};
468}
469
470// === Generic node matching ===
471template <unsigned OpIdx, typename... OpndPreds> struct Operands_match {
472 template <typename MatchContext>
473 bool match(const MatchContext &Ctx, SDValue N) {
474 // Returns false if there are more operands than predicates;
475 // Ignores the last two operands if both the Context and the Node are VP
476 return Ctx.getNumOperands(N) == OpIdx;
477 }
478};
479
480template <unsigned OpIdx, typename OpndPred, typename... OpndPreds>
481struct Operands_match<OpIdx, OpndPred, OpndPreds...>
482 : Operands_match<OpIdx + 1, OpndPreds...> {
483 OpndPred P;
484
485 Operands_match(const OpndPred &p, const OpndPreds &...preds)
486 : Operands_match<OpIdx + 1, OpndPreds...>(preds...), P(p) {}
487
488 template <typename MatchContext>
489 bool match(const MatchContext &Ctx, SDValue N) {
490 if (OpIdx < N->getNumOperands())
491 return P.match(Ctx, N->getOperand(Num: OpIdx)) &&
492 Operands_match<OpIdx + 1, OpndPreds...>::match(Ctx, N);
493
494 // This is the case where there are more predicates than operands.
495 return false;
496 }
497};
498
499template <typename... OpndPreds>
500auto m_Node(unsigned Opcode, const OpndPreds &...preds) {
501 return m_AllOf(m_SpecificOpc(Opcode),
502 Operands_match<0, OpndPreds...>(preds...));
503}
504
505/// Provide number of operands that are not chain or glue, as well as the first
506/// index of such operand.
507template <bool ExcludeChain> struct EffectiveOperands {
508 unsigned Size = 0;
509 unsigned FirstIndex = 0;
510
511 template <typename MatchContext>
512 explicit EffectiveOperands(SDValue N, const MatchContext &Ctx) {
513 const unsigned TotalNumOps = Ctx.getNumOperands(N);
514 FirstIndex = TotalNumOps;
515 for (unsigned I = 0; I < TotalNumOps; ++I) {
516 // Count the number of non-chain and non-glue nodes (we ignore chain
517 // and glue by default) and retreive the operand index offset.
518 EVT VT = N->getOperand(Num: I).getValueType();
519 if (VT != MVT::Glue && VT != MVT::Other) {
520 ++Size;
521 if (FirstIndex == TotalNumOps)
522 FirstIndex = I;
523 }
524 }
525 }
526};
527
528template <> struct EffectiveOperands<false> {
529 unsigned Size = 0;
530 unsigned FirstIndex = 0;
531
532 template <typename MatchContext>
533 explicit EffectiveOperands(SDValue N, const MatchContext &Ctx)
534 : Size(Ctx.getNumOperands(N)) {}
535};
536
537// === Ternary operations ===
538template <typename T0_P, typename T1_P, typename T2_P, bool Commutable = false,
539 bool ExcludeChain = false>
540struct TernaryOpc_match {
541 unsigned Opcode;
542 T0_P Op0;
543 T1_P Op1;
544 T2_P Op2;
545
546 TernaryOpc_match(unsigned Opc, const T0_P &Op0, const T1_P &Op1,
547 const T2_P &Op2)
548 : Opcode(Opc), Op0(Op0), Op1(Op1), Op2(Op2) {}
549
550 template <typename MatchContext>
551 bool match(const MatchContext &Ctx, SDValue N) {
552 if (sd_context_match(N, Ctx, m_SpecificOpc(Opcode))) {
553 EffectiveOperands<ExcludeChain> EO(N, Ctx);
554 assert(EO.Size == 3);
555 return ((Op0.match(Ctx, N->getOperand(Num: EO.FirstIndex)) &&
556 Op1.match(Ctx, N->getOperand(Num: EO.FirstIndex + 1))) ||
557 (Commutable && Op0.match(Ctx, N->getOperand(Num: EO.FirstIndex + 1)) &&
558 Op1.match(Ctx, N->getOperand(Num: EO.FirstIndex)))) &&
559 Op2.match(Ctx, N->getOperand(Num: EO.FirstIndex + 2));
560 }
561
562 return false;
563 }
564};
565
566template <typename T0_P, typename T1_P, typename T2_P>
567inline TernaryOpc_match<T0_P, T1_P, T2_P>
568m_SetCC(const T0_P &LHS, const T1_P &RHS, const T2_P &CC) {
569 return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::SETCC, LHS, RHS, CC);
570}
571
572template <typename T0_P, typename T1_P, typename T2_P>
573inline TernaryOpc_match<T0_P, T1_P, T2_P, true, false>
574m_c_SetCC(const T0_P &LHS, const T1_P &RHS, const T2_P &CC) {
575 return TernaryOpc_match<T0_P, T1_P, T2_P, true, false>(ISD::SETCC, LHS, RHS,
576 CC);
577}
578
579template <typename T0_P, typename T1_P, typename T2_P>
580inline TernaryOpc_match<T0_P, T1_P, T2_P>
581m_Select(const T0_P &Cond, const T1_P &T, const T2_P &F) {
582 return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::SELECT, Cond, T, F);
583}
584
585template <typename T0_P, typename T1_P, typename T2_P>
586inline TernaryOpc_match<T0_P, T1_P, T2_P>
587m_VSelect(const T0_P &Cond, const T1_P &T, const T2_P &F) {
588 return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::VSELECT, Cond, T, F);
589}
590
591template <typename T0_P, typename T1_P, typename T2_P>
592inline auto m_SelectLike(const T0_P &Cond, const T1_P &T, const T2_P &F) {
593 return m_AnyOf(m_Select(Cond, T, F), m_VSelect(Cond, T, F));
594}
595
596template <typename T0_P, typename T1_P, typename T2_P>
597inline Result_match<0, TernaryOpc_match<T0_P, T1_P, T2_P>>
598m_Load(const T0_P &Ch, const T1_P &Ptr, const T2_P &Offset) {
599 return m_Result<0>(
600 TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::LOAD, Ch, Ptr, Offset));
601}
602
603template <typename T0_P, typename T1_P, typename T2_P>
604inline TernaryOpc_match<T0_P, T1_P, T2_P>
605m_InsertElt(const T0_P &Vec, const T1_P &Val, const T2_P &Idx) {
606 return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::INSERT_VECTOR_ELT, Vec, Val,
607 Idx);
608}
609
610template <typename LHS, typename RHS, typename IDX>
611inline TernaryOpc_match<LHS, RHS, IDX>
612m_InsertSubvector(const LHS &Base, const RHS &Sub, const IDX &Idx) {
613 return TernaryOpc_match<LHS, RHS, IDX>(ISD::INSERT_SUBVECTOR, Base, Sub, Idx);
614}
615
616template <typename T0_P, typename T1_P, typename T2_P>
617inline TernaryOpc_match<T0_P, T1_P, T2_P>
618m_TernaryOp(unsigned Opc, const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
619 return TernaryOpc_match<T0_P, T1_P, T2_P>(Opc, Op0, Op1, Op2);
620}
621
622template <typename T0_P, typename T1_P, typename T2_P>
623inline TernaryOpc_match<T0_P, T1_P, T2_P, true>
624m_c_TernaryOp(unsigned Opc, const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
625 return TernaryOpc_match<T0_P, T1_P, T2_P, true>(Opc, Op0, Op1, Op2);
626}
627
628template <typename LTy, typename RTy, typename TTy, typename FTy, typename CCTy>
629inline auto m_SelectCC(const LTy &L, const RTy &R, const TTy &T, const FTy &F,
630 const CCTy &CC) {
631 return m_Node(ISD::SELECT_CC, L, R, T, F, CC);
632}
633
634template <typename LTy, typename RTy, typename TTy, typename FTy, typename CCTy>
635inline auto m_SelectCCLike(const LTy &L, const RTy &R, const TTy &T,
636 const FTy &F, const CCTy &CC) {
637 return m_AnyOf(m_Select(m_SetCC(L, R, CC), T, F), m_SelectCC(L, R, T, F, CC));
638}
639
640// === Binary operations ===
641template <typename LHS_P, typename RHS_P, bool Commutable = false,
642 bool ExcludeChain = false>
643struct BinaryOpc_match {
644 unsigned Opcode;
645 LHS_P LHS;
646 RHS_P RHS;
647 SDNodeFlags Flags;
648 BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R,
649 SDNodeFlags Flgs = SDNodeFlags())
650 : Opcode(Opc), LHS(L), RHS(R), Flags(Flgs) {}
651
652 template <typename MatchContext>
653 bool match(const MatchContext &Ctx, SDValue N) {
654 if (sd_context_match(N, Ctx, m_SpecificOpc(Opcode))) {
655 EffectiveOperands<ExcludeChain> EO(N, Ctx);
656 assert(EO.Size == 2);
657 if (!((LHS.match(Ctx, N->getOperand(Num: EO.FirstIndex)) &&
658 RHS.match(Ctx, N->getOperand(Num: EO.FirstIndex + 1))) ||
659 (Commutable && LHS.match(Ctx, N->getOperand(Num: EO.FirstIndex + 1)) &&
660 RHS.match(Ctx, N->getOperand(Num: EO.FirstIndex)))))
661 return false;
662
663 return (Flags & N->getFlags()) == Flags;
664 }
665
666 return false;
667 }
668};
669
670/// Matching while capturing mask
671template <typename T0, typename T1, typename T2> struct SDShuffle_match {
672 T0 Op1;
673 T1 Op2;
674 T2 Mask;
675
676 SDShuffle_match(const T0 &Op1, const T1 &Op2, const T2 &Mask)
677 : Op1(Op1), Op2(Op2), Mask(Mask) {}
678
679 template <typename MatchContext>
680 bool match(const MatchContext &Ctx, SDValue N) {
681 if (auto *I = dyn_cast<ShuffleVectorSDNode>(Val&: N)) {
682 return Op1.match(Ctx, I->getOperand(Num: 0)) &&
683 Op2.match(Ctx, I->getOperand(Num: 1)) && Mask.match(I->getMask());
684 }
685 return false;
686 }
687};
688struct m_Mask {
689 ArrayRef<int> &MaskRef;
690 m_Mask(ArrayRef<int> &MaskRef) : MaskRef(MaskRef) {}
691 bool match(ArrayRef<int> Mask) {
692 MaskRef = Mask;
693 return true;
694 }
695};
696
697struct m_SpecificMask {
698 ArrayRef<int> MaskRef;
699 m_SpecificMask(ArrayRef<int> MaskRef) : MaskRef(MaskRef) {}
700 bool match(ArrayRef<int> Mask) { return MaskRef == Mask; }
701};
702
703template <typename LHS_P, typename RHS_P, typename Pred_t,
704 bool Commutable = false, bool ExcludeChain = false>
705struct MaxMin_match {
706 using PredType = Pred_t;
707 LHS_P LHS;
708 RHS_P RHS;
709
710 MaxMin_match(const LHS_P &L, const RHS_P &R) : LHS(L), RHS(R) {}
711
712 template <typename MatchContext>
713 bool match(const MatchContext &Ctx, SDValue N) {
714 auto MatchMinMax = [&](SDValue L, SDValue R, SDValue TrueValue,
715 SDValue FalseValue, ISD::CondCode CC) {
716 if ((TrueValue != L || FalseValue != R) &&
717 (TrueValue != R || FalseValue != L))
718 return false;
719
720 ISD::CondCode Cond =
721 TrueValue == L ? CC : getSetCCInverse(Operation: CC, Type: L.getValueType());
722 if (!Pred_t::match(Cond))
723 return false;
724
725 return (LHS.match(Ctx, L) && RHS.match(Ctx, R)) ||
726 (Commutable && LHS.match(Ctx, R) && RHS.match(Ctx, L));
727 };
728
729 if (sd_context_match(N, Ctx, m_SpecificOpc(Opcode: ISD::SELECT)) ||
730 sd_context_match(N, Ctx, m_SpecificOpc(Opcode: ISD::VSELECT))) {
731 EffectiveOperands<ExcludeChain> EO_SELECT(N, Ctx);
732 assert(EO_SELECT.Size == 3);
733 SDValue Cond = N->getOperand(Num: EO_SELECT.FirstIndex);
734 SDValue TrueValue = N->getOperand(Num: EO_SELECT.FirstIndex + 1);
735 SDValue FalseValue = N->getOperand(Num: EO_SELECT.FirstIndex + 2);
736
737 if (sd_context_match(Cond, Ctx, m_SpecificOpc(Opcode: ISD::SETCC))) {
738 EffectiveOperands<ExcludeChain> EO_SETCC(Cond, Ctx);
739 assert(EO_SETCC.Size == 3);
740 SDValue L = Cond->getOperand(Num: EO_SETCC.FirstIndex);
741 SDValue R = Cond->getOperand(Num: EO_SETCC.FirstIndex + 1);
742 auto *CondNode =
743 cast<CondCodeSDNode>(Cond->getOperand(Num: EO_SETCC.FirstIndex + 2));
744 return MatchMinMax(L, R, TrueValue, FalseValue, CondNode->get());
745 }
746 }
747
748 if (sd_context_match(N, Ctx, m_SpecificOpc(Opcode: ISD::SELECT_CC))) {
749 EffectiveOperands<ExcludeChain> EO_SELECT(N, Ctx);
750 assert(EO_SELECT.Size == 5);
751 SDValue L = N->getOperand(Num: EO_SELECT.FirstIndex);
752 SDValue R = N->getOperand(Num: EO_SELECT.FirstIndex + 1);
753 SDValue TrueValue = N->getOperand(Num: EO_SELECT.FirstIndex + 2);
754 SDValue FalseValue = N->getOperand(Num: EO_SELECT.FirstIndex + 3);
755 auto *CondNode =
756 cast<CondCodeSDNode>(N->getOperand(Num: EO_SELECT.FirstIndex + 4));
757 return MatchMinMax(L, R, TrueValue, FalseValue, CondNode->get());
758 }
759
760 return false;
761 }
762};
763
764// Helper class for identifying signed max predicates.
765struct smax_pred_ty {
766 static bool match(ISD::CondCode Cond) {
767 return Cond == ISD::CondCode::SETGT || Cond == ISD::CondCode::SETGE;
768 }
769};
770
771// Helper class for identifying unsigned max predicates.
772struct umax_pred_ty {
773 static bool match(ISD::CondCode Cond) {
774 return Cond == ISD::CondCode::SETUGT || Cond == ISD::CondCode::SETUGE;
775 }
776};
777
778// Helper class for identifying signed min predicates.
779struct smin_pred_ty {
780 static bool match(ISD::CondCode Cond) {
781 return Cond == ISD::CondCode::SETLT || Cond == ISD::CondCode::SETLE;
782 }
783};
784
785// Helper class for identifying unsigned min predicates.
786struct umin_pred_ty {
787 static bool match(ISD::CondCode Cond) {
788 return Cond == ISD::CondCode::SETULT || Cond == ISD::CondCode::SETULE;
789 }
790};
791
792template <typename LHS, typename RHS>
793inline BinaryOpc_match<LHS, RHS> m_BinOp(unsigned Opc, const LHS &L,
794 const RHS &R,
795 SDNodeFlags Flgs = SDNodeFlags()) {
796 return BinaryOpc_match<LHS, RHS>(Opc, L, R, Flgs);
797}
798template <typename LHS, typename RHS>
799inline BinaryOpc_match<LHS, RHS, true>
800m_c_BinOp(unsigned Opc, const LHS &L, const RHS &R,
801 SDNodeFlags Flgs = SDNodeFlags()) {
802 return BinaryOpc_match<LHS, RHS, true>(Opc, L, R, Flgs);
803}
804
805template <typename LHS, typename RHS>
806inline BinaryOpc_match<LHS, RHS, false, true>
807m_ChainedBinOp(unsigned Opc, const LHS &L, const RHS &R) {
808 return BinaryOpc_match<LHS, RHS, false, true>(Opc, L, R);
809}
810template <typename LHS, typename RHS>
811inline BinaryOpc_match<LHS, RHS, true, true>
812m_c_ChainedBinOp(unsigned Opc, const LHS &L, const RHS &R) {
813 return BinaryOpc_match<LHS, RHS, true, true>(Opc, L, R);
814}
815
816// Common binary operations
817template <typename LHS, typename RHS>
818inline BinaryOpc_match<LHS, RHS, true> m_Add(const LHS &L, const RHS &R) {
819 return BinaryOpc_match<LHS, RHS, true>(ISD::ADD, L, R);
820}
821
822template <typename LHS, typename RHS>
823inline BinaryOpc_match<LHS, RHS> m_Sub(const LHS &L, const RHS &R) {
824 return BinaryOpc_match<LHS, RHS>(ISD::SUB, L, R);
825}
826
827template <typename LHS, typename RHS>
828inline BinaryOpc_match<LHS, RHS, true> m_Mul(const LHS &L, const RHS &R) {
829 return BinaryOpc_match<LHS, RHS, true>(ISD::MUL, L, R);
830}
831
832template <typename LHS, typename RHS>
833inline BinaryOpc_match<LHS, RHS, true> m_And(const LHS &L, const RHS &R) {
834 return BinaryOpc_match<LHS, RHS, true>(ISD::AND, L, R);
835}
836
837template <typename LHS, typename RHS>
838inline BinaryOpc_match<LHS, RHS, true> m_Or(const LHS &L, const RHS &R) {
839 return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R);
840}
841
842template <typename LHS, typename RHS>
843inline BinaryOpc_match<LHS, RHS, true> m_DisjointOr(const LHS &L,
844 const RHS &R) {
845 return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R, SDNodeFlags::Disjoint);
846}
847
848template <typename LHS, typename RHS>
849inline auto m_AddLike(const LHS &L, const RHS &R) {
850 return m_AnyOf(m_Add(L, R), m_DisjointOr(L, R));
851}
852
853template <typename LHS, typename RHS>
854inline BinaryOpc_match<LHS, RHS, true> m_Xor(const LHS &L, const RHS &R) {
855 return BinaryOpc_match<LHS, RHS, true>(ISD::XOR, L, R);
856}
857
858template <typename LHS, typename RHS>
859inline auto m_BitwiseLogic(const LHS &L, const RHS &R) {
860 return m_AnyOf(m_And(L, R), m_Or(L, R), m_Xor(L, R));
861}
862
863template <unsigned Opc, typename Pred, typename LHS, typename RHS>
864inline auto m_MaxMinLike(const LHS &L, const RHS &R) {
865 return m_AnyOf(BinaryOpc_match<LHS, RHS, true>(Opc, L, R),
866 MaxMin_match<LHS, RHS, Pred, true>(L, R));
867}
868
869template <typename LHS, typename RHS>
870inline BinaryOpc_match<LHS, RHS, true> m_SMin(const LHS &L, const RHS &R) {
871 return BinaryOpc_match<LHS, RHS, true>(ISD::SMIN, L, R);
872}
873
874template <typename LHS, typename RHS>
875inline auto m_SMinLike(const LHS &L, const RHS &R) {
876 return m_AnyOf(
877 m_MaxMinLike<ISD::SMIN, smin_pred_ty>(L, R),
878 m_MaxMinLike<ISD::UMIN, umin_pred_ty>(m_NonNegative(L), m_NonNegative(R)),
879 m_MaxMinLike<ISD::UMIN, umin_pred_ty>(m_Negative(L), m_Negative(R)));
880}
881
882template <typename LHS, typename RHS>
883inline BinaryOpc_match<LHS, RHS, true> m_SMax(const LHS &L, const RHS &R) {
884 return BinaryOpc_match<LHS, RHS, true>(ISD::SMAX, L, R);
885}
886
887template <typename LHS, typename RHS>
888inline auto m_SMaxLike(const LHS &L, const RHS &R) {
889 return m_AnyOf(
890 m_MaxMinLike<ISD::SMAX, smax_pred_ty>(L, R),
891 m_MaxMinLike<ISD::UMAX, umax_pred_ty>(m_NonNegative(L), m_NonNegative(R)),
892 m_MaxMinLike<ISD::UMAX, umax_pred_ty>(m_Negative(L), m_Negative(R)));
893}
894
895template <typename LHS, typename RHS>
896inline BinaryOpc_match<LHS, RHS, true> m_UMin(const LHS &L, const RHS &R) {
897 return BinaryOpc_match<LHS, RHS, true>(ISD::UMIN, L, R);
898}
899
900template <typename LHS, typename RHS>
901inline auto m_UMinLike(const LHS &L, const RHS &R) {
902 return m_AnyOf(
903 m_MaxMinLike<ISD::UMIN, umin_pred_ty>(L, R),
904 m_MaxMinLike<ISD::SMIN, smin_pred_ty>(m_NonNegative(L), m_NonNegative(R)),
905 m_MaxMinLike<ISD::SMIN, smin_pred_ty>(m_Negative(L), m_Negative(R)));
906}
907
908template <typename LHS, typename RHS>
909inline BinaryOpc_match<LHS, RHS, true> m_UMax(const LHS &L, const RHS &R) {
910 return BinaryOpc_match<LHS, RHS, true>(ISD::UMAX, L, R);
911}
912
913template <typename LHS, typename RHS>
914inline auto m_UMaxLike(const LHS &L, const RHS &R) {
915 return m_AnyOf(
916 m_MaxMinLike<ISD::UMAX, umax_pred_ty>(L, R),
917 m_MaxMinLike<ISD::SMAX, smax_pred_ty>(m_NonNegative(L), m_NonNegative(R)),
918 m_MaxMinLike<ISD::SMAX, smax_pred_ty>(m_Negative(L), m_Negative(R)));
919}
920
921template <typename LHS, typename RHS>
922inline BinaryOpc_match<LHS, RHS> m_UDiv(const LHS &L, const RHS &R) {
923 return BinaryOpc_match<LHS, RHS>(ISD::UDIV, L, R);
924}
925template <typename LHS, typename RHS>
926inline BinaryOpc_match<LHS, RHS> m_SDiv(const LHS &L, const RHS &R) {
927 return BinaryOpc_match<LHS, RHS>(ISD::SDIV, L, R);
928}
929
930template <typename LHS, typename RHS>
931inline BinaryOpc_match<LHS, RHS> m_URem(const LHS &L, const RHS &R) {
932 return BinaryOpc_match<LHS, RHS>(ISD::UREM, L, R);
933}
934template <typename LHS, typename RHS>
935inline BinaryOpc_match<LHS, RHS> m_SRem(const LHS &L, const RHS &R) {
936 return BinaryOpc_match<LHS, RHS>(ISD::SREM, L, R);
937}
938
939template <typename LHS, typename RHS>
940inline BinaryOpc_match<LHS, RHS> m_Shl(const LHS &L, const RHS &R) {
941 return BinaryOpc_match<LHS, RHS>(ISD::SHL, L, R);
942}
943
944template <typename LHS, typename RHS>
945inline BinaryOpc_match<LHS, RHS> m_Sra(const LHS &L, const RHS &R) {
946 return BinaryOpc_match<LHS, RHS>(ISD::SRA, L, R);
947}
948template <typename LHS, typename RHS>
949inline BinaryOpc_match<LHS, RHS> m_Srl(const LHS &L, const RHS &R) {
950 return BinaryOpc_match<LHS, RHS>(ISD::SRL, L, R);
951}
952template <typename LHS, typename RHS>
953inline auto m_ExactSr(const LHS &L, const RHS &R) {
954 return m_AnyOf(BinaryOpc_match<LHS, RHS>(ISD::SRA, L, R, SDNodeFlags::Exact),
955 BinaryOpc_match<LHS, RHS>(ISD::SRL, L, R, SDNodeFlags::Exact));
956}
957
958template <typename LHS, typename RHS>
959inline BinaryOpc_match<LHS, RHS> m_Rotl(const LHS &L, const RHS &R) {
960 return BinaryOpc_match<LHS, RHS>(ISD::ROTL, L, R);
961}
962
963template <typename LHS, typename RHS>
964inline BinaryOpc_match<LHS, RHS> m_Rotr(const LHS &L, const RHS &R) {
965 return BinaryOpc_match<LHS, RHS>(ISD::ROTR, L, R);
966}
967
968template <typename LHS, typename RHS>
969inline BinaryOpc_match<LHS, RHS, true> m_Clmul(const LHS &L, const RHS &R) {
970 return BinaryOpc_match<LHS, RHS, true>(ISD::CLMUL, L, R);
971}
972
973template <typename LHS, typename RHS>
974inline BinaryOpc_match<LHS, RHS, true> m_FAdd(const LHS &L, const RHS &R) {
975 return BinaryOpc_match<LHS, RHS, true>(ISD::FADD, L, R);
976}
977
978template <typename LHS, typename RHS>
979inline BinaryOpc_match<LHS, RHS> m_FSub(const LHS &L, const RHS &R) {
980 return BinaryOpc_match<LHS, RHS>(ISD::FSUB, L, R);
981}
982
983template <typename LHS, typename RHS>
984inline BinaryOpc_match<LHS, RHS, true> m_FMul(const LHS &L, const RHS &R) {
985 return BinaryOpc_match<LHS, RHS, true>(ISD::FMUL, L, R);
986}
987
988template <typename LHS, typename RHS>
989inline BinaryOpc_match<LHS, RHS> m_FDiv(const LHS &L, const RHS &R) {
990 return BinaryOpc_match<LHS, RHS>(ISD::FDIV, L, R);
991}
992
993template <typename LHS, typename RHS>
994inline BinaryOpc_match<LHS, RHS> m_FRem(const LHS &L, const RHS &R) {
995 return BinaryOpc_match<LHS, RHS>(ISD::FREM, L, R);
996}
997
998template <typename V1_t, typename V2_t>
999inline BinaryOpc_match<V1_t, V2_t> m_Shuffle(const V1_t &v1, const V2_t &v2) {
1000 return BinaryOpc_match<V1_t, V2_t>(ISD::VECTOR_SHUFFLE, v1, v2);
1001}
1002
1003template <typename V1_t, typename V2_t, typename Mask_t>
1004inline SDShuffle_match<V1_t, V2_t, Mask_t>
1005m_Shuffle(const V1_t &v1, const V2_t &v2, const Mask_t &mask) {
1006 return SDShuffle_match<V1_t, V2_t, Mask_t>(v1, v2, mask);
1007}
1008
1009template <typename LHS, typename RHS>
1010inline BinaryOpc_match<LHS, RHS> m_ExtractElt(const LHS &Vec, const RHS &Idx) {
1011 return BinaryOpc_match<LHS, RHS>(ISD::EXTRACT_VECTOR_ELT, Vec, Idx);
1012}
1013
1014template <typename LHS, typename RHS>
1015inline BinaryOpc_match<LHS, RHS> m_ExtractSubvector(const LHS &Vec,
1016 const RHS &Idx) {
1017 return BinaryOpc_match<LHS, RHS>(ISD::EXTRACT_SUBVECTOR, Vec, Idx);
1018}
1019
1020// === Unary operations ===
1021template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
1022 unsigned Opcode;
1023 Opnd_P Opnd;
1024 SDNodeFlags Flags;
1025 UnaryOpc_match(unsigned Opc, const Opnd_P &Op,
1026 SDNodeFlags Flgs = SDNodeFlags())
1027 : Opcode(Opc), Opnd(Op), Flags(Flgs) {}
1028
1029 template <typename MatchContext>
1030 bool match(const MatchContext &Ctx, SDValue N) {
1031 if (sd_context_match(N, Ctx, m_SpecificOpc(Opcode))) {
1032 EffectiveOperands<ExcludeChain> EO(N, Ctx);
1033 assert(EO.Size == 1);
1034 if (!Opnd.match(Ctx, N->getOperand(Num: EO.FirstIndex)))
1035 return false;
1036
1037 return (Flags & N->getFlags()) == Flags;
1038 }
1039
1040 return false;
1041 }
1042};
1043
1044template <typename Opnd>
1045inline UnaryOpc_match<Opnd> m_UnaryOp(unsigned Opc, const Opnd &Op) {
1046 return UnaryOpc_match<Opnd>(Opc, Op);
1047}
1048template <typename Opnd>
1049inline UnaryOpc_match<Opnd, true> m_ChainedUnaryOp(unsigned Opc,
1050 const Opnd &Op) {
1051 return UnaryOpc_match<Opnd, true>(Opc, Op);
1052}
1053
1054template <typename Opnd> inline UnaryOpc_match<Opnd> m_BitCast(const Opnd &Op) {
1055 return UnaryOpc_match<Opnd>(ISD::BITCAST, Op);
1056}
1057
1058template <typename Opnd>
1059inline UnaryOpc_match<Opnd> m_BSwap(const Opnd &Op) {
1060 return UnaryOpc_match<Opnd>(ISD::BSWAP, Op);
1061}
1062
1063template <typename Opnd>
1064inline UnaryOpc_match<Opnd> m_BitReverse(const Opnd &Op) {
1065 return UnaryOpc_match<Opnd>(ISD::BITREVERSE, Op);
1066}
1067
1068template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) {
1069 return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op);
1070}
1071
1072template <typename Opnd>
1073inline UnaryOpc_match<Opnd> m_NNegZExt(const Opnd &Op) {
1074 return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op, SDNodeFlags::NonNeg);
1075}
1076
1077template <typename Opnd> inline auto m_SExt(const Opnd &Op) {
1078 return UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op);
1079}
1080
1081template <typename Opnd> inline UnaryOpc_match<Opnd> m_AnyExt(const Opnd &Op) {
1082 return UnaryOpc_match<Opnd>(ISD::ANY_EXTEND, Op);
1083}
1084
1085template <typename Opnd> inline UnaryOpc_match<Opnd> m_Trunc(const Opnd &Op) {
1086 return UnaryOpc_match<Opnd>(ISD::TRUNCATE, Op);
1087}
1088
1089template <typename Opnd> inline UnaryOpc_match<Opnd> m_Abs(const Opnd &Op) {
1090 return UnaryOpc_match<Opnd>(ISD::ABS, Op);
1091}
1092
1093template <typename Opnd> inline UnaryOpc_match<Opnd> m_FAbs(const Opnd &Op) {
1094 return UnaryOpc_match<Opnd>(ISD::FABS, Op);
1095}
1096
/// Match either zext(\p Op) or \p Op itself, so a pattern can look through
/// an optional zero extension.
template <typename Opnd> inline auto m_ZExtOrSelf(const Opnd &Op) {
  auto Extended = m_ZExt(Op);
  return m_AnyOf(Extended, Op);
}

/// Match either sext(\p Op) or \p Op itself, so a pattern can look through
/// an optional sign extension.
template <typename Opnd> inline auto m_SExtOrSelf(const Opnd &Op) {
  auto Extended = m_SExt(Op);
  return m_AnyOf(Extended, Op);
}

/// Match a sign-extension-like node: sext(\p Op), or a zext(\p Op) carrying
/// the NonNeg flag. Zero- and sign-extending a non-negative value give the
/// same result.
template <typename Opnd> inline auto m_SExtLike(const Opnd &Op) {
  auto PlainSExt = m_SExt(Op);
  auto NonNegZExt = m_NNegZExt(Op);
  return m_AnyOf(PlainSExt, NonNegZExt);
}
1112
1113/// Match a aext or identity
1114/// Allows to peek through optional extensions
1115template <typename Opnd>
1116inline Or<UnaryOpc_match<Opnd>, Opnd> m_AExtOrSelf(const Opnd &Op) {
1117 return Or<UnaryOpc_match<Opnd>, Opnd>(m_AnyExt(Op), Op);
1118}
1119
1120/// Match a trunc or identity
1121/// Allows to peek through optional truncations
1122template <typename Opnd>
1123inline Or<UnaryOpc_match<Opnd>, Opnd> m_TruncOrSelf(const Opnd &Op) {
1124 return Or<UnaryOpc_match<Opnd>, Opnd>(m_Trunc(Op), Op);
1125}
1126
1127template <typename Opnd> inline UnaryOpc_match<Opnd> m_VScale(const Opnd &Op) {
1128 return UnaryOpc_match<Opnd>(ISD::VSCALE, Op);
1129}
1130
1131template <typename Opnd> inline UnaryOpc_match<Opnd> m_FPToUI(const Opnd &Op) {
1132 return UnaryOpc_match<Opnd>(ISD::FP_TO_UINT, Op);
1133}
1134
1135template <typename Opnd> inline UnaryOpc_match<Opnd> m_FPToSI(const Opnd &Op) {
1136 return UnaryOpc_match<Opnd>(ISD::FP_TO_SINT, Op);
1137}
1138
1139template <typename Opnd> inline UnaryOpc_match<Opnd> m_Ctpop(const Opnd &Op) {
1140 return UnaryOpc_match<Opnd>(ISD::CTPOP, Op);
1141}
1142
1143template <typename Opnd> inline UnaryOpc_match<Opnd> m_Ctlz(const Opnd &Op) {
1144 return UnaryOpc_match<Opnd>(ISD::CTLZ, Op);
1145}
1146
1147template <typename Opnd> inline UnaryOpc_match<Opnd> m_Cttz(const Opnd &Op) {
1148 return UnaryOpc_match<Opnd>(ISD::CTTZ, Op);
1149}
1150
1151template <typename Opnd> inline UnaryOpc_match<Opnd> m_FNeg(const Opnd &Op) {
1152 return UnaryOpc_match<Opnd>(ISD::FNEG, Op);
1153}
1154
1155// === Constants ===
1156struct ConstantInt_match {
1157 APInt *BindVal;
1158
1159 explicit ConstantInt_match(APInt *V) : BindVal(V) {}
1160
1161 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
1162 // The logics here are similar to that in
1163 // SelectionDAG::isConstantIntBuildVectorOrConstantInt, but the latter also
1164 // treats GlobalAddressSDNode as a constant, which is difficult to turn into
1165 // APInt.
1166 if (auto *C = dyn_cast_or_null<ConstantSDNode>(Val: N.getNode())) {
1167 if (BindVal)
1168 *BindVal = C->getAPIntValue();
1169 return true;
1170 }
1171
1172 APInt Discard;
1173 return ISD::isConstantSplatVector(N: N.getNode(),
1174 SplatValue&: BindVal ? *BindVal : Discard);
1175 }
1176};
1177
1178template <typename T> struct Constant64_match {
1179 static_assert(sizeof(T) == 8, "T must be 64 bits wide");
1180
1181 T &BindVal;
1182
1183 explicit Constant64_match(T &V) : BindVal(V) {}
1184
1185 template <typename MatchContext>
1186 bool match(const MatchContext &Ctx, SDValue N) {
1187 APInt V;
1188 if (!ConstantInt_match(&V).match(Ctx, N))
1189 return false;
1190
1191 if constexpr (std::is_signed_v<T>) {
1192 if (std::optional<int64_t> TrySExt = V.trySExtValue()) {
1193 BindVal = *TrySExt;
1194 return true;
1195 }
1196 }
1197
1198 if constexpr (std::is_unsigned_v<T>) {
1199 if (std::optional<uint64_t> TryZExt = V.tryZExtValue()) {
1200 BindVal = *TryZExt;
1201 return true;
1202 }
1203 }
1204
1205 return false;
1206 }
1207};
1208
1209/// Match any integer constants or splat of an integer constant.
1210inline ConstantInt_match m_ConstInt() { return ConstantInt_match(nullptr); }
1211/// Match any integer constants or splat of an integer constant; return the
1212/// specific constant or constant splat value.
1213inline ConstantInt_match m_ConstInt(APInt &V) { return ConstantInt_match(&V); }
1214/// Match any integer constants or splat of an integer constant that can fit in
1215/// 64 bits; return the specific constant or constant splat value, zero-extended
1216/// to 64 bits.
1217inline Constant64_match<uint64_t> m_ConstInt(uint64_t &V) {
1218 return Constant64_match<uint64_t>(V);
1219}
1220/// Match any integer constants or splat of an integer constant that can fit in
1221/// 64 bits; return the specific constant or constant splat value, sign-extended
1222/// to 64 bits.
1223inline Constant64_match<int64_t> m_ConstInt(int64_t &V) {
1224 return Constant64_match<int64_t>(V);
1225}
1226
1227struct SpecificInt_match {
1228 APInt IntVal;
1229
1230 explicit SpecificInt_match(APInt APV) : IntVal(std::move(APV)) {}
1231
1232 template <typename MatchContext>
1233 bool match(const MatchContext &Ctx, SDValue N) {
1234 APInt ConstInt;
1235 if (sd_context_match(N, Ctx, m_ConstInt(V&: ConstInt)))
1236 return APInt::isSameValue(I1: IntVal, I2: ConstInt);
1237 return false;
1238 }
1239};
1240
1241/// Match a specific integer constant or constant splat value.
1242inline SpecificInt_match m_SpecificInt(APInt V) {
1243 return SpecificInt_match(std::move(V));
1244}
1245inline SpecificInt_match m_SpecificInt(uint64_t V) {
1246 return SpecificInt_match(APInt(64, V));
1247}
1248
1249struct SpecificFP_match {
1250 APFloat Val;
1251
1252 explicit SpecificFP_match(APFloat V) : Val(V) {}
1253
1254 template <typename MatchContext>
1255 bool match(const MatchContext &Ctx, SDValue V) {
1256 if (const auto *CFP = dyn_cast<ConstantFPSDNode>(Val: V.getNode()))
1257 return CFP->isExactlyValue(V: Val);
1258 if (ConstantFPSDNode *C = isConstOrConstSplatFP(N: V, /*AllowUndefs=*/AllowUndefs: true))
1259 return C->getValueAPF().compare(RHS: Val) == APFloat::cmpEqual;
1260 return false;
1261 }
1262};
1263
1264/// Match a specific float constant.
1265inline SpecificFP_match m_SpecificFP(APFloat V) { return SpecificFP_match(V); }
1266
1267inline SpecificFP_match m_SpecificFP(double V) {
1268 return SpecificFP_match(APFloat(V));
1269}
1270
1271struct Negative_match {
1272 template <typename MatchContext>
1273 bool match(const MatchContext &Ctx, SDValue N) {
1274 const SelectionDAG *DAG = Ctx.getDAG();
1275 return DAG && DAG->computeKnownBits(Op: N).isNegative();
1276 }
1277};
1278
1279struct NonNegative_match {
1280 template <typename MatchContext>
1281 bool match(const MatchContext &Ctx, SDValue N) {
1282 const SelectionDAG *DAG = Ctx.getDAG();
1283 return DAG && DAG->computeKnownBits(Op: N).isNonNegative();
1284 }
1285};
1286
1287struct StrictlyPositive_match {
1288 template <typename MatchContext>
1289 bool match(const MatchContext &Ctx, SDValue N) {
1290 const SelectionDAG *DAG = Ctx.getDAG();
1291 return DAG && DAG->computeKnownBits(Op: N).isStrictlyPositive();
1292 }
1293};
1294
1295struct NonPositive_match {
1296 template <typename MatchContext>
1297 bool match(const MatchContext &Ctx, SDValue N) {
1298 const SelectionDAG *DAG = Ctx.getDAG();
1299 return DAG && DAG->computeKnownBits(Op: N).isNonPositive();
1300 }
1301};
1302
1303struct NonZero_match {
1304 template <typename MatchContext>
1305 bool match(const MatchContext &Ctx, SDValue N) {
1306 const SelectionDAG *DAG = Ctx.getDAG();
1307 return DAG && DAG->computeKnownBits(Op: N).isNonZero();
1308 }
1309};
1310
1311struct Zero_match {
1312 bool AllowUndefs;
1313
1314 explicit Zero_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
1315
1316 template <typename MatchContext>
1317 bool match(const MatchContext &, SDValue N) const {
1318 return isZeroOrZeroSplat(N, AllowUndefs);
1319 }
1320};
1321
1322struct Ones_match {
1323 bool AllowUndefs;
1324
1325 Ones_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
1326
1327 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
1328 return isOnesOrOnesSplat(N, AllowUndefs);
1329 }
1330};
1331
1332struct AllOnes_match {
1333 bool AllowUndefs;
1334
1335 AllOnes_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
1336
1337 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
1338 return isAllOnesOrAllOnesSplat(V: N, AllowUndefs);
1339 }
1340};
1341
1342inline Negative_match m_Negative() { return Negative_match(); }
1343template <typename Pattern> inline auto m_Negative(const Pattern &P) {
1344 return m_AllOf(m_Negative(), P);
1345}
1346inline NonNegative_match m_NonNegative() { return NonNegative_match(); }
1347template <typename Pattern> inline auto m_NonNegative(const Pattern &P) {
1348 return m_AllOf(m_NonNegative(), P);
1349}
1350inline StrictlyPositive_match m_StrictlyPositive() {
1351 return StrictlyPositive_match();
1352}
1353template <typename Pattern> inline auto m_StrictlyPositive(const Pattern &P) {
1354 return m_AllOf(m_StrictlyPositive(), P);
1355}
1356inline NonPositive_match m_NonPositive() { return NonPositive_match(); }
1357template <typename Pattern> inline auto m_NonPositive(const Pattern &P) {
1358 return m_AllOf(m_NonPositive(), P);
1359}
1360inline NonZero_match m_NonZero() { return NonZero_match(); }
1361template <typename Pattern> inline auto m_NonZero(const Pattern &P) {
1362 return m_AllOf(m_NonZero(), P);
1363}
1364inline Ones_match m_One(bool AllowUndefs = false) {
1365 return Ones_match(AllowUndefs);
1366}
1367inline Zero_match m_Zero(bool AllowUndefs = false) {
1368 return Zero_match(AllowUndefs);
1369}
1370inline AllOnes_match m_AllOnes(bool AllowUndefs = false) {
1371 return AllOnes_match(AllowUndefs);
1372}
1373
1374/// Match true boolean value based on the information provided by
1375/// TargetLowering.
1376inline auto m_True() {
1377 return TLI_pred_match{
1378 [](const TargetLowering &TLI, SDValue N) {
1379 APInt ConstVal;
1380 if (sd_match(N, P: m_ConstInt(V&: ConstVal)))
1381 switch (TLI.getBooleanContents(Type: N.getValueType())) {
1382 case TargetLowering::ZeroOrOneBooleanContent:
1383 return ConstVal.isOne();
1384 case TargetLowering::ZeroOrNegativeOneBooleanContent:
1385 return ConstVal.isAllOnes();
1386 case TargetLowering::UndefinedBooleanContent:
1387 return (ConstVal & 0x01) == 1;
1388 }
1389
1390 return false;
1391 },
1392 m_Value()};
1393}
1394/// Match false boolean value based on the information provided by
1395/// TargetLowering.
1396inline auto m_False() {
1397 return TLI_pred_match{
1398 [](const TargetLowering &TLI, SDValue N) {
1399 APInt ConstVal;
1400 if (sd_match(N, P: m_ConstInt(V&: ConstVal)))
1401 switch (TLI.getBooleanContents(Type: N.getValueType())) {
1402 case TargetLowering::ZeroOrOneBooleanContent:
1403 case TargetLowering::ZeroOrNegativeOneBooleanContent:
1404 return ConstVal.isZero();
1405 case TargetLowering::UndefinedBooleanContent:
1406 return (ConstVal & 0x01) == 0;
1407 }
1408
1409 return false;
1410 },
1411 m_Value()};
1412}
1413
1414struct CondCode_match {
1415 std::optional<ISD::CondCode> CCToMatch;
1416 ISD::CondCode *BindCC = nullptr;
1417
1418 explicit CondCode_match(ISD::CondCode CC) : CCToMatch(CC) {}
1419
1420 explicit CondCode_match(ISD::CondCode *CC) : BindCC(CC) {}
1421
1422 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
1423 if (auto *CC = dyn_cast<CondCodeSDNode>(Val: N.getNode())) {
1424 if (CCToMatch && *CCToMatch != CC->get())
1425 return false;
1426
1427 if (BindCC)
1428 *BindCC = CC->get();
1429 return true;
1430 }
1431
1432 return false;
1433 }
1434};
1435
1436/// Match any conditional code SDNode.
1437inline CondCode_match m_CondCode() { return CondCode_match(nullptr); }
1438/// Match any conditional code SDNode and return its ISD::CondCode value.
1439inline CondCode_match m_CondCode(ISD::CondCode &CC) {
1440 return CondCode_match(&CC);
1441}
1442/// Match a conditional code SDNode with a specific ISD::CondCode.
1443inline CondCode_match m_SpecificCondCode(ISD::CondCode CC) {
1444 return CondCode_match(CC);
1445}
1446
1447/// Match a negate as a sub(0, v)
1448template <typename ValTy>
1449inline BinaryOpc_match<Zero_match, ValTy, false> m_Neg(const ValTy &V) {
1450 return m_Sub(m_Zero(), V);
1451}
1452
1453/// Match a Not as a xor(v, -1) or xor(-1, v)
1454template <typename ValTy>
1455inline BinaryOpc_match<ValTy, AllOnes_match, true> m_Not(const ValTy &V) {
1456 return m_Xor(V, m_AllOnes());
1457}
1458
1459template <unsigned IntrinsicId, typename... OpndPreds>
1460inline auto m_IntrinsicWOChain(const OpndPreds &...Opnds) {
1461 return m_Node(ISD::INTRINSIC_WO_CHAIN, m_SpecificInt(V: IntrinsicId), Opnds...);
1462}
1463
1464struct SpecificNeg_match {
1465 SDValue V;
1466
1467 explicit SpecificNeg_match(SDValue V) : V(V) {}
1468
1469 template <typename MatchContext>
1470 bool match(const MatchContext &Ctx, SDValue N) {
1471 if (sd_context_match(N, Ctx, m_Neg(V: m_Specific(N: V))))
1472 return true;
1473
1474 return ISD::matchBinaryPredicate(
1475 LHS: V, RHS: N, Match: [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
1476 return LHS->getAPIntValue() == -RHS->getAPIntValue();
1477 });
1478 }
1479};
1480
1481/// Match a negation of a specific value V, either as sub(0, V) or as
1482/// constant(s) that are the negation of V's constant(s).
1483inline SpecificNeg_match m_SpecificNeg(SDValue V) {
1484 return SpecificNeg_match(V);
1485}
1486
1487template <typename... PatternTs> struct ReassociatableOpc_match {
1488 unsigned Opcode;
1489 std::tuple<PatternTs...> Patterns;
1490 constexpr static size_t NumPatterns =
1491 std::tuple_size_v<std::tuple<PatternTs...>>;
1492
1493 SDNodeFlags Flags;
1494
1495 ReassociatableOpc_match(unsigned Opcode, const PatternTs &...Patterns)
1496 : Opcode(Opcode), Patterns(Patterns...) {}
1497
1498 ReassociatableOpc_match(unsigned Opcode, SDNodeFlags Flags,
1499 const PatternTs &...Patterns)
1500 : Opcode(Opcode), Patterns(Patterns...), Flags(Flags) {}
1501
1502 template <typename MatchContext>
1503 bool match(const MatchContext &Ctx, SDValue N) {
1504 std::array<SDValue, NumPatterns> Leaves;
1505 size_t LeavesIdx = 0;
1506 if (!(collectLeaves(V: N, Leaves, LeafIdx&: LeavesIdx) && (LeavesIdx == NumPatterns)))
1507 return false;
1508
1509 Bitset<NumPatterns> Used;
1510 return std::apply(
1511 [&](auto &...P) -> bool {
1512 return reassociatableMatchHelper(Ctx, Leaves, Used, P...);
1513 },
1514 Patterns);
1515 }
1516
1517 bool collectLeaves(SDValue V, std::array<SDValue, NumPatterns> &Leaves,
1518 std::size_t &LeafIdx) {
1519 if (V->getOpcode() == Opcode && (Flags & V->getFlags()) == Flags) {
1520 for (size_t I = 0, N = V->getNumOperands(); I < N; I++)
1521 if ((LeafIdx == NumPatterns) ||
1522 !collectLeaves(V: V->getOperand(Num: I), Leaves, LeafIdx))
1523 return false;
1524 } else {
1525 Leaves[LeafIdx] = V;
1526 LeafIdx++;
1527 }
1528 return true;
1529 }
1530
1531 // Searchs for a matching leaf for every sub-pattern.
1532 template <typename MatchContext, typename PatternHd, typename... PatternTl>
1533 [[nodiscard]] inline bool
1534 reassociatableMatchHelper(const MatchContext &Ctx, ArrayRef<SDValue> Leaves,
1535 Bitset<NumPatterns> &Used, PatternHd &HeadPattern,
1536 PatternTl &...TailPatterns) {
1537 for (size_t Match = 0, N = Used.size(); Match < N; Match++) {
1538 if (Used[Match] || !(sd_context_match(Leaves[Match], Ctx, HeadPattern)))
1539 continue;
1540 Used.set(Match);
1541 if (reassociatableMatchHelper(Ctx, Leaves, Used, TailPatterns...))
1542 return true;
1543 Used.reset(Match);
1544 }
1545 return false;
1546 }
1547
1548 template <typename MatchContext>
1549 [[nodiscard]] inline bool
1550 reassociatableMatchHelper(const MatchContext &Ctx, ArrayRef<SDValue> Leaves,
1551 Bitset<NumPatterns> &Used) {
1552 return true;
1553 }
1554};
1555
1556template <typename... PatternTs>
1557inline ReassociatableOpc_match<PatternTs...>
1558m_ReassociatableAdd(const PatternTs &...Patterns) {
1559 return ReassociatableOpc_match<PatternTs...>(ISD::ADD, Patterns...);
1560}
1561
1562template <typename... PatternTs>
1563inline ReassociatableOpc_match<PatternTs...>
1564m_ReassociatableOr(const PatternTs &...Patterns) {
1565 return ReassociatableOpc_match<PatternTs...>(ISD::OR, Patterns...);
1566}
1567
1568template <typename... PatternTs>
1569inline ReassociatableOpc_match<PatternTs...>
1570m_ReassociatableAnd(const PatternTs &...Patterns) {
1571 return ReassociatableOpc_match<PatternTs...>(ISD::AND, Patterns...);
1572}
1573
1574template <typename... PatternTs>
1575inline ReassociatableOpc_match<PatternTs...>
1576m_ReassociatableMul(const PatternTs &...Patterns) {
1577 return ReassociatableOpc_match<PatternTs...>(ISD::MUL, Patterns...);
1578}
1579
1580template <typename... PatternTs>
1581inline ReassociatableOpc_match<PatternTs...>
1582m_ReassociatableNSWAdd(const PatternTs &...Patterns) {
1583 return ReassociatableOpc_match<PatternTs...>(
1584 ISD::ADD, SDNodeFlags::NoSignedWrap, Patterns...);
1585}
1586
1587template <typename... PatternTs>
1588inline ReassociatableOpc_match<PatternTs...>
1589m_ReassociatableNUWAdd(const PatternTs &...Patterns) {
1590 return ReassociatableOpc_match<PatternTs...>(
1591 ISD::ADD, SDNodeFlags::NoUnsignedWrap, Patterns...);
1592}
1593
1594} // namespace SDPatternMatch
1595} // namespace llvm
1596#endif
1597