1//===- VPlanPatternMatch.h - Match on VPValues and recipes ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file provides a simple and efficient mechanism for performing general
10// tree-based pattern matches on the VPlan values and recipes, based on
11// LLVM's IR pattern matchers.
12//
13//===----------------------------------------------------------------------===//
14
15#ifndef LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H
16#define LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H
17
18#include "VPlan.h"
19
20namespace llvm::VPlanPatternMatch {
21
22template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
23 return P.match(V);
24}
25
26/// A match functor that can be used as a UnaryPredicate in functional
27/// algorithms like all_of.
28template <typename Val, typename Pattern> auto match_fn(const Pattern &P) {
29 return bind_back<match<Val, Pattern>>(P);
30}
31
32template <typename Pattern> bool match(VPUser *U, const Pattern &P) {
33 auto *R = dyn_cast<VPRecipeBase>(Val: U);
34 return R && match(R, P);
35}
36
37/// Match functor for VPUser.
38template <typename Pattern> auto match_fn(const Pattern &P) {
39 return bind_back<match<Pattern>>(P);
40}
41
42template <typename Pattern> bool match(VPSingleDefRecipe *R, const Pattern &P) {
43 return P.match(static_cast<const VPRecipeBase *>(R));
44}
45
46template <typename... Classes> struct class_match {
47 template <typename ITy> bool match(ITy *V) const {
48 return isa<Classes...>(V);
49 }
50};
51
52/// Match an arbitrary VPValue and ignore it.
53inline class_match<VPValue> m_VPValue() { return class_match<VPValue>(); }
54
55template <typename Class> struct bind_ty {
56 Class *&VR;
57
58 bind_ty(Class *&V) : VR(V) {}
59
60 template <typename ITy> bool match(ITy *V) const {
61 if (auto *CV = dyn_cast<Class>(V)) {
62 VR = CV;
63 return true;
64 }
65 return false;
66 }
67};
68
69/// Match a specified VPValue.
70struct specificval_ty {
71 const VPValue *Val;
72
73 specificval_ty(const VPValue *V) : Val(V) {}
74
75 bool match(VPValue *VPV) const { return VPV == Val; }
76};
77
78inline specificval_ty m_Specific(const VPValue *VPV) { return VPV; }
79
80/// Stores a reference to the VPValue *, not the VPValue * itself,
81/// thus can be used in commutative matchers.
82struct deferredval_ty {
83 VPValue *const &Val;
84
85 deferredval_ty(VPValue *const &V) : Val(V) {}
86
87 bool match(VPValue *const V) const { return V == Val; }
88};
89
90/// Like m_Specific(), but works if the specific value to match is determined
91/// as part of the same match() expression. For example:
92/// m_Mul(m_VPValue(X), m_Specific(X)) is incorrect, because m_Specific() will
93/// bind X before the pattern match starts.
94/// m_Mul(m_VPValue(X), m_Deferred(X)) is correct, and will check against
95/// whichever value m_VPValue(X) populated.
96inline deferredval_ty m_Deferred(VPValue *const &V) { return V; }
97
98/// Match an integer constant or vector of constants if Pred::isValue returns
99/// true for the APInt. \p BitWidth optionally specifies the bitwidth the
100/// matched constant must have. If it is 0, the matched constant can have any
101/// bitwidth.
102template <typename Pred, unsigned BitWidth = 0> struct int_pred_ty {
103 Pred P;
104
105 int_pred_ty(Pred P) : P(std::move(P)) {}
106 int_pred_ty() : P() {}
107
108 bool match(VPValue *VPV) const {
109 auto *VPI = dyn_cast<VPInstruction>(Val: VPV);
110 if (VPI && VPI->getOpcode() == VPInstruction::Broadcast)
111 VPV = VPI->getOperand(N: 0);
112 auto *CI = dyn_cast<VPConstantInt>(Val: VPV);
113 if (!CI)
114 return false;
115
116 if (BitWidth != 0 && CI->getBitWidth() != BitWidth)
117 return false;
118 return P.isValue(CI->getAPInt());
119 }
120};
121
122/// Match a specified integer value or vector of all elements of that
123/// value. \p BitWidth optionally specifies the bitwidth the matched constant
124/// must have. If it is 0, the matched constant can have any bitwidth.
125struct is_specific_int {
126 APInt Val;
127
128 is_specific_int(APInt Val) : Val(std::move(Val)) {}
129
130 bool isValue(const APInt &C) const { return APInt::isSameValue(I1: Val, I2: C); }
131};
132
133template <unsigned Bitwidth = 0>
134using specific_intval = int_pred_ty<is_specific_int, Bitwidth>;
135
136inline specific_intval<0> m_SpecificInt(uint64_t V) {
137 return specific_intval<0>(is_specific_int(APInt(64, V)));
138}
139
140inline specific_intval<1> m_False() {
141 return specific_intval<1>(is_specific_int(APInt(64, 0)));
142}
143
144inline specific_intval<1> m_True() {
145 return specific_intval<1>(is_specific_int(APInt(64, 1)));
146}
147
148struct is_all_ones {
149 bool isValue(const APInt &C) const { return C.isAllOnes(); }
150};
151
152/// Match an integer or vector with all bits set.
153/// For vectors, this includes constants with undefined elements.
154inline int_pred_ty<is_all_ones> m_AllOnes() {
155 return int_pred_ty<is_all_ones>();
156}
157
158struct is_zero_int {
159 bool isValue(const APInt &C) const { return C.isZero(); }
160};
161
162struct is_one {
163 bool isValue(const APInt &C) const { return C.isOne(); }
164};
165
166/// Match an integer 0 or a vector with all elements equal to 0.
167/// For vectors, this includes constants with undefined elements.
168inline int_pred_ty<is_zero_int> m_ZeroInt() {
169 return int_pred_ty<is_zero_int>();
170}
171
172/// Match an integer 1 or a vector with all elements equal to 1.
173/// For vectors, this includes constants with undefined elements.
174inline int_pred_ty<is_one> m_One() { return int_pred_ty<is_one>(); }
175
176struct bind_apint {
177 const APInt *&Res;
178
179 bind_apint(const APInt *&Res) : Res(Res) {}
180
181 bool match(VPValue *VPV) const {
182 auto *CI = dyn_cast<VPConstantInt>(Val: VPV);
183 if (!CI)
184 return false;
185 Res = &CI->getAPInt();
186 return true;
187 }
188};
189
190inline bind_apint m_APInt(const APInt *&C) { return C; }
191
192struct bind_const_int {
193 uint64_t &Res;
194
195 bind_const_int(uint64_t &Res) : Res(Res) {}
196
197 bool match(VPValue *VPV) const {
198 const APInt *APConst;
199 if (!bind_apint(APConst).match(VPV))
200 return false;
201 if (auto C = APConst->tryZExtValue()) {
202 Res = *C;
203 return true;
204 }
205 return false;
206 }
207};
208
209/// Match a plain integer constant no wider than 64-bits, capturing it if we
210/// match.
211inline bind_const_int m_ConstantInt(uint64_t &C) { return C; }
212
/// Matching combinators.

