1//===- AssumptionCache.cpp - Cache finding @llvm.assume calls -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains a pass that keeps track of @llvm.assume intrinsics in
10// the functions of a module.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/Analysis/AssumptionCache.h"
15#include "llvm/ADT/STLExtras.h"
16#include "llvm/ADT/SmallPtrSet.h"
17#include "llvm/ADT/SmallVector.h"
18#include "llvm/Analysis/AssumeBundleQueries.h"
19#include "llvm/Analysis/TargetTransformInfo.h"
20#include "llvm/Analysis/ValueTracking.h"
21#include "llvm/IR/BasicBlock.h"
22#include "llvm/IR/Function.h"
23#include "llvm/IR/InstrTypes.h"
24#include "llvm/IR/Instruction.h"
25#include "llvm/IR/Instructions.h"
26#include "llvm/IR/PassManager.h"
27#include "llvm/IR/PatternMatch.h"
28#include "llvm/InitializePasses.h"
29#include "llvm/Pass.h"
30#include "llvm/Support/Casting.h"
31#include "llvm/Support/CommandLine.h"
32#include "llvm/Support/ErrorHandling.h"
33#include "llvm/Support/raw_ostream.h"
34#include <cassert>
35
36using namespace llvm;
37using namespace llvm::PatternMatch;
38
39static cl::opt<bool>
40 VerifyAssumptionCache("verify-assumption-cache", cl::Hidden,
41 cl::desc("Enable verification of assumption cache"),
42 cl::init(Val: false));
43
44SmallVector<AssumptionCache::ResultElem, 1> &
45AssumptionCache::getOrInsertAffectedValues(Value *V) {
46 // Try using find_as first to avoid creating extra value handles just for the
47 // purpose of doing the lookup.
48 auto AVI = AffectedValues.find_as(Val: V);
49 if (AVI != AffectedValues.end())
50 return AVI->second;
51
52 return AffectedValues[AffectedValueCallbackVH(V, this)];
53}
54
55void AssumptionCache::findValuesAffectedByOperandBundle(
56 OperandBundleUse Bundle, function_ref<void(Value *)> InsertAffected) {
57 auto AddAffectedVal = [&](Value *V) {
58 if (isa<Argument, GlobalValue, Instruction>(Val: V))
59 InsertAffected(V);
60 };
61
62 if (Bundle.getTagName() == "separate_storage") {
63 assert(Bundle.Inputs.size() == 2 && "separate_storage must have two args");
64 AddAffectedVal(getUnderlyingObject(V: Bundle.Inputs[0]));
65 AddAffectedVal(getUnderlyingObject(V: Bundle.Inputs[1]));
66 } else if (Bundle.Inputs.size() > ABA_WasOn &&
67 Bundle.getTagName() != IgnoreBundleTag)
68 AddAffectedVal(Bundle.Inputs[ABA_WasOn]);
69}
70
71static void
72findAffectedValues(CallBase *CI, TargetTransformInfo *TTI,
73 SmallVectorImpl<AssumptionCache::ResultElem> &Affected) {
74 // Note: This code must be kept in-sync with the code in
75 // computeKnownBitsFromAssume in ValueTracking.
76
77 auto InsertAffected = [&Affected](Value *V) {
78 Affected.push_back(Elt: {.Assume: V, .Index: AssumptionCache::ExprResultIdx});
79 };
80
81 auto AddAffectedVal = [&Affected](Value *V, unsigned Idx) {
82 if (isa<Argument>(Val: V) || isa<GlobalValue>(Val: V) || isa<Instruction>(Val: V)) {
83 Affected.push_back(Elt: {.Assume: V, .Index: Idx});
84 }
85 };
86
87 for (unsigned Idx = 0; Idx != CI->getNumOperandBundles(); Idx++)
88 AssumptionCache::findValuesAffectedByOperandBundle(
89 Bundle: CI->getOperandBundleAt(Index: Idx),
90 InsertAffected: [&](Value *V) { Affected.push_back(Elt: {.Assume: V, .Index: Idx}); });
91
92 Value *Cond = CI->getArgOperand(i: 0);
93 findValuesAffectedByCondition(Cond, /*IsAssume=*/true, InsertAffected);
94
95 if (TTI) {
96 const Value *Ptr;
97 unsigned AS;
98 std::tie(args&: Ptr, args&: AS) = TTI->getPredicatedAddrSpace(V: Cond);
99 if (Ptr)
100 AddAffectedVal(const_cast<Value *>(Ptr->stripInBoundsOffsets()),
101 AssumptionCache::ExprResultIdx);
102 }
103}
104
105void AssumptionCache::updateAffectedValues(AssumeInst *CI) {
106 SmallVector<AssumptionCache::ResultElem, 16> Affected;
107 findAffectedValues(CI, TTI, Affected);
108
109 for (auto &AV : Affected) {
110 auto &AVV = getOrInsertAffectedValues(V: AV.Assume);
111 if (llvm::none_of(Range&: AVV, P: [&](ResultElem &Elem) {
112 return Elem.Assume == CI && Elem.Index == AV.Index;
113 }))
114 AVV.push_back(Elt: {.Assume: CI, .Index: AV.Index});
115 }
116}
117
118void AssumptionCache::removeAffectedValues(AssumeInst *CI) {
119 SmallVector<AssumptionCache::ResultElem, 16> Affected;
120 findAffectedValues(CI, TTI, Affected);
121
122 // If a value appears more than once in an AssumeInst e.g., 'ptr %arg1' in:
123 // call void @llvm.assume(i1 true)
124 // [ "dereferenceable"(ptr %arg1, i64 1),
125 // "align"(ptr %arg1, i64 8) ]
126 // it will appear multiple times in Affected, but we may (depending on
127 // how the results in AffectedValues.find_as(AV.Assume) are ordered)
128 // nullify multiple instances of Elem.Assume during one iteration of the
129 // 'for (auto &AV : Affected)' loop below. The next iteration of that for
130 // loop may then find only a match to a different AssumeInst, resulting in
131 // an assertion failure. Avoid this by counting the number of expected
132 // matches.
133#ifndef NDEBUG
134 DenseMap<Value *, int> ExpectedMatches;
135 for (auto &AV : Affected)
136 if (AffectedValues.find_as(AV.Assume) != AffectedValues.end())
137 ExpectedMatches[AV.Assume]++;
138#endif
139
140 for (auto &AV : Affected) {
141 auto AVI = AffectedValues.find_as(Val: AV.Assume);
142 if (AVI == AffectedValues.end())
143 continue;
144 bool Found = false;
145 bool HasNonnull = false;
146 for (ResultElem &Elem : AVI->second) {
147 if (Elem.Assume == CI) {
148 Found = true;
149 Elem.Assume = nullptr;
150
151#ifndef NDEBUG
152 ExpectedMatches[AV.Assume]--;
153#endif
154 assert(ExpectedMatches[AV.Assume] >= 0);
155 // After ExpectedMatches[AV.Assume] == 0, we still need to iterate
156 // through this loop to determine the value of HasNonnull, to avoid
157 // prematurely calling AffectedValues.erase(AVI).
158 }
159 HasNonnull |= !!Elem.Assume;
160 if (HasNonnull && Found)
161 break;
162 }
163
164 assert(ExpectedMatches[AV.Assume] == 0 ||
165 Found && "already unregistered or incorrect cache state");
166
167 if (!HasNonnull)
168 AffectedValues.erase(I: AVI);
169 }
170
171 assert(
172 none_of(Affected, [&](auto &AV) { return ExpectedMatches[AV.Assume]; }) &&
173 "already unregistered or incorrect cache state");
174}
175
176void AssumptionCache::unregisterAssumption(AssumeInst *CI) {
177 removeAffectedValues(CI);
178 llvm::erase(C&: AssumeHandles, V: CI);
179}
180
181void AssumptionCache::replaceAssumption(WeakVH &Handle, AssumeInst *New) {
182 removeAffectedValues(CI: cast<AssumeInst>(Val&: Handle));
183 Handle = New;
184 updateAffectedValues(CI: New);
185}
186
187void AssumptionCache::AffectedValueCallbackVH::deleted() {
188 AC->AffectedValues.erase(Val: getValPtr());
189 // 'this' now dangles!
190}
191
192void AssumptionCache::transferAffectedValuesInCache(Value *OV, Value *NV) {
193 auto &NAVV = getOrInsertAffectedValues(V: NV);
194 auto AVI = AffectedValues.find(Val: OV);
195 if (AVI == AffectedValues.end())
196 return;
197
198 for (auto &A : AVI->second)
199 if (!llvm::is_contained(Range&: NAVV, Element: A))
200 NAVV.push_back(Elt: A);
201 AffectedValues.erase(Val: OV);
202}
203
204void AssumptionCache::AffectedValueCallbackVH::allUsesReplacedWith(Value *NV) {
205 if (!isa<Instruction>(Val: NV) && !isa<Argument>(Val: NV))
206 return;
207
208 // Any assumptions that affected this value now affect the new value.
209
210 AC->transferAffectedValuesInCache(OV: getValPtr(), NV);
211 // 'this' now might dangle! If the AffectedValues map was resized to add an
212 // entry for NV then this object might have been destroyed in favor of some
213 // copy in the grown map.
214}
215
216void AssumptionCache::scanFunction() {
217 assert(!Scanned && "Tried to scan the function twice!");
218 assert(AssumeHandles.empty() && "Already have assumes when scanning!");
219
220 // Go through all instructions in all blocks, add all calls to @llvm.assume
221 // to this cache.
222 for (BasicBlock &B : F)
223 for (Instruction &I : B)
224 if (isa<AssumeInst>(Val: &I))
225 AssumeHandles.push_back(Elt: &I);
226
227 // Mark the scan as complete.
228 Scanned = true;
229
230 // Update affected values.
231 for (auto &A : AssumeHandles)
232 updateAffectedValues(CI: cast<AssumeInst>(Val&: A));
233}
234
235void AssumptionCache::registerAssumption(AssumeInst *CI) {
236 // If we haven't scanned the function yet, just drop this assumption. It will
237 // be found when we scan later.
238 if (!Scanned)
239 return;
240
241 AssumeHandles.push_back(Elt: CI);
242
243#ifndef NDEBUG
244 assert(CI->getParent() &&
245 "Cannot register @llvm.assume call not in a basic block");
246 assert(&F == CI->getParent()->getParent() &&
247 "Cannot register @llvm.assume call not in this function");
248
249 // We expect the number of assumptions to be small, so in an asserts build
250 // check that we don't accumulate duplicates and that all assumptions point
251 // to the same function.
252 SmallPtrSet<Value *, 16> AssumptionSet;
253 for (auto &VH : AssumeHandles) {
254 if (!VH)
255 continue;
256
257 assert(&F == cast<Instruction>(VH)->getParent()->getParent() &&
258 "Cached assumption not inside this function!");
259 assert(match(cast<CallInst>(VH), m_Intrinsic<Intrinsic::assume>()) &&
260 "Cached something other than a call to @llvm.assume!");
261 assert(AssumptionSet.insert(VH).second &&
262 "Cache contains multiple copies of a call!");
263 }
264#endif
265
266 updateAffectedValues(CI);
267}
268
269AssumptionCache AssumptionAnalysis::run(Function &F,
270 FunctionAnalysisManager &FAM) {
271 auto &TTI = FAM.getResult<TargetIRAnalysis>(IR&: F);
272 return AssumptionCache(F, &TTI);
273}
274
275AnalysisKey AssumptionAnalysis::Key;
276
277PreservedAnalyses AssumptionPrinterPass::run(Function &F,
278 FunctionAnalysisManager &AM) {
279 AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(IR&: F);
280
281 OS << "Cached assumptions for function: " << F.getName() << "\n";
282 for (auto &VH : AC.assumptions())
283 if (VH)
284 OS << " " << *cast<CallInst>(Val&: VH)->getArgOperand(i: 0) << "\n";
285
286 return PreservedAnalyses::all();
287}
288
289void AssumptionCacheTracker::FunctionCallbackVH::deleted() {
290 auto I = ACT->AssumptionCaches.find_as(Val: cast<Function>(Val: getValPtr()));
291 if (I != ACT->AssumptionCaches.end())
292 ACT->AssumptionCaches.erase(I);
293 // 'this' now dangles!
294}
295
296AssumptionCache &AssumptionCacheTracker::getAssumptionCache(Function &F) {
297 // We probe the function map twice to try and avoid creating a value handle
298 // around the function in common cases. This makes insertion a bit slower,
299 // but if we have to insert we're going to scan the whole function so that
300 // shouldn't matter.
301 auto I = AssumptionCaches.find_as(Val: &F);
302 if (I != AssumptionCaches.end())
303 return *I->second;
304
305 auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>();
306 auto *TTI = TTIWP ? &TTIWP->getTTI(F) : nullptr;
307
308 // Ok, build a new cache by scanning the function, insert it and the value
309 // handle into our map, and return the newly populated cache.
310 auto IP = AssumptionCaches.insert(KV: std::make_pair(
311 x: FunctionCallbackVH(&F, this), y: std::make_unique<AssumptionCache>(args&: F, args&: TTI)));
312 assert(IP.second && "Scanning function already in the map?");
313 return *IP.first->second;
314}
315
316AssumptionCache *AssumptionCacheTracker::lookupAssumptionCache(Function &F) {
317 auto I = AssumptionCaches.find_as(Val: &F);
318 if (I != AssumptionCaches.end())
319 return I->second.get();
320 return nullptr;
321}
322
323void AssumptionCacheTracker::verifyAnalysis() const {
324 // FIXME: In the long term the verifier should not be controllable with a
325 // flag. We should either fix all passes to correctly update the assumption
326 // cache and enable the verifier unconditionally or somehow arrange for the
327 // assumption list to be updated automatically by passes.
328 if (!VerifyAssumptionCache)
329 return;
330
331 SmallPtrSet<const CallInst *, 4> AssumptionSet;
332 for (const auto &I : AssumptionCaches) {
333 for (auto &VH : I.second->assumptions())
334 if (VH)
335 AssumptionSet.insert(Ptr: cast<CallInst>(Val&: VH));
336
337 for (const BasicBlock &B : cast<Function>(Val&: *I.first))
338 for (const Instruction &II : B)
339 if (match(V: &II, P: m_Intrinsic<Intrinsic::assume>()) &&
340 !AssumptionSet.count(Ptr: cast<CallInst>(Val: &II)))
341 report_fatal_error(reason: "Assumption in scanned function not in cache");
342 }
343}
344
345AssumptionCacheTracker::AssumptionCacheTracker() : ImmutablePass(ID) {}
346
347AssumptionCacheTracker::~AssumptionCacheTracker() = default;
348
349char AssumptionCacheTracker::ID = 0;
350
351INITIALIZE_PASS(AssumptionCacheTracker, "assumption-cache-tracker",
352 "Assumption Cache Tracker", false, true)
353