1//===- AssumptionCache.cpp - Cache finding @llvm.assume calls -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains a pass that keeps track of @llvm.assume intrinsics in
10// the functions of a module.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/Analysis/AssumptionCache.h"
15#include "llvm/ADT/STLExtras.h"
16#include "llvm/ADT/SmallPtrSet.h"
17#include "llvm/ADT/SmallVector.h"
18#include "llvm/Analysis/AssumeBundleQueries.h"
19#include "llvm/Analysis/TargetTransformInfo.h"
20#include "llvm/Analysis/ValueTracking.h"
21#include "llvm/IR/BasicBlock.h"
22#include "llvm/IR/Function.h"
23#include "llvm/IR/InstrTypes.h"
24#include "llvm/IR/Instruction.h"
25#include "llvm/IR/Instructions.h"
26#include "llvm/IR/PassManager.h"
27#include "llvm/IR/PatternMatch.h"
28#include "llvm/InitializePasses.h"
29#include "llvm/Pass.h"
30#include "llvm/Support/Casting.h"
31#include "llvm/Support/CommandLine.h"
32#include "llvm/Support/ErrorHandling.h"
33#include "llvm/Support/raw_ostream.h"
34#include <cassert>
35
36using namespace llvm;
37using namespace llvm::PatternMatch;
38
39static cl::opt<bool>
40 VerifyAssumptionCache("verify-assumption-cache", cl::Hidden,
41 cl::desc("Enable verification of assumption cache"),
42 cl::init(Val: false));
43
44SmallVector<AssumptionCache::ResultElem, 1> &
45AssumptionCache::getOrInsertAffectedValues(Value *V) {
46 // Try using find_as first to avoid creating extra value handles just for the
47 // purpose of doing the lookup.
48 auto AVI = AffectedValues.find_as(Val: V);
49 if (AVI != AffectedValues.end())
50 return AVI->second;
51
52 return AffectedValues[AffectedValueCallbackVH(V, this)];
53}
54
55void AssumptionCache::findValuesAffectedByOperandBundle(
56 OperandBundleUse Bundle, function_ref<void(Value *)> InsertAffected) {
57 auto AddAffectedVal = [&](Value *V) {
58 if (isa<Argument, GlobalValue, Instruction>(Val: V))
59 InsertAffected(V);
60 };
61
62 if (Bundle.getTagName() == "separate_storage") {
63 assert(Bundle.Inputs.size() == 2 && "separate_storage must have two args");
64 AddAffectedVal(getUnderlyingObject(V: Bundle.Inputs[0]));
65 AddAffectedVal(getUnderlyingObject(V: Bundle.Inputs[1]));
66 } else if (Bundle.Inputs.size() > ABA_WasOn &&
67 Bundle.getTagName() != IgnoreBundleTag)
68 AddAffectedVal(Bundle.Inputs[ABA_WasOn]);
69}
70
71static void
72findAffectedValues(CallBase *CI, TargetTransformInfo *TTI,
73 SmallVectorImpl<AssumptionCache::ResultElem> &Affected) {
74 // Note: This code must be kept in-sync with the code in
75 // computeKnownBitsFromAssume in ValueTracking.
76
77 auto InsertAffected = [&Affected](Value *V) {
78 Affected.push_back(Elt: {.Assume: V, .Index: AssumptionCache::ExprResultIdx});
79 };
80
81 auto AddAffectedVal = [&Affected](Value *V, unsigned Idx) {
82 if (isa<Argument>(Val: V) || isa<GlobalValue>(Val: V) || isa<Instruction>(Val: V)) {
83 Affected.push_back(Elt: {.Assume: V, .Index: Idx});
84 }
85 };
86
87 for (unsigned Idx = 0; Idx != CI->getNumOperandBundles(); Idx++)
88 AssumptionCache::findValuesAffectedByOperandBundle(
89 Bundle: CI->getOperandBundleAt(Index: Idx),
90 InsertAffected: [&](Value *V) { Affected.push_back(Elt: {.Assume: V, .Index: Idx}); });
91
92 Value *Cond = CI->getArgOperand(i: 0);
93 findValuesAffectedByCondition(Cond, /*IsAssume=*/true, InsertAffected);
94
95 if (TTI) {
96 const Value *Ptr;
97 unsigned AS;
98 std::tie(args&: Ptr, args&: AS) = TTI->getPredicatedAddrSpace(V: Cond);
99 if (Ptr)
100 AddAffectedVal(const_cast<Value *>(Ptr->stripInBoundsOffsets()),
101 AssumptionCache::ExprResultIdx);
102 }
103}
104
105void AssumptionCache::updateAffectedValues(AssumeInst *CI) {
106 SmallVector<AssumptionCache::ResultElem, 16> Affected;
107 findAffectedValues(CI, TTI, Affected);
108
109 for (auto &AV : Affected) {
110 auto &AVV = getOrInsertAffectedValues(V: AV.Assume);
111 if (llvm::none_of(Range&: AVV, P: [&](ResultElem &Elem) {
112 return Elem.Assume == CI && Elem.Index == AV.Index;
113 }))
114 AVV.push_back(Elt: {.Assume: CI, .Index: AV.Index});
115 }
116}
117
118void AssumptionCache::unregisterAssumption(AssumeInst *CI) {
119 SmallVector<AssumptionCache::ResultElem, 16> Affected;
120 findAffectedValues(CI, TTI, Affected);
121
122 for (auto &AV : Affected) {
123 auto AVI = AffectedValues.find_as(Val: AV.Assume);
124 if (AVI == AffectedValues.end())
125 continue;
126 bool Found = false;
127 bool HasNonnull = false;
128 for (ResultElem &Elem : AVI->second) {
129 if (Elem.Assume == CI) {
130 Found = true;
131 Elem.Assume = nullptr;
132 }
133 HasNonnull |= !!Elem.Assume;
134 if (HasNonnull && Found)
135 break;
136 }
137 assert(Found && "already unregistered or incorrect cache state");
138 if (!HasNonnull)
139 AffectedValues.erase(I: AVI);
140 }
141
142 llvm::erase(C&: AssumeHandles, V: CI);
143}
144
145void AssumptionCache::AffectedValueCallbackVH::deleted() {
146 AC->AffectedValues.erase(Val: getValPtr());
147 // 'this' now dangles!
148}
149
150void AssumptionCache::transferAffectedValuesInCache(Value *OV, Value *NV) {
151 auto &NAVV = getOrInsertAffectedValues(V: NV);
152 auto AVI = AffectedValues.find(Val: OV);
153 if (AVI == AffectedValues.end())
154 return;
155
156 for (auto &A : AVI->second)
157 if (!llvm::is_contained(Range&: NAVV, Element: A))
158 NAVV.push_back(Elt: A);
159 AffectedValues.erase(Val: OV);
160}
161
162void AssumptionCache::AffectedValueCallbackVH::allUsesReplacedWith(Value *NV) {
163 if (!isa<Instruction>(Val: NV) && !isa<Argument>(Val: NV))
164 return;
165
166 // Any assumptions that affected this value now affect the new value.
167
168 AC->transferAffectedValuesInCache(OV: getValPtr(), NV);
169 // 'this' now might dangle! If the AffectedValues map was resized to add an
170 // entry for NV then this object might have been destroyed in favor of some
171 // copy in the grown map.
172}
173
174void AssumptionCache::scanFunction() {
175 assert(!Scanned && "Tried to scan the function twice!");
176 assert(AssumeHandles.empty() && "Already have assumes when scanning!");
177
178 // Go through all instructions in all blocks, add all calls to @llvm.assume
179 // to this cache.
180 for (BasicBlock &B : F)
181 for (Instruction &I : B)
182 if (isa<AssumeInst>(Val: &I))
183 AssumeHandles.push_back(Elt: &I);
184
185 // Mark the scan as complete.
186 Scanned = true;
187
188 // Update affected values.
189 for (auto &A : AssumeHandles)
190 updateAffectedValues(CI: cast<AssumeInst>(Val&: A));
191}
192
193void AssumptionCache::registerAssumption(AssumeInst *CI) {
194 // If we haven't scanned the function yet, just drop this assumption. It will
195 // be found when we scan later.
196 if (!Scanned)
197 return;
198
199 AssumeHandles.push_back(Elt: CI);
200
201#ifndef NDEBUG
202 assert(CI->getParent() &&
203 "Cannot register @llvm.assume call not in a basic block");
204 assert(&F == CI->getParent()->getParent() &&
205 "Cannot register @llvm.assume call not in this function");
206
207 // We expect the number of assumptions to be small, so in an asserts build
208 // check that we don't accumulate duplicates and that all assumptions point
209 // to the same function.
210 SmallPtrSet<Value *, 16> AssumptionSet;
211 for (auto &VH : AssumeHandles) {
212 if (!VH)
213 continue;
214
215 assert(&F == cast<Instruction>(VH)->getParent()->getParent() &&
216 "Cached assumption not inside this function!");
217 assert(match(cast<CallInst>(VH), m_Intrinsic<Intrinsic::assume>()) &&
218 "Cached something other than a call to @llvm.assume!");
219 assert(AssumptionSet.insert(VH).second &&
220 "Cache contains multiple copies of a call!");
221 }
222#endif
223
224 updateAffectedValues(CI);
225}
226
227AssumptionCache AssumptionAnalysis::run(Function &F,
228 FunctionAnalysisManager &FAM) {
229 auto &TTI = FAM.getResult<TargetIRAnalysis>(IR&: F);
230 return AssumptionCache(F, &TTI);
231}
232
233AnalysisKey AssumptionAnalysis::Key;
234
235PreservedAnalyses AssumptionPrinterPass::run(Function &F,
236 FunctionAnalysisManager &AM) {
237 AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(IR&: F);
238
239 OS << "Cached assumptions for function: " << F.getName() << "\n";
240 for (auto &VH : AC.assumptions())
241 if (VH)
242 OS << " " << *cast<CallInst>(Val&: VH)->getArgOperand(i: 0) << "\n";
243
244 return PreservedAnalyses::all();
245}
246
247void AssumptionCacheTracker::FunctionCallbackVH::deleted() {
248 auto I = ACT->AssumptionCaches.find_as(Val: cast<Function>(Val: getValPtr()));
249 if (I != ACT->AssumptionCaches.end())
250 ACT->AssumptionCaches.erase(I);
251 // 'this' now dangles!
252}
253
254AssumptionCache &AssumptionCacheTracker::getAssumptionCache(Function &F) {
255 // We probe the function map twice to try and avoid creating a value handle
256 // around the function in common cases. This makes insertion a bit slower,
257 // but if we have to insert we're going to scan the whole function so that
258 // shouldn't matter.
259 auto I = AssumptionCaches.find_as(Val: &F);
260 if (I != AssumptionCaches.end())
261 return *I->second;
262
263 auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>();
264 auto *TTI = TTIWP ? &TTIWP->getTTI(F) : nullptr;
265
266 // Ok, build a new cache by scanning the function, insert it and the value
267 // handle into our map, and return the newly populated cache.
268 auto IP = AssumptionCaches.insert(KV: std::make_pair(
269 x: FunctionCallbackVH(&F, this), y: std::make_unique<AssumptionCache>(args&: F, args&: TTI)));
270 assert(IP.second && "Scanning function already in the map?");
271 return *IP.first->second;
272}
273
274AssumptionCache *AssumptionCacheTracker::lookupAssumptionCache(Function &F) {
275 auto I = AssumptionCaches.find_as(Val: &F);
276 if (I != AssumptionCaches.end())
277 return I->second.get();
278 return nullptr;
279}
280
281void AssumptionCacheTracker::verifyAnalysis() const {
282 // FIXME: In the long term the verifier should not be controllable with a
283 // flag. We should either fix all passes to correctly update the assumption
284 // cache and enable the verifier unconditionally or somehow arrange for the
285 // assumption list to be updated automatically by passes.
286 if (!VerifyAssumptionCache)
287 return;
288
289 SmallPtrSet<const CallInst *, 4> AssumptionSet;
290 for (const auto &I : AssumptionCaches) {
291 for (auto &VH : I.second->assumptions())
292 if (VH)
293 AssumptionSet.insert(Ptr: cast<CallInst>(Val&: VH));
294
295 for (const BasicBlock &B : cast<Function>(Val&: *I.first))
296 for (const Instruction &II : B)
297 if (match(V: &II, P: m_Intrinsic<Intrinsic::assume>()) &&
298 !AssumptionSet.count(Ptr: cast<CallInst>(Val: &II)))
299 report_fatal_error(reason: "Assumption in scanned function not in cache");
300 }
301}
302
303AssumptionCacheTracker::AssumptionCacheTracker() : ImmutablePass(ID) {}
304
305AssumptionCacheTracker::~AssumptionCacheTracker() = default;
306
307char AssumptionCacheTracker::ID = 0;
308
309INITIALIZE_PASS(AssumptionCacheTracker, "assumption-cache-tracker",
310 "Assumption Cache Tracker", false, true)
311