1//===- CmpInstAnalysis.cpp - Utils to help fold compares ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file holds routines to help analyse compare instructions
10// and fold them into constants or other compare instructions
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/Analysis/CmpInstAnalysis.h"
15#include "llvm/IR/Constants.h"
16#include "llvm/IR/Instructions.h"
17#include "llvm/IR/PatternMatch.h"
18
19using namespace llvm;
20
21unsigned llvm::getICmpCode(CmpInst::Predicate Pred) {
22 switch (Pred) {
23 // False -> 0
24 case ICmpInst::ICMP_UGT: return 1; // 001
25 case ICmpInst::ICMP_SGT: return 1; // 001
26 case ICmpInst::ICMP_EQ: return 2; // 010
27 case ICmpInst::ICMP_UGE: return 3; // 011
28 case ICmpInst::ICMP_SGE: return 3; // 011
29 case ICmpInst::ICMP_ULT: return 4; // 100
30 case ICmpInst::ICMP_SLT: return 4; // 100
31 case ICmpInst::ICMP_NE: return 5; // 101
32 case ICmpInst::ICMP_ULE: return 6; // 110
33 case ICmpInst::ICMP_SLE: return 6; // 110
34 // True -> 7
35 default:
36 llvm_unreachable("Invalid ICmp predicate!");
37 }
38}
39
40Constant *llvm::getPredForICmpCode(unsigned Code, bool Sign, Type *OpTy,
41 CmpInst::Predicate &Pred) {
42 switch (Code) {
43 default: llvm_unreachable("Illegal ICmp code!");
44 case 0: // False.
45 return ConstantInt::get(Ty: CmpInst::makeCmpResultType(opnd_type: OpTy), V: 0);
46 case 1: Pred = Sign ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT; break;
47 case 2: Pred = ICmpInst::ICMP_EQ; break;
48 case 3: Pred = Sign ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; break;
49 case 4: Pred = Sign ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; break;
50 case 5: Pred = ICmpInst::ICMP_NE; break;
51 case 6: Pred = Sign ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; break;
52 case 7: // True.
53 return ConstantInt::get(Ty: CmpInst::makeCmpResultType(opnd_type: OpTy), V: 1);
54 }
55 return nullptr;
56}
57
58bool llvm::predicatesFoldable(ICmpInst::Predicate P1, ICmpInst::Predicate P2) {
59 return (CmpInst::isSigned(predicate: P1) == CmpInst::isSigned(predicate: P2)) ||
60 (CmpInst::isSigned(predicate: P1) && ICmpInst::isEquality(P: P2)) ||
61 (CmpInst::isSigned(predicate: P2) && ICmpInst::isEquality(P: P1));
62}
63
64Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy,
65 CmpInst::Predicate &Pred) {
66 Pred = static_cast<FCmpInst::Predicate>(Code);
67 assert(FCmpInst::FCMP_FALSE <= Pred && Pred <= FCmpInst::FCMP_TRUE &&
68 "Unexpected FCmp predicate!");
69 if (Pred == FCmpInst::FCMP_FALSE)
70 return ConstantInt::get(Ty: CmpInst::makeCmpResultType(opnd_type: OpTy), V: 0);
71 if (Pred == FCmpInst::FCMP_TRUE)
72 return ConstantInt::get(Ty: CmpInst::makeCmpResultType(opnd_type: OpTy), V: 1);
73 return nullptr;
74}
75
76std::optional<DecomposedBitTest>
77llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
78 bool LookThruTrunc, bool AllowNonZeroC,
79 bool DecomposeAnd) {
80 using namespace PatternMatch;
81
82 const APInt *OrigC;
83 if ((ICmpInst::isEquality(P: Pred) && !DecomposeAnd) ||
84 !match(V: RHS, P: m_APIntAllowPoison(Res&: OrigC)))
85 return std::nullopt;
86
87 bool Inverted = false;
88 if (ICmpInst::isGT(P: Pred) || ICmpInst::isGE(P: Pred)) {
89 Inverted = true;
90 Pred = ICmpInst::getInversePredicate(pred: Pred);
91 }
92
93 APInt C = *OrigC;
94 if (ICmpInst::isLE(P: Pred)) {
95 if (ICmpInst::isSigned(predicate: Pred) ? C.isMaxSignedValue() : C.isMaxValue())
96 return std::nullopt;
97 ++C;
98 Pred = ICmpInst::getStrictPredicate(pred: Pred);
99 }
100
101 DecomposedBitTest Result;
102 switch (Pred) {
103 default:
104 llvm_unreachable("Unexpected predicate");
105 case ICmpInst::ICMP_SLT: {
106 // X < 0 is equivalent to (X & SignMask) != 0.
107 if (C.isZero()) {
108 Result.Mask = APInt::getSignMask(BitWidth: C.getBitWidth());
109 Result.C = APInt::getZero(numBits: C.getBitWidth());
110 Result.Pred = ICmpInst::ICMP_NE;
111 break;
112 }
113
114 APInt FlippedSign = C ^ APInt::getSignMask(BitWidth: C.getBitWidth());
115 if (FlippedSign.isPowerOf2()) {
116 // X s< 10000100 is equivalent to (X & 11111100 == 10000000)
117 Result.Mask = -FlippedSign;
118 Result.C = APInt::getSignMask(BitWidth: C.getBitWidth());
119 Result.Pred = ICmpInst::ICMP_EQ;
120 break;
121 }
122
123 if (FlippedSign.isNegatedPowerOf2()) {
124 // X s< 01111100 is equivalent to (X & 11111100 != 01111100)
125 Result.Mask = FlippedSign;
126 Result.C = C;
127 Result.Pred = ICmpInst::ICMP_NE;
128 break;
129 }
130
131 return std::nullopt;
132 }
133 case ICmpInst::ICMP_ULT: {
134 // X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
135 if (C.isPowerOf2()) {
136 Result.Mask = -C;
137 Result.C = APInt::getZero(numBits: C.getBitWidth());
138 Result.Pred = ICmpInst::ICMP_EQ;
139 break;
140 }
141
142 // X u< 11111100 is equivalent to (X & 11111100 != 11111100)
143 if (C.isNegatedPowerOf2()) {
144 Result.Mask = C;
145 Result.C = C;
146 Result.Pred = ICmpInst::ICMP_NE;
147 break;
148 }
149
150 return std::nullopt;
151 }
152 case ICmpInst::ICMP_EQ:
153 case ICmpInst::ICMP_NE: {
154 assert(DecomposeAnd);
155 const APInt *AndC;
156 Value *AndVal;
157 if (match(V: LHS, P: m_And(L: m_Value(V&: AndVal), R: m_APIntAllowPoison(Res&: AndC)))) {
158 LHS = AndVal;
159 Result.Mask = *AndC;
160 Result.C = C;
161 Result.Pred = Pred;
162 break;
163 }
164
165 return std::nullopt;
166 }
167 }
168
169 if (!AllowNonZeroC && !Result.C.isZero())
170 return std::nullopt;
171
172 if (Inverted)
173 Result.Pred = ICmpInst::getInversePredicate(pred: Result.Pred);
174
175 Value *X;
176 if (LookThruTrunc && match(V: LHS, P: m_Trunc(Op: m_Value(V&: X)))) {
177 Result.X = X;
178 Result.Mask = Result.Mask.zext(width: X->getType()->getScalarSizeInBits());
179 Result.C = Result.C.zext(width: X->getType()->getScalarSizeInBits());
180 } else {
181 Result.X = LHS;
182 }
183
184 return Result;
185}
186
187std::optional<DecomposedBitTest> llvm::decomposeBitTest(Value *Cond,
188 bool LookThruTrunc,
189 bool AllowNonZeroC,
190 bool DecomposeAnd) {
191 using namespace PatternMatch;
192 if (auto *ICmp = dyn_cast<ICmpInst>(Val: Cond)) {
193 // Don't allow pointers. Splat vectors are fine.
194 if (!ICmp->getOperand(i_nocapture: 0)->getType()->isIntOrIntVectorTy())
195 return std::nullopt;
196 return decomposeBitTestICmp(LHS: ICmp->getOperand(i_nocapture: 0), RHS: ICmp->getOperand(i_nocapture: 1),
197 Pred: ICmp->getPredicate(), LookThruTrunc,
198 AllowNonZeroC, DecomposeAnd);
199 }
200 Value *X;
201 if (Cond->getType()->isIntOrIntVectorTy(BitWidth: 1) &&
202 (match(V: Cond, P: m_Trunc(Op: m_Value(V&: X))) ||
203 match(V: Cond, P: m_Not(V: m_Trunc(Op: m_Value(V&: X)))))) {
204 DecomposedBitTest Result;
205 Result.X = X;
206 unsigned BitWidth = X->getType()->getScalarSizeInBits();
207 Result.Mask = APInt(BitWidth, 1);
208 Result.C = APInt::getZero(numBits: BitWidth);
209 Result.Pred = isa<TruncInst>(Val: Cond) ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ;
210
211 return Result;
212 }
213
214 return std::nullopt;
215}
216