//===- ScalarEvolutionDivision.cpp - Divide SCEV expressions --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements SCEVDivision, the class that knows how to divide SCEVs.
10//
11//===----------------------------------------------------------------------===//
12
13#include "llvm/Analysis/ScalarEvolutionDivision.h"
14#include "llvm/ADT/APInt.h"
15#include "llvm/ADT/DenseMap.h"
16#include "llvm/ADT/SmallVector.h"
17#include "llvm/Analysis/ScalarEvolution.h"
18#include "llvm/IR/InstIterator.h"
19#include "llvm/IR/Instructions.h"
20#include "llvm/Support/Casting.h"
21#include <cassert>
22#include <cstdint>
23
24#define DEBUG_TYPE "scev-division"
25
26namespace llvm {
27class Type;
28} // namespace llvm
29
30using namespace llvm;
31
32static inline int sizeOfSCEV(const SCEV *S) {
33 struct FindSCEVSize {
34 int Size = 0;
35
36 FindSCEVSize() = default;
37
38 bool follow(const SCEV *S) {
39 ++Size;
40 // Keep looking at all operands of S.
41 return true;
42 }
43
44 bool isDone() const { return false; }
45 };
46
47 FindSCEVSize F;
48 SCEVTraversal<FindSCEVSize> ST(F);
49 ST.visitAll(Root: S);
50 return F.Size;
51}
52
53// Computes the Quotient and Remainder of the division of Numerator by
54// Denominator.
55void SCEVDivision::divide(ScalarEvolution &SE, const SCEV *Numerator,
56 const SCEV *Denominator, const SCEV **Quotient,
57 const SCEV **Remainder) {
58 assert(Numerator && Denominator && "Uninitialized SCEV");
59
60 SCEVDivision D(SE, Numerator, Denominator);
61
62 // Check for the trivial case here to avoid having to check for it in the
63 // rest of the code.
64 if (Numerator == Denominator) {
65 *Quotient = D.One;
66 *Remainder = D.Zero;
67 return;
68 }
69
70 if (Numerator->isZero()) {
71 *Quotient = D.Zero;
72 *Remainder = D.Zero;
73 return;
74 }
75
76 // A simple case when N/1. The quotient is N.
77 if (Denominator->isOne()) {
78 *Quotient = Numerator;
79 *Remainder = D.Zero;
80 return;
81 }
82
83 // Split the Denominator when it is a product.
84 if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Val: Denominator)) {
85 const SCEV *Q, *R;
86 *Quotient = Numerator;
87 for (const SCEV *Op : T->operands()) {
88 divide(SE, Numerator: *Quotient, Denominator: Op, Quotient: &Q, Remainder: &R);
89 *Quotient = Q;
90
91 // Bail out when the Numerator is not divisible by one of the terms of
92 // the Denominator.
93 if (!R->isZero()) {
94 *Quotient = D.Zero;
95 *Remainder = Numerator;
96 return;
97 }
98 }
99 *Remainder = D.Zero;
100 return;
101 }
102
103 D.visit(S: Numerator);
104 *Quotient = D.Quotient;
105 *Remainder = D.Remainder;
106}
107
108void SCEVDivision::visitConstant(const SCEVConstant *Numerator) {
109 if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Val: Denominator)) {
110 APInt NumeratorVal = Numerator->getAPInt();
111 APInt DenominatorVal = D->getAPInt();
112 uint32_t NumeratorBW = NumeratorVal.getBitWidth();
113 uint32_t DenominatorBW = DenominatorVal.getBitWidth();
114
115 if (NumeratorBW > DenominatorBW)
116 DenominatorVal = DenominatorVal.sext(width: NumeratorBW);
117 else if (NumeratorBW < DenominatorBW)
118 NumeratorVal = NumeratorVal.sext(width: DenominatorBW);
119
120 APInt QuotientVal(NumeratorVal.getBitWidth(), 0);
121 APInt RemainderVal(NumeratorVal.getBitWidth(), 0);
122 APInt::sdivrem(LHS: NumeratorVal, RHS: DenominatorVal, Quotient&: QuotientVal, Remainder&: RemainderVal);
123 Quotient = SE.getConstant(Val: QuotientVal);
124 Remainder = SE.getConstant(Val: RemainderVal);
125 return;
126 }
127}
128
129void SCEVDivision::visitVScale(const SCEVVScale *Numerator) {
130 return cannotDivide(Numerator);
131}
132
133void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr *Numerator) {
134 const SCEV *StartQ, *StartR, *StepQ, *StepR;
135 if (!Numerator->isAffine())
136 return cannotDivide(Numerator);
137 divide(SE, Numerator: Numerator->getStart(), Denominator, Quotient: &StartQ, Remainder: &StartR);
138 divide(SE, Numerator: Numerator->getStepRecurrence(SE), Denominator, Quotient: &StepQ, Remainder: &StepR);
139 // Bail out if the types do not match.
140 Type *Ty = Denominator->getType();
141 if (Ty != StartQ->getType() || Ty != StartR->getType() ||
142 Ty != StepQ->getType() || Ty != StepR->getType())
143 return cannotDivide(Numerator);
144
145 Quotient = SE.getAddRecExpr(Start: StartQ, Step: StepQ, L: Numerator->getLoop(),
146 Flags: SCEV::NoWrapFlags::FlagAnyWrap);
147 Remainder = SE.getAddRecExpr(Start: StartR, Step: StepR, L: Numerator->getLoop(),
148 Flags: SCEV::NoWrapFlags::FlagAnyWrap);
149}
150
151void SCEVDivision::visitAddExpr(const SCEVAddExpr *Numerator) {
152 SmallVector<const SCEV *, 2> Qs, Rs;
153 Type *Ty = Denominator->getType();
154
155 for (const SCEV *Op : Numerator->operands()) {
156 const SCEV *Q, *R;
157 divide(SE, Numerator: Op, Denominator, Quotient: &Q, Remainder: &R);
158
159 // Bail out if types do not match.
160 if (Ty != Q->getType() || Ty != R->getType())
161 return cannotDivide(Numerator);
162
163 Qs.push_back(Elt: Q);
164 Rs.push_back(Elt: R);
165 }
166
167 if (Qs.size() == 1) {
168 Quotient = Qs[0];
169 Remainder = Rs[0];
170 return;
171 }
172
173 Quotient = SE.getAddExpr(Ops&: Qs);
174 Remainder = SE.getAddExpr(Ops&: Rs);
175}
176
177void SCEVDivision::visitMulExpr(const SCEVMulExpr *Numerator) {
178 SmallVector<const SCEV *, 2> Qs;
179 Type *Ty = Denominator->getType();
180
181 bool FoundDenominatorTerm = false;
182 for (const SCEV *Op : Numerator->operands()) {
183 // Bail out if types do not match.
184 if (Ty != Op->getType())
185 return cannotDivide(Numerator);
186
187 if (FoundDenominatorTerm) {
188 Qs.push_back(Elt: Op);
189 continue;
190 }
191
192 // Check whether Denominator divides one of the product operands.
193 const SCEV *Q, *R;
194 divide(SE, Numerator: Op, Denominator, Quotient: &Q, Remainder: &R);
195 if (!R->isZero()) {
196 Qs.push_back(Elt: Op);
197 continue;
198 }
199
200 // Bail out if types do not match.
201 if (Ty != Q->getType())
202 return cannotDivide(Numerator);
203
204 FoundDenominatorTerm = true;
205 Qs.push_back(Elt: Q);
206 }
207
208 if (FoundDenominatorTerm) {
209 Remainder = Zero;
210 if (Qs.size() == 1)
211 Quotient = Qs[0];
212 else
213 Quotient = SE.getMulExpr(Ops&: Qs);
214 return;
215 }
216
217 if (!isa<SCEVUnknown>(Val: Denominator))
218 return cannotDivide(Numerator);
219
220 // The Remainder is obtained by replacing Denominator by 0 in Numerator.
221 ValueToSCEVMapTy RewriteMap;
222 RewriteMap[cast<SCEVUnknown>(Val: Denominator)->getValue()] = Zero;
223 Remainder = SCEVParameterRewriter::rewrite(Scev: Numerator, SE, Map&: RewriteMap);
224
225 if (Remainder->isZero()) {
226 // The Quotient is obtained by replacing Denominator by 1 in Numerator.
227 RewriteMap[cast<SCEVUnknown>(Val: Denominator)->getValue()] = One;
228 Quotient = SCEVParameterRewriter::rewrite(Scev: Numerator, SE, Map&: RewriteMap);
229 return;
230 }
231
232 // Quotient is (Numerator - Remainder) divided by Denominator.
233 const SCEV *Q, *R;
234 const SCEV *Diff = SE.getMinusSCEV(LHS: Numerator, RHS: Remainder);
235 // This SCEV does not seem to simplify: fail the division here.
236 if (sizeOfSCEV(S: Diff) > sizeOfSCEV(S: Numerator))
237 return cannotDivide(Numerator);
238 divide(SE, Numerator: Diff, Denominator, Quotient: &Q, Remainder: &R);
239 if (R != Zero)
240 return cannotDivide(Numerator);
241 Quotient = Q;
242}
243
244SCEVDivision::SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
245 const SCEV *Denominator)
246 : SE(S), Denominator(Denominator) {
247 Zero = SE.getZero(Ty: Denominator->getType());
248 One = SE.getOne(Ty: Denominator->getType());
249
250 // We generally do not know how to divide Expr by Denominator. We initialize
251 // the division to a "cannot divide" state to simplify the rest of the code.
252 cannotDivide(Numerator);
253}
254
255// Convenience function for giving up on the division. We set the quotient to
256// be equal to zero and the remainder to be equal to the numerator.
257void SCEVDivision::cannotDivide(const SCEV *Numerator) {
258 Quotient = Zero;
259 Remainder = Numerator;
260}
261
262void SCEVDivisionPrinterPass::runImpl(Function &F, ScalarEvolution &SE) {
263 OS << "Printing analysis 'Scalar Evolution Division' for function '"
264 << F.getName() << "':\n";
265 for (Instruction &Inst : instructions(F)) {
266 BinaryOperator *Div = dyn_cast<BinaryOperator>(Val: &Inst);
267 if (!Div || Div->getOpcode() != Instruction::SDiv)
268 continue;
269
270 const SCEV *Numerator = SE.getSCEV(V: Div->getOperand(i_nocapture: 0));
271 const SCEV *Denominator = SE.getSCEV(V: Div->getOperand(i_nocapture: 1));
272 const SCEV *Quotient, *Remainder;
273 SCEVDivision::divide(SE, Numerator, Denominator, Quotient: &Quotient, Remainder: &Remainder);
274
275 OS << "Instruction: " << *Div << "\n";
276 OS.indent(NumSpaces: 2) << "Numerator: " << *Numerator << "\n";
277 OS.indent(NumSpaces: 2) << "Denominator: " << *Denominator << "\n";
278 OS.indent(NumSpaces: 2) << "Quotient: " << *Quotient << "\n";
279 OS.indent(NumSpaces: 2) << "Remainder: " << *Remainder << "\n";
280 }
281}
282
283PreservedAnalyses SCEVDivisionPrinterPass::run(Function &F,
284 FunctionAnalysisManager &AM) {
285 ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(IR&: F);
286 runImpl(F, SE);
287 return PreservedAnalyses::all();
288}
289