/// Match if \p L matches or, failing that, if \p R matches. \p R is only tried
/// when \p L fails.
template <typename LTy, typename RTy> struct match_combine_or {
  LTy L;
  RTy R;

  match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) {}

  template <typename ITy> bool match(ITy *V) const {
    if (L.match(V))
      return true;
    return R.match(V);
  }
};

/// Match if both \p L and \p R match. \p R is only tried when \p L succeeds.
template <typename LTy, typename RTy> struct match_combine_and {
  LTy L;
  RTy R;

  match_combine_and(const LTy &Left, const RTy &Right) : L(Left), R(Right) {}

  template <typename ITy> bool match(ITy *V) const {
    if (!L.match(V))
      return false;
    return R.match(V);
  }
};

/// Combine two pattern matchers matching L || R.
template <typename LTy, typename RTy>
inline match_combine_or<LTy, RTy> m_CombineOr(const LTy &L, const RTy &R) {
  return {L, R};
}

/// Combine two pattern matchers matching L && R.
template <typename LTy, typename RTy>
inline match_combine_and<LTy, RTy> m_CombineAnd(const LTy &L, const RTy &R) {
  return {L, R};
}
247
248/// Match a VPValue, capturing it if we match.
249inline bind_ty<VPValue> m_VPValue(VPValue *&V) { return V; }
250
251/// Match a VPIRValue.
252inline bind_ty<VPIRValue> m_VPIRValue(VPIRValue *&V) { return V; }
253
254/// Match a VPInstruction, capturing if we match.
255inline bind_ty<VPInstruction> m_VPInstruction(VPInstruction *&V) { return V; }
256
257template <typename Ops_t, unsigned Opcode, bool Commutative,
258 typename... RecipeTys>
259struct Recipe_match {
260 Ops_t Ops;
261
262 template <typename... OpTy> Recipe_match(OpTy... Ops) : Ops(Ops...) {
263 static_assert(std::tuple_size<Ops_t>::value == sizeof...(Ops) &&
264 "number of operands in constructor doesn't match Ops_t");
265 static_assert((!Commutative || std::tuple_size<Ops_t>::value == 2) &&
266 "only binary ops can be commutative");
267 }
268
269 bool match(const VPValue *V) const {
270 auto *DefR = V->getDefiningRecipe();
271 return DefR && match(DefR);
272 }
273
274 bool match(const VPSingleDefRecipe *R) const {
275 return match(static_cast<const VPRecipeBase *>(R));
276 }
277
278 bool match(const VPRecipeBase *R) const {
279 if (std::tuple_size_v<Ops_t> == 0) {
280 auto *VPI = dyn_cast<VPInstruction>(Val: R);
281 return VPI && VPI->getOpcode() == Opcode;
282 }
283
284 if ((!matchRecipeAndOpcode<RecipeTys>(R) && ...))
285 return false;
286
287 if (R->getNumOperands() != std::tuple_size_v<Ops_t>) {
288 [[maybe_unused]] auto *RepR = dyn_cast<VPReplicateRecipe>(Val: R);
289 assert(((isa<VPInstruction>(R) &&
290 VPInstruction::getNumOperandsForOpcode(Opcode) == -1u) ||
291 (RepR && std::tuple_size_v<Ops_t> ==
292 RepR->getNumOperands() - RepR->isPredicated())) &&
293 "non-variadic recipe with matched opcode does not have the "
294 "expected number of operands");
295 return false;
296 }
297
298 auto IdxSeq = std::make_index_sequence<std::tuple_size<Ops_t>::value>();
299 if (all_of_tuple_elements(IdxSeq, [R](auto Op, unsigned Idx) {
300 return Op.match(R->getOperand(N: Idx));
301 }))
302 return true;
303
304 return Commutative &&
305 all_of_tuple_elements(IdxSeq, [R](auto Op, unsigned Idx) {
306 return Op.match(R->getOperand(N: R->getNumOperands() - Idx - 1));
307 });
308 }
309
310private:
311 template <typename RecipeTy>
312 static bool matchRecipeAndOpcode(const VPRecipeBase *R) {
313 auto *DefR = dyn_cast<RecipeTy>(R);
314 // Check for recipes that do not have opcodes.
315 if constexpr (std::is_same_v<RecipeTy, VPScalarIVStepsRecipe> ||
316 std::is_same_v<RecipeTy, VPCanonicalIVPHIRecipe> ||
317 std::is_same_v<RecipeTy, VPDerivedIVRecipe> ||
318 std::is_same_v<RecipeTy, VPVectorEndPointerRecipe>)
319 return DefR;
320 else
321 return DefR && DefR->getOpcode() == Opcode;
322 }
323
324 /// Helper to check if predicate \p P holds on all tuple elements in Ops using
325 /// the provided index sequence.
326 template <typename Fn, std::size_t... Is>
327 bool all_of_tuple_elements(std::index_sequence<Is...>, Fn P) const {
328 return (P(std::get<Is>(Ops), Is) && ...);
329 }
330};
331
332template <unsigned Opcode, typename... OpTys>
333using AllRecipe_match =
334 Recipe_match<std::tuple<OpTys...>, Opcode, /*Commutative*/ false,
335 VPWidenRecipe, VPReplicateRecipe, VPWidenCastRecipe,
336 VPInstruction>;
337
338template <unsigned Opcode, typename... OpTys>
339using AllRecipe_commutative_match =
340 Recipe_match<std::tuple<OpTys...>, Opcode, /*Commutative*/ true,
341 VPWidenRecipe, VPReplicateRecipe, VPInstruction>;
342
343template <unsigned Opcode, typename... OpTys>
344using VPInstruction_match = Recipe_match<std::tuple<OpTys...>, Opcode,
345 /*Commutative*/ false, VPInstruction>;
346
347template <unsigned Opcode, typename... OpTys>
348inline VPInstruction_match<Opcode, OpTys...>
349m_VPInstruction(const OpTys &...Ops) {
350 return VPInstruction_match<Opcode, OpTys...>(Ops...);
351}
352
353/// BuildVector is matches only its opcode, w/o matching its operands as the
354/// number of operands is not fixed.
355inline VPInstruction_match<VPInstruction::BuildVector> m_BuildVector() {
356 return m_VPInstruction<VPInstruction::BuildVector>();
357}
358
359template <typename Op0_t>
360inline VPInstruction_match<Instruction::Freeze, Op0_t>
361m_Freeze(const Op0_t &Op0) {
362 return m_VPInstruction<Instruction::Freeze>(Op0);
363}
364
365inline VPInstruction_match<VPInstruction::BranchOnCond> m_BranchOnCond() {
366 return m_VPInstruction<VPInstruction::BranchOnCond>();
367}
368
369template <typename Op0_t>
370inline VPInstruction_match<VPInstruction::BranchOnCond, Op0_t>
371m_BranchOnCond(const Op0_t &Op0) {
372 return m_VPInstruction<VPInstruction::BranchOnCond>(Op0);
373}
374
375inline VPInstruction_match<VPInstruction::BranchOnTwoConds>
376m_BranchOnTwoConds() {
377 return m_VPInstruction<VPInstruction::BranchOnTwoConds>();
378}
379
380template <typename Op0_t, typename Op1_t>
381inline VPInstruction_match<VPInstruction::BranchOnTwoConds, Op0_t, Op1_t>
382m_BranchOnTwoConds(const Op0_t &Op0, const Op1_t &Op1) {
383 return m_VPInstruction<VPInstruction::BranchOnTwoConds>(Op0, Op1);
384}
385
386template <typename Op0_t>
387inline VPInstruction_match<VPInstruction::Broadcast, Op0_t>
388m_Broadcast(const Op0_t &Op0) {
389 return m_VPInstruction<VPInstruction::Broadcast>(Op0);
390}
391
392template <typename Op0_t>
393inline VPInstruction_match<VPInstruction::ExplicitVectorLength, Op0_t>
394m_EVL(const Op0_t &Op0) {
395 return m_VPInstruction<VPInstruction::ExplicitVectorLength>(Op0);
396}
397
398template <typename Op0_t>
399inline VPInstruction_match<VPInstruction::ExtractLastLane, Op0_t>
400m_ExtractLastLane(const Op0_t &Op0) {
401 return m_VPInstruction<VPInstruction::ExtractLastLane>(Op0);
402}
403
404template <typename Op0_t, typename Op1_t>
405inline VPInstruction_match<Instruction::ExtractElement, Op0_t, Op1_t>
406m_ExtractElement(const Op0_t &Op0, const Op1_t &Op1) {
407 return m_VPInstruction<Instruction::ExtractElement>(Op0, Op1);
408}
409
410template <typename Op0_t, typename Op1_t>
411inline VPInstruction_match<VPInstruction::ExtractLane, Op0_t, Op1_t>
412m_ExtractLane(const Op0_t &Op0, const Op1_t &Op1) {
413 return m_VPInstruction<VPInstruction::ExtractLane>(Op0, Op1);
414}
415
416template <typename Op0_t>
417inline VPInstruction_match<VPInstruction::ExtractLastPart, Op0_t>
418m_ExtractLastPart(const Op0_t &Op0) {
419 return m_VPInstruction<VPInstruction::ExtractLastPart>(Op0);
420}
421
422template <typename Op0_t>
423inline VPInstruction_match<
424 VPInstruction::ExtractLastLane,
425 VPInstruction_match<VPInstruction::ExtractLastPart, Op0_t>>
426m_ExtractLastLaneOfLastPart(const Op0_t &Op0) {
427 return m_ExtractLastLane(m_ExtractLastPart(Op0));
428}
429
430template <typename Op0_t>
431inline VPInstruction_match<VPInstruction::ExtractPenultimateElement, Op0_t>
432m_ExtractPenultimateElement(const Op0_t &Op0) {
433 return m_VPInstruction<VPInstruction::ExtractPenultimateElement>(Op0);
434}
435
436template <typename Op0_t, typename Op1_t, typename Op2_t>
437inline VPInstruction_match<VPInstruction::ActiveLaneMask, Op0_t, Op1_t, Op2_t>
438m_ActiveLaneMask(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) {
439 return m_VPInstruction<VPInstruction::ActiveLaneMask>(Op0, Op1, Op2);
440}
441
442inline VPInstruction_match<VPInstruction::BranchOnCount> m_BranchOnCount() {
443 return m_VPInstruction<VPInstruction::BranchOnCount>();
444}
445
446template <typename Op0_t, typename Op1_t>
447inline VPInstruction_match<VPInstruction::BranchOnCount, Op0_t, Op1_t>
448m_BranchOnCount(const Op0_t &Op0, const Op1_t &Op1) {
449 return m_VPInstruction<VPInstruction::BranchOnCount>(Op0, Op1);
450}
451
452inline VPInstruction_match<VPInstruction::AnyOf> m_AnyOf() {
453 return m_VPInstruction<VPInstruction::AnyOf>();
454}
455
456template <typename Op0_t>
457inline VPInstruction_match<VPInstruction::AnyOf, Op0_t>
458m_AnyOf(const Op0_t &Op0) {
459 return m_VPInstruction<VPInstruction::AnyOf>(Op0);
460}
461
462template <typename Op0_t>
463inline VPInstruction_match<VPInstruction::FirstActiveLane, Op0_t>
464m_FirstActiveLane(const Op0_t &Op0) {
465 return m_VPInstruction<VPInstruction::FirstActiveLane>(Op0);
466}
467
468template <typename Op0_t>
469inline VPInstruction_match<VPInstruction::LastActiveLane, Op0_t>
470m_LastActiveLane(const Op0_t &Op0) {
471 return m_VPInstruction<VPInstruction::LastActiveLane>(Op0);
472}
473
474template <typename Op0_t>
475inline VPInstruction_match<VPInstruction::ComputeReductionResult, Op0_t>
476m_ComputeReductionResult(const Op0_t &Op0) {
477 return m_VPInstruction<VPInstruction::ComputeReductionResult>(Op0);
478}
479
480/// Match FindIV result pattern:
481/// select(icmp ne ComputeReductionResult(ReducedIV), Sentinel),
482/// ComputeReductionResult(ReducedIV), Start.
483template <typename Op0_t, typename Op1_t>
484inline bool matchFindIVResult(VPInstruction *VPI, Op0_t ReducedIV, Op1_t Start) {
485 return match(VPI, m_Select(m_SpecificICmp(ICmpInst::ICMP_NE,
486 m_ComputeReductionResult(ReducedIV),
487 m_VPValue()),
488 m_ComputeReductionResult(ReducedIV), Start));
489}
490
491template <typename Op0_t, typename Op1_t, typename Op2_t>
492inline VPInstruction_match<VPInstruction::ComputeAnyOfResult, Op0_t, Op1_t,
493 Op2_t>
494m_ComputeAnyOfResult(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) {
495 return m_VPInstruction<VPInstruction::ComputeAnyOfResult>(Op0, Op1, Op2);
496}
497
498template <typename Op0_t>
499inline VPInstruction_match<VPInstruction::Reverse, Op0_t>
500m_Reverse(const Op0_t &Op0) {
501 return m_VPInstruction<VPInstruction::Reverse>(Op0);
502}
503
504inline VPInstruction_match<VPInstruction::StepVector> m_StepVector() {
505 return m_VPInstruction<VPInstruction::StepVector>();
506}
507
508template <unsigned Opcode, typename Op0_t>
509inline AllRecipe_match<Opcode, Op0_t> m_Unary(const Op0_t &Op0) {
510 return AllRecipe_match<Opcode, Op0_t>(Op0);
511}
512
513template <typename Op0_t>
514inline AllRecipe_match<Instruction::Trunc, Op0_t> m_Trunc(const Op0_t &Op0) {
515 return m_Unary<Instruction::Trunc, Op0_t>(Op0);
516}
517
518template <typename Op0_t>
519inline match_combine_or<AllRecipe_match<Instruction::Trunc, Op0_t>, Op0_t>
520m_TruncOrSelf(const Op0_t &Op0) {
521 return m_CombineOr(m_Trunc(Op0), Op0);
522}
523
524template <typename Op0_t>
525inline AllRecipe_match<Instruction::ZExt, Op0_t> m_ZExt(const Op0_t &Op0) {
526 return m_Unary<Instruction::ZExt, Op0_t>(Op0);
527}
528
529template <typename Op0_t>
530inline AllRecipe_match<Instruction::SExt, Op0_t> m_SExt(const Op0_t &Op0) {
531 return m_Unary<Instruction::SExt, Op0_t>(Op0);
532}
533
534template <typename Op0_t>
535inline AllRecipe_match<Instruction::FPExt, Op0_t> m_FPExt(const Op0_t &Op0) {
536 return m_Unary<Instruction::FPExt, Op0_t>(Op0);
537}
538
539template <typename Op0_t>
540inline match_combine_or<AllRecipe_match<Instruction::ZExt, Op0_t>,
541 AllRecipe_match<Instruction::SExt, Op0_t>>
542m_ZExtOrSExt(const Op0_t &Op0) {
543 return m_CombineOr(m_ZExt(Op0), m_SExt(Op0));
544}
545
546template <typename Op0_t>
547inline match_combine_or<AllRecipe_match<Instruction::ZExt, Op0_t>, Op0_t>
548m_ZExtOrSelf(const Op0_t &Op0) {
549 return m_CombineOr(m_ZExt(Op0), Op0);
550}
551
552template <unsigned Opcode, typename Op0_t, typename Op1_t>
553inline AllRecipe_match<Opcode, Op0_t, Op1_t> m_Binary(const Op0_t &Op0,
554 const Op1_t &Op1) {
555 return AllRecipe_match<Opcode, Op0_t, Op1_t>(Op0, Op1);
556}
557
558template <unsigned Opcode, typename Op0_t, typename Op1_t>
559inline AllRecipe_commutative_match<Opcode, Op0_t, Op1_t>
560m_c_Binary(const Op0_t &Op0, const Op1_t &Op1) {
561 return AllRecipe_commutative_match<Opcode, Op0_t, Op1_t>(Op0, Op1);
562}
563
564template <typename Op0_t, typename Op1_t>
565inline AllRecipe_match<Instruction::Add, Op0_t, Op1_t> m_Add(const Op0_t &Op0,
566 const Op1_t &Op1) {
567 return m_Binary<Instruction::Add, Op0_t, Op1_t>(Op0, Op1);
568}
569
570template <typename Op0_t, typename Op1_t>
571inline AllRecipe_commutative_match<Instruction::Add, Op0_t, Op1_t>
572m_c_Add(const Op0_t &Op0, const Op1_t &Op1) {
573 return m_c_Binary<Instruction::Add, Op0_t, Op1_t>(Op0, Op1);
574}
575
576template <typename Op0_t, typename Op1_t>
577inline AllRecipe_match<Instruction::Sub, Op0_t, Op1_t> m_Sub(const Op0_t &Op0,
578 const Op1_t &Op1) {
579 return m_Binary<Instruction::Sub, Op0_t, Op1_t>(Op0, Op1);
580}
581
582template <typename Op0_t, typename Op1_t>
583inline AllRecipe_match<Instruction::Mul, Op0_t, Op1_t> m_Mul(const Op0_t &Op0,
584 const Op1_t &Op1) {
585 return m_Binary<Instruction::Mul, Op0_t, Op1_t>(Op0, Op1);
586}
587
588template <typename Op0_t, typename Op1_t>
589inline AllRecipe_commutative_match<Instruction::Mul, Op0_t, Op1_t>
590m_c_Mul(const Op0_t &Op0, const Op1_t &Op1) {
591 return m_c_Binary<Instruction::Mul, Op0_t, Op1_t>(Op0, Op1);
592}
593
594template <typename Op0_t, typename Op1_t>
595inline AllRecipe_match<Instruction::FMul, Op0_t, Op1_t>
596m_FMul(const Op0_t &Op0, const Op1_t &Op1) {
597 return m_Binary<Instruction::FMul, Op0_t, Op1_t>(Op0, Op1);
598}
599
600template <typename Op0_t, typename Op1_t>
601inline AllRecipe_match<Instruction::FAdd, Op0_t, Op1_t>
602m_FAdd(const Op0_t &Op0, const Op1_t &Op1) {
603 return m_Binary<Instruction::FAdd, Op0_t, Op1_t>(Op0, Op1);
604}
605
606template <typename Op0_t, typename Op1_t>
607inline AllRecipe_commutative_match<Instruction::FAdd, Op0_t, Op1_t>
608m_c_FAdd(const Op0_t &Op0, const Op1_t &Op1) {
609 return m_c_Binary<Instruction::FAdd, Op0_t, Op1_t>(Op0, Op1);
610}
611
612template <typename Op0_t, typename Op1_t>
613inline AllRecipe_match<Instruction::UDiv, Op0_t, Op1_t>
614m_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
615 return m_Binary<Instruction::UDiv, Op0_t, Op1_t>(Op0, Op1);
616}
617
618/// Match a binary AND operation.
619template <typename Op0_t, typename Op1_t>
620inline AllRecipe_commutative_match<Instruction::And, Op0_t, Op1_t>
621m_c_BinaryAnd(const Op0_t &Op0, const Op1_t &Op1) {
622 return m_c_Binary<Instruction::And, Op0_t, Op1_t>(Op0, Op1);
623}
624
625/// Match a binary OR operation. Note that while conceptually the operands can
626/// be matched commutatively, \p Commutative defaults to false in line with the
627/// IR-based pattern matching infrastructure. Use m_c_BinaryOr for a commutative
628/// version of the matcher.
629template <typename Op0_t, typename Op1_t>
630inline AllRecipe_match<Instruction::Or, Op0_t, Op1_t>
631m_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) {
632 return m_Binary<Instruction::Or, Op0_t, Op1_t>(Op0, Op1);
633}
634
635template <typename Op0_t, typename Op1_t>
636inline AllRecipe_commutative_match<Instruction::Or, Op0_t, Op1_t>
637m_c_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) {
638 return m_c_Binary<Instruction::Or, Op0_t, Op1_t>(Op0, Op1);
639}
640
641/// Cmp_match is a variant of BinaryRecipe_match that also binds the comparison
642/// predicate. Opcodes must either be Instruction::ICmp or Instruction::FCmp, or
643/// both.
644template <typename Op0_t, typename Op1_t, unsigned... Opcodes>
645struct Cmp_match {
646 static_assert((sizeof...(Opcodes) == 1 || sizeof...(Opcodes) == 2) &&
647 "Expected one or two opcodes");
648 static_assert(
649 ((Opcodes == Instruction::ICmp || Opcodes == Instruction::FCmp) && ...) &&
650 "Expected a compare instruction opcode");
651
652 CmpPredicate *Predicate = nullptr;
653 Op0_t Op0;
654 Op1_t Op1;
655
656 Cmp_match(CmpPredicate &Pred, const Op0_t &Op0, const Op1_t &Op1)
657 : Predicate(&Pred), Op0(Op0), Op1(Op1) {}
658 Cmp_match(const Op0_t &Op0, const Op1_t &Op1) : Op0(Op0), Op1(Op1) {}
659
660 bool match(const VPValue *V) const {
661 auto *DefR = V->getDefiningRecipe();
662 return DefR && match(DefR);
663 }
664
665 bool match(const VPRecipeBase *V) const {
666 if ((m_Binary<Opcodes>(Op0, Op1).match(V) || ...)) {
667 if (Predicate)
668 *Predicate = cast<VPRecipeWithIRFlags>(Val: V)->getPredicate();
669 return true;
670 }
671 return false;
672 }
673};
674
675/// SpecificCmp_match is a variant of Cmp_match that matches the comparison
676/// predicate, instead of binding it.
677template <typename Op0_t, typename Op1_t, unsigned... Opcodes>
678struct SpecificCmp_match {
679 const CmpPredicate Predicate;
680 Op0_t Op0;
681 Op1_t Op1;
682
683 SpecificCmp_match(CmpPredicate Pred, const Op0_t &LHS, const Op1_t &RHS)
684 : Predicate(Pred), Op0(LHS), Op1(RHS) {}
685
686 bool match(const VPValue *V) const {
687 auto *DefR = V->getDefiningRecipe();
688 return DefR && match(DefR);
689 }
690
691 bool match(const VPRecipeBase *V) const {
692 CmpPredicate CurrentPred;
693 return Cmp_match<Op0_t, Op1_t, Opcodes...>(CurrentPred, Op0, Op1)
694 .match(V) &&
695 CmpPredicate::getMatching(A: CurrentPred, B: Predicate);
696 }
697};
698
699template <typename Op0_t, typename Op1_t>
700inline Cmp_match<Op0_t, Op1_t, Instruction::ICmp> m_ICmp(const Op0_t &Op0,
701 const Op1_t &Op1) {
702 return Cmp_match<Op0_t, Op1_t, Instruction::ICmp>(Op0, Op1);
703}
704
705template <typename Op0_t, typename Op1_t>
706inline Cmp_match<Op0_t, Op1_t, Instruction::ICmp>
707m_ICmp(CmpPredicate &Pred, const Op0_t &Op0, const Op1_t &Op1) {
708 return Cmp_match<Op0_t, Op1_t, Instruction::ICmp>(Pred, Op0, Op1);
709}
710
711template <typename Op0_t, typename Op1_t>
712inline SpecificCmp_match<Op0_t, Op1_t, Instruction::ICmp>
713m_SpecificICmp(CmpPredicate MatchPred, const Op0_t &Op0, const Op1_t &Op1) {
714 return SpecificCmp_match<Op0_t, Op1_t, Instruction::ICmp>(MatchPred, Op0,
715 Op1);
716}
717
718template <typename Op0_t, typename Op1_t>
719inline Cmp_match<Op0_t, Op1_t, Instruction::ICmp, Instruction::FCmp>
720m_Cmp(const Op0_t &Op0, const Op1_t &Op1) {
721 return Cmp_match<Op0_t, Op1_t, Instruction::ICmp, Instruction::FCmp>(Op0,
722 Op1);
723}
724
725template <typename Op0_t, typename Op1_t>
726inline Cmp_match<Op0_t, Op1_t, Instruction::ICmp, Instruction::FCmp>
727m_Cmp(CmpPredicate &Pred, const Op0_t &Op0, const Op1_t &Op1) {
728 return Cmp_match<Op0_t, Op1_t, Instruction::ICmp, Instruction::FCmp>(
729 Pred, Op0, Op1);
730}
731
732template <typename Op0_t, typename Op1_t>
733inline SpecificCmp_match<Op0_t, Op1_t, Instruction::ICmp, Instruction::FCmp>
734m_SpecificCmp(CmpPredicate MatchPred, const Op0_t &Op0, const Op1_t &Op1) {
735 return SpecificCmp_match<Op0_t, Op1_t, Instruction::ICmp, Instruction::FCmp>(
736 MatchPred, Op0, Op1);
737}
738
739template <typename Op0_t, typename Op1_t>
740using GEPLikeRecipe_match = match_combine_or<
741 Recipe_match<std::tuple<Op0_t, Op1_t>, Instruction::GetElementPtr,
742 /*Commutative*/ false, VPReplicateRecipe, VPWidenGEPRecipe>,
743 match_combine_or<
744 VPInstruction_match<VPInstruction::PtrAdd, Op0_t, Op1_t>,
745 VPInstruction_match<VPInstruction::WidePtrAdd, Op0_t, Op1_t>>>;
746
747template <typename Op0_t, typename Op1_t>
748inline GEPLikeRecipe_match<Op0_t, Op1_t> m_GetElementPtr(const Op0_t &Op0,
749 const Op1_t &Op1) {
750 return m_CombineOr(
751 Recipe_match<std::tuple<Op0_t, Op1_t>, Instruction::GetElementPtr,
752 /*Commutative*/ false, VPReplicateRecipe, VPWidenGEPRecipe>(
753 Op0, Op1),
754 m_CombineOr(
755 VPInstruction_match<VPInstruction::PtrAdd, Op0_t, Op1_t>(Op0, Op1),
756 VPInstruction_match<VPInstruction::WidePtrAdd, Op0_t, Op1_t>(Op0,
757 Op1)));
758}
759
760template <typename Op0_t, typename Op1_t, typename Op2_t>
761inline AllRecipe_match<Instruction::Select, Op0_t, Op1_t, Op2_t>
762m_Select(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) {
763 return AllRecipe_match<Instruction::Select, Op0_t, Op1_t, Op2_t>(
764 {Op0, Op1, Op2});
765}
766
767template <typename Op0_t>
768inline match_combine_or<VPInstruction_match<VPInstruction::Not, Op0_t>,
769 AllRecipe_commutative_match<
770 Instruction::Xor, int_pred_ty<is_all_ones>, Op0_t>>
771m_Not(const Op0_t &Op0) {
772 return m_CombineOr(m_VPInstruction<VPInstruction::Not>(Op0),
773 m_c_Binary<Instruction::Xor>(m_AllOnes(), Op0));
774}
775
776template <typename Op0_t, typename Op1_t>
777inline match_combine_or<
778 VPInstruction_match<VPInstruction::LogicalAnd, Op0_t, Op1_t>,
779 AllRecipe_match<Instruction::Select, Op0_t, Op1_t, specific_intval<1>>>
780m_LogicalAnd(const Op0_t &Op0, const Op1_t &Op1) {
781 return m_CombineOr(
782 m_VPInstruction<VPInstruction::LogicalAnd, Op0_t, Op1_t>(Op0, Op1),
783 m_Select(Op0, Op1, m_False()));
784}
785
786template <typename Op0_t, typename Op1_t>
787inline AllRecipe_match<Instruction::Select, Op0_t, specific_intval<1>, Op1_t>
788m_LogicalOr(const Op0_t &Op0, const Op1_t &Op1) {
789 return m_Select(Op0, m_True(), Op1);
790}
791
792template <typename Op0_t, typename Op1_t, typename Op2_t>
793using VPScalarIVSteps_match = Recipe_match<std::tuple<Op0_t, Op1_t, Op2_t>, 0,
794 false, VPScalarIVStepsRecipe>;
795
796template <typename Op0_t, typename Op1_t, typename Op2_t>
797inline VPScalarIVSteps_match<Op0_t, Op1_t, Op2_t>
798m_ScalarIVSteps(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) {
799 return VPScalarIVSteps_match<Op0_t, Op1_t, Op2_t>({Op0, Op1, Op2});
800}
801
802template <typename Op0_t, typename Op1_t, typename Op2_t>
803using VPDerivedIV_match =
804 Recipe_match<std::tuple<Op0_t, Op1_t, Op2_t>, 0, false, VPDerivedIVRecipe>;
805
806template <typename Op0_t, typename Op1_t, typename Op2_t>
807inline VPDerivedIV_match<Op0_t, Op1_t, Op2_t>
808m_DerivedIV(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) {
809 return VPDerivedIV_match<Op0_t, Op1_t, Op2_t>({Op0, Op1, Op2});
810}
811
812template <typename Addr_t, typename Mask_t> struct Load_match {
813 Addr_t Addr;
814 Mask_t Mask;
815
816 Load_match(Addr_t Addr, Mask_t Mask) : Addr(Addr), Mask(Mask) {}
817
818 template <typename OpTy> bool match(const OpTy *V) const {
819 auto *Load = dyn_cast<VPWidenLoadRecipe>(V);
820 if (!Load || !Addr.match(Load->getAddr()) || !Load->isMasked() ||
821 !Mask.match(Load->getMask()))
822 return false;
823 return true;
824 }
825};
826
827/// Match a (possibly reversed) masked load.
828template <typename Addr_t, typename Mask_t>
829inline Load_match<Addr_t, Mask_t> m_MaskedLoad(const Addr_t &Addr,
830 const Mask_t &Mask) {
831 return Load_match<Addr_t, Mask_t>(Addr, Mask);
832}
833
834template <typename Addr_t, typename Val_t, typename Mask_t> struct Store_match {
835 Addr_t Addr;
836 Val_t Val;
837 Mask_t Mask;
838
839 Store_match(Addr_t Addr, Val_t Val, Mask_t Mask)
840 : Addr(Addr), Val(Val), Mask(Mask) {}
841
842 template <typename OpTy> bool match(const OpTy *V) const {
843 auto *Store = dyn_cast<VPWidenStoreRecipe>(V);
844 if (!Store || !Addr.match(Store->getAddr()) ||
845 !Val.match(Store->getStoredValue()) || !Store->isMasked() ||
846 !Mask.match(Store->getMask()))
847 return false;
848 return true;
849 }
850};
851
852/// Match a (possibly reversed) masked store.
853template <typename Addr_t, typename Val_t, typename Mask_t>
854inline Store_match<Addr_t, Val_t, Mask_t>
855m_MaskedStore(const Addr_t &Addr, const Val_t &Val, const Mask_t &Mask) {
856 return Store_match<Addr_t, Val_t, Mask_t>(Addr, Val, Mask);
857}
858
859template <typename Op0_t, typename Op1_t>
860using VectorEndPointerRecipe_match =
861 Recipe_match<std::tuple<Op0_t, Op1_t>, 0,
862 /*Commutative*/ false, VPVectorEndPointerRecipe>;
863
864template <typename Op0_t, typename Op1_t>
865VectorEndPointerRecipe_match<Op0_t, Op1_t> m_VecEndPtr(const Op0_t &Op0,
866 const Op1_t &Op1) {
867 return VectorEndPointerRecipe_match<Op0_t, Op1_t>(Op0, Op1);
868}
869
870/// Match a call argument at a given argument index.
871template <typename Opnd_t> struct Argument_match {
872 /// Call argument index to match.
873 unsigned OpI;
874 Opnd_t Val;
875
876 Argument_match(unsigned OpIdx, const Opnd_t &V) : OpI(OpIdx), Val(V) {}
877
878 template <typename OpTy> bool match(OpTy *V) const {
879 if (const auto *R = dyn_cast<VPWidenIntrinsicRecipe>(V))
880 return Val.match(R->getOperand(OpI));
881 if (const auto *R = dyn_cast<VPWidenCallRecipe>(V))
882 return Val.match(R->getOperand(OpI));
883 if (const auto *R = dyn_cast<VPReplicateRecipe>(V))
884 if (R->getOpcode() == Instruction::Call)
885 return Val.match(R->getOperand(OpI));
886 if (const auto *R = dyn_cast<VPInstruction>(V))
887 if (R->getOpcode() == Instruction::Call)
888 return Val.match(R->getOperand(OpI));
889 return false;
890 }
891};
892
893/// Match a call argument.
894template <unsigned OpI, typename Opnd_t>
895inline Argument_match<Opnd_t> m_Argument(const Opnd_t &Op) {
896 return Argument_match<Opnd_t>(OpI, Op);
897}
898
899/// Intrinsic matchers.
900struct IntrinsicID_match {
901 unsigned ID;
902
903 IntrinsicID_match(Intrinsic::ID IntrID) : ID(IntrID) {}
904
905 template <typename OpTy> bool match(OpTy *V) const {
906 if (const auto *R = dyn_cast<VPWidenIntrinsicRecipe>(V))
907 return R->getVectorIntrinsicID() == ID;
908 if (const auto *R = dyn_cast<VPWidenCallRecipe>(V))
909 return R->getCalledScalarFunction()->getIntrinsicID() == ID;
910
911 auto MatchCalleeIntrinsic = [&](VPValue *CalleeOp) {
912 if (!isa<VPIRValue>(Val: CalleeOp))
913 return false;
914 auto *F = cast<Function>(Val: CalleeOp->getLiveInIRValue());
915 return F->getIntrinsicID() == ID;
916 };
917 if (const auto *R = dyn_cast<VPReplicateRecipe>(V))
918 if (R->getOpcode() == Instruction::Call) {
919 // The mask is always the last operand if predicated.
920 return MatchCalleeIntrinsic(
921 R->getOperand(R->getNumOperands() - 1 - R->isPredicated()));
922 }
923 if (const auto *R = dyn_cast<VPInstruction>(V))
924 if (R->getOpcode() == Instruction::Call)
925 return MatchCalleeIntrinsic(R->getOperand(R->getNumOperands() - 1));
926 return false;
927 }
928};
929
/// Intrinsic matchers are combinations of an ID matcher and argument matchers.
/// Higher-arity matchers are defined recursively by and-ing them with
/// lower-arity matchers. Here are some convenient typedefs for up to four
/// arguments; more can be added as needed.
template <typename T0 = void, typename T1 = void, typename T2 = void,
          typename T3 = void>
