//===----------------------- AlignmentFromAssumptions.cpp -----------------===//
// Set Load/Store Alignments From Assumptions
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a ScalarEvolution-based transformation to set
// the alignments of loads, stores, and memory intrinsics based on the truth
// expressions of assume intrinsics. The primary motivation is to handle
// complex alignment assumptions that apply to vector loads and stores that
// appear after vectorization and unrolling.
//
//===----------------------------------------------------------------------===//
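//
// Illustrative example (hypothetical IR): Clang's __builtin_assume_aligned
// typically lowers to an assume call carrying an "align" operand bundle,
//
//   call void @llvm.assume(i1 true) [ "align"(ptr %p, i64 32) ]
//
// which asserts that %p is 32-byte aligned; an optional third bundle operand
// supplies a byte offset from the aligned address. This pass propagates such
// facts to the alignments of loads, stores, and memory intrinsics that use
// the assumed pointer.
//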

#include "llvm/Transforms/Scalar/AlignmentFromAssumptions.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "alignment-from-assumptions"
using namespace llvm;

STATISTIC(NumLoadAlignChanged,
          "Number of loads changed by alignment assumptions");
STATISTIC(NumStoreAlignChanged,
          "Number of stores changed by alignment assumptions");
STATISTIC(NumMemIntAlignChanged,
          "Number of memory intrinsics changed by alignment assumptions");

// Given an expression for the (constant) alignment, AlignSCEV, and an
// expression for the displacement between a pointer and the aligned address,
// DiffSCEV, compute the alignment of the displaced pointer if it can be
// reduced to a constant. Using SCEV to compute alignment handles the case
// where DiffSCEV is a recurrence with constant start such that the aligned
// offset is constant. e.g. {16,+,32} % 32 -> 16.
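//
// As a worked check of that example (hypothetical values): with AlignSCEV =
// 32 and DiffSCEV = {16,+,32}, the displacements 16, 48, 80, ... are all
// congruent to 16 modulo 32, so the remainder folds to the constant 16 and
// the displaced pointer can be given alignment 16. A plain DiffSCEV of 64
// folds to remainder 0, and the pointer keeps the full 32-byte alignment.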
static MaybeAlign getNewAlignmentDiff(const SCEV *DiffSCEV,
                                      const SCEV *AlignSCEV,
                                      ScalarEvolution *SE) {
  // DiffUnits = Diff % int64_t(Alignment)
  const SCEV *DiffUnitsSCEV = SE->getURemExpr(DiffSCEV, AlignSCEV);

  LLVM_DEBUG(dbgs() << "\talignment relative to " << *AlignSCEV << " is "
                    << *DiffUnitsSCEV << " (diff: " << *DiffSCEV << ")\n");

  if (const SCEVConstant *ConstDUSCEV =
          dyn_cast<SCEVConstant>(DiffUnitsSCEV)) {
    int64_t DiffUnits = ConstDUSCEV->getValue()->getSExtValue();

    // If the displacement is an exact multiple of the alignment, then the
    // displaced pointer has the same alignment as the aligned pointer, so
    // return the alignment value.
    if (!DiffUnits)
      return cast<SCEVConstant>(AlignSCEV)->getValue()->getAlignValue();

    // If the displacement is not an exact multiple, but the remainder is a
    // constant, then return this remainder (but only if it is a power of 2).
    uint64_t DiffUnitsAbs = std::abs(DiffUnits);
    if (isPowerOf2_64(DiffUnitsAbs))
      return Align(DiffUnitsAbs);
  }

  return std::nullopt;
}

// There is an address given by an offset OffSCEV from AASCEV which has an
// alignment AlignSCEV. Use that information, if possible, to compute a new
// alignment for Ptr.
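//
// A small hypothetical instance: with OffSCEV = 0, AlignSCEV = 32, and Ptr
// equal to AAPtr plus 16 bytes, DiffSCEV below evaluates to 16; 16 modulo 32
// is 16, a power of two, so Ptr is given alignment 16.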
static Align getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
                             const SCEV *OffSCEV, Value *Ptr,
                             ScalarEvolution *SE) {
  const SCEV *PtrSCEV = SE->getSCEV(Ptr);

  const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV);
  if (isa<SCEVCouldNotCompute>(DiffSCEV))
    return Align(1);

  // On 32-bit platforms, DiffSCEV might now have type i32 -- we've always
  // sign-extended OffSCEV to i64, so make sure they agree again.
  DiffSCEV = SE->getNoopOrSignExtend(DiffSCEV, OffSCEV->getType());

  // What we really want to know is the overall offset to the aligned
  // address. This address is displaced by the provided offset.
  DiffSCEV = SE->getAddExpr(DiffSCEV, OffSCEV);

  LLVM_DEBUG(dbgs() << "AFI: alignment of " << *Ptr << " relative to "
                    << *AlignSCEV << " and offset " << *OffSCEV
                    << " using diff " << *DiffSCEV << "\n");

  if (MaybeAlign NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE)) {
    LLVM_DEBUG(dbgs() << "\tnew alignment: " << DebugStr(NewAlignment) << "\n");
    return *NewAlignment;
  }

  if (const SCEVAddRecExpr *DiffARSCEV = dyn_cast<SCEVAddRecExpr>(DiffSCEV)) {
    // The relative offset to the alignment assumption did not yield a
    // constant, but we should try harder: if we assume that a is 32-byte
    // aligned, then in for (i = 0; i < 1024; i += 4) r += a[i]; not all of
    // the loads from a are 32-byte aligned, but instead alternate between 32
    // and 16-byte alignment. As a result, the new alignment will not be a
    // constant, but can still be improved over the default (of 4) to 16.

    const SCEV *DiffStartSCEV = DiffARSCEV->getStart();
    const SCEV *DiffIncSCEV = DiffARSCEV->getStepRecurrence(*SE);

    LLVM_DEBUG(dbgs() << "\ttrying start/inc alignment using start "
                      << *DiffStartSCEV << " and inc " << *DiffIncSCEV << "\n");

    // Now compute the new alignment using the displacement to the value in
    // the first iteration, and also the alignment using the per-iteration
    // delta. If these are the same, then use that answer. Otherwise, use the
    // smaller one, but only if it divides the larger one.
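    //
    // Continuing the hypothetical 32-byte example: DiffSCEV = {0,+,16} has no
    // constant remainder modulo 32, so the constant path above fails.
    // Splitting the recurrence gives start 0 (alignment 32) and step 16
    // (alignment 16); since 16 divides 32, every iteration's address is at
    // least 16-byte aligned, and the logic below returns Align(16).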
    MaybeAlign NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE);
    MaybeAlign NewIncAlignment =
        getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE);

    LLVM_DEBUG(dbgs() << "\tnew start alignment: " << DebugStr(NewAlignment)
                      << "\n");
    LLVM_DEBUG(dbgs() << "\tnew inc alignment: " << DebugStr(NewIncAlignment)
                      << "\n");

    if (!NewAlignment || !NewIncAlignment)
      return Align(1);

    const Align NewAlign = *NewAlignment;
    const Align NewIncAlign = *NewIncAlignment;
    if (NewAlign > NewIncAlign) {
      LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: "
                        << DebugStr(NewIncAlign) << "\n");
      return NewIncAlign;
    }
    if (NewIncAlign > NewAlign) {
      LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
                        << "\n");
      return NewAlign;
    }
    assert(NewIncAlign == NewAlign);
    LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
                      << "\n");
    return NewAlign;
  }

  return Align(1);
}

bool AlignmentFromAssumptionsPass::extractAlignmentInfo(CallInst *I,
                                                        unsigned Idx,
                                                        Value *&AAPtr,
                                                        const SCEV *&AlignSCEV,
                                                        const SCEV *&OffSCEV) {
  Type *Int64Ty = Type::getInt64Ty(I->getContext());
  OperandBundleUse AlignOB = I->getOperandBundleAt(Idx);
  if (AlignOB.getTagName() != "align")
    return false;
  assert(AlignOB.Inputs.size() >= 2);
  AAPtr = AlignOB.Inputs[0].get();
  // TODO: Consider accumulating the offset to the base.
  AAPtr = AAPtr->stripPointerCastsSameRepresentation();
  AlignSCEV = SE->getSCEV(AlignOB.Inputs[1].get());
  AlignSCEV = SE->getTruncateOrZeroExtend(AlignSCEV, Int64Ty);
  if (!isa<SCEVConstant>(AlignSCEV))
    // Added to suppress a crash because consumer doesn't expect non-constant
    // alignments in the assume bundle. TODO: Consider generalizing caller.
    return false;
  if (!cast<SCEVConstant>(AlignSCEV)->getAPInt().isPowerOf2())
    // Only power of two alignments are supported.
    return false;
  if (AlignOB.Inputs.size() == 3)
    OffSCEV = SE->getSCEV(AlignOB.Inputs[2].get());
  else
    OffSCEV = SE->getZero(Int64Ty);
  OffSCEV = SE->getTruncateOrZeroExtend(OffSCEV, Int64Ty);
  return true;
}
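
// For reference, a hypothetical bundle and what extractAlignmentInfo yields
// from it: given
//
//   call void @llvm.assume(i1 true) [ "align"(ptr %p, i64 32, i64 24) ]
//
// Inputs[0] is %p (returned as AAPtr after stripping same-representation
// pointer casts), Inputs[1] gives AlignSCEV = 32, and the optional third
// input gives OffSCEV = 24. With only two inputs, OffSCEV defaults to zero.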

bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall,
                                                     unsigned Idx) {
  Value *AAPtr;
  const SCEV *AlignSCEV, *OffSCEV;
  if (!extractAlignmentInfo(ACall, Idx, AAPtr, AlignSCEV, OffSCEV))
    return false;

  // Skip ConstantPointerNull and UndefValue. Assumptions on these shouldn't
  // affect other users.
  if (isa<ConstantData>(AAPtr))
    return false;

  const SCEV *AASCEV = SE->getSCEV(AAPtr);

  // Apply the assumption to all other users of the specified pointer.
  SmallPtrSet<Instruction *, 32> Visited;
  SmallVector<Instruction *, 16> WorkList;
  for (User *J : AAPtr->users()) {
    if (J == ACall)
      continue;

    if (Instruction *K = dyn_cast<Instruction>(J))
      if (K->getFunction() == ACall->getFunction())
        WorkList.push_back(K);
  }

  while (!WorkList.empty()) {
    Instruction *J = WorkList.pop_back_val();
    if (LoadInst *LI = dyn_cast<LoadInst>(J)) {
      if (!isValidAssumeForContext(ACall, J, DT))
        continue;
      Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                           LI->getPointerOperand(), SE);
      if (NewAlignment > LI->getAlign()) {
        LI->setAlignment(NewAlignment);
        ++NumLoadAlignChanged;
      }
    } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) {
      if (!isValidAssumeForContext(ACall, J, DT))
        continue;
      Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                           SI->getPointerOperand(), SE);
      if (NewAlignment > SI->getAlign()) {
        SI->setAlignment(NewAlignment);
        ++NumStoreAlignChanged;
      }
    } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) {
      if (!isValidAssumeForContext(ACall, J, DT))
        continue;
      Align NewDestAlignment =
          getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MI->getDest(), SE);

      LLVM_DEBUG(dbgs() << "\tmem inst: " << DebugStr(NewDestAlignment)
                        << "\n";);
      if (NewDestAlignment > *MI->getDestAlign()) {
        MI->setDestAlignment(NewDestAlignment);
        ++NumMemIntAlignChanged;
      }

      // For memory transfers, there is also a source alignment that
      // can be set.
      if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
        Align NewSrcAlignment =
            getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MTI->getSource(), SE);

        LLVM_DEBUG(dbgs() << "\tmem trans: " << DebugStr(NewSrcAlignment)
                          << "\n";);

        if (NewSrcAlignment > *MTI->getSourceAlign()) {
          MTI->setSourceAlignment(NewSrcAlignment);
          ++NumMemIntAlignChanged;
        }
      }
    }

    // Now that we've updated that use of the pointer, look for other uses of
    // the pointer to update.
    Visited.insert(J);
    if (isa<GetElementPtrInst>(J) || isa<PHINode>(J))
      for (auto &U : J->uses()) {
        if (U->getType()->isPointerTy()) {
          Instruction *K = cast<Instruction>(U.getUser());
          StoreInst *SI = dyn_cast<StoreInst>(K);
          if (SI && SI->getPointerOperandIndex() != U.getOperandNo())
            continue;
          if (!Visited.count(K))
            WorkList.push_back(K);
        }
      }
  }

  return true;
}
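
// End-to-end, the effect on a user of the assumed pointer looks like this
// (hypothetical IR): given
//
//   call void @llvm.assume(i1 true) [ "align"(ptr %p, i64 32) ]
//   %v = load <8 x float>, ptr %p, align 4
//
// processAssumption raises the load's alignment in place, yielding
//
//   %v = load <8 x float>, ptr %p, align 32
//
// The assume itself is left alone; only the alignments of loads, stores, and
// memory intrinsics for which the assumption is valid in context change.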

bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC,
                                           ScalarEvolution *SE_,
                                           DominatorTree *DT_) {
  SE = SE_;
  DT = DT_;

  bool Changed = false;
  for (auto &AssumeVH : AC.assumptions())
    if (AssumeVH) {
      CallInst *Call = cast<CallInst>(AssumeVH);
      for (unsigned Idx = 0; Idx < Call->getNumOperandBundles(); Idx++)
        Changed |= processAssumption(Call, Idx);
    }

  return Changed;
}

PreservedAnalyses
AlignmentFromAssumptionsPass::run(Function &F, FunctionAnalysisManager &AM) {
  AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
  ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
  DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, AC, &SE, &DT))
    return PreservedAnalyses::all();

  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<ScalarEvolutionAnalysis>();
  return PA;
}