1//==--------------- llvm/CodeGen/SDPatternMatch.h ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8/// \file
9/// Contains matchers for matching SelectionDAG nodes and values.
10///
11//===----------------------------------------------------------------------===//
12
13#ifndef LLVM_CODEGEN_SDPATTERNMATCH_H
14#define LLVM_CODEGEN_SDPATTERNMATCH_H
15
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/Support/KnownBits.h"
#include <type_traits>
#include <utility>
24
25namespace llvm {
26namespace SDPatternMatch {
27
28/// MatchContext can repurpose existing patterns to behave differently under
29/// a certain context. For instance, `m_Opc(ISD::ADD)` matches plain ADD nodes
30/// in normal circumstances, but matches VP_ADD nodes under a custom
31/// VPMatchContext. This design is meant to facilitate code / pattern reusing.
32class BasicMatchContext {
33 const SelectionDAG *DAG;
34 const TargetLowering *TLI;
35
36public:
37 explicit BasicMatchContext(const SelectionDAG *DAG)
38 : DAG(DAG), TLI(DAG ? &DAG->getTargetLoweringInfo() : nullptr) {}
39
40 explicit BasicMatchContext(const TargetLowering *TLI)
41 : DAG(nullptr), TLI(TLI) {}
42
43 // A valid MatchContext has to implement the following functions.
44
45 const SelectionDAG *getDAG() const { return DAG; }
46
47 const TargetLowering *getTLI() const { return TLI; }
48
49 /// Return true if N effectively has opcode Opcode.
50 bool match(SDValue N, unsigned Opcode) const {
51 return N->getOpcode() == Opcode;
52 }
53
54 unsigned getNumOperands(SDValue N) const { return N->getNumOperands(); }
55};
56
57template <typename Pattern, typename MatchContext>
58[[nodiscard]] bool sd_context_match(SDValue N, const MatchContext &Ctx,
59 Pattern &&P) {
60 return P.match(Ctx, N);
61}
62
63template <typename Pattern, typename MatchContext>
64[[nodiscard]] bool sd_context_match(SDNode *N, const MatchContext &Ctx,
65 Pattern &&P) {
66 return sd_context_match(SDValue(N, 0), Ctx, P);
67}
68
69template <typename Pattern>
70[[nodiscard]] bool sd_match(SDNode *N, const SelectionDAG *DAG, Pattern &&P) {
71 return sd_context_match(N, BasicMatchContext(DAG), P);
72}
73
74template <typename Pattern>
75[[nodiscard]] bool sd_match(SDValue N, const SelectionDAG *DAG, Pattern &&P) {
76 return sd_context_match(N, BasicMatchContext(DAG), P);
77}
78
79template <typename Pattern>
80[[nodiscard]] bool sd_match(SDNode *N, Pattern &&P) {
81 return sd_match(N, nullptr, P);
82}
83
84template <typename Pattern>
85[[nodiscard]] bool sd_match(SDValue N, Pattern &&P) {
86 return sd_match(N, nullptr, P);
87}
88
89// === Utilities ===
90struct Value_match {
91 SDValue MatchVal;
92
93 Value_match() = default;
94
95 explicit Value_match(SDValue Match) : MatchVal(Match) {}
96
97 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
98 if (MatchVal)
99 return MatchVal == N;
100 return N.getNode();
101 }
102};
103
104/// Match any valid SDValue.
105inline Value_match m_Value() { return Value_match(); }
106
107inline Value_match m_Specific(SDValue N) {
108 assert(N);
109 return Value_match(N);
110}
111
112template <unsigned ResNo, typename Pattern> struct Result_match {
113 Pattern P;
114
115 explicit Result_match(const Pattern &P) : P(P) {}
116
117 template <typename MatchContext>
118 bool match(const MatchContext &Ctx, SDValue N) {
119 return N.getResNo() == ResNo && P.match(Ctx, N);
120 }
121};
122
123/// Match only if the SDValue is a certain result at ResNo.
124template <unsigned ResNo, typename Pattern>
125inline Result_match<ResNo, Pattern> m_Result(const Pattern &P) {
126 return Result_match<ResNo, Pattern>(P);
127}
128
129struct DeferredValue_match {
130 SDValue &MatchVal;
131
132 explicit DeferredValue_match(SDValue &Match) : MatchVal(Match) {}
133
134 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
135 return N == MatchVal;
136 }
137};
138
139/// Similar to m_Specific, but the specific value to match is determined by
140/// another sub-pattern in the same sd_match() expression. For instance,
141/// We cannot match `(add V, V)` with `m_Add(m_Value(X), m_Specific(X))` since
142/// `X` is not initialized at the time it got copied into `m_Specific`. Instead,
143/// we should use `m_Add(m_Value(X), m_Deferred(X))`.
144inline DeferredValue_match m_Deferred(SDValue &V) {
145 return DeferredValue_match(V);
146}
147
148struct Opcode_match {
149 unsigned Opcode;
150
151 explicit Opcode_match(unsigned Opc) : Opcode(Opc) {}
152
153 template <typename MatchContext>
154 bool match(const MatchContext &Ctx, SDValue N) {
155 return Ctx.match(N, Opcode);
156 }
157};
158
159// === Patterns combinators ===
160template <typename... Preds> struct And {
161 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
162 return true;
163 }
164};
165
166template <typename Pred, typename... Preds>
167struct And<Pred, Preds...> : And<Preds...> {
168 Pred P;
169 And(const Pred &p, const Preds &...preds) : And<Preds...>(preds...), P(p) {}
170
171 template <typename MatchContext>
172 bool match(const MatchContext &Ctx, SDValue N) {
173 return P.match(Ctx, N) && And<Preds...>::match(Ctx, N);
174 }
175};
176
177template <typename... Preds> struct Or {
178 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
179 return false;
180 }
181};
182
183template <typename Pred, typename... Preds>
184struct Or<Pred, Preds...> : Or<Preds...> {
185 Pred P;
186 Or(const Pred &p, const Preds &...preds) : Or<Preds...>(preds...), P(p) {}
187
188 template <typename MatchContext>
189 bool match(const MatchContext &Ctx, SDValue N) {
190 return P.match(Ctx, N) || Or<Preds...>::match(Ctx, N);
191 }
192};
193
194template <typename Pred> struct Not {
195 Pred P;
196
197 explicit Not(const Pred &P) : P(P) {}
198
199 template <typename MatchContext>
200 bool match(const MatchContext &Ctx, SDValue N) {
201 return !P.match(Ctx, N);
202 }
203};
204// Explicit deduction guide.
205template <typename Pred> Not(const Pred &P) -> Not<Pred>;
206
207/// Match if the inner pattern does NOT match.
208template <typename Pred> inline Not<Pred> m_Unless(const Pred &P) {
209 return Not{P};
210}
211
212template <typename... Preds> And<Preds...> m_AllOf(const Preds &...preds) {
213 return And<Preds...>(preds...);
214}
215
216template <typename... Preds> Or<Preds...> m_AnyOf(const Preds &...preds) {
217 return Or<Preds...>(preds...);
218}
219
220template <typename... Preds> auto m_NoneOf(const Preds &...preds) {
221 return m_Unless(m_AnyOf(preds...));
222}
223
224inline Opcode_match m_Opc(unsigned Opcode) { return Opcode_match(Opcode); }
225
226inline auto m_Undef() {
227 return m_AnyOf(preds: Opcode_match(ISD::UNDEF), preds: Opcode_match(ISD::POISON));
228}
229
230inline Opcode_match m_Poison() { return Opcode_match(ISD::POISON); }
231
232template <unsigned NumUses, typename Pattern> struct NUses_match {
233 Pattern P;
234
235 explicit NUses_match(const Pattern &P) : P(P) {}
236
237 template <typename MatchContext>
238 bool match(const MatchContext &Ctx, SDValue N) {
239 // SDNode::hasNUsesOfValue is pretty expensive when the SDNode produces
240 // multiple results, hence we check the subsequent pattern here before
241 // checking the number of value users.
242 return P.match(Ctx, N) && N->hasNUsesOfValue(NUses: NumUses, Value: N.getResNo());
243 }
244};
245
246template <typename Pattern>
247inline NUses_match<1, Pattern> m_OneUse(const Pattern &P) {
248 return NUses_match<1, Pattern>(P);
249}
250template <unsigned N, typename Pattern>
251inline NUses_match<N, Pattern> m_NUses(const Pattern &P) {
252 return NUses_match<N, Pattern>(P);
253}
254
255inline NUses_match<1, Value_match> m_OneUse() {
256 return NUses_match<1, Value_match>(m_Value());
257}
258template <unsigned N> inline NUses_match<N, Value_match> m_NUses() {
259 return NUses_match<N, Value_match>(m_Value());
260}
261
262struct Value_bind {
263 SDValue &BindVal;
264
265 explicit Value_bind(SDValue &N) : BindVal(N) {}
266
267 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
268 BindVal = N;
269 return true;
270 }
271};
272
273inline Value_bind m_Value(SDValue &N) { return Value_bind(N); }
274
275template <typename Pattern, typename PredFuncT> struct TLI_pred_match {
276 Pattern P;
277 PredFuncT PredFunc;
278
279 TLI_pred_match(const PredFuncT &Pred, const Pattern &P)
280 : P(P), PredFunc(Pred) {}
281
282 template <typename MatchContext>
283 bool match(const MatchContext &Ctx, SDValue N) {
284 assert(Ctx.getTLI() && "TargetLowering is required for this pattern.");
285 return PredFunc(*Ctx.getTLI(), N) && P.match(Ctx, N);
286 }
287};
288
289// Explicit deduction guide.
290template <typename PredFuncT, typename Pattern>
291TLI_pred_match(const PredFuncT &Pred, const Pattern &P)
292 -> TLI_pred_match<Pattern, PredFuncT>;
293
294/// Match legal SDNodes based on the information provided by TargetLowering.
295template <typename Pattern> inline auto m_LegalOp(const Pattern &P) {
296 return TLI_pred_match{[](const TargetLowering &TLI, SDValue N) {
297 return TLI.isOperationLegal(Op: N->getOpcode(),
298 VT: N.getValueType());
299 },
300 P};
301}
302
303/// Switch to a different MatchContext for subsequent patterns.
304template <typename NewMatchContext, typename Pattern> struct SwitchContext {
305 const NewMatchContext &Ctx;
306 Pattern P;
307
308 template <typename OrigMatchContext>
309 bool match(const OrigMatchContext &, SDValue N) {
310 return P.match(Ctx, N);
311 }
312};
313
314template <typename MatchContext, typename Pattern>
315inline SwitchContext<MatchContext, Pattern> m_Context(const MatchContext &Ctx,
316 Pattern &&P) {
317 return SwitchContext<MatchContext, Pattern>{Ctx, std::move(P)};
318}
319
320// === Value type ===
321struct ValueType_bind {
322 EVT &BindVT;
323
324 explicit ValueType_bind(EVT &Bind) : BindVT(Bind) {}
325
326 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
327 BindVT = N.getValueType();
328 return true;
329 }
330};
331
332/// Retreive the ValueType of the current SDValue.
333inline ValueType_bind m_VT(EVT &VT) { return ValueType_bind(VT); }
334
335template <typename Pattern, typename PredFuncT> struct ValueType_match {
336 PredFuncT PredFunc;
337 Pattern P;
338
339 ValueType_match(const PredFuncT &Pred, const Pattern &P)
340 : PredFunc(Pred), P(P) {}
341
342 template <typename MatchContext>
343 bool match(const MatchContext &Ctx, SDValue N) {
344 return PredFunc(N.getValueType()) && P.match(Ctx, N);
345 }
346};
347
348// Explicit deduction guide.
349template <typename PredFuncT, typename Pattern>
350ValueType_match(const PredFuncT &Pred, const Pattern &P)
351 -> ValueType_match<Pattern, PredFuncT>;
352
353/// Match a specific ValueType.
354template <typename Pattern>
355inline auto m_SpecificVT(EVT RefVT, const Pattern &P) {
356 return ValueType_match{[=](EVT VT) { return VT == RefVT; }, P};
357}
358inline auto m_SpecificVT(EVT RefVT) {
359 return ValueType_match{[=](EVT VT) { return VT == RefVT; }, m_Value()};
360}
361
362inline auto m_Glue() { return m_SpecificVT(RefVT: MVT::Glue); }
363inline auto m_OtherVT() { return m_SpecificVT(RefVT: MVT::Other); }
364
365/// Match a scalar ValueType.
366template <typename Pattern>
367inline auto m_SpecificScalarVT(EVT RefVT, const Pattern &P) {
368 return ValueType_match{[=](EVT VT) { return VT.getScalarType() == RefVT; },
369 P};
370}
371inline auto m_SpecificScalarVT(EVT RefVT) {
372 return ValueType_match{[=](EVT VT) { return VT.getScalarType() == RefVT; },
373 m_Value()};
374}
375
376/// Match a vector ValueType.
377template <typename Pattern>
378inline auto m_SpecificVectorElementVT(EVT RefVT, const Pattern &P) {
379 return ValueType_match{[=](EVT VT) {
380 return VT.isVector() &&
381 VT.getVectorElementType() == RefVT;
382 },
383 P};
384}
385inline auto m_SpecificVectorElementVT(EVT RefVT) {
386 return ValueType_match{[=](EVT VT) {
387 return VT.isVector() &&
388 VT.getVectorElementType() == RefVT;
389 },
390 m_Value()};
391}
392
393/// Match any integer ValueTypes.
394template <typename Pattern> inline auto m_IntegerVT(const Pattern &P) {
395 return ValueType_match{[](EVT VT) { return VT.isInteger(); }, P};
396}
397inline auto m_IntegerVT() {
398 return ValueType_match{[](EVT VT) { return VT.isInteger(); }, m_Value()};
399}
400
401/// Match any floating point ValueTypes.
402template <typename Pattern> inline auto m_FloatingPointVT(const Pattern &P) {
403 return ValueType_match{[](EVT VT) { return VT.isFloatingPoint(); }, P};
404}
405inline auto m_FloatingPointVT() {
406 return ValueType_match{[](EVT VT) { return VT.isFloatingPoint(); },
407 m_Value()};
408}
409
410/// Match any vector ValueTypes.
411template <typename Pattern> inline auto m_VectorVT(const Pattern &P) {
412 return ValueType_match{[](EVT VT) { return VT.isVector(); }, P};
413}
414inline auto m_VectorVT() {
415 return ValueType_match{[](EVT VT) { return VT.isVector(); }, m_Value()};
416}
417
418/// Match fixed-length vector ValueTypes.
419template <typename Pattern> inline auto m_FixedVectorVT(const Pattern &P) {
420 return ValueType_match{[](EVT VT) { return VT.isFixedLengthVector(); }, P};
421}
422inline auto m_FixedVectorVT() {
423 return ValueType_match{[](EVT VT) { return VT.isFixedLengthVector(); },
424 m_Value()};
425}
426
427/// Match scalable vector ValueTypes.
428template <typename Pattern> inline auto m_ScalableVectorVT(const Pattern &P) {
429 return ValueType_match{[](EVT VT) { return VT.isScalableVector(); }, P};
430}
431inline auto m_ScalableVectorVT() {
432 return ValueType_match{[](EVT VT) { return VT.isScalableVector(); },
433 m_Value()};
434}
435
436/// Match legal ValueTypes based on the information provided by TargetLowering.
437template <typename Pattern> inline auto m_LegalType(const Pattern &P) {
438 return TLI_pred_match{[](const TargetLowering &TLI, SDValue N) {
439 return TLI.isTypeLegal(VT: N.getValueType());
440 },
441 P};
442}
443
444// === Generic node matching ===
445template <unsigned OpIdx, typename... OpndPreds> struct Operands_match {
446 template <typename MatchContext>
447 bool match(const MatchContext &Ctx, SDValue N) {
448 // Returns false if there are more operands than predicates;
449 // Ignores the last two operands if both the Context and the Node are VP
450 return Ctx.getNumOperands(N) == OpIdx;
451 }
452};
453
454template <unsigned OpIdx, typename OpndPred, typename... OpndPreds>
455struct Operands_match<OpIdx, OpndPred, OpndPreds...>
456 : Operands_match<OpIdx + 1, OpndPreds...> {
457 OpndPred P;
458
459 Operands_match(const OpndPred &p, const OpndPreds &...preds)
460 : Operands_match<OpIdx + 1, OpndPreds...>(preds...), P(p) {}
461
462 template <typename MatchContext>
463 bool match(const MatchContext &Ctx, SDValue N) {
464 if (OpIdx < N->getNumOperands())
465 return P.match(Ctx, N->getOperand(Num: OpIdx)) &&
466 Operands_match<OpIdx + 1, OpndPreds...>::match(Ctx, N);
467
468 // This is the case where there are more predicates than operands.
469 return false;
470 }
471};
472
473template <typename... OpndPreds>
474auto m_Node(unsigned Opcode, const OpndPreds &...preds) {
475 return m_AllOf(m_Opc(Opcode), Operands_match<0, OpndPreds...>(preds...));
476}
477
478/// Provide number of operands that are not chain or glue, as well as the first
479/// index of such operand.
480template <bool ExcludeChain> struct EffectiveOperands {
481 unsigned Size = 0;
482 unsigned FirstIndex = 0;
483
484 template <typename MatchContext>
485 explicit EffectiveOperands(SDValue N, const MatchContext &Ctx) {
486 const unsigned TotalNumOps = Ctx.getNumOperands(N);
487 FirstIndex = TotalNumOps;
488 for (unsigned I = 0; I < TotalNumOps; ++I) {
489 // Count the number of non-chain and non-glue nodes (we ignore chain
490 // and glue by default) and retreive the operand index offset.
491 EVT VT = N->getOperand(Num: I).getValueType();
492 if (VT != MVT::Glue && VT != MVT::Other) {
493 ++Size;
494 if (FirstIndex == TotalNumOps)
495 FirstIndex = I;
496 }
497 }
498 }
499};
500
501template <> struct EffectiveOperands<false> {
502 unsigned Size = 0;
503 unsigned FirstIndex = 0;
504
505 template <typename MatchContext>
506 explicit EffectiveOperands(SDValue N, const MatchContext &Ctx)
507 : Size(Ctx.getNumOperands(N)) {}
508};
509
510// === Ternary operations ===
511template <typename T0_P, typename T1_P, typename T2_P, bool Commutable = false,
512 bool ExcludeChain = false>
513struct TernaryOpc_match {
514 unsigned Opcode;
515 T0_P Op0;
516 T1_P Op1;
517 T2_P Op2;
518
519 TernaryOpc_match(unsigned Opc, const T0_P &Op0, const T1_P &Op1,
520 const T2_P &Op2)
521 : Opcode(Opc), Op0(Op0), Op1(Op1), Op2(Op2) {}
522
523 template <typename MatchContext>
524 bool match(const MatchContext &Ctx, SDValue N) {
525 if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
526 EffectiveOperands<ExcludeChain> EO(N, Ctx);
527 assert(EO.Size == 3);
528 return ((Op0.match(Ctx, N->getOperand(Num: EO.FirstIndex)) &&
529 Op1.match(Ctx, N->getOperand(Num: EO.FirstIndex + 1))) ||
530 (Commutable && Op0.match(Ctx, N->getOperand(Num: EO.FirstIndex + 1)) &&
531 Op1.match(Ctx, N->getOperand(Num: EO.FirstIndex)))) &&
532 Op2.match(Ctx, N->getOperand(Num: EO.FirstIndex + 2));
533 }
534
535 return false;
536 }
537};
538
539template <typename T0_P, typename T1_P, typename T2_P>
540inline TernaryOpc_match<T0_P, T1_P, T2_P>
541m_SetCC(const T0_P &LHS, const T1_P &RHS, const T2_P &CC) {
542 return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::SETCC, LHS, RHS, CC);
543}
544
545template <typename T0_P, typename T1_P, typename T2_P>
546inline TernaryOpc_match<T0_P, T1_P, T2_P, true, false>
547m_c_SetCC(const T0_P &LHS, const T1_P &RHS, const T2_P &CC) {
548 return TernaryOpc_match<T0_P, T1_P, T2_P, true, false>(ISD::SETCC, LHS, RHS,
549 CC);
550}
551
552template <typename T0_P, typename T1_P, typename T2_P>
553inline TernaryOpc_match<T0_P, T1_P, T2_P>
554m_Select(const T0_P &Cond, const T1_P &T, const T2_P &F) {
555 return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::SELECT, Cond, T, F);
556}
557
558template <typename T0_P, typename T1_P, typename T2_P>
559inline TernaryOpc_match<T0_P, T1_P, T2_P>
560m_VSelect(const T0_P &Cond, const T1_P &T, const T2_P &F) {
561 return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::VSELECT, Cond, T, F);
562}
563
564template <typename T0_P, typename T1_P, typename T2_P>
565inline auto m_SelectLike(const T0_P &Cond, const T1_P &T, const T2_P &F) {
566 return m_AnyOf(m_Select(Cond, T, F), m_VSelect(Cond, T, F));
567}
568
569template <typename T0_P, typename T1_P, typename T2_P>
570inline Result_match<0, TernaryOpc_match<T0_P, T1_P, T2_P>>
571m_Load(const T0_P &Ch, const T1_P &Ptr, const T2_P &Offset) {
572 return m_Result<0>(
573 TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::LOAD, Ch, Ptr, Offset));
574}
575
576template <typename T0_P, typename T1_P, typename T2_P>
577inline TernaryOpc_match<T0_P, T1_P, T2_P>
578m_InsertElt(const T0_P &Vec, const T1_P &Val, const T2_P &Idx) {
579 return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::INSERT_VECTOR_ELT, Vec, Val,
580 Idx);
581}
582
583template <typename LHS, typename RHS, typename IDX>
584inline TernaryOpc_match<LHS, RHS, IDX>
585m_InsertSubvector(const LHS &Base, const RHS &Sub, const IDX &Idx) {
586 return TernaryOpc_match<LHS, RHS, IDX>(ISD::INSERT_SUBVECTOR, Base, Sub, Idx);
587}
588
589template <typename T0_P, typename T1_P, typename T2_P>
590inline TernaryOpc_match<T0_P, T1_P, T2_P>
591m_TernaryOp(unsigned Opc, const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
592 return TernaryOpc_match<T0_P, T1_P, T2_P>(Opc, Op0, Op1, Op2);
593}
594
595template <typename T0_P, typename T1_P, typename T2_P>
596inline TernaryOpc_match<T0_P, T1_P, T2_P, true>
597m_c_TernaryOp(unsigned Opc, const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
598 return TernaryOpc_match<T0_P, T1_P, T2_P, true>(Opc, Op0, Op1, Op2);
599}
600
601template <typename LTy, typename RTy, typename TTy, typename FTy, typename CCTy>
602inline auto m_SelectCC(const LTy &L, const RTy &R, const TTy &T, const FTy &F,
603 const CCTy &CC) {
604 return m_Node(ISD::SELECT_CC, L, R, T, F, CC);
605}
606
607template <typename LTy, typename RTy, typename TTy, typename FTy, typename CCTy>
608inline auto m_SelectCCLike(const LTy &L, const RTy &R, const TTy &T,
609 const FTy &F, const CCTy &CC) {
610 return m_AnyOf(m_Select(m_SetCC(L, R, CC), T, F), m_SelectCC(L, R, T, F, CC));
611}
612
613// === Binary operations ===
614template <typename LHS_P, typename RHS_P, bool Commutable = false,
615 bool ExcludeChain = false>
616struct BinaryOpc_match {
617 unsigned Opcode;
618 LHS_P LHS;
619 RHS_P RHS;
620 SDNodeFlags Flags;
621 BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R,
622 SDNodeFlags Flgs = SDNodeFlags())
623 : Opcode(Opc), LHS(L), RHS(R), Flags(Flgs) {}
624
625 template <typename MatchContext>
626 bool match(const MatchContext &Ctx, SDValue N) {
627 if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
628 EffectiveOperands<ExcludeChain> EO(N, Ctx);
629 assert(EO.Size == 2);
630 if (!((LHS.match(Ctx, N->getOperand(Num: EO.FirstIndex)) &&
631 RHS.match(Ctx, N->getOperand(Num: EO.FirstIndex + 1))) ||
632 (Commutable && LHS.match(Ctx, N->getOperand(Num: EO.FirstIndex + 1)) &&
633 RHS.match(Ctx, N->getOperand(Num: EO.FirstIndex)))))
634 return false;
635
636 return (Flags & N->getFlags()) == Flags;
637 }
638
639 return false;
640 }
641};
642
643/// Matching while capturing mask
644template <typename T0, typename T1, typename T2> struct SDShuffle_match {
645 T0 Op1;
646 T1 Op2;
647 T2 Mask;
648
649 SDShuffle_match(const T0 &Op1, const T1 &Op2, const T2 &Mask)
650 : Op1(Op1), Op2(Op2), Mask(Mask) {}
651
652 template <typename MatchContext>
653 bool match(const MatchContext &Ctx, SDValue N) {
654 if (auto *I = dyn_cast<ShuffleVectorSDNode>(Val&: N)) {
655 return Op1.match(Ctx, I->getOperand(Num: 0)) &&
656 Op2.match(Ctx, I->getOperand(Num: 1)) && Mask.match(I->getMask());
657 }
658 return false;
659 }
660};
661struct m_Mask {
662 ArrayRef<int> &MaskRef;
663 m_Mask(ArrayRef<int> &MaskRef) : MaskRef(MaskRef) {}
664 bool match(ArrayRef<int> Mask) {
665 MaskRef = Mask;
666 return true;
667 }
668};
669
670struct m_SpecificMask {
671 ArrayRef<int> MaskRef;
672 m_SpecificMask(ArrayRef<int> MaskRef) : MaskRef(MaskRef) {}
673 bool match(ArrayRef<int> Mask) { return MaskRef == Mask; }
674};
675
676template <typename LHS_P, typename RHS_P, typename Pred_t,
677 bool Commutable = false, bool ExcludeChain = false>
678struct MaxMin_match {
679 using PredType = Pred_t;
680 LHS_P LHS;
681 RHS_P RHS;
682
683 MaxMin_match(const LHS_P &L, const RHS_P &R) : LHS(L), RHS(R) {}
684
685 template <typename MatchContext>
686 bool match(const MatchContext &Ctx, SDValue N) {
687 auto MatchMinMax = [&](SDValue L, SDValue R, SDValue TrueValue,
688 SDValue FalseValue, ISD::CondCode CC) {
689 if ((TrueValue != L || FalseValue != R) &&
690 (TrueValue != R || FalseValue != L))
691 return false;
692
693 ISD::CondCode Cond =
694 TrueValue == L ? CC : getSetCCInverse(Operation: CC, Type: L.getValueType());
695 if (!Pred_t::match(Cond))
696 return false;
697
698 return (LHS.match(Ctx, L) && RHS.match(Ctx, R)) ||
699 (Commutable && LHS.match(Ctx, R) && RHS.match(Ctx, L));
700 };
701
702 if (sd_context_match(N, Ctx, m_Opc(Opcode: ISD::SELECT)) ||
703 sd_context_match(N, Ctx, m_Opc(Opcode: ISD::VSELECT))) {
704 EffectiveOperands<ExcludeChain> EO_SELECT(N, Ctx);
705 assert(EO_SELECT.Size == 3);
706 SDValue Cond = N->getOperand(Num: EO_SELECT.FirstIndex);
707 SDValue TrueValue = N->getOperand(Num: EO_SELECT.FirstIndex + 1);
708 SDValue FalseValue = N->getOperand(Num: EO_SELECT.FirstIndex + 2);
709
710 if (sd_context_match(Cond, Ctx, m_Opc(Opcode: ISD::SETCC))) {
711 EffectiveOperands<ExcludeChain> EO_SETCC(Cond, Ctx);
712 assert(EO_SETCC.Size == 3);
713 SDValue L = Cond->getOperand(Num: EO_SETCC.FirstIndex);
714 SDValue R = Cond->getOperand(Num: EO_SETCC.FirstIndex + 1);
715 auto *CondNode =
716 cast<CondCodeSDNode>(Cond->getOperand(Num: EO_SETCC.FirstIndex + 2));
717 return MatchMinMax(L, R, TrueValue, FalseValue, CondNode->get());
718 }
719 }
720
721 if (sd_context_match(N, Ctx, m_Opc(Opcode: ISD::SELECT_CC))) {
722 EffectiveOperands<ExcludeChain> EO_SELECT(N, Ctx);
723 assert(EO_SELECT.Size == 5);
724 SDValue L = N->getOperand(Num: EO_SELECT.FirstIndex);
725 SDValue R = N->getOperand(Num: EO_SELECT.FirstIndex + 1);
726 SDValue TrueValue = N->getOperand(Num: EO_SELECT.FirstIndex + 2);
727 SDValue FalseValue = N->getOperand(Num: EO_SELECT.FirstIndex + 3);
728 auto *CondNode =
729 cast<CondCodeSDNode>(N->getOperand(Num: EO_SELECT.FirstIndex + 4));
730 return MatchMinMax(L, R, TrueValue, FalseValue, CondNode->get());
731 }
732
733 return false;
734 }
735};
736
737// Helper class for identifying signed max predicates.
738struct smax_pred_ty {
739 static bool match(ISD::CondCode Cond) {
740 return Cond == ISD::CondCode::SETGT || Cond == ISD::CondCode::SETGE;
741 }
742};
743
744// Helper class for identifying unsigned max predicates.
745struct umax_pred_ty {
746 static bool match(ISD::CondCode Cond) {
747 return Cond == ISD::CondCode::SETUGT || Cond == ISD::CondCode::SETUGE;
748 }
749};
750
751// Helper class for identifying signed min predicates.
752struct smin_pred_ty {
753 static bool match(ISD::CondCode Cond) {
754 return Cond == ISD::CondCode::SETLT || Cond == ISD::CondCode::SETLE;
755 }
756};
757
758// Helper class for identifying unsigned min predicates.
759struct umin_pred_ty {
760 static bool match(ISD::CondCode Cond) {
761 return Cond == ISD::CondCode::SETULT || Cond == ISD::CondCode::SETULE;
762 }
763};
764
765template <typename LHS, typename RHS>
766inline BinaryOpc_match<LHS, RHS> m_BinOp(unsigned Opc, const LHS &L,
767 const RHS &R,
768 SDNodeFlags Flgs = SDNodeFlags()) {
769 return BinaryOpc_match<LHS, RHS>(Opc, L, R, Flgs);
770}
771template <typename LHS, typename RHS>
772inline BinaryOpc_match<LHS, RHS, true>
773m_c_BinOp(unsigned Opc, const LHS &L, const RHS &R,
774 SDNodeFlags Flgs = SDNodeFlags()) {
775 return BinaryOpc_match<LHS, RHS, true>(Opc, L, R, Flgs);
776}
777
778template <typename LHS, typename RHS>
779inline BinaryOpc_match<LHS, RHS, false, true>
780m_ChainedBinOp(unsigned Opc, const LHS &L, const RHS &R) {
781 return BinaryOpc_match<LHS, RHS, false, true>(Opc, L, R);
782}
783template <typename LHS, typename RHS>
784inline BinaryOpc_match<LHS, RHS, true, true>
785m_c_ChainedBinOp(unsigned Opc, const LHS &L, const RHS &R) {
786 return BinaryOpc_match<LHS, RHS, true, true>(Opc, L, R);
787}
788
789// Common binary operations
790template <typename LHS, typename RHS>
791inline BinaryOpc_match<LHS, RHS, true> m_Add(const LHS &L, const RHS &R) {
792 return BinaryOpc_match<LHS, RHS, true>(ISD::ADD, L, R);
793}
794
795template <typename LHS, typename RHS>
796inline BinaryOpc_match<LHS, RHS> m_Sub(const LHS &L, const RHS &R) {
797 return BinaryOpc_match<LHS, RHS>(ISD::SUB, L, R);
798}
799
800template <typename LHS, typename RHS>
801inline BinaryOpc_match<LHS, RHS, true> m_Mul(const LHS &L, const RHS &R) {
802 return BinaryOpc_match<LHS, RHS, true>(ISD::MUL, L, R);
803}
804
805template <typename LHS, typename RHS>
806inline BinaryOpc_match<LHS, RHS, true> m_And(const LHS &L, const RHS &R) {
807 return BinaryOpc_match<LHS, RHS, true>(ISD::AND, L, R);
808}
809
810template <typename LHS, typename RHS>
811inline BinaryOpc_match<LHS, RHS, true> m_Or(const LHS &L, const RHS &R) {
812 return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R);
813}
814
815template <typename LHS, typename RHS>
816inline BinaryOpc_match<LHS, RHS, true> m_DisjointOr(const LHS &L,
817 const RHS &R) {
818 return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R, SDNodeFlags::Disjoint);
819}
820
821template <typename LHS, typename RHS>
822inline auto m_AddLike(const LHS &L, const RHS &R) {
823 return m_AnyOf(m_Add(L, R), m_DisjointOr(L, R));
824}
825
826template <typename LHS, typename RHS>
827inline BinaryOpc_match<LHS, RHS, true> m_Xor(const LHS &L, const RHS &R) {
828 return BinaryOpc_match<LHS, RHS, true>(ISD::XOR, L, R);
829}
830
831template <typename LHS, typename RHS>
832inline auto m_BitwiseLogic(const LHS &L, const RHS &R) {
833 return m_AnyOf(m_And(L, R), m_Or(L, R), m_Xor(L, R));
834}
835
836template <unsigned Opc, typename Pred, typename LHS, typename RHS>
837inline auto m_MaxMinLike(const LHS &L, const RHS &R) {
838 return m_AnyOf(BinaryOpc_match<LHS, RHS, true>(Opc, L, R),
839 MaxMin_match<LHS, RHS, Pred, true>(L, R));
840}
841
842template <typename LHS, typename RHS>
843inline BinaryOpc_match<LHS, RHS, true> m_SMin(const LHS &L, const RHS &R) {
844 return BinaryOpc_match<LHS, RHS, true>(ISD::SMIN, L, R);
845}
846
847template <typename LHS, typename RHS>
848inline auto m_SMinLike(const LHS &L, const RHS &R) {
849 return m_AnyOf(
850 m_MaxMinLike<ISD::SMIN, smin_pred_ty>(L, R),
851 m_MaxMinLike<ISD::UMIN, umin_pred_ty>(m_NonNegative(L), m_NonNegative(R)),
852 m_MaxMinLike<ISD::UMIN, umin_pred_ty>(m_Negative(L), m_Negative(R)));
853}
854
855template <typename LHS, typename RHS>
856inline BinaryOpc_match<LHS, RHS, true> m_SMax(const LHS &L, const RHS &R) {
857 return BinaryOpc_match<LHS, RHS, true>(ISD::SMAX, L, R);
858}
859
860template <typename LHS, typename RHS>
861inline auto m_SMaxLike(const LHS &L, const RHS &R) {
862 return m_AnyOf(
863 m_MaxMinLike<ISD::SMAX, smax_pred_ty>(L, R),
864 m_MaxMinLike<ISD::UMAX, umax_pred_ty>(m_NonNegative(L), m_NonNegative(R)),
865 m_MaxMinLike<ISD::UMAX, umax_pred_ty>(m_Negative(L), m_Negative(R)));
866}
867
868template <typename LHS, typename RHS>
869inline BinaryOpc_match<LHS, RHS, true> m_UMin(const LHS &L, const RHS &R) {
870 return BinaryOpc_match<LHS, RHS, true>(ISD::UMIN, L, R);
871}
872
873template <typename LHS, typename RHS>
874inline auto m_UMinLike(const LHS &L, const RHS &R) {
875 return m_AnyOf(
876 m_MaxMinLike<ISD::UMIN, umin_pred_ty>(L, R),
877 m_MaxMinLike<ISD::SMIN, smin_pred_ty>(m_NonNegative(L), m_NonNegative(R)),
878 m_MaxMinLike<ISD::SMIN, smin_pred_ty>(m_Negative(L), m_Negative(R)));
879}
880
881template <typename LHS, typename RHS>
882inline BinaryOpc_match<LHS, RHS, true> m_UMax(const LHS &L, const RHS &R) {
883 return BinaryOpc_match<LHS, RHS, true>(ISD::UMAX, L, R);
884}
885
886template <typename LHS, typename RHS>
887inline auto m_UMaxLike(const LHS &L, const RHS &R) {
888 return m_AnyOf(
889 m_MaxMinLike<ISD::UMAX, umax_pred_ty>(L, R),
890 m_MaxMinLike<ISD::SMAX, smax_pred_ty>(m_NonNegative(L), m_NonNegative(R)),
891 m_MaxMinLike<ISD::SMAX, smax_pred_ty>(m_Negative(L), m_Negative(R)));
892}
893
894template <typename LHS, typename RHS>
895inline BinaryOpc_match<LHS, RHS> m_UDiv(const LHS &L, const RHS &R) {
896 return BinaryOpc_match<LHS, RHS>(ISD::UDIV, L, R);
897}
898template <typename LHS, typename RHS>
899inline BinaryOpc_match<LHS, RHS> m_SDiv(const LHS &L, const RHS &R) {
900 return BinaryOpc_match<LHS, RHS>(ISD::SDIV, L, R);
901}
902
903template <typename LHS, typename RHS>
904inline BinaryOpc_match<LHS, RHS> m_URem(const LHS &L, const RHS &R) {
905 return BinaryOpc_match<LHS, RHS>(ISD::UREM, L, R);
906}
907template <typename LHS, typename RHS>
908inline BinaryOpc_match<LHS, RHS> m_SRem(const LHS &L, const RHS &R) {
909 return BinaryOpc_match<LHS, RHS>(ISD::SREM, L, R);
910}
911
912template <typename LHS, typename RHS>
913inline BinaryOpc_match<LHS, RHS> m_Shl(const LHS &L, const RHS &R) {
914 return BinaryOpc_match<LHS, RHS>(ISD::SHL, L, R);
915}
916
917template <typename LHS, typename RHS>
918inline BinaryOpc_match<LHS, RHS> m_Sra(const LHS &L, const RHS &R) {
919 return BinaryOpc_match<LHS, RHS>(ISD::SRA, L, R);
920}
921template <typename LHS, typename RHS>
922inline BinaryOpc_match<LHS, RHS> m_Srl(const LHS &L, const RHS &R) {
923 return BinaryOpc_match<LHS, RHS>(ISD::SRL, L, R);
924}
925template <typename LHS, typename RHS>
926inline auto m_ExactSr(const LHS &L, const RHS &R) {
927 return m_AnyOf(BinaryOpc_match<LHS, RHS>(ISD::SRA, L, R, SDNodeFlags::Exact),
928 BinaryOpc_match<LHS, RHS>(ISD::SRL, L, R, SDNodeFlags::Exact));
929}
930
931template <typename LHS, typename RHS>
932inline BinaryOpc_match<LHS, RHS> m_Rotl(const LHS &L, const RHS &R) {
933 return BinaryOpc_match<LHS, RHS>(ISD::ROTL, L, R);
934}
935
936template <typename LHS, typename RHS>
937inline BinaryOpc_match<LHS, RHS> m_Rotr(const LHS &L, const RHS &R) {
938 return BinaryOpc_match<LHS, RHS>(ISD::ROTR, L, R);
939}
940
941template <typename LHS, typename RHS>
942inline BinaryOpc_match<LHS, RHS, true> m_Clmul(const LHS &L, const RHS &R) {
943 return BinaryOpc_match<LHS, RHS, true>(ISD::CLMUL, L, R);
944}
945
946template <typename LHS, typename RHS>
947inline BinaryOpc_match<LHS, RHS, true> m_FAdd(const LHS &L, const RHS &R) {
948 return BinaryOpc_match<LHS, RHS, true>(ISD::FADD, L, R);
949}
950
951template <typename LHS, typename RHS>
952inline BinaryOpc_match<LHS, RHS> m_FSub(const LHS &L, const RHS &R) {
953 return BinaryOpc_match<LHS, RHS>(ISD::FSUB, L, R);
954}
955
956template <typename LHS, typename RHS>
957inline BinaryOpc_match<LHS, RHS, true> m_FMul(const LHS &L, const RHS &R) {
958 return BinaryOpc_match<LHS, RHS, true>(ISD::FMUL, L, R);
959}
960
961template <typename LHS, typename RHS>
962inline BinaryOpc_match<LHS, RHS> m_FDiv(const LHS &L, const RHS &R) {
963 return BinaryOpc_match<LHS, RHS>(ISD::FDIV, L, R);
964}
965
966template <typename LHS, typename RHS>
967inline BinaryOpc_match<LHS, RHS> m_FRem(const LHS &L, const RHS &R) {
968 return BinaryOpc_match<LHS, RHS>(ISD::FREM, L, R);
969}
970
971template <typename V1_t, typename V2_t>
972inline BinaryOpc_match<V1_t, V2_t> m_Shuffle(const V1_t &v1, const V2_t &v2) {
973 return BinaryOpc_match<V1_t, V2_t>(ISD::VECTOR_SHUFFLE, v1, v2);
974}
975
976template <typename V1_t, typename V2_t, typename Mask_t>
977inline SDShuffle_match<V1_t, V2_t, Mask_t>
978m_Shuffle(const V1_t &v1, const V2_t &v2, const Mask_t &mask) {
979 return SDShuffle_match<V1_t, V2_t, Mask_t>(v1, v2, mask);
980}
981
982template <typename LHS, typename RHS>
983inline BinaryOpc_match<LHS, RHS> m_ExtractElt(const LHS &Vec, const RHS &Idx) {
984 return BinaryOpc_match<LHS, RHS>(ISD::EXTRACT_VECTOR_ELT, Vec, Idx);
985}
986
987template <typename LHS, typename RHS>
988inline BinaryOpc_match<LHS, RHS> m_ExtractSubvector(const LHS &Vec,
989 const RHS &Idx) {
990 return BinaryOpc_match<LHS, RHS>(ISD::EXTRACT_SUBVECTOR, Vec, Idx);
991}
992
993// === Unary operations ===
994template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
995 unsigned Opcode;
996 Opnd_P Opnd;
997 SDNodeFlags Flags;
998 UnaryOpc_match(unsigned Opc, const Opnd_P &Op,
999 SDNodeFlags Flgs = SDNodeFlags())
1000 : Opcode(Opc), Opnd(Op), Flags(Flgs) {}
1001
1002 template <typename MatchContext>
1003 bool match(const MatchContext &Ctx, SDValue N) {
1004 if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
1005 EffectiveOperands<ExcludeChain> EO(N, Ctx);
1006 assert(EO.Size == 1);
1007 if (!Opnd.match(Ctx, N->getOperand(Num: EO.FirstIndex)))
1008 return false;
1009
1010 return (Flags & N->getFlags()) == Flags;
1011 }
1012
1013 return false;
1014 }
1015};
1016
1017template <typename Opnd>
1018inline UnaryOpc_match<Opnd> m_UnaryOp(unsigned Opc, const Opnd &Op) {
1019 return UnaryOpc_match<Opnd>(Opc, Op);
1020}
1021template <typename Opnd>
1022inline UnaryOpc_match<Opnd, true> m_ChainedUnaryOp(unsigned Opc,
1023 const Opnd &Op) {
1024 return UnaryOpc_match<Opnd, true>(Opc, Op);
1025}
1026
1027template <typename Opnd> inline UnaryOpc_match<Opnd> m_BitCast(const Opnd &Op) {
1028 return UnaryOpc_match<Opnd>(ISD::BITCAST, Op);
1029}
1030
1031template <typename Opnd>
1032inline UnaryOpc_match<Opnd> m_BSwap(const Opnd &Op) {
1033 return UnaryOpc_match<Opnd>(ISD::BSWAP, Op);
1034}
1035
1036template <typename Opnd>
1037inline UnaryOpc_match<Opnd> m_BitReverse(const Opnd &Op) {
1038 return UnaryOpc_match<Opnd>(ISD::BITREVERSE, Op);
1039}
1040
1041template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) {
1042 return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op);
1043}
1044
1045template <typename Opnd>
1046inline UnaryOpc_match<Opnd> m_NNegZExt(const Opnd &Op) {
1047 return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op, SDNodeFlags::NonNeg);
1048}
1049
1050template <typename Opnd> inline auto m_SExt(const Opnd &Op) {
1051 return UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op);
1052}
1053
1054template <typename Opnd> inline UnaryOpc_match<Opnd> m_AnyExt(const Opnd &Op) {
1055 return UnaryOpc_match<Opnd>(ISD::ANY_EXTEND, Op);
1056}
1057
1058template <typename Opnd> inline UnaryOpc_match<Opnd> m_Trunc(const Opnd &Op) {
1059 return UnaryOpc_match<Opnd>(ISD::TRUNCATE, Op);
1060}
1061
1062template <typename Opnd> inline UnaryOpc_match<Opnd> m_Abs(const Opnd &Op) {
1063 return UnaryOpc_match<Opnd>(ISD::ABS, Op);
1064}
1065
1066template <typename Opnd> inline UnaryOpc_match<Opnd> m_FAbs(const Opnd &Op) {
1067 return UnaryOpc_match<Opnd>(ISD::FABS, Op);
1068}
1069
1070/// Match a zext or identity
1071/// Allows to peek through optional extensions
1072template <typename Opnd> inline auto m_ZExtOrSelf(const Opnd &Op) {
1073 return m_AnyOf(m_ZExt(Op), Op);
1074}
1075
1076/// Match a sext or identity
1077/// Allows to peek through optional extensions
1078template <typename Opnd> inline auto m_SExtOrSelf(const Opnd &Op) {
1079 return m_AnyOf(m_SExt(Op), Op);
1080}
1081
1082template <typename Opnd> inline auto m_SExtLike(const Opnd &Op) {
1083 return m_AnyOf(m_SExt(Op), m_NNegZExt(Op));
1084}
1085
1086/// Match a aext or identity
1087/// Allows to peek through optional extensions
1088template <typename Opnd>
1089inline Or<UnaryOpc_match<Opnd>, Opnd> m_AExtOrSelf(const Opnd &Op) {
1090 return Or<UnaryOpc_match<Opnd>, Opnd>(m_AnyExt(Op), Op);
1091}
1092
1093/// Match a trunc or identity
1094/// Allows to peek through optional truncations
1095template <typename Opnd>
1096inline Or<UnaryOpc_match<Opnd>, Opnd> m_TruncOrSelf(const Opnd &Op) {
1097 return Or<UnaryOpc_match<Opnd>, Opnd>(m_Trunc(Op), Op);
1098}
1099
1100template <typename Opnd> inline UnaryOpc_match<Opnd> m_VScale(const Opnd &Op) {
1101 return UnaryOpc_match<Opnd>(ISD::VSCALE, Op);
1102}
1103
1104template <typename Opnd> inline UnaryOpc_match<Opnd> m_FPToUI(const Opnd &Op) {
1105 return UnaryOpc_match<Opnd>(ISD::FP_TO_UINT, Op);
1106}
1107
1108template <typename Opnd> inline UnaryOpc_match<Opnd> m_FPToSI(const Opnd &Op) {
1109 return UnaryOpc_match<Opnd>(ISD::FP_TO_SINT, Op);
1110}
1111
1112template <typename Opnd> inline UnaryOpc_match<Opnd> m_Ctpop(const Opnd &Op) {
1113 return UnaryOpc_match<Opnd>(ISD::CTPOP, Op);
1114}
1115
1116template <typename Opnd> inline UnaryOpc_match<Opnd> m_Ctlz(const Opnd &Op) {
1117 return UnaryOpc_match<Opnd>(ISD::CTLZ, Op);
1118}
1119
1120template <typename Opnd> inline UnaryOpc_match<Opnd> m_Cttz(const Opnd &Op) {
1121 return UnaryOpc_match<Opnd>(ISD::CTTZ, Op);
1122}
1123
1124template <typename Opnd> inline UnaryOpc_match<Opnd> m_FNeg(const Opnd &Op) {
1125 return UnaryOpc_match<Opnd>(ISD::FNEG, Op);
1126}
1127
1128// === Constants ===
1129struct ConstantInt_match {
1130 APInt *BindVal;
1131
1132 explicit ConstantInt_match(APInt *V) : BindVal(V) {}
1133
1134 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
1135 // The logics here are similar to that in
1136 // SelectionDAG::isConstantIntBuildVectorOrConstantInt, but the latter also
1137 // treats GlobalAddressSDNode as a constant, which is difficult to turn into
1138 // APInt.
1139 if (auto *C = dyn_cast_or_null<ConstantSDNode>(Val: N.getNode())) {
1140 if (BindVal)
1141 *BindVal = C->getAPIntValue();
1142 return true;
1143 }
1144
1145 APInt Discard;
1146 return ISD::isConstantSplatVector(N: N.getNode(),
1147 SplatValue&: BindVal ? *BindVal : Discard);
1148 }
1149};
1150/// Match any integer constants or splat of an integer constant.
1151inline ConstantInt_match m_ConstInt() { return ConstantInt_match(nullptr); }
1152/// Match any integer constants or splat of an integer constant; return the
1153/// specific constant or constant splat value.
1154inline ConstantInt_match m_ConstInt(APInt &V) { return ConstantInt_match(&V); }
1155
1156struct SpecificInt_match {
1157 APInt IntVal;
1158
1159 explicit SpecificInt_match(APInt APV) : IntVal(std::move(APV)) {}
1160
1161 template <typename MatchContext>
1162 bool match(const MatchContext &Ctx, SDValue N) {
1163 APInt ConstInt;
1164 if (sd_context_match(N, Ctx, m_ConstInt(V&: ConstInt)))
1165 return APInt::isSameValue(I1: IntVal, I2: ConstInt);
1166 return false;
1167 }
1168};
1169
1170/// Match a specific integer constant or constant splat value.
1171inline SpecificInt_match m_SpecificInt(APInt V) {
1172 return SpecificInt_match(std::move(V));
1173}
1174inline SpecificInt_match m_SpecificInt(uint64_t V) {
1175 return SpecificInt_match(APInt(64, V));
1176}
1177
1178struct SpecificFP_match {
1179 APFloat Val;
1180
1181 explicit SpecificFP_match(APFloat V) : Val(V) {}
1182
1183 template <typename MatchContext>
1184 bool match(const MatchContext &Ctx, SDValue V) {
1185 if (const auto *CFP = dyn_cast<ConstantFPSDNode>(Val: V.getNode()))
1186 return CFP->isExactlyValue(V: Val);
1187 if (ConstantFPSDNode *C = isConstOrConstSplatFP(N: V, /*AllowUndefs=*/AllowUndefs: true))
1188 return C->getValueAPF().compare(RHS: Val) == APFloat::cmpEqual;
1189 return false;
1190 }
1191};
1192
1193/// Match a specific float constant.
1194inline SpecificFP_match m_SpecificFP(APFloat V) { return SpecificFP_match(V); }
1195
1196inline SpecificFP_match m_SpecificFP(double V) {
1197 return SpecificFP_match(APFloat(V));
1198}
1199
1200struct Negative_match {
1201 template <typename MatchContext>
1202 bool match(const MatchContext &Ctx, SDValue N) {
1203 const SelectionDAG *DAG = Ctx.getDAG();
1204 return DAG && DAG->computeKnownBits(Op: N).isNegative();
1205 }
1206};
1207
1208struct NonNegative_match {
1209 template <typename MatchContext>
1210 bool match(const MatchContext &Ctx, SDValue N) {
1211 const SelectionDAG *DAG = Ctx.getDAG();
1212 return DAG && DAG->computeKnownBits(Op: N).isNonNegative();
1213 }
1214};
1215
1216struct StrictlyPositive_match {
1217 template <typename MatchContext>
1218 bool match(const MatchContext &Ctx, SDValue N) {
1219 const SelectionDAG *DAG = Ctx.getDAG();
1220 return DAG && DAG->computeKnownBits(Op: N).isStrictlyPositive();
1221 }
1222};
1223
1224struct NonPositive_match {
1225 template <typename MatchContext>
1226 bool match(const MatchContext &Ctx, SDValue N) {
1227 const SelectionDAG *DAG = Ctx.getDAG();
1228 return DAG && DAG->computeKnownBits(Op: N).isNonPositive();
1229 }
1230};
1231
1232struct NonZero_match {
1233 template <typename MatchContext>
1234 bool match(const MatchContext &Ctx, SDValue N) {
1235 const SelectionDAG *DAG = Ctx.getDAG();
1236 return DAG && DAG->computeKnownBits(Op: N).isNonZero();
1237 }
1238};
1239
1240struct Zero_match {
1241 bool AllowUndefs;
1242
1243 explicit Zero_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
1244
1245 template <typename MatchContext>
1246 bool match(const MatchContext &, SDValue N) const {
1247 return isZeroOrZeroSplat(N, AllowUndefs);
1248 }
1249};
1250
1251struct Ones_match {
1252 bool AllowUndefs;
1253
1254 Ones_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
1255
1256 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
1257 return isOnesOrOnesSplat(N, AllowUndefs);
1258 }
1259};
1260
1261struct AllOnes_match {
1262 bool AllowUndefs;
1263
1264 AllOnes_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
1265
1266 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
1267 return isAllOnesOrAllOnesSplat(V: N, AllowUndefs);
1268 }
1269};
1270
1271inline Negative_match m_Negative() { return Negative_match(); }
1272template <typename Pattern> inline auto m_Negative(const Pattern &P) {
1273 return m_AllOf(m_Negative(), P);
1274}
1275inline NonNegative_match m_NonNegative() { return NonNegative_match(); }
1276template <typename Pattern> inline auto m_NonNegative(const Pattern &P) {
1277 return m_AllOf(m_NonNegative(), P);
1278}
1279inline StrictlyPositive_match m_StrictlyPositive() {
1280 return StrictlyPositive_match();
1281}
1282template <typename Pattern> inline auto m_StrictlyPositive(const Pattern &P) {
1283 return m_AllOf(m_StrictlyPositive(), P);
1284}
1285inline NonPositive_match m_NonPositive() { return NonPositive_match(); }
1286template <typename Pattern> inline auto m_NonPositive(const Pattern &P) {
1287 return m_AllOf(m_NonPositive(), P);
1288}
1289inline NonZero_match m_NonZero() { return NonZero_match(); }
1290template <typename Pattern> inline auto m_NonZero(const Pattern &P) {
1291 return m_AllOf(m_NonZero(), P);
1292}
1293inline Ones_match m_One(bool AllowUndefs = false) {
1294 return Ones_match(AllowUndefs);
1295}
1296inline Zero_match m_Zero(bool AllowUndefs = false) {
1297 return Zero_match(AllowUndefs);
1298}
1299inline AllOnes_match m_AllOnes(bool AllowUndefs = false) {
1300 return AllOnes_match(AllowUndefs);
1301}
1302
1303/// Match true boolean value based on the information provided by
1304/// TargetLowering.
1305inline auto m_True() {
1306 return TLI_pred_match{
1307 [](const TargetLowering &TLI, SDValue N) {
1308 APInt ConstVal;
1309 if (sd_match(N, P: m_ConstInt(V&: ConstVal)))
1310 switch (TLI.getBooleanContents(Type: N.getValueType())) {
1311 case TargetLowering::ZeroOrOneBooleanContent:
1312 return ConstVal.isOne();
1313 case TargetLowering::ZeroOrNegativeOneBooleanContent:
1314 return ConstVal.isAllOnes();
1315 case TargetLowering::UndefinedBooleanContent:
1316 return (ConstVal & 0x01) == 1;
1317 }
1318
1319 return false;
1320 },
1321 m_Value()};
1322}
1323/// Match false boolean value based on the information provided by
1324/// TargetLowering.
1325inline auto m_False() {
1326 return TLI_pred_match{
1327 [](const TargetLowering &TLI, SDValue N) {
1328 APInt ConstVal;
1329 if (sd_match(N, P: m_ConstInt(V&: ConstVal)))
1330 switch (TLI.getBooleanContents(Type: N.getValueType())) {
1331 case TargetLowering::ZeroOrOneBooleanContent:
1332 case TargetLowering::ZeroOrNegativeOneBooleanContent:
1333 return ConstVal.isZero();
1334 case TargetLowering::UndefinedBooleanContent:
1335 return (ConstVal & 0x01) == 0;
1336 }
1337
1338 return false;
1339 },
1340 m_Value()};
1341}
1342
1343struct CondCode_match {
1344 std::optional<ISD::CondCode> CCToMatch;
1345 ISD::CondCode *BindCC = nullptr;
1346
1347 explicit CondCode_match(ISD::CondCode CC) : CCToMatch(CC) {}
1348
1349 explicit CondCode_match(ISD::CondCode *CC) : BindCC(CC) {}
1350
1351 template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
1352 if (auto *CC = dyn_cast<CondCodeSDNode>(Val: N.getNode())) {
1353 if (CCToMatch && *CCToMatch != CC->get())
1354 return false;
1355
1356 if (BindCC)
1357 *BindCC = CC->get();
1358 return true;
1359 }
1360
1361 return false;
1362 }
1363};
1364
1365/// Match any conditional code SDNode.
1366inline CondCode_match m_CondCode() { return CondCode_match(nullptr); }
1367/// Match any conditional code SDNode and return its ISD::CondCode value.
1368inline CondCode_match m_CondCode(ISD::CondCode &CC) {
1369 return CondCode_match(&CC);
1370}
1371/// Match a conditional code SDNode with a specific ISD::CondCode.
1372inline CondCode_match m_SpecificCondCode(ISD::CondCode CC) {
1373 return CondCode_match(CC);
1374}
1375
1376/// Match a negate as a sub(0, v)
1377template <typename ValTy>
1378inline BinaryOpc_match<Zero_match, ValTy, false> m_Neg(const ValTy &V) {
1379 return m_Sub(m_Zero(), V);
1380}
1381
1382/// Match a Not as a xor(v, -1) or xor(-1, v)
1383template <typename ValTy>
1384inline BinaryOpc_match<ValTy, AllOnes_match, true> m_Not(const ValTy &V) {
1385 return m_Xor(V, m_AllOnes());
1386}
1387
1388template <unsigned IntrinsicId, typename... OpndPreds>
1389inline auto m_IntrinsicWOChain(const OpndPreds &...Opnds) {
1390 return m_Node(ISD::INTRINSIC_WO_CHAIN, m_SpecificInt(V: IntrinsicId), Opnds...);
1391}
1392
1393struct SpecificNeg_match {
1394 SDValue V;
1395
1396 explicit SpecificNeg_match(SDValue V) : V(V) {}
1397
1398 template <typename MatchContext>
1399 bool match(const MatchContext &Ctx, SDValue N) {
1400 if (sd_context_match(N, Ctx, m_Neg(V: m_Specific(N: V))))
1401 return true;
1402
1403 return ISD::matchBinaryPredicate(
1404 LHS: V, RHS: N, Match: [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
1405 return LHS->getAPIntValue() == -RHS->getAPIntValue();
1406 });
1407 }
1408};
1409
1410/// Match a negation of a specific value V, either as sub(0, V) or as
1411/// constant(s) that are the negation of V's constant(s).
1412inline SpecificNeg_match m_SpecificNeg(SDValue V) {
1413 return SpecificNeg_match(V);
1414}
1415
1416template <typename... PatternTs> struct ReassociatableOpc_match {
1417 unsigned Opcode;
1418 std::tuple<PatternTs...> Patterns;
1419 constexpr static size_t NumPatterns =
1420 std::tuple_size_v<std::tuple<PatternTs...>>;
1421
1422 SDNodeFlags Flags;
1423
1424 ReassociatableOpc_match(unsigned Opcode, const PatternTs &...Patterns)
1425 : Opcode(Opcode), Patterns(Patterns...) {}
1426
1427 ReassociatableOpc_match(unsigned Opcode, SDNodeFlags Flags,
1428 const PatternTs &...Patterns)
1429 : Opcode(Opcode), Patterns(Patterns...), Flags(Flags) {}
1430
1431 template <typename MatchContext>
1432 bool match(const MatchContext &Ctx, SDValue N) {
1433 std::array<SDValue, NumPatterns> Leaves;
1434 size_t LeavesIdx = 0;
1435 if (!(collectLeaves(V: N, Leaves, LeafIdx&: LeavesIdx) && (LeavesIdx == NumPatterns)))
1436 return false;
1437
1438 Bitset<NumPatterns> Used;
1439 return std::apply(
1440 [&](auto &...P) -> bool {
1441 return reassociatableMatchHelper(Ctx, Leaves, Used, P...);
1442 },
1443 Patterns);
1444 }
1445
1446 bool collectLeaves(SDValue V, std::array<SDValue, NumPatterns> &Leaves,
1447 std::size_t &LeafIdx) {
1448 if (V->getOpcode() == Opcode && (Flags & V->getFlags()) == Flags) {
1449 for (size_t I = 0, N = V->getNumOperands(); I < N; I++)
1450 if ((LeafIdx == NumPatterns) ||
1451 !collectLeaves(V: V->getOperand(Num: I), Leaves, LeafIdx))
1452 return false;
1453 } else {
1454 Leaves[LeafIdx] = V;
1455 LeafIdx++;
1456 }
1457 return true;
1458 }
1459
1460 // Searchs for a matching leaf for every sub-pattern.
1461 template <typename MatchContext, typename PatternHd, typename... PatternTl>
1462 [[nodiscard]] inline bool
1463 reassociatableMatchHelper(const MatchContext &Ctx, ArrayRef<SDValue> Leaves,
1464 Bitset<NumPatterns> &Used, PatternHd &HeadPattern,
1465 PatternTl &...TailPatterns) {
1466 for (size_t Match = 0, N = Used.size(); Match < N; Match++) {
1467 if (Used[Match] || !(sd_context_match(Leaves[Match], Ctx, HeadPattern)))
1468 continue;
1469 Used.set(Match);
1470 if (reassociatableMatchHelper(Ctx, Leaves, Used, TailPatterns...))
1471 return true;
1472 Used.reset(Match);
1473 }
1474 return false;
1475 }
1476
1477 template <typename MatchContext>
1478 [[nodiscard]] inline bool
1479 reassociatableMatchHelper(const MatchContext &Ctx, ArrayRef<SDValue> Leaves,
1480 Bitset<NumPatterns> &Used) {
1481 return true;
1482 }
1483};
1484
1485template <typename... PatternTs>
1486inline ReassociatableOpc_match<PatternTs...>
1487m_ReassociatableAdd(const PatternTs &...Patterns) {
1488 return ReassociatableOpc_match<PatternTs...>(ISD::ADD, Patterns...);
1489}
1490
1491template <typename... PatternTs>
1492inline ReassociatableOpc_match<PatternTs...>
1493m_ReassociatableOr(const PatternTs &...Patterns) {
1494 return ReassociatableOpc_match<PatternTs...>(ISD::OR, Patterns...);
1495}
1496
1497template <typename... PatternTs>
1498inline ReassociatableOpc_match<PatternTs...>
1499m_ReassociatableAnd(const PatternTs &...Patterns) {
1500 return ReassociatableOpc_match<PatternTs...>(ISD::AND, Patterns...);
1501}
1502
1503template <typename... PatternTs>
1504inline ReassociatableOpc_match<PatternTs...>
1505m_ReassociatableMul(const PatternTs &...Patterns) {
1506 return ReassociatableOpc_match<PatternTs...>(ISD::MUL, Patterns...);
1507}
1508
1509template <typename... PatternTs>
1510inline ReassociatableOpc_match<PatternTs...>
1511m_ReassociatableNSWAdd(const PatternTs &...Patterns) {
1512 return ReassociatableOpc_match<PatternTs...>(
1513 ISD::ADD, SDNodeFlags::NoSignedWrap, Patterns...);
1514}
1515
1516template <typename... PatternTs>
1517inline ReassociatableOpc_match<PatternTs...>
1518m_ReassociatableNUWAdd(const PatternTs &...Patterns) {
1519 return ReassociatableOpc_match<PatternTs...>(
1520 ISD::ADD, SDNodeFlags::NoUnsignedWrap, Patterns...);
1521}
1522
1523} // namespace SDPatternMatch
1524} // namespace llvm
1525#endif
1526