//===----------------------- AlignmentFromAssumptions.cpp -----------------===//
// Set Load/Store Alignments From Assumptions
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a ScalarEvolution-based transformation to set
// the alignments of loads, stores, and memory intrinsics based on the truth
// expressions of assume intrinsics. The primary motivation is to handle
// complex alignment assumptions that apply to vector loads and stores that
// appear after vectorization and unrolling.
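//
// Illustratively (the IR below is a sketch, with %p as a placeholder
// pointer), the alignment information arrives as an "align" operand bundle
// on an llvm.assume call, e.g. the form Clang emits for
// __builtin_assume_aligned:
//
//   call void @llvm.assume(i1 true) ["align"(ptr %p, i64 32)]
//   call void @llvm.assume(i1 true) ["align"(ptr %p, i64 32, i64 28)]
//
// The first form asserts that %p is 32-byte aligned; the second asserts that
// %p is 28 bytes past a 32-byte-aligned address (i.e. %p - 28 is 32-byte
// aligned). The pass uses such bundles to raise the alignment recorded on
// loads, stores, and memory intrinsics that use the assumed pointer.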
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/AlignmentFromAssumptions.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "alignment-from-assumptions"
using namespace llvm;

STATISTIC(NumLoadAlignChanged,
          "Number of loads changed by alignment assumptions");
STATISTIC(NumStoreAlignChanged,
          "Number of stores changed by alignment assumptions");
STATISTIC(NumMemIntAlignChanged,
          "Number of memory intrinsics changed by alignment assumptions");

// Given an expression for the (constant) alignment, AlignSCEV, and an
// expression for the displacement between a pointer and the aligned address,
// DiffSCEV, compute the alignment of the displaced pointer if it can be
// reduced to a constant. Using SCEV to compute alignment handles the case
// where DiffSCEV is a recurrence with constant start such that the aligned
// offset is constant. e.g. {16,+,32} % 32 -> 16.
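//
// Worked example (a sketch using the recurrence above): with AlignSCEV = 32
// and DiffSCEV = {16,+,32}, every value of the recurrence (16, 48, 80, ...)
// leaves remainder 16 modulo 32, so DiffUnitsSCEV is the constant 16; since
// 16 is a power of two, the displaced pointer can be given Align(16).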
static MaybeAlign getNewAlignmentDiff(const SCEV *DiffSCEV,
                                      const SCEV *AlignSCEV,
                                      ScalarEvolution *SE) {
  // DiffUnits = Diff % int64_t(Alignment)
  const SCEV *DiffUnitsSCEV = SE->getURemExpr(DiffSCEV, AlignSCEV);

  LLVM_DEBUG(dbgs() << "\talignment relative to " << *AlignSCEV << " is "
                    << *DiffUnitsSCEV << " (diff: " << *DiffSCEV << ")\n");

  if (const SCEVConstant *ConstDUSCEV =
          dyn_cast<SCEVConstant>(DiffUnitsSCEV)) {
    int64_t DiffUnits = ConstDUSCEV->getValue()->getSExtValue();

    // If the displacement is an exact multiple of the alignment, then the
    // displaced pointer has the same alignment as the aligned pointer, so
    // return the alignment value.
    if (!DiffUnits)
      return cast<SCEVConstant>(AlignSCEV)->getValue()->getAlignValue();

    // If the displacement is not an exact multiple, but the remainder is a
    // constant, then return this remainder (but only if it is a power of 2).
    uint64_t DiffUnitsAbs = std::abs(DiffUnits);
    if (isPowerOf2_64(DiffUnitsAbs))
      return Align(DiffUnitsAbs);
  }

  return std::nullopt;
}

// There is an address given by an offset OffSCEV from AASCEV which has an
// alignment AlignSCEV. Use that information, if possible, to compute a new
// alignment for Ptr.
static Align getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
                             const SCEV *OffSCEV, Value *Ptr,
                             ScalarEvolution *SE) {
  const SCEV *PtrSCEV = SE->getSCEV(Ptr);

  const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV);
  if (isa<SCEVCouldNotCompute>(DiffSCEV))
    return Align(1);

  // On 32-bit platforms, DiffSCEV might now have type i32 -- we've always
  // sign-extended OffSCEV to i64, so make sure they agree again.
  DiffSCEV = SE->getNoopOrSignExtend(DiffSCEV, OffSCEV->getType());

  // What we really want to know is the overall offset to the aligned
  // address. This address is displaced by the provided offset.
  DiffSCEV = SE->getAddExpr(DiffSCEV, OffSCEV);

  LLVM_DEBUG(dbgs() << "AFI: alignment of " << *Ptr << " relative to "
                    << *AlignSCEV << " and offset " << *OffSCEV
                    << " using diff " << *DiffSCEV << "\n");

  if (MaybeAlign NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE)) {
    LLVM_DEBUG(dbgs() << "\tnew alignment: " << DebugStr(NewAlignment) << "\n");
    return *NewAlignment;
  }
  if (const SCEVAddRecExpr *DiffARSCEV = dyn_cast<SCEVAddRecExpr>(DiffSCEV)) {
    // The relative offset to the alignment assumption did not yield a
    // constant, but we should try harder: if we assume that a is 32-byte
    // aligned, then in the loop
    //   for (i = 0; i < 1024; i += 4) r += a[i];
    // not all of the loads from a are 32-byte aligned; instead they alternate
    // between 32-byte and 16-byte alignment. As a result, the new alignment
    // will not be a constant, but can still be improved over the default
    // (of 4) to 16.
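    //
    // Worked example for the loop above (assuming, for illustration, 4-byte
    // elements and loads that start at the assumed-aligned a itself):
    // DiffSCEV is {0,+,16}, so the start offset yields alignment 32 and the
    // per-iteration step yields alignment 16; the code below then returns the
    // smaller of the two, Align(16).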

    const SCEV *DiffStartSCEV = DiffARSCEV->getStart();
    const SCEV *DiffIncSCEV = DiffARSCEV->getStepRecurrence(*SE);

    LLVM_DEBUG(dbgs() << "\ttrying start/inc alignment using start "
                      << *DiffStartSCEV << " and inc " << *DiffIncSCEV << "\n");

    // Now compute the new alignment using the displacement to the value in the
    // first iteration, and also the alignment using the per-iteration delta.
    // If these are the same, then use that answer. Otherwise, use the smaller
    // one, but only if it divides the larger one.
    MaybeAlign NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE);
    MaybeAlign NewIncAlignment =
        getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE);

    LLVM_DEBUG(dbgs() << "\tnew start alignment: " << DebugStr(NewAlignment)
                      << "\n");
    LLVM_DEBUG(dbgs() << "\tnew inc alignment: " << DebugStr(NewIncAlignment)
                      << "\n");

    if (!NewAlignment || !NewIncAlignment)
      return Align(1);

    const Align NewAlign = *NewAlignment;
    const Align NewIncAlign = *NewIncAlignment;
    if (NewAlign > NewIncAlign) {
      LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: "
                        << DebugStr(NewIncAlign) << "\n");
      return NewIncAlign;
    }
    if (NewIncAlign > NewAlign) {
      LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
                        << "\n");
      return NewAlign;
    }
    assert(NewIncAlign == NewAlign);
    LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
                      << "\n");
    return NewAlign;
  }

  return Align(1);
}

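// Extract the pointer, alignment, and offset described by the "align" operand
// bundle at index Idx of the assume call I. The alignment must be a constant
// power of two; the optional third bundle operand is a byte offset and
// defaults to zero. Both the alignment and the offset are returned through
// the out-parameters as i64-typed SCEVs. Returns false if the bundle is not
// an "align" bundle or is not in a supported form.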
bool AlignmentFromAssumptionsPass::extractAlignmentInfo(CallInst *I,
                                                        unsigned Idx,
                                                        Value *&AAPtr,
                                                        const SCEV *&AlignSCEV,
                                                        const SCEV *&OffSCEV) {
  Type *Int64Ty = Type::getInt64Ty(I->getContext());
  OperandBundleUse AlignOB = I->getOperandBundleAt(Idx);
  if (AlignOB.getTagName() != "align")
    return false;
  assert(AlignOB.Inputs.size() >= 2);
  AAPtr = AlignOB.Inputs[0].get();
  // TODO: Consider accumulating the offset to the base.
  AAPtr = AAPtr->stripPointerCastsSameRepresentation();
  AlignSCEV = SE->getSCEV(AlignOB.Inputs[1].get());
  AlignSCEV = SE->getTruncateOrZeroExtend(AlignSCEV, Int64Ty);
  if (!isa<SCEVConstant>(AlignSCEV))
    // Added to suppress a crash because consumer doesn't expect non-constant
    // alignments in the assume bundle. TODO: Consider generalizing caller.
    return false;
  if (!cast<SCEVConstant>(AlignSCEV)->getAPInt().isPowerOf2())
    // Only power of two alignments are supported.
    return false;
  if (AlignOB.Inputs.size() == 3)
    OffSCEV = SE->getSCEV(AlignOB.Inputs[2].get());
  else
    OffSCEV = SE->getZero(Int64Ty);
  OffSCEV = SE->getTruncateOrZeroExtend(OffSCEV, Int64Ty);
  return true;
}

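// Apply the "align" operand bundle at index Idx of the assume call ACall:
// walk the users of the assumed pointer (looking through GEPs and PHIs) and
// raise the alignment of any load, store, or memory intrinsic for which the
// assumption is valid at that program point. Returns true if the bundle was
// extracted and processed, even if no alignment was actually changed.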
bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall,
                                                     unsigned Idx) {
  Value *AAPtr;
  const SCEV *AlignSCEV, *OffSCEV;
  if (!extractAlignmentInfo(ACall, Idx, AAPtr, AlignSCEV, OffSCEV))
    return false;

  // Skip ConstantPointerNull and UndefValue. Assumptions on these shouldn't
  // affect other users.
  if (isa<ConstantData>(AAPtr))
    return false;

  const SCEV *AASCEV = SE->getSCEV(AAPtr);

  // Apply the assumption to all other users of the specified pointer.
  SmallPtrSet<Instruction *, 32> Visited;
  SmallVector<Instruction *, 16> WorkList;
  for (User *J : AAPtr->users()) {
    if (J == ACall)
      continue;

    if (Instruction *K = dyn_cast<Instruction>(J))
      WorkList.push_back(K);
  }

  while (!WorkList.empty()) {
    Instruction *J = WorkList.pop_back_val();
    if (LoadInst *LI = dyn_cast<LoadInst>(J)) {
      if (!isValidAssumeForContext(ACall, J, DT))
        continue;
      Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                           LI->getPointerOperand(), SE);
      if (NewAlignment > LI->getAlign()) {
        LI->setAlignment(NewAlignment);
        ++NumLoadAlignChanged;
      }
    } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) {
      if (!isValidAssumeForContext(ACall, J, DT))
        continue;
      Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                           SI->getPointerOperand(), SE);
      if (NewAlignment > SI->getAlign()) {
        SI->setAlignment(NewAlignment);
        ++NumStoreAlignChanged;
      }
    } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) {
      if (!isValidAssumeForContext(ACall, J, DT))
        continue;
      Align NewDestAlignment =
          getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MI->getDest(), SE);

      LLVM_DEBUG(dbgs() << "\tmem inst: " << DebugStr(NewDestAlignment)
                        << "\n";);
      if (NewDestAlignment > *MI->getDestAlign()) {
        MI->setDestAlignment(NewDestAlignment);
        ++NumMemIntAlignChanged;
      }

      // For memory transfers, there is also a source alignment that
      // can be set.
      if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
        Align NewSrcAlignment =
            getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MTI->getSource(), SE);

        LLVM_DEBUG(dbgs() << "\tmem trans: " << DebugStr(NewSrcAlignment)
                          << "\n";);

        if (NewSrcAlignment > *MTI->getSourceAlign()) {
          MTI->setSourceAlignment(NewSrcAlignment);
          ++NumMemIntAlignChanged;
        }
      }
    }

    // Now that we've updated that use of the pointer, look for other uses of
    // the pointer to update.
    Visited.insert(J);
    if (isa<GetElementPtrInst>(J) || isa<PHINode>(J))
      for (auto &U : J->uses()) {
        if (U->getType()->isPointerTy()) {
          Instruction *K = cast<Instruction>(U.getUser());
          StoreInst *SI = dyn_cast<StoreInst>(K);
          if (SI && SI->getPointerOperandIndex() != U.getOperandNo())
            continue;
          if (!Visited.count(K))
            WorkList.push_back(K);
        }
      }
  }

  return true;
}

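// Process every operand bundle of every cached llvm.assume call in F, stashing
// ScalarEvolution and DominatorTree in members for use by the helpers above.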
bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC,
                                           ScalarEvolution *SE_,
                                           DominatorTree *DT_) {
  SE = SE_;
  DT = DT_;

  bool Changed = false;
  for (auto &AssumeVH : AC.assumptions())
    if (AssumeVH) {
      CallInst *Call = cast<CallInst>(AssumeVH);
      for (unsigned Idx = 0; Idx < Call->getNumOperandBundles(); Idx++)
        Changed |= processAssumption(Call, Idx);
    }

  return Changed;
}

PreservedAnalyses
AlignmentFromAssumptionsPass::run(Function &F, FunctionAnalysisManager &AM) {
  AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
  ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
  DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, AC, &SE, &DT))
    return PreservedAnalyses::all();

  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<ScalarEvolutionAnalysis>();
  return PA;
}