1 | //===- ScalarEvolutionNormalization.cpp - See below -----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements utilities for working with "normalized" expressions. |
10 | // See the comments at the top of ScalarEvolutionNormalization.h for details. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "llvm/Analysis/ScalarEvolutionNormalization.h" |
15 | #include "llvm/Analysis/LoopInfo.h" |
16 | #include "llvm/Analysis/ScalarEvolution.h" |
17 | #include "llvm/Analysis/ScalarEvolutionExpressions.h" |
18 | using namespace llvm; |
19 | |
/// TransformKind - Selects which of the two rewrites
/// NormalizeDenormalizeRewriter applies to the add recurrences it is asked to
/// transform.
enum TransformKind {
  /// Normalize - Rewrite the expression into normalized form with respect to
  /// the given loops.
  Normalize,
  /// Denormalize - Apply the inverse of Normalize with respect to the given
  /// loop set.
  Denormalize
};
29 | |
30 | namespace { |
31 | struct NormalizeDenormalizeRewriter |
32 | : public SCEVRewriteVisitor<NormalizeDenormalizeRewriter> { |
33 | const TransformKind Kind; |
34 | |
35 | // NB! Pred is a function_ref. Storing it here is okay only because |
36 | // we're careful about the lifetime of NormalizeDenormalizeRewriter. |
37 | const NormalizePredTy Pred; |
38 | |
39 | NormalizeDenormalizeRewriter(TransformKind Kind, NormalizePredTy Pred, |
40 | ScalarEvolution &SE) |
41 | : SCEVRewriteVisitor<NormalizeDenormalizeRewriter>(SE), Kind(Kind), |
42 | Pred(Pred) {} |
43 | const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr); |
44 | }; |
45 | } // namespace |
46 | |
47 | const SCEV * |
48 | NormalizeDenormalizeRewriter::visitAddRecExpr(const SCEVAddRecExpr *AR) { |
49 | SmallVector<const SCEV *, 8> Operands; |
50 | |
51 | transform(Range: AR->operands(), d_first: std::back_inserter(x&: Operands), |
52 | F: [&](const SCEV *Op) { return visit(S: Op); }); |
53 | |
54 | if (!Pred(AR)) |
55 | return SE.getAddRecExpr(Operands, L: AR->getLoop(), Flags: SCEV::FlagAnyWrap); |
56 | |
57 | // Normalization and denormalization are fancy names for decrementing and |
58 | // incrementing a SCEV expression with respect to a set of loops. Since |
59 | // Pred(AR) has returned true, we know we need to normalize or denormalize AR |
60 | // with respect to its loop. |
61 | |
62 | if (Kind == Denormalize) { |
63 | // Denormalization / "partial increment" is essentially the same as \c |
64 | // SCEVAddRecExpr::getPostIncExpr. Here we use an explicit loop to make the |
65 | // symmetry with Normalization clear. |
66 | for (int i = 0, e = Operands.size() - 1; i < e; i++) |
67 | Operands[i] = SE.getAddExpr(LHS: Operands[i], RHS: Operands[i + 1]); |
68 | } else { |
69 | assert(Kind == Normalize && "Only two possibilities!" ); |
70 | |
71 | // Normalization / "partial decrement" is a bit more subtle. Since |
72 | // incrementing a SCEV expression (in general) changes the step of the SCEV |
73 | // expression as well, we cannot use the step of the current expression. |
74 | // Instead, we have to use the step of the very expression we're trying to |
75 | // compute! |
76 | // |
77 | // We solve the issue by recursively building up the result, starting from |
78 | // the "least significant" operand in the add recurrence: |
79 | // |
80 | // Base case: |
81 | // Single operand add recurrence. It's its own normalization. |
82 | // |
83 | // N-operand case: |
84 | // {S_{N-1},+,S_{N-2},+,...,+,S_0} = S |
85 | // |
86 | // Since the step recurrence of S is {S_{N-2},+,...,+,S_0}, we know its |
87 | // normalization by induction. We subtract the normalized step |
88 | // recurrence from S_{N-1} to get the normalization of S. |
89 | |
90 | for (int i = Operands.size() - 2; i >= 0; i--) |
91 | Operands[i] = SE.getMinusSCEV(LHS: Operands[i], RHS: Operands[i + 1]); |
92 | } |
93 | |
94 | return SE.getAddRecExpr(Operands, L: AR->getLoop(), Flags: SCEV::FlagAnyWrap); |
95 | } |
96 | |
97 | const SCEV *llvm::normalizeForPostIncUse(const SCEV *S, |
98 | const PostIncLoopSet &Loops, |
99 | ScalarEvolution &SE, |
100 | bool CheckInvertible) { |
101 | if (Loops.empty()) |
102 | return S; |
103 | auto Pred = [&](const SCEVAddRecExpr *AR) { |
104 | return Loops.count(Ptr: AR->getLoop()); |
105 | }; |
106 | const SCEV *Normalized = |
107 | NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S); |
108 | const SCEV *Denormalized = denormalizeForPostIncUse(S: Normalized, Loops, SE); |
109 | // If the normalized expression isn't invertible. |
110 | if (CheckInvertible && Denormalized != S) |
111 | return nullptr; |
112 | return Normalized; |
113 | } |
114 | |
115 | const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred, |
116 | ScalarEvolution &SE) { |
117 | return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S); |
118 | } |
119 | |
120 | const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S, |
121 | const PostIncLoopSet &Loops, |
122 | ScalarEvolution &SE) { |
123 | if (Loops.empty()) |
124 | return S; |
125 | auto Pred = [&](const SCEVAddRecExpr *AR) { |
126 | return Loops.count(Ptr: AR->getLoop()); |
127 | }; |
128 | return NormalizeDenormalizeRewriter(Denormalize, Pred, SE).visit(S); |
129 | } |
130 | |