struct m_Intrinsic_Ty;
/// One argument: the intrinsic ID check and-ed with an argument matcher.
template <typename T0> struct m_Intrinsic_Ty<T0> {
  using Ty = match_combine_and<IntrinsicID_match, Argument_match<T0>>;
};
/// Two arguments: the one-argument matcher and-ed with one more argument.
template <typename T0, typename T1> struct m_Intrinsic_Ty<T0, T1> {
  using Ty =
      match_combine_and<typename m_Intrinsic_Ty<T0>::Ty, Argument_match<T1>>;
};
/// Three arguments: the two-argument matcher and-ed with one more argument.
template <typename T0, typename T1, typename T2>
struct m_Intrinsic_Ty<T0, T1, T2> {
  using Ty = match_combine_and<typename m_Intrinsic_Ty<T0, T1>::Ty,
                               Argument_match<T2>>;
};
/// Four arguments: this is the definition of the primary template declared
/// above; the three-argument matcher and-ed with one more argument.
template <typename T0, typename T1, typename T2, typename T3>
struct m_Intrinsic_Ty {
  using Ty = match_combine_and<typename m_Intrinsic_Ty<T0, T1, T2>::Ty,
                               Argument_match<T3>>;
};
954
955/// Match intrinsic calls like this:
956/// m_Intrinsic<Intrinsic::fabs>(m_VPValue(X), ...)
957template <Intrinsic::ID IntrID> inline IntrinsicID_match m_Intrinsic() {
958 return IntrinsicID_match(IntrID);
959}
960
961/// Match intrinsic calls with a runtime intrinsic ID.
962inline IntrinsicID_match m_Intrinsic(Intrinsic::ID IntrID) {
963 return IntrinsicID_match(IntrID);
964}
965
966template <Intrinsic::ID IntrID, typename T0>
967inline typename m_Intrinsic_Ty<T0>::Ty m_Intrinsic(const T0 &Op0) {
968 return m_CombineAnd(m_Intrinsic<IntrID>(), m_Argument<0>(Op0));
969}
970
971template <Intrinsic::ID IntrID, typename T0, typename T1>
972inline typename m_Intrinsic_Ty<T0, T1>::Ty m_Intrinsic(const T0 &Op0,
973 const T1 &Op1) {
974 return m_CombineAnd(m_Intrinsic<IntrID>(Op0), m_Argument<1>(Op1));
975}
976
977template <Intrinsic::ID IntrID, typename T0, typename T1, typename T2>
978inline typename m_Intrinsic_Ty<T0, T1, T2>::Ty
979m_Intrinsic(const T0 &Op0, const T1 &Op1, const T2 &Op2) {
980 return m_CombineAnd(m_Intrinsic<IntrID>(Op0, Op1), m_Argument<2>(Op2));
981}
982
983template <Intrinsic::ID IntrID, typename T0, typename T1, typename T2,
984 typename T3>
985inline typename m_Intrinsic_Ty<T0, T1, T2, T3>::Ty
986m_Intrinsic(const T0 &Op0, const T1 &Op1, const T2 &Op2, const T3 &Op3) {
987 return m_CombineAnd(m_Intrinsic<IntrID>(Op0, Op1, Op2), m_Argument<3>(Op3));
988}
989
990inline auto m_LiveIn() { return class_match<VPIRValue, VPSymbolicValue>(); }
991
992/// Match a GEP recipe (VPWidenGEPRecipe, VPInstruction, or VPReplicateRecipe)
993/// and bind the source element type and operands.
994struct GetElementPtr_match {
995 Type *&SourceElementType;
996 ArrayRef<VPValue *> &Operands;
997
998 GetElementPtr_match(Type *&SourceElementType, ArrayRef<VPValue *> &Operands)
999 : SourceElementType(SourceElementType), Operands(Operands) {}
1000
1001 template <typename ITy> bool match(ITy *V) const {
1002 return matchRecipeAndBind<VPWidenGEPRecipe>(V) ||
1003 matchRecipeAndBind<VPInstruction>(V) ||
1004 matchRecipeAndBind<VPReplicateRecipe>(V);
1005 }
1006
1007private:
1008 template <typename RecipeTy> bool matchRecipeAndBind(const VPValue *V) const {
1009 auto *DefR = dyn_cast<RecipeTy>(V);
1010 if (!DefR)
1011 return false;
1012
1013 if constexpr (std::is_same_v<RecipeTy, VPWidenGEPRecipe>) {
1014 SourceElementType = DefR->getSourceElementType();
1015 } else if (DefR->getOpcode() == Instruction::GetElementPtr) {
1016 SourceElementType = cast<GetElementPtrInst>(DefR->getUnderlyingInstr())
1017 ->getSourceElementType();
1018 } else if constexpr (std::is_same_v<RecipeTy, VPInstruction>) {
1019 if (DefR->getOpcode() == VPInstruction::PtrAdd) {
1020 // PtrAdd is a byte-offset GEP with i8 element type.
1021 LLVMContext &Ctx = DefR->getParent()->getPlan()->getContext();
1022 SourceElementType = Type::getInt8Ty(C&: Ctx);
1023 } else {
1024 return false;
1025 }
1026 } else {
1027 return false;
1028 }
1029
1030 Operands = ArrayRef<VPValue *>(DefR->op_begin(), DefR->op_end());
1031 return true;
1032 }
1033};
1034
1035/// Match a GEP recipe with any number of operands and bind source element type
1036/// and operands.
1037inline GetElementPtr_match m_GetElementPtr(Type *&SourceElementType,
1038 ArrayRef<VPValue *> &Operands) {
1039 return GetElementPtr_match(SourceElementType, Operands);
1040}
1041
/// Matches a value that has exactly one use and also matches SubPattern.
/// The use count is checked first, so SubPattern is only tried (and may only
/// bind) on single-use values.
template <typename SubPattern_t> struct OneUse_match {
  SubPattern_t SubPattern;

  OneUse_match(const SubPattern_t &SP) : SubPattern(SP) {}

  /// Const like every other matcher, so it can be used through the top-level
  /// match(V, const Pattern &) entry point and as a nested sub-pattern.
  template <typename OpTy> bool match(OpTy *V) const {
    return V->hasOneUse() && SubPattern.match(V);
  }
};

/// Wrap \p SubPattern so that it only matches values with exactly one use.
template <typename T> inline OneUse_match<T> m_OneUse(const T &SubPattern) {
  return OneUse_match<T>(SubPattern);
}
1055
1056} // namespace llvm::VPlanPatternMatch
1057
1058#endif
1059