1//===- InferAlignment.cpp -------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Infer alignment for loads, stores and other memory operations based on the
// known trailing zero bits of their pointer operands and on alignment
// propagated from dominating accesses to the same base pointer.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/Transforms/Scalar/InferAlignment.h"
15#include "llvm/ADT/APInt.h"
16#include "llvm/ADT/STLFunctionalExtras.h"
17#include "llvm/ADT/ScopedHashTable.h"
18#include "llvm/Analysis/AssumptionCache.h"
19#include "llvm/Analysis/ValueTracking.h"
20#include "llvm/IR/Instruction.h"
21#include "llvm/IR/Instructions.h"
22#include "llvm/IR/IntrinsicInst.h"
23#include "llvm/IR/PatternMatch.h"
24#include "llvm/Support/KnownBits.h"
25#include "llvm/Transforms/Scalar.h"
26#include "llvm/Transforms/Utils/Local.h"
27
28using namespace llvm;
29using namespace llvm::PatternMatch;
30
31static bool tryToImproveAlign(
32 const DataLayout &DL, Instruction *I,
33 function_ref<Align(Value *PtrOp, Align OldAlign, Align PrefAlign)> Fn) {
34
35 if (auto *PtrOp = getLoadStorePointerOperand(V: I)) {
36 Align OldAlign = getLoadStoreAlignment(I);
37 Align PrefAlign = DL.getPrefTypeAlign(Ty: getLoadStoreType(I));
38
39 Align NewAlign = Fn(PtrOp, OldAlign, PrefAlign);
40 if (NewAlign > OldAlign) {
41 setLoadStoreAlignment(I, NewAlign);
42 return true;
43 }
44 }
45
46 Value *PtrOp;
47 const APInt *Const;
48 if (match(V: I, P: m_And(L: m_PtrToIntOrAddr(Op: m_Value(V&: PtrOp)), R: m_APInt(Res&: Const)))) {
49 Align ActualAlign = Fn(PtrOp, Align(1), Align(1));
50 if (Const->ult(RHS: ActualAlign.value())) {
51 I->replaceAllUsesWith(V: Constant::getNullValue(Ty: I->getType()));
52 return true;
53 }
54 if (Const->uge(
55 RHS: APInt::getBitsSetFrom(numBits: Const->getBitWidth(), loBit: Log2(A: ActualAlign)))) {
56 I->replaceAllUsesWith(V: I->getOperand(i: 0));
57 return true;
58 }
59 }
60 if (match(V: I, P: m_Trunc(Op: m_PtrToIntOrAddr(Op: m_Value(V&: PtrOp))))) {
61 Align ActualAlign = Fn(PtrOp, Align(1), Align(1));
62 if (Log2(A: ActualAlign) >= I->getType()->getScalarSizeInBits()) {
63 I->replaceAllUsesWith(V: Constant::getNullValue(Ty: I->getType()));
64 return true;
65 }
66 }
67
68 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: I);
69 if (!II)
70 return false;
71
72 // TODO: Handle more memory intrinsics.
73 switch (II->getIntrinsicID()) {
74 case Intrinsic::masked_load:
75 case Intrinsic::masked_store: {
76 unsigned PtrOpIdx = II->getIntrinsicID() == Intrinsic::masked_load ? 0 : 1;
77 Value *PtrOp = II->getArgOperand(i: PtrOpIdx);
78 Type *Type = II->getIntrinsicID() == Intrinsic::masked_load
79 ? II->getType()
80 : II->getArgOperand(i: 0)->getType();
81
82 Align OldAlign = II->getParamAlign(ArgNo: PtrOpIdx).valueOrOne();
83 Align PrefAlign = DL.getPrefTypeAlign(Ty: Type);
84 Align NewAlign = Fn(PtrOp, OldAlign, PrefAlign);
85 if (NewAlign <= OldAlign)
86 return false;
87
88 II->addParamAttr(ArgNo: PtrOpIdx,
89 Attr: Attribute::getWithAlignment(Context&: II->getContext(), Alignment: NewAlign));
90 return true;
91 }
92 default:
93 return false;
94 }
95}
96
97using ScopedHT =
98 ScopedHashTable<Value *, Align, DenseMapInfo<Value *>, BumpPtrAllocator>;
99struct AlignmentScope {
100 // If BB is nullptr, the BB is processed.
101 BasicBlock *BB;
102 DomTreeNode::const_iterator Iter;
103 DomTreeNode::const_iterator End;
104 ScopedHT::ScopeTy Scope;
105
106 AlignmentScope(DomTreeNode *N, ScopedHT &Table)
107 : BB(N->getBlock()), Iter(N->begin()), End(N->end()), Scope(Table) {}
108};
109
110bool inferAlignment(Function &F, AssumptionCache &AC, DominatorTree &DT) {
111 const DataLayout &DL = F.getDataLayout();
112 bool Changed = false;
113
114 // Enforce preferred type alignment if possible. We do this as a separate
115 // pass first, because it may improve the alignments we infer below.
116 for (BasicBlock &BB : F) {
117 for (Instruction &I : BB) {
118 Changed |= tryToImproveAlign(
119 DL, I: &I, Fn: [&](Value *PtrOp, Align OldAlign, Align PrefAlign) {
120 if (PrefAlign > OldAlign)
121 return std::max(a: OldAlign,
122 b: tryEnforceAlignment(V: PtrOp, PrefAlign, DL));
123 return OldAlign;
124 });
125 }
126 }
127
128 // Compute alignment from known bits.
129 auto InferFromKnownBits = [&](Instruction &I, Value *PtrOp) {
130 KnownBits Known = computeKnownBits(V: PtrOp, DL, AC: &AC, CxtI: &I, DT: &DT);
131 unsigned TrailZ =
132 std::min(a: Known.countMinTrailingZeros(), b: +Value::MaxAlignmentExponent);
133 return Align(1ull << std::min(a: Known.getBitWidth() - 1, b: TrailZ));
134 };
135
136 // Propagate alignment between loads and stores that originate from the
137 // same base pointer.
138 ScopedHT BestBasePointerAligns;
139 auto InferFromBasePointer = [&](Value *PtrOp, Align LoadStoreAlign) {
140 APInt OffsetFromBase(DL.getIndexTypeSizeInBits(Ty: PtrOp->getType()), 0);
141 PtrOp = PtrOp->stripAndAccumulateConstantOffsets(DL, Offset&: OffsetFromBase, AllowNonInbounds: true);
142 // Derive the base pointer alignment from the load/store alignment
143 // and the offset from the base pointer.
144 Align BasePointerAlign =
145 commonAlignment(A: LoadStoreAlign, Offset: OffsetFromBase.getLimitedValue());
146
147 if (auto BestAlign = BestBasePointerAligns.lookup(Key: PtrOp);
148 BestAlign != Align()) {
149 // If the stored base pointer alignment is better than the
150 // base pointer alignment we derived, we may be able to use it
151 // to improve the load/store alignment. If not, store the
152 // improved base pointer alignment for future iterations.
153 if (BestAlign > BasePointerAlign) {
154 Align BetterLoadStoreAlign =
155 commonAlignment(A: BestAlign, Offset: OffsetFromBase.getLimitedValue());
156 return BetterLoadStoreAlign;
157 }
158 }
159
160 BestBasePointerAligns.insert(Key: PtrOp, Val: BasePointerAlign);
161 return LoadStoreAlign;
162 };
163
164 // AlignmentScope is unmovable.
165 std::list<AlignmentScope> Stack;
166 Stack.emplace_back(args: DT.getRootNode(), args&: BestBasePointerAligns);
167 while (!Stack.empty()) {
168 AlignmentScope &Top = Stack.back();
169 if (Top.BB) {
170 for (Instruction &I : *Top.BB) {
171 Changed |= tryToImproveAlign(
172 DL, I: &I, Fn: [&](Value *PtrOp, Align OldAlign, Align PrefAlign) {
173 return std::max(a: InferFromKnownBits(I, PtrOp),
174 b: InferFromBasePointer(PtrOp, OldAlign));
175 });
176 }
177 Top.BB = nullptr;
178 }
179
180 if (Top.Iter != Top.End)
181 Stack.emplace_back(args: *Top.Iter++, args&: BestBasePointerAligns);
182 else
183 Stack.pop_back();
184 }
185
186 return Changed;
187}
188
189PreservedAnalyses InferAlignmentPass::run(Function &F,
190 FunctionAnalysisManager &AM) {
191 AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(IR&: F);
192 DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(IR&: F);
193 inferAlignment(F, AC, DT);
194 // Changes to alignment shouldn't invalidated analyses.
195 return PreservedAnalyses::all();
196}
197