1//===- WholeProgramDevirt.cpp - Whole program virtual call optimization ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements whole program optimization of virtual calls in cases
10// where we know (via !type metadata) that the list of callees is fixed. This
11// includes the following:
12// - Single implementation devirtualization: if a virtual call has a single
13// possible callee, replace all calls with a direct call to that callee.
14// - Virtual constant propagation: if the virtual function's return type is an
15// integer <=64 bits and all possible callees are readnone, for each class and
16// each list of constant arguments: evaluate the function, store the return
17// value alongside the virtual table, and rewrite each virtual call as a load
18// from the virtual table.
19// - Uniform return value optimization: if the conditions for virtual constant
20// propagation hold and each function returns the same constant value, replace
21// each virtual call with that constant.
22// - Unique return value optimization for i1 return values: if the conditions
23// for virtual constant propagation hold and a single vtable's function
24// returns 0, or a single vtable's function returns 1, replace each virtual
25// call with a comparison of the vptr against that vtable's address.
26//
27// This pass is intended to be used during the regular/thin and non-LTO
28// pipelines:
29//
30// During regular LTO, the pass determines the best optimization for each
31// virtual call and applies the resolutions directly to virtual calls that are
32// eligible for virtual call optimization (i.e. calls that use either of the
33// llvm.assume(llvm.type.test) or llvm.type.checked.load intrinsics).
34//
35// During hybrid Regular/ThinLTO, the pass operates in two phases:
36// - Export phase: this is run during the thin link over a single merged module
37// that contains all vtables with !type metadata that participate in the link.
38// The pass computes a resolution for each virtual call and stores it in the
39// type identifier summary.
40// - Import phase: this is run during the thin backends over the individual
41// modules. The pass applies the resolutions previously computed during the
//   export phase to each eligible virtual call.
43//
44// During ThinLTO, the pass operates in two phases:
45// - Export phase: this is run during the thin link over the index which
46// contains a summary of all vtables with !type metadata that participate in
47// the link. It computes a resolution for each virtual call and stores it in
48// the type identifier summary. Only single implementation devirtualization
49// is supported.
50// - Import phase: (same as with hybrid case above).
51//
52// During Speculative devirtualization mode -not restricted to LTO-:
53// - The pass applies speculative devirtualization without requiring any type of
54// visibility.
55// - Skips other features like virtual constant propagation, uniform return
56// value optimization, unique return value optimization and branch funnels as
57// they need LTO.
58// - This mode is enabled via 'devirtualize-speculatively' flag.
59//
60//===----------------------------------------------------------------------===//
61
62#include "llvm/Transforms/IPO/WholeProgramDevirt.h"
63#include "llvm/ADT/ArrayRef.h"
64#include "llvm/ADT/DenseMap.h"
65#include "llvm/ADT/DenseMapInfo.h"
66#include "llvm/ADT/DenseSet.h"
67#include "llvm/ADT/MapVector.h"
68#include "llvm/ADT/SmallVector.h"
69#include "llvm/ADT/Statistic.h"
70#include "llvm/Analysis/AssumptionCache.h"
71#include "llvm/Analysis/BasicAliasAnalysis.h"
72#include "llvm/Analysis/BlockFrequencyInfo.h"
73#include "llvm/Analysis/ModuleSummaryAnalysis.h"
74#include "llvm/Analysis/OptimizationRemarkEmitter.h"
75#include "llvm/Analysis/ProfileSummaryInfo.h"
76#include "llvm/Analysis/TypeMetadataUtils.h"
77#include "llvm/Bitcode/BitcodeReader.h"
78#include "llvm/Bitcode/BitcodeWriter.h"
79#include "llvm/IR/Constants.h"
80#include "llvm/IR/DataLayout.h"
81#include "llvm/IR/DebugLoc.h"
82#include "llvm/IR/DerivedTypes.h"
83#include "llvm/IR/DiagnosticInfo.h"
84#include "llvm/IR/Dominators.h"
85#include "llvm/IR/Function.h"
86#include "llvm/IR/GlobalAlias.h"
87#include "llvm/IR/GlobalVariable.h"
88#include "llvm/IR/IRBuilder.h"
89#include "llvm/IR/InstrTypes.h"
90#include "llvm/IR/Instruction.h"
91#include "llvm/IR/Instructions.h"
92#include "llvm/IR/Intrinsics.h"
93#include "llvm/IR/LLVMContext.h"
94#include "llvm/IR/MDBuilder.h"
95#include "llvm/IR/Metadata.h"
96#include "llvm/IR/Module.h"
97#include "llvm/IR/ModuleSummaryIndexYAML.h"
98#include "llvm/IR/PassManager.h"
99#include "llvm/IR/ProfDataUtils.h"
100#include "llvm/Support/Casting.h"
101#include "llvm/Support/CommandLine.h"
102#include "llvm/Support/DebugCounter.h"
103#include "llvm/Support/Errc.h"
104#include "llvm/Support/Error.h"
105#include "llvm/Support/FileSystem.h"
106#include "llvm/Support/GlobPattern.h"
107#include "llvm/Support/TimeProfiler.h"
108#include "llvm/TargetParser/Triple.h"
109#include "llvm/Transforms/IPO.h"
110#include "llvm/Transforms/IPO/FunctionAttrs.h"
111#include "llvm/Transforms/Utils/BasicBlockUtils.h"
112#include "llvm/Transforms/Utils/CallPromotionUtils.h"
113#include "llvm/Transforms/Utils/Evaluator.h"
114#include <algorithm>
115#include <cmath>
116#include <cstddef>
117#include <map>
118#include <set>
119#include <string>
120
using namespace llvm;
using namespace wholeprogramdevirt;

#define DEBUG_TYPE "wholeprogramdevirt"

// Pass-wide statistics (reported with -stats): one counter per
// devirtualization strategy implemented by this pass.
STATISTIC(NumDevirtTargets, "Number of whole program devirtualization targets");
STATISTIC(NumSingleImpl, "Number of single implementation devirtualizations");
STATISTIC(NumBranchFunnel, "Number of branch funnels");
STATISTIC(NumUniformRetVal, "Number of uniform return value optimizations");
STATISTIC(NumUniqueRetVal, "Number of unique return value optimizations");
STATISTIC(NumVirtConstProp1Bit,
          "Number of 1 bit virtual constant propagations");
STATISTIC(NumVirtConstProp, "Number of virtual constant propagations");
// Debug counter allowing bisection of which individual call sites get
// devirtualized (via -debug-counter=calls-to-devirt=...).
DEBUG_COUNTER(CallsToDevirt, "calls-to-devirt",
              "Controls how many calls should be devirtualized.");
136
namespace llvm {

// Controls how the pass treats the module summary when run via the
// command-line testing entry point (see DevirtModule::runForTesting).
static cl::opt<PassSummaryAction> ClSummaryAction(
    "wholeprogramdevirt-summary-action",
    cl::desc("What to do with the summary when running this pass"),
    cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"),
               clEnumValN(PassSummaryAction::Import, "import",
                          "Import typeid resolutions from summary and globals"),
               clEnumValN(PassSummaryAction::Export, "export",
                          "Export typeid resolutions to summary and globals")),
    cl::Hidden);

static cl::opt<std::string> ClReadSummary(
    "wholeprogramdevirt-read-summary",
    cl::desc(
        "Read summary from given bitcode or YAML file before running pass"),
    cl::Hidden);

static cl::opt<std::string> ClWriteSummary(
    "wholeprogramdevirt-write-summary",
    cl::desc("Write summary to given bitcode or YAML file after running pass. "
             "Output file format is deduced from extension: *.bc means writing "
             "bitcode, otherwise YAML"),
    cl::Hidden);

// TODO: This option eventually should support any public visibility vtables
// with/out LTO.
static cl::opt<bool> ClDevirtualizeSpeculatively(
    "devirtualize-speculatively",
    cl::desc("Enable speculative devirtualization optimization"),
    cl::init(Val: false));

// Branch funnels with more targets than this are not emitted.
static cl::opt<unsigned>
    ClThreshold("wholeprogramdevirt-branch-funnel-threshold", cl::Hidden,
                cl::init(Val: 10),
                cl::desc("Maximum number of call targets per "
                         "call site to enable branch funnels"));

static cl::opt<bool>
    PrintSummaryDevirt("wholeprogramdevirt-print-index-based", cl::Hidden,
                       cl::desc("Print index-based devirtualization messages"));

/// Provide a way to force enable whole program visibility in tests.
/// This is needed to support legacy tests that don't contain
/// !vcall_visibility metadata (the mere presence of type tests
/// previously implied hidden visibility).
static cl::opt<bool>
    WholeProgramVisibility("whole-program-visibility", cl::Hidden,
                           cl::desc("Enable whole program visibility"));

/// Provide a way to force disable whole program for debugging or workarounds,
/// when enabled via the linker.
static cl::opt<bool> DisableWholeProgramVisibility(
    "disable-whole-program-visibility", cl::Hidden,
    cl::desc("Disable whole program visibility (overrides enabling options)"));

/// Provide way to prevent certain function from being devirtualized
static cl::list<std::string>
    SkipFunctionNames("wholeprogramdevirt-skip",
                      cl::desc("Prevent function(s) from being devirtualized"),
                      cl::Hidden, cl::CommaSeparated);

// Declared here, defined in another translation unit.
extern cl::opt<bool> ProfcheckDisableMetadataFixes;

} // end namespace llvm
202
/// With Clang, a pure virtual class's deleting destructor is emitted as a
/// `llvm.trap` intrinsic followed by an unreachable IR instruction. In the
/// context of whole program devirtualization, the deleting destructor of a pure
/// virtual class won't be invoked by the source code so safe to skip as a
/// devirtualize target.
///
/// However, not all unreachable functions are safe to skip. In some cases, the
/// program intends to run such functions and terminate, for instance, a unit
/// test may run a death test. A non-test program might (or allowed to) invoke
/// such functions to report failures (whether/when it's a good practice or not
/// is a different topic).
///
/// This option is enabled to keep an unreachable function as a possible
/// devirtualize target to conservatively keep the program behavior.
///
/// TODO: Make a pure virtual class's deleting destructor precisely identifiable
/// in Clang's codegen for more devirtualization in LLVM.
static cl::opt<bool> WholeProgramDevirtKeepUnreachableFunction(
    "wholeprogramdevirt-keep-unreachable-function",
    cl::desc("Regard unreachable functions as possible devirtualize targets."),
    cl::Hidden, cl::init(Val: true));

/// Mechanism to add runtime checking of devirtualization decisions, optionally
/// trapping or falling back to indirect call on any that are not correct.
/// Trapping mode is useful for debugging undefined behavior leading to failures
/// with WPD. Fallback mode is useful for ensuring safety when whole program
/// visibility may be compromised.
enum WPDCheckMode { None, Trap, Fallback };
// No cl::init is given, so this defaults to a value-initialized
// WPDCheckMode, i.e. WPDCheckMode::None (no checking).
static cl::opt<WPDCheckMode> DevirtCheckMode(
    "wholeprogramdevirt-check", cl::Hidden,
    cl::desc("Type of checking for incorrect devirtualizations"),
    cl::values(clEnumValN(WPDCheckMode::None, "none", "No checking"),
               clEnumValN(WPDCheckMode::Trap, "trap", "Trap when incorrect"),
               clEnumValN(WPDCheckMode::Fallback, "fallback",
                          "Fallback to indirect when incorrect")));
238
239namespace {
240struct PatternList {
241 std::vector<GlobPattern> Patterns;
242 template <class T> void init(const T &StringList) {
243 for (const auto &S : StringList)
244 if (Expected<GlobPattern> Pat = GlobPattern::create(Pat: S))
245 Patterns.push_back(x: std::move(*Pat));
246 }
247 bool match(StringRef S) {
248 for (const GlobPattern &P : Patterns)
249 if (P.match(S))
250 return true;
251 return false;
252 }
253};
254} // namespace
255
// Find the minimum offset that we may store a value of size Size bits at. If
// IsAfter is set, look for an offset after the object, otherwise look for an
// offset before the object.
uint64_t
wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets,
                                     bool IsAfter, uint64_t Size) {
  // Find a minimum offset taking into account only vtable sizes.
  uint64_t MinByte = 0;
  for (const VirtualCallTarget &Target : Targets) {
    if (IsAfter)
      MinByte = std::max(a: MinByte, b: Target.minAfterBytes());
    else
      MinByte = std::max(a: MinByte, b: Target.minBeforeBytes());
  }

  // Build a vector of arrays of bytes covering, for each target, a slice of the
  // used region (see AccumBitVector::BytesUsed in
  // llvm/Transforms/IPO/WholeProgramDevirt.h) starting at MinByte. Effectively,
  // this aligns the used regions to start at MinByte.
  //
  // In this example, A, B and C are vtables, # is a byte already allocated for
  // a virtual function pointer, AAAA... (etc.) are the used regions for the
  // vtables and Offset(X) is the value computed for the Offset variable below
  // for X.
  //
  //                    Offset(A)
  //                    |       |
  //                            |MinByte
  // A: ################AAAAAAAA|AAAAAAAA
  // B: ########BBBBBBBBBBBBBBBB|BBBB
  // C: ########################|CCCCCCCCCCCCCCCC
  //            |   Offset(B)   |
  //
  // This code produces the slices of A, B and C that appear after the divider
  // at MinByte.
  std::vector<ArrayRef<uint8_t>> Used;
  for (const VirtualCallTarget &Target : Targets) {
    ArrayRef<uint8_t> VTUsed = IsAfter ? Target.TM->Bits->After.BytesUsed
                                       : Target.TM->Bits->Before.BytesUsed;
    uint64_t Offset = IsAfter ? MinByte - Target.minAfterBytes()
                              : MinByte - Target.minBeforeBytes();

    // Disregard used regions that are smaller than Offset. These are
    // effectively all-free regions that do not need to be checked.
    if (VTUsed.size() > Offset)
      Used.push_back(x: VTUsed.slice(N: Offset));
  }

  if (Size == 1) {
    // Find a free bit in each member of Used. Scan byte-by-byte, OR-ing the
    // used bits of every vtable at that byte; the first byte that is not
    // fully used (0xff) has a free bit, and countr_zero of the complement
    // selects the lowest such free bit.
    for (unsigned I = 0;; ++I) {
      uint8_t BitsUsed = 0;
      for (auto &&B : Used)
        if (I < B.size())
          BitsUsed |= B[I];
      if (BitsUsed != 0xff)
        return (MinByte + I) * 8 + llvm::countr_zero(Val: uint8_t(~BitsUsed));
    }
  } else {
    // Find a free (Size/8) byte region in each member of Used.
    // FIXME: see if alignment helps.
    // Slide a (Size/8)-byte window starting at every byte offset I; the goto
    // restarts the outer scan as soon as any vtable has a used byte inside
    // the window. Bytes past the end of a vtable's used region are free.
    for (unsigned I = 0;; ++I) {
      for (auto &&B : Used) {
        unsigned Byte = 0;
        while ((I + Byte) < B.size() && Byte < (Size / 8)) {
          if (B[I + Byte])
            goto NextI;
          ++Byte;
        }
      }
      // Rounding up ensures the constant is always stored at address we
      // can directly load from without misalignment.
      return alignTo(Value: (MinByte + I) * 8, Align: Size);
    NextI:;
    }
  }
}
333
334void wholeprogramdevirt::setBeforeReturnValues(
335 MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocBefore,
336 unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
337 if (BitWidth == 1)
338 OffsetByte = -(AllocBefore / 8 + 1);
339 else
340 OffsetByte = -((AllocBefore + 7) / 8 + (BitWidth + 7) / 8);
341 OffsetBit = AllocBefore % 8;
342
343 for (VirtualCallTarget &Target : Targets) {
344 if (BitWidth == 1)
345 Target.setBeforeBit(AllocBefore);
346 else
347 Target.setBeforeBytes(Pos: AllocBefore, Size: (BitWidth + 7) / 8);
348 }
349}
350
351void wholeprogramdevirt::setAfterReturnValues(
352 MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocAfter,
353 unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
354 if (BitWidth == 1)
355 OffsetByte = AllocAfter / 8;
356 else
357 OffsetByte = (AllocAfter + 7) / 8;
358 OffsetBit = AllocAfter % 8;
359
360 for (VirtualCallTarget &Target : Targets) {
361 if (BitWidth == 1)
362 Target.setAfterBit(AllocAfter);
363 else
364 Target.setAfterBytes(Pos: AllocAfter, Size: (BitWidth + 7) / 8);
365 }
366}
367
// A possible devirtualization target: the function Fn reachable through the
// vtable described by TM. Endianness is cached from the function's data
// layout; WasDevirt tracks whether any call was devirtualized to it.
VirtualCallTarget::VirtualCallTarget(GlobalValue *Fn, const TypeMemberInfo *TM)
    : Fn(Fn), TM(TM),
      IsBigEndian(Fn->getDataLayout().isBigEndian()),
      WasDevirt(false) {}
372
namespace {

// A slot in a set of virtual tables. The TypeID identifies the set of virtual
// tables, and the ByteOffset is the offset in bytes from the address point to
// the virtual function pointer.
// Used as a key when grouping virtual call sites (see CallSlots below); the
// DenseMapInfo specialization supplies hashing and equality.
struct VTableSlot {
  Metadata *TypeID;
  uint64_t ByteOffset;
};

} // end anonymous namespace
384
385template <> struct llvm::DenseMapInfo<VTableSlot> {
386 static VTableSlot getEmptyKey() {
387 return {.TypeID: DenseMapInfo<Metadata *>::getEmptyKey(),
388 .ByteOffset: DenseMapInfo<uint64_t>::getEmptyKey()};
389 }
390 static VTableSlot getTombstoneKey() {
391 return {.TypeID: DenseMapInfo<Metadata *>::getTombstoneKey(),
392 .ByteOffset: DenseMapInfo<uint64_t>::getTombstoneKey()};
393 }
394 static unsigned getHashValue(const VTableSlot &I) {
395 return DenseMapInfo<Metadata *>::getHashValue(PtrVal: I.TypeID) ^
396 DenseMapInfo<uint64_t>::getHashValue(Val: I.ByteOffset);
397 }
398 static bool isEqual(const VTableSlot &LHS,
399 const VTableSlot &RHS) {
400 return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset;
401 }
402};
403
404template <> struct llvm::DenseMapInfo<VTableSlotSummary> {
405 static VTableSlotSummary getEmptyKey() {
406 return {.TypeID: DenseMapInfo<StringRef>::getEmptyKey(),
407 .ByteOffset: DenseMapInfo<uint64_t>::getEmptyKey()};
408 }
409 static VTableSlotSummary getTombstoneKey() {
410 return {.TypeID: DenseMapInfo<StringRef>::getTombstoneKey(),
411 .ByteOffset: DenseMapInfo<uint64_t>::getTombstoneKey()};
412 }
413 static unsigned getHashValue(const VTableSlotSummary &I) {
414 return DenseMapInfo<StringRef>::getHashValue(Val: I.TypeID) ^
415 DenseMapInfo<uint64_t>::getHashValue(Val: I.ByteOffset);
416 }
417 static bool isEqual(const VTableSlotSummary &LHS,
418 const VTableSlotSummary &RHS) {
419 return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset;
420 }
421};
422
// Returns true if the function must be unreachable based on ValueInfo.
//
// In particular, identifies a function as unreachable in the following
// conditions
// 1) All summaries are live.
// 2) All function summaries indicate it's unreachable
// 3) There is no non-function with the same GUID (which is rare)
static bool mustBeUnreachableFunction(ValueInfo TheFnVI) {
  // Under this flag, unreachable functions stay eligible as devirtualization
  // targets, so never report them as must-be-unreachable.
  if (WholeProgramDevirtKeepUnreachableFunction)
    return false;

  if ((!TheFnVI) || TheFnVI.getSummaryList().empty()) {
    // Returns false if ValueInfo is absent, or the summary list is empty
    // (e.g., function declarations).
    return false;
  }

  for (const auto &Summary : TheFnVI.getSummaryList()) {
    // Conservatively returns false if any non-live functions are seen.
    // In general either all summaries should be live or all should be dead.
    if (!Summary->isLive())
      return false;
    if (auto *FS = dyn_cast<FunctionSummary>(Val: Summary->getBaseObject())) {
      if (!FS->fflags().MustBeUnreachable)
        return false;
    }
    // Be conservative if a non-function has the same GUID (which is rare).
    else
      return false;
  }
  // All function summaries are live and all of them agree that the function is
  // unreachable.
  return true;
}
457
458namespace {
459// A virtual call site. VTable is the loaded virtual table pointer, and CS is
460// the indirect virtual call.
461struct VirtualCallSite {
462 Value *VTable = nullptr;
463 CallBase &CB;
464
465 // If non-null, this field points to the associated unsafe use count stored in
466 // the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description
467 // of that field for details.
468 unsigned *NumUnsafeUses = nullptr;
469
470 void
471 emitRemark(const StringRef OptName, const StringRef TargetName,
472 function_ref<OptimizationRemarkEmitter &(Function &)> OREGetter) {
473 Function *F = CB.getCaller();
474 DebugLoc DLoc = CB.getDebugLoc();
475 BasicBlock *Block = CB.getParent();
476
477 using namespace ore;
478 OREGetter(*F).emit(OptDiag: OptimizationRemark(DEBUG_TYPE, OptName, DLoc, Block)
479 << NV("Optimization", OptName)
480 << ": devirtualized a call to "
481 << NV("FunctionName", TargetName));
482 }
483
484 void replaceAndErase(
485 const StringRef OptName, const StringRef TargetName, bool RemarksEnabled,
486 function_ref<OptimizationRemarkEmitter &(Function &)> OREGetter,
487 Value *New) {
488 if (RemarksEnabled)
489 emitRemark(OptName, TargetName, OREGetter);
490 CB.replaceAllUsesWith(V: New);
491 if (auto *II = dyn_cast<InvokeInst>(Val: &CB)) {
492 UncondBrInst::Create(IfTrue: II->getNormalDest(), InsertBefore: CB.getIterator());
493 II->getUnwindDest()->removePredecessor(Pred: II->getParent());
494 }
495 CB.eraseFromParent();
496 // This use is no longer unsafe.
497 if (NumUnsafeUses)
498 --*NumUnsafeUses;
499 }
500};
501
// Call site information collected for a specific VTableSlot and possibly a list
// of constant integer arguments. The grouping by arguments is handled by the
// VTableSlotInfo class.
struct CallSiteInfo {
  /// The set of call sites for this slot. Used during regular LTO and the
  /// import phase of ThinLTO (as well as the export phase of ThinLTO for any
  /// call sites that appear in the merged module itself); in each of these
  /// cases we are directly operating on the call sites at the IR level.
  std::vector<VirtualCallSite> CallSites;

  /// Whether all call sites represented by this CallSiteInfo, including those
  /// in summaries, have been devirtualized. This starts off as true because a
  /// default constructed CallSiteInfo represents no call sites.
  ///
  /// If at the end of the pass there are still undevirtualized calls, we will
  /// need to add a use of llvm.type.test to each of the function summaries in
  /// the vector.
  bool AllCallSitesDevirted = true;

  // These fields are used during the export phase of ThinLTO and reflect
  // information collected from function summaries.

  /// CFI-specific: a vector containing the list of function summaries that use
  /// the llvm.type.checked.load intrinsic and therefore will require
  /// resolutions for llvm.type.test in order to implement CFI checks if
  /// devirtualization was unsuccessful.
  std::vector<FunctionSummary *> SummaryTypeCheckedLoadUsers;

  /// A vector containing the list of function summaries that use
  /// assume(llvm.type.test).
  std::vector<FunctionSummary *> SummaryTypeTestAssumeUsers;

  /// True if some other module's summary references this slot, in which case
  /// resolutions must be exported rather than applied purely locally.
  bool isExported() const {
    return !SummaryTypeCheckedLoadUsers.empty() ||
           !SummaryTypeTestAssumeUsers.empty();
  }

  // Registering a summary user marks the slot as not fully devirtualized
  // until a resolution is found for it.
  void addSummaryTypeCheckedLoadUser(FunctionSummary *FS) {
    SummaryTypeCheckedLoadUsers.push_back(x: FS);
    AllCallSitesDevirted = false;
  }

  void addSummaryTypeTestAssumeUser(FunctionSummary *FS) {
    SummaryTypeTestAssumeUsers.push_back(x: FS);
    AllCallSitesDevirted = false;
  }

  void markDevirt() { AllCallSitesDevirted = true; }
};
551
// Call site information collected for a specific VTableSlot.
struct VTableSlotInfo {
  // The set of call sites which do not have all constant integer arguments
  // (excluding "this").
  CallSiteInfo CSInfo;

  // The set of call sites with all constant integer arguments (excluding
  // "this"), grouped by argument list.
  std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo;

  // Record a call site against the CallSiteInfo bucket matching its argument
  // constancy (see findCallSiteInfo).
  void addCallSite(Value *VTable, CallBase &CB, unsigned *NumUnsafeUses);

private:
  // Select CSInfo or the ConstCSInfo bucket for CB's constant arguments.
  CallSiteInfo &findCallSiteInfo(CallBase &CB);
};
567
568CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallBase &CB) {
569 std::vector<uint64_t> Args;
570 auto *CBType = dyn_cast<IntegerType>(Val: CB.getType());
571 if (!CBType || CBType->getBitWidth() > 64 || CB.arg_empty())
572 return CSInfo;
573 for (auto &&Arg : drop_begin(RangeOrContainer: CB.args())) {
574 auto *CI = dyn_cast<ConstantInt>(Val&: Arg);
575 if (!CI || CI->getBitWidth() > 64)
576 return CSInfo;
577 Args.push_back(x: CI->getZExtValue());
578 }
579 return ConstCSInfo[Args];
580}
581
582void VTableSlotInfo::addCallSite(Value *VTable, CallBase &CB,
583 unsigned *NumUnsafeUses) {
584 auto &CSI = findCallSiteInfo(CB);
585 CSI.AllCallSitesDevirted = false;
586 CSI.CallSites.push_back(x: {.VTable: VTable, .CB: CB, .NumUnsafeUses: NumUnsafeUses});
587}
588
// The per-module devirtualization driver: owns the collected call-site
// information, cached IR types and the try*/apply* implementations for each
// optimization strategy.
struct DevirtModule {
  Module &M;
  ModuleAnalysisManager &MAM;
  FunctionAnalysisManager &FAM;

  // At most one of these is non-null (asserted in the constructor):
  // ExportSummary during regular LTO / the hybrid export phase, ImportSummary
  // during the ThinLTO import phase.
  ModuleSummaryIndex *const ExportSummary;
  const ModuleSummaryIndex *const ImportSummary;

  // Frequently used types, cached from the module's context.
  IntegerType *const Int8Ty;
  PointerType *const Int8PtrTy;
  IntegerType *const Int32Ty;
  IntegerType *const Int64Ty;
  IntegerType *const IntPtrTy;
  /// Sizeless array type, used for imported vtables. This provides a signal
  /// to analyzers that these imports may alias, as they do for example
  /// when multiple unique return values occur in the same vtable.
  ArrayType *const Int8Arr0Ty;

  const bool RemarksEnabled;
  std::function<OptimizationRemarkEmitter &(Function &)> OREGetter;
  // All virtual call sites found in the module, grouped by vtable slot.
  MapVector<VTableSlot, VTableSlotInfo> CallSlots;

  // Calls that have already been optimized. We may add a call to multiple
  // VTableSlotInfos if vtable loads are coalesced and need to make sure not to
  // optimize a call more than once.
  SmallPtrSet<CallBase *, 8> OptimizedCalls;

  // Store calls that had their ptrauth bundle removed. They are to be deleted
  // at the end of the optimization.
  SmallVector<CallBase *, 8> CallsWithPtrAuthBundleRemoved;

  // This map keeps track of the number of "unsafe" uses of a loaded function
  // pointer. The key is the associated llvm.type.test intrinsic call generated
  // by this pass. An unsafe use is one that calls the loaded function pointer
  // directly. Every time we eliminate an unsafe use (for example, by
  // devirtualizing it or by applying virtual constant propagation), we
  // decrement the value stored in this map. If a value reaches zero, we can
  // eliminate the type check by RAUWing the associated llvm.type.test call with
  // true.
  std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest;
  // Glob patterns from -wholeprogramdevirt-skip; matching targets are never
  // devirtualized.
  PatternList FunctionsToSkip;

  // True when running in speculative devirtualization mode (set from the
  // devirtualize-speculatively flag path; not restricted to LTO).
  const bool DevirtSpeculatively;
  DevirtModule(Module &M, ModuleAnalysisManager &MAM,
               ModuleSummaryIndex *ExportSummary,
               const ModuleSummaryIndex *ImportSummary,
               bool DevirtSpeculatively)
      : M(M), MAM(MAM),
        FAM(MAM.getResult<FunctionAnalysisManagerModuleProxy>(IR&: M).getManager()),
        ExportSummary(ExportSummary), ImportSummary(ImportSummary),
        Int8Ty(Type::getInt8Ty(C&: M.getContext())),
        Int8PtrTy(PointerType::getUnqual(C&: M.getContext())),
        Int32Ty(Type::getInt32Ty(C&: M.getContext())),
        Int64Ty(Type::getInt64Ty(C&: M.getContext())),
        IntPtrTy(M.getDataLayout().getIntPtrType(C&: M.getContext(), AddressSpace: 0)),
        Int8Arr0Ty(ArrayType::get(ElementType: Type::getInt8Ty(C&: M.getContext()), NumElements: 0)),
        RemarksEnabled(areRemarksEnabled()),
        OREGetter([&](Function &F) -> OptimizationRemarkEmitter & {
          return FAM.getResult<OptimizationRemarkEmitterAnalysis>(IR&: F);
        }),
        DevirtSpeculatively(DevirtSpeculatively) {
    assert(!(ExportSummary && ImportSummary));
    FunctionsToSkip.init(StringList: SkipFunctionNames);
  }

  // Whether optimization remarks are enabled for the first function in the
  // module (used to gate remark emission for the whole pass).
  bool areRemarksEnabled();

  // Collect virtual call sites reached via llvm.assume(llvm.type.test) /
  // llvm.type.checked.load into CallSlots.
  void
  scanTypeTestUsers(Function *TypeTestFunc,
                    DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);
  void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc);

  // Build a map from type identifier metadata to the vtables (and offsets)
  // compatible with it.
  void buildTypeIdentifierMap(
      std::vector<VTableBits> &Bits,
      DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);

  // Compute the possible callees for a slot; returns false if any member
  // cannot be resolved to a usable target.
  bool
  tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot,
                            const std::set<TypeMemberInfo> &TypeMemberInfos,
                            uint64_t ByteOffset,
                            ModuleSummaryIndex *ExportSummary);

  // Single implementation devirtualization: try* decides, apply* rewrites the
  // IR call sites.
  void applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn,
                             bool &IsExported);
  bool trySingleImplDevirt(ModuleSummaryIndex *ExportSummary,
                           MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                           VTableSlotInfo &SlotInfo,
                           WholeProgramDevirtResolution *Res);

  // Branch funnels for slots with multiple targets.
  void applyICallBranchFunnel(VTableSlotInfo &SlotInfo, Function &JT,
                              bool &IsExported);
  void tryICallBranchFunnel(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                            VTableSlotInfo &SlotInfo,
                            WholeProgramDevirtResolution *Res, VTableSlot Slot);

  // Evaluate each target with the given constant arguments, recording the
  // produced return value; returns false if any evaluation fails.
  bool tryEvaluateFunctionsWithArgs(
      MutableArrayRef<VirtualCallTarget> TargetsForSlot,
      ArrayRef<uint64_t> Args);

  // Uniform return value optimization: all targets return the same constant.
  void applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
                             uint64_t TheRetVal);
  bool tryUniformRetValOpt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                           CallSiteInfo &CSInfo,
                           WholeProgramDevirtResolution::ByArg *Res);

  // Returns the global symbol name that is used to export information about the
  // given vtable slot and list of arguments.
  std::string getGlobalName(VTableSlot Slot, ArrayRef<uint64_t> Args,
                            StringRef Name);

  bool shouldExportConstantsAsAbsoluteSymbols();

  // This function is called during the export phase to create a symbol
  // definition containing information about the given vtable slot and list of
  // arguments.
  void exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name,
                    Constant *C);
  void exportConstant(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name,
                      uint32_t Const, uint32_t &Storage);

  // This function is called during the import phase to create a reference to
  // the symbol definition created during the export phase.
  Constant *importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
                         StringRef Name);
  Constant *importConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
                           StringRef Name, IntegerType *IntTy,
                           uint32_t Storage);

  Constant *getMemberAddr(const TypeMemberInfo *M);

  // Unique return value optimization for i1 returns: compare the vptr against
  // the single vtable whose target returns IsOne.
  void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne,
                            Constant *UniqueMemberAddr);
  bool tryUniqueRetValOpt(unsigned BitWidth,
                          MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                          CallSiteInfo &CSInfo,
                          WholeProgramDevirtResolution::ByArg *Res,
                          VTableSlot Slot, ArrayRef<uint64_t> Args);

  // Virtual constant propagation: load the precomputed return value from
  // beside the vtable.
  void applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
                             Constant *Byte, Constant *Bit);
  bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                           VTableSlotInfo &SlotInfo,
                           WholeProgramDevirtResolution *Res, VTableSlot Slot);

  // Re-emit a vtable global with constant-propagated values attached.
  void rebuildGlobal(VTableBits &B);

  // Apply the summary resolution for Slot to all virtual calls in SlotInfo.
  void importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo);

  // If we were able to eliminate all unsafe uses for a type checked load,
  // eliminate the associated type tests by replacing them with true.
  void removeRedundantTypeTests();

  bool run();

  // Look up the corresponding ValueInfo entry of `TheFn` in `ExportSummary`.
  //
  // Caller guarantees that `ExportSummary` is not nullptr.
  static ValueInfo lookUpFunctionValueInfo(Function *TheFn,
                                           ModuleSummaryIndex *ExportSummary);

  // Returns true if the function definition must be unreachable.
  //
  // Note if this helper function returns true, `F` is guaranteed
  // to be unreachable; if it returns false, `F` might still
  // be unreachable but not covered by this helper function.
  //
  // Implementation-wise, if function definition is present, IR is analyzed; if
  // not, look up function flags from ExportSummary as a fallback.
  static bool mustBeUnreachableFunction(Function *const F,
                                        ModuleSummaryIndex *ExportSummary);

  // Lower the module using the action and summary passed as command line
  // arguments. For testing purposes only.
  static bool runForTesting(Module &M, ModuleAnalysisManager &MAM,
                            bool DevirtSpeculatively);
};
766
// The index-only (ThinLTO export phase) devirtualization driver: operates on
// the combined summary index rather than IR, and only supports single
// implementation devirtualization.
struct DevirtIndex {
  ModuleSummaryIndex &ExportSummary;
  // The set in which to record GUIDs exported from their module by
  // devirtualization, used by client to ensure they are not internalized.
  std::set<GlobalValue::GUID> &ExportedGUIDs;
  // A map in which to record the information necessary to locate the WPD
  // resolution for local targets in case they are exported by cross module
  // importing.
  std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap;
  // The promoted, renamed name of a local devirtualized target is hardcoded
  // in the WPD summary, so we need to ensure those targets actually get
  // renamed; adding their current names to this set ensures the backend
  // continues to rename them.
  DenseSet<StringRef> *ExternallyVisibleSymbolNamesPtr;

  // All virtual call sites found in the index, grouped by vtable slot.
  MapVector<VTableSlotSummary, VTableSlotInfo> CallSlots;

  // Glob patterns from -wholeprogramdevirt-skip; matching targets are never
  // devirtualized.
  PatternList FunctionsToSkip;

  DevirtIndex(
      ModuleSummaryIndex &ExportSummary,
      std::set<GlobalValue::GUID> &ExportedGUIDs,
      std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap,
      DenseSet<StringRef> *ExternallyVisibleSymbolNamesPtr)
      : ExportSummary(ExportSummary), ExportedGUIDs(ExportedGUIDs),
        LocalWPDTargetsMap(LocalWPDTargetsMap),
        ExternallyVisibleSymbolNamesPtr(ExternallyVisibleSymbolNamesPtr) {
    FunctionsToSkip.init(StringList: SkipFunctionNames);
  }

  // Compute the possible callees for a slot from the compatible-vtable info;
  // returns false if any member cannot be resolved.
  bool tryFindVirtualCallTargets(std::vector<ValueInfo> &TargetsForSlot,
                                 const TypeIdCompatibleVtableInfo TIdInfo,
                                 uint64_t ByteOffset);

  // Record a single implementation resolution in the summary if all call
  // targets for the slot agree.
  bool trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot,
                           VTableSlotSummary &SlotSummary,
                           VTableSlotInfo &SlotInfo,
                           WholeProgramDevirtResolution *Res,
                           std::set<ValueInfo> &DevirtTargets);

  void run();
};
809} // end anonymous namespace
810
811PreservedAnalyses WholeProgramDevirtPass::run(Module &M,
812 ModuleAnalysisManager &MAM) {
813 if (UseCommandLine) {
814 if (!DevirtModule::runForTesting(M, MAM, DevirtSpeculatively: ClDevirtualizeSpeculatively))
815 return PreservedAnalyses::all();
816 return PreservedAnalyses::none();
817 }
818
819 std::optional<ModuleSummaryIndex> Index;
820 if (!ExportSummary && !ImportSummary && DevirtSpeculatively) {
821 // Build the ExportSummary from the module.
822 assert(!ExportSummary &&
823 "ExportSummary is expected to be empty in non-LTO mode");
824 ProfileSummaryInfo PSI(M);
825 Index.emplace(args: buildModuleSummaryIndex(M, GetBFICallback: nullptr, PSI: &PSI));
826 ExportSummary = Index.has_value() ? &Index.value() : nullptr;
827 }
828 if (!DevirtModule(M, MAM, ExportSummary, ImportSummary, DevirtSpeculatively)
829 .run())
830 return PreservedAnalyses::all();
831 return PreservedAnalyses::none();
832}
833
834// Enable whole program visibility if enabled by client (e.g. linker) or
835// internal option, and not force disabled.
836bool llvm::hasWholeProgramVisibility(bool WholeProgramVisibilityEnabledInLTO) {
837 return (WholeProgramVisibilityEnabledInLTO || WholeProgramVisibility) &&
838 !DisableWholeProgramVisibility;
839}
840
841static bool
842typeIDVisibleToRegularObj(StringRef TypeID,
843 function_ref<bool(StringRef)> IsVisibleToRegularObj) {
844 // TypeID for member function pointer type is an internal construct
845 // and won't exist in IsVisibleToRegularObj. The full TypeID
846 // will be present and participate in invalidation.
847 if (TypeID.ends_with(Suffix: ".virtual"))
848 return false;
849
850 // TypeID that doesn't start with Itanium mangling (_ZTS) will be
851 // non-externally visible types which cannot interact with
852 // external native files. See CodeGenModule::CreateMetadataIdentifierImpl.
853 if (!TypeID.consume_front(Prefix: "_ZTS"))
854 return false;
855
856 // TypeID is keyed off the type name symbol (_ZTS). However, the native
857 // object may not contain this symbol if it does not contain a key
858 // function for the base type and thus only contains a reference to the
859 // type info (_ZTI). To catch this case we query using the type info
860 // symbol corresponding to the TypeID.
861 std::string TypeInfo = ("_ZTI" + TypeID).str();
862 return IsVisibleToRegularObj(TypeInfo);
863}
864
865static bool
866skipUpdateDueToValidation(GlobalVariable &GV,
867 function_ref<bool(StringRef)> IsVisibleToRegularObj) {
868 SmallVector<MDNode *, 2> Types;
869 GV.getMetadata(KindID: LLVMContext::MD_type, MDs&: Types);
870
871 for (auto *Type : Types)
872 if (auto *TypeID = dyn_cast<MDString>(Val: Type->getOperand(I: 1).get()))
873 return typeIDVisibleToRegularObj(TypeID: TypeID->getString(),
874 IsVisibleToRegularObj);
875
876 return false;
877}
878
879/// If whole program visibility asserted, then upgrade all public vcall
880/// visibility metadata on vtable definitions to linkage unit visibility in
881/// Module IR (for regular or hybrid LTO).
882void llvm::updateVCallVisibilityInModule(
883 Module &M, bool WholeProgramVisibilityEnabledInLTO,
884 const DenseSet<GlobalValue::GUID> &DynamicExportSymbols,
885 bool ValidateAllVtablesHaveTypeInfos,
886 function_ref<bool(StringRef)> IsVisibleToRegularObj) {
887 if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO))
888 return;
889 for (GlobalVariable &GV : M.globals()) {
890 // Add linkage unit visibility to any variable with type metadata, which are
891 // the vtable definitions. We won't have an existing vcall_visibility
892 // metadata on vtable definitions with public visibility.
893 if (GV.hasMetadata(KindID: LLVMContext::MD_type) &&
894 GV.getVCallVisibility() == GlobalObject::VCallVisibilityPublic &&
895 // Don't upgrade the visibility for symbols exported to the dynamic
896 // linker, as we have no information on their eventual use.
897 !DynamicExportSymbols.count(V: GV.getGUID()) &&
898 // With validation enabled, we want to exclude symbols visible to
899 // regular objects. Local symbols will be in this group due to the
900 // current implementation but those with VCallVisibilityTranslationUnit
901 // will have already been marked in clang so are unaffected.
902 !(ValidateAllVtablesHaveTypeInfos &&
903 skipUpdateDueToValidation(GV, IsVisibleToRegularObj)))
904 GV.setVCallVisibilityMetadata(GlobalObject::VCallVisibilityLinkageUnit);
905 }
906}
907
908void llvm::updatePublicTypeTestCalls(Module &M,
909 bool WholeProgramVisibilityEnabledInLTO) {
910 llvm::TimeTraceScope timeScope("Update public type test calls");
911 Function *PublicTypeTestFunc =
912 Intrinsic::getDeclarationIfExists(M: &M, id: Intrinsic::public_type_test);
913 if (!PublicTypeTestFunc)
914 return;
915 if (hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO)) {
916 Function *TypeTestFunc =
917 Intrinsic::getOrInsertDeclaration(M: &M, id: Intrinsic::type_test);
918 for (Use &U : make_early_inc_range(Range: PublicTypeTestFunc->uses())) {
919 auto *CI = cast<CallInst>(Val: U.getUser());
920 auto *NewCI = CallInst::Create(
921 Func: TypeTestFunc, Args: {CI->getArgOperand(i: 0), CI->getArgOperand(i: 1)}, Bundles: {}, NameStr: "",
922 InsertBefore: CI->getIterator());
923 CI->replaceAllUsesWith(V: NewCI);
924 CI->eraseFromParent();
925 }
926 } else {
927 // TODO: Don't replace public type tests when speculative devirtualization
928 // gets enabled in LTO mode.
929 auto *True = ConstantInt::getTrue(Context&: M.getContext());
930 for (Use &U : make_early_inc_range(Range: PublicTypeTestFunc->uses())) {
931 auto *CI = cast<CallInst>(Val: U.getUser());
932 CI->replaceAllUsesWith(V: True);
933 CI->eraseFromParent();
934 }
935 }
936}
937
938/// Based on typeID string, get all associated vtable GUIDS that are
939/// visible to regular objects.
940void llvm::getVisibleToRegularObjVtableGUIDs(
941 ModuleSummaryIndex &Index,
942 DenseSet<GlobalValue::GUID> &VisibleToRegularObjSymbols,
943 function_ref<bool(StringRef)> IsVisibleToRegularObj) {
944 for (const auto &TypeID : Index.typeIdCompatibleVtableMap()) {
945 if (typeIDVisibleToRegularObj(TypeID: TypeID.first, IsVisibleToRegularObj))
946 for (const TypeIdOffsetVtableInfo &P : TypeID.second)
947 VisibleToRegularObjSymbols.insert(V: P.VTableVI.getGUID());
948 }
949}
950
951/// If whole program visibility asserted, then upgrade all public vcall
952/// visibility metadata on vtable definition summaries to linkage unit
953/// visibility in Module summary index (for ThinLTO).
954void llvm::updateVCallVisibilityInIndex(
955 ModuleSummaryIndex &Index, bool WholeProgramVisibilityEnabledInLTO,
956 const DenseSet<GlobalValue::GUID> &DynamicExportSymbols,
957 const DenseSet<GlobalValue::GUID> &VisibleToRegularObjSymbols) {
958 if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO))
959 return;
960 for (auto &P : Index) {
961 // Don't upgrade the visibility for symbols exported to the dynamic
962 // linker, as we have no information on their eventual use.
963 if (DynamicExportSymbols.count(V: P.first))
964 continue;
965 // With validation enabled, we want to exclude symbols visible to regular
966 // objects. Local symbols will be in this group due to the current
967 // implementation but those with VCallVisibilityTranslationUnit will have
968 // already been marked in clang so are unaffected.
969 if (VisibleToRegularObjSymbols.count(V: P.first))
970 continue;
971 for (auto &S : P.second.getSummaryList()) {
972 auto *GVar = dyn_cast<GlobalVarSummary>(Val: S.get());
973 if (!GVar ||
974 GVar->getVCallVisibility() != GlobalObject::VCallVisibilityPublic)
975 continue;
976 GVar->setVCallVisibility(GlobalObject::VCallVisibilityLinkageUnit);
977 }
978 }
979}
980
981void llvm::runWholeProgramDevirtOnIndex(
982 ModuleSummaryIndex &Summary, std::set<GlobalValue::GUID> &ExportedGUIDs,
983 std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap,
984 DenseSet<StringRef> *ExternallyVisibleSymbolNamesPtr) {
985 DevirtIndex(Summary, ExportedGUIDs, LocalWPDTargetsMap,
986 ExternallyVisibleSymbolNamesPtr)
987 .run();
988}
989
990void llvm::updateIndexWPDForExports(
991 ModuleSummaryIndex &Summary,
992 function_ref<bool(StringRef, ValueInfo)> IsExported,
993 std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap,
994 DenseSet<StringRef> *ExternallyVisibleSymbolNamesPtr) {
995 for (auto &T : LocalWPDTargetsMap) {
996 auto &VI = T.first;
997 // This was enforced earlier during trySingleImplDevirt.
998 assert(VI.getSummaryList().size() == 1 &&
999 "Devirt of local target has more than one copy");
1000 auto &S = VI.getSummaryList()[0];
1001 if (!IsExported(S->modulePath(), VI))
1002 continue;
1003
1004 // It's been exported by a cross module import.
1005 for (auto &SlotSummary : T.second) {
1006 auto *TIdSum = Summary.getTypeIdSummary(TypeId: SlotSummary.TypeID);
1007 assert(TIdSum);
1008 auto WPDRes = TIdSum->WPDRes.find(x: SlotSummary.ByteOffset);
1009 assert(WPDRes != TIdSum->WPDRes.end());
1010 if (ExternallyVisibleSymbolNamesPtr)
1011 ExternallyVisibleSymbolNamesPtr->insert(V: WPDRes->second.SingleImplName);
1012 WPDRes->second.SingleImplName = ModuleSummaryIndex::getGlobalNameForLocal(
1013 Name: WPDRes->second.SingleImplName,
1014 ModHash: Summary.getModuleHash(ModPath: S->modulePath()));
1015 }
1016 }
1017}
1018
1019static Error checkCombinedSummaryForTesting(ModuleSummaryIndex *Summary) {
1020 // Check that summary index contains regular LTO module when performing
1021 // export to prevent occasional use of index from pure ThinLTO compilation
1022 // (-fno-split-lto-module). This kind of summary index is passed to
1023 // DevirtIndex::run, not to DevirtModule::run used by opt/runForTesting.
1024 const auto &ModPaths = Summary->modulePaths();
1025 if (ClSummaryAction != PassSummaryAction::Import &&
1026 !ModPaths.contains(Key: ModuleSummaryIndex::getRegularLTOModuleName()))
1027 return createStringError(
1028 EC: errc::invalid_argument,
1029 S: "combined summary should contain Regular LTO module");
1030 return ErrorSuccess();
1031}
1032
// Testing-only driver: optionally reads a summary index (bitcode, falling
// back to YAML) from -wholeprogramdevirt-read-summary, runs DevirtModule
// with that summary wired up as export or import per ClSummaryAction, and
// optionally writes the summary back via -wholeprogramdevirt-write-summary.
// Returns whether the module was changed.
bool DevirtModule::runForTesting(Module &M, ModuleAnalysisManager &MAM,
                                 bool DevirtSpeculatively) {
  std::unique_ptr<ModuleSummaryIndex> Summary =
      std::make_unique<ModuleSummaryIndex>(/*HaveGVs=*/args: false);

  // Handle the command-line summary arguments. This code is for testing
  // purposes only, so we handle errors directly.
  if (!ClReadSummary.empty()) {
    ExitOnError ExitOnErr("-wholeprogramdevirt-read-summary: " + ClReadSummary +
                          ": ");
    auto ReadSummaryFile =
        ExitOnErr(errorOrToExpected(EO: MemoryBuffer::getFile(Filename: ClReadSummary)));
    if (Expected<std::unique_ptr<ModuleSummaryIndex>> SummaryOrErr =
            getModuleSummaryIndex(Buffer: *ReadSummaryFile)) {
      // Bitcode parsed successfully; additionally verify it is suitable for
      // this driver (see checkCombinedSummaryForTesting).
      Summary = std::move(*SummaryOrErr);
      ExitOnErr(checkCombinedSummaryForTesting(Summary: Summary.get()));
    } else {
      // Try YAML if we've failed with bitcode.
      consumeError(Err: SummaryOrErr.takeError());
      yaml::Input In(ReadSummaryFile->getBuffer());
      In >> *Summary;
      ExitOnErr(errorCodeToError(EC: In.error()));
    }
  }

  // Run the module pass with the summary as ExportSummary or ImportSummary
  // depending on the requested action; the other side stays null.
  bool Changed =
      DevirtModule(M, MAM,
                   ClSummaryAction == PassSummaryAction::Export ? Summary.get()
                                                                : nullptr,
                   ClSummaryAction == PassSummaryAction::Import ? Summary.get()
                                                                : nullptr,
                   DevirtSpeculatively)
          .run();

  if (!ClWriteSummary.empty()) {
    ExitOnError ExitOnErr(
        "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": ");
    std::error_code EC;
    // A ".bc" extension selects bitcode output; any other name writes YAML.
    if (StringRef(ClWriteSummary).ends_with(Suffix: ".bc")) {
      raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_None);
      ExitOnErr(errorCodeToError(EC));
      writeIndexToFile(Index: *Summary, Out&: OS);
    } else {
      raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_TextWithCRLF);
      ExitOnErr(errorCodeToError(EC));
      yaml::Output Out(OS);
      Out << *Summary;
    }
  }

  return Changed;
}
1085
// Scan all vtable definitions (globals with !type metadata) in the module,
// filling Bits with one VTableBits entry per vtable and TypeIdMap with the
// set of (vtable, address-point offset) members for each type identifier.
void DevirtModule::buildTypeIdentifierMap(
    std::vector<VTableBits> &Bits,
    DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
  DenseMap<GlobalVariable *, VTableBits *> GVToBits;
  // Reserving one slot per global guarantees Bits never reallocates below
  // (at most one entry is appended per global), so the VTableBits pointers
  // stored in GVToBits and TypeIdMap remain valid.
  Bits.reserve(n: M.global_size());
  SmallVector<MDNode *, 2> Types;
  for (GlobalVariable &GV : M.globals()) {
    Types.clear();
    GV.getMetadata(KindID: LLVMContext::MD_type, MDs&: Types);
    // Only defined globals carrying type metadata are vtables of interest.
    if (GV.isDeclaration() || Types.empty())
      continue;

    // Create the VTableBits entry for this vtable on first sight.
    VTableBits *&BitsPtr = GVToBits[&GV];
    if (!BitsPtr) {
      Bits.emplace_back();
      Bits.back().GV = &GV;
      Bits.back().ObjectSize =
          M.getDataLayout().getTypeAllocSize(Ty: GV.getInitializer()->getType());
      BitsPtr = &Bits.back();
    }

    for (MDNode *Type : Types) {
      auto *TypeID = Type->getOperand(I: 1).get();

      // Operand 0 of a !type annotation is the byte offset within the vtable
      // to which the type identifier applies.
      uint64_t Offset =
          cast<ConstantInt>(
              Val: cast<ConstantAsMetadata>(Val: Type->getOperand(I: 0))->getValue())
              ->getZExtValue();

      TypeIdMap[TypeID].insert(x: {.Bits: BitsPtr, .Offset: Offset});
    }
  }
}
1119
// Collect the possible targets of a virtual call slot: for each compatible
// vtable, the function stored at (member offset + ByteOffset). Returns false
// when devirtualization must be abandoned for the whole slot (non-constant
// or public-visibility vtable, no function at the offset, or a skipped
// function); returns true with TargetsForSlot populated otherwise.
bool DevirtModule::tryFindVirtualCallTargets(
    std::vector<VirtualCallTarget> &TargetsForSlot,
    const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset,
    ModuleSummaryIndex *ExportSummary) {
  for (const TypeMemberInfo &TM : TypeMemberInfos) {
    // A non-constant vtable could be rewritten at run time, so its contents
    // cannot be trusted.
    if (!TM.Bits->GV->isConstant())
      return false;

    // Without DevirtSpeculatively, we cannot perform whole program
    // devirtualization analysis on a vtable with public LTO visibility.
    if (!DevirtSpeculatively && TM.Bits->GV->getVCallVisibility() ==
                                    GlobalObject::VCallVisibilityPublic)
      return false;

    Function *Fn = nullptr;
    Constant *C = nullptr;
    std::tie(args&: Fn, args&: C) =
        getFunctionAtVTableOffset(GV: TM.Bits->GV, Offset: TM.Offset + ByteOffset, M);

    if (!Fn)
      return false;

    if (FunctionsToSkip.match(S: Fn->getName()))
      return false;

    // We can disregard __cxa_pure_virtual as a possible call target, as
    // calls to pure virtuals are UB.
    if (Fn->getName() == "__cxa_pure_virtual")
      continue;

    // In most cases empty functions will be overridden by the
    // implementation of the derived class, so we can skip them.
    if (DevirtSpeculatively && Fn->getReturnType()->isVoidTy() &&
        Fn->getInstructionCount() <= 1)
      continue;

    // We can disregard unreachable functions as possible call targets, as
    // unreachable functions shouldn't be called.
    if (mustBeUnreachableFunction(F: Fn, ExportSummary))
      continue;

    // Save the symbol used in the vtable to use as the devirtualization
    // target.
    auto *GV = dyn_cast<GlobalValue>(Val: C);
    assert(GV);
    TargetsForSlot.push_back(x: {GV, &TM});
  }

  // Give up if we couldn't find any targets.
  return !TargetsForSlot.empty();
}
1171
// Summary-index analogue of DevirtModule::tryFindVirtualCallTargets: collect
// the possible targets (as ValueInfos) of the slot described by TIdInfo at
// ByteOffset. Returns false when the slot must be treated conservatively;
// returns true with TargetsForSlot populated otherwise.
bool DevirtIndex::tryFindVirtualCallTargets(
    std::vector<ValueInfo> &TargetsForSlot,
    const TypeIdCompatibleVtableInfo TIdInfo, uint64_t ByteOffset) {
  for (const TypeIdOffsetVtableInfo &P : TIdInfo) {
    // Find a representative copy of the vtable initializer.
    // We can have multiple available_externally, linkonce_odr and weak_odr
    // vtable initializers. We can also have multiple external vtable
    // initializers in the case of comdats, which we cannot check here.
    // The linker should give an error in this case.
    //
    // Also, handle the case of same-named local Vtables with the same path
    // and therefore the same GUID. This can happen if there isn't enough
    // distinguishing path when compiling the source file. In that case we
    // conservatively return false early.
    if (P.VTableVI.hasLocal() && P.VTableVI.getSummaryList().size() > 1)
      return false;
    const GlobalVarSummary *VS = nullptr;
    for (const auto &S : P.VTableVI.getSummaryList()) {
      auto *CurVS = cast<GlobalVarSummary>(Val: S->getBaseObject());
      if (!CurVS->vTableFuncs().empty() ||
          // Previously clang did not attach the necessary type metadata to
          // available_externally vtables, in which case there would not
          // be any vtable functions listed in the summary and we need
          // to treat this case conservatively (in case the bitcode is old).
          // However, we will also not have any vtable functions in the
          // case of a pure virtual base class. In that case we do want
          // to set VS to avoid treating it conservatively.
          !GlobalValue::isAvailableExternallyLinkage(Linkage: S->linkage())) {
        VS = CurVS;
        // We cannot perform whole program devirtualization analysis on a vtable
        // with public LTO visibility.
        if (VS->getVCallVisibility() == GlobalObject::VCallVisibilityPublic)
          return false;
        break;
      }
    }
    // There will be no VS if all copies are available_externally having no
    // type metadata. In that case we can't safely perform WPD.
    if (!VS)
      return false;
    // Dead vtables contribute no targets but don't block devirtualization.
    if (!VS->isLive())
      continue;
    for (auto VTP : VS->vTableFuncs()) {
      // Only the entry at this slot's offset is a candidate target.
      if (VTP.VTableOffset != P.AddressPointOffset + ByteOffset)
        continue;

      if (mustBeUnreachableFunction(TheFnVI: VTP.FuncVI))
        continue;

      TargetsForSlot.push_back(x: VTP.FuncVI);
    }
  }

  // Give up if we couldn't find any targets.
  return !TargetsForSlot.empty();
}
1228
// Rewrite every virtual call recorded in SlotInfo into a direct call to
// TheFn, honoring the configured check mode (trap, fallback, speculative, or
// unconditional). Sets IsExported when any rewritten call-site group is
// visible outside this module.
void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
                                         Constant *TheFn, bool &IsExported) {
  // Don't devirtualize function if we're told to skip it
  // in -wholeprogramdevirt-skip.
  if (FunctionsToSkip.match(S: TheFn->stripPointerCasts()->getName()))
    return;
  auto Apply = [&](CallSiteInfo &CSInfo) {
    for (auto &&VCallSite : CSInfo.CallSites) {
      // Skip call sites that were already rewritten by another slot.
      if (!OptimizedCalls.insert(Ptr: &VCallSite.CB).second)
        continue;

      // Stop when the number of devirted calls reaches the cutoff.
      if (!DebugCounter::shouldExecute(Counter&: CallsToDevirt))
        continue;

      if (RemarksEnabled)
        VCallSite.emitRemark(OptName: "single-impl",
                             TargetName: TheFn->stripPointerCasts()->getName(), OREGetter);
      NumSingleImpl++;
      auto &CB = VCallSite.CB;
      assert(!CB.getCalledFunction() && "devirtualizing direct call?");
      IRBuilder<> Builder(&CB);
      Value *Callee =
          Builder.CreateBitCast(V: TheFn, DestTy: CB.getCalledOperand()->getType());

      // If trap checking is enabled, add support to compare the virtual
      // function pointer to the devirtualized target. In case of a mismatch,
      // perform a debug trap.
      if (DevirtCheckMode == WPDCheckMode::Trap) {
        auto *Cond = Builder.CreateICmpNE(LHS: CB.getCalledOperand(), RHS: Callee);
        Instruction *ThenTerm = SplitBlockAndInsertIfThen(
            Cond, SplitBefore: &CB, /*Unreachable=*/false,
            BranchWeights: MDBuilder(M.getContext()).createUnlikelyBranchWeights());
        Builder.SetInsertPoint(ThenTerm);
        Function *TrapFn =
            Intrinsic::getOrInsertDeclaration(M: &M, id: Intrinsic::debugtrap);
        auto *CallTrap = Builder.CreateCall(Callee: TrapFn);
        CallTrap->setDebugLoc(CB.getDebugLoc());
      }

      // If fallback checking or speculative devirtualization are enabled,
      // add support to compare the virtual function pointer to the
      // devirtualized target. In case of a mismatch, fall back to indirect
      // call.
      if (DevirtCheckMode == WPDCheckMode::Fallback || DevirtSpeculatively) {
        MDNode *Weights = MDBuilder(M.getContext()).createLikelyBranchWeights();
        // Version the indirect call site. If the called value is equal to the
        // given callee, 'NewInst' will be executed, otherwise the original call
        // site will be executed.
        CallBase &NewInst = versionCallSite(CB, Callee, BranchWeights: Weights);
        NewInst.setCalledOperand(Callee);
        // Since the new call site is direct, we must clear metadata that
        // is only appropriate for indirect calls. This includes !prof and
        // !callees metadata.
        NewInst.setMetadata(KindID: LLVMContext::MD_prof, Node: nullptr);
        NewInst.setMetadata(KindID: LLVMContext::MD_callees, Node: nullptr);
        // Additionally, we should remove them from the fallback indirect call,
        // so that we don't attempt to perform indirect call promotion later.
        CB.setMetadata(KindID: LLVMContext::MD_prof, Node: nullptr);
        CB.setMetadata(KindID: LLVMContext::MD_callees, Node: nullptr);
      }

      // In either trapping or non-checking mode, devirtualize original call.
      else {
        // Devirtualize unconditionally.
        CB.setCalledOperand(Callee);
        // Since the call site is now direct, we must clear metadata that
        // is only appropriate for indirect calls. This includes !prof and
        // !callees metadata.
        CB.setMetadata(KindID: LLVMContext::MD_prof, Node: nullptr);
        CB.setMetadata(KindID: LLVMContext::MD_callees, Node: nullptr);
        // A ptrauth bundle no longer applies once the call is direct; strip
        // it by cloning the call without the bundle.
        if (CB.getCalledOperand() &&
            CB.getOperandBundle(ID: LLVMContext::OB_ptrauth)) {
          auto *NewCS = CallBase::removeOperandBundle(
              CB: &CB, ID: LLVMContext::OB_ptrauth, InsertPt: CB.getIterator());
          CB.replaceAllUsesWith(V: NewCS);
          // Schedule for deletion at the end of pass run.
          CallsWithPtrAuthBundleRemoved.push_back(Elt: &CB);
        }
      }

      // This use is no longer unsafe.
      if (VCallSite.NumUnsafeUses)
        --*VCallSite.NumUnsafeUses;
    }
    if (CSInfo.isExported())
      IsExported = true;
    CSInfo.markDevirt();
  };
  // Rewrite both the plain call sites and the per-constant-argument groups.
  Apply(SlotInfo.CSInfo);
  for (auto &P : SlotInfo.ConstCSInfo)
    Apply(P.second);
}
1322
1323static bool addCalls(VTableSlotInfo &SlotInfo, const ValueInfo &Callee) {
1324 // We can't add calls if we haven't seen a definition
1325 if (Callee.getSummaryList().empty())
1326 return false;
1327
1328 // Insert calls into the summary index so that the devirtualized targets
1329 // are eligible for import.
1330 // FIXME: Annotate type tests with hotness. For now, mark these as hot
1331 // to better ensure we have the opportunity to inline them.
1332 bool IsExported = false;
1333 auto &S = Callee.getSummaryList()[0];
1334 CalleeInfo CI(CalleeInfo::HotnessType::Hot, /* HasTailCall = */ false);
1335 auto AddCalls = [&](CallSiteInfo &CSInfo) {
1336 for (auto *FS : CSInfo.SummaryTypeCheckedLoadUsers) {
1337 FS->addCall(E: {Callee, CI});
1338 IsExported |= S->modulePath() != FS->modulePath();
1339 }
1340 for (auto *FS : CSInfo.SummaryTypeTestAssumeUsers) {
1341 FS->addCall(E: {Callee, CI});
1342 IsExported |= S->modulePath() != FS->modulePath();
1343 }
1344 };
1345 AddCalls(SlotInfo.CSInfo);
1346 for (auto &P : SlotInfo.ConstCSInfo)
1347 AddCalls(P.second);
1348 return IsExported;
1349}
1350
// Attempt single-implementation devirtualization for a slot: if every target
// resolves to the same function, rewrite all calls to it directly. When the
// result is exported to other modules, promote a local implementation to a
// hidden external symbol and record a SingleImpl resolution in Res. Returns
// true iff a resolution was recorded.
bool DevirtModule::trySingleImplDevirt(
    ModuleSummaryIndex *ExportSummary,
    MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
    WholeProgramDevirtResolution *Res) {
  // See if the program contains a single implementation of this virtual
  // function.
  auto *TheFn = TargetsForSlot[0].Fn;
  for (auto &&Target : TargetsForSlot)
    if (TheFn != Target.Fn)
      return false;

  // If so, update each call site to call that implementation directly.
  if (RemarksEnabled || AreStatisticsEnabled())
    TargetsForSlot[0].WasDevirt = true;

  bool IsExported = false;
  applySingleImplDevirt(SlotInfo, TheFn, IsExported);
  // If nothing is exported, no summary resolution is needed.
  if (!IsExported)
    return false;

  // If the only implementation has local linkage, we must promote to external
  // to make it visible to thin LTO objects. We can only get here during the
  // ThinLTO export phase.
  if (TheFn->hasLocalLinkage()) {
    std::string NewName = (TheFn->getName() + ".llvm.merged").str();

    // Since we are renaming the function, any comdats with the same name must
    // also be renamed. This is required when targeting COFF, as the comdat name
    // must match one of the names of the symbols in the comdat.
    if (Comdat *C = TheFn->getComdat()) {
      if (C->getName() == TheFn->getName()) {
        Comdat *NewC = M.getOrInsertComdat(Name: NewName);
        NewC->setSelectionKind(C->getSelectionKind());
        // Move every member of the old comdat into the renamed one.
        for (GlobalObject &GO : M.global_objects())
          if (GO.getComdat() == C)
            GO.setComdat(NewC);
      }
    }

    TheFn->setLinkage(GlobalValue::ExternalLinkage);
    TheFn->setVisibility(GlobalValue::HiddenVisibility);
    TheFn->setName(NewName);
  }
  if (ValueInfo TheFnVI = ExportSummary->getValueInfo(GUID: TheFn->getGUID()))
    // Any needed promotion of 'TheFn' has already been done during
    // LTO unit split, so we can ignore return value of AddCalls.
    addCalls(SlotInfo, Callee: TheFnVI);

  Res->TheKind = WholeProgramDevirtResolution::SingleImpl;
  Res->SingleImplName = std::string(TheFn->getName());

  return true;
}
1404
// Summary-index analogue of DevirtModule::trySingleImplDevirt: if all
// targets agree on one implementation, record a SingleImpl resolution in Res
// for use during the ThinLTO import step. Handles name promotion bookkeeping
// for local targets that become exported. Returns true iff a resolution was
// recorded.
bool DevirtIndex::trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot,
                                      VTableSlotSummary &SlotSummary,
                                      VTableSlotInfo &SlotInfo,
                                      WholeProgramDevirtResolution *Res,
                                      std::set<ValueInfo> &DevirtTargets) {
  // See if the program contains a single implementation of this virtual
  // function.
  auto TheFn = TargetsForSlot[0];
  for (auto &&Target : TargetsForSlot)
    if (TheFn != Target)
      return false;

  // Don't devirtualize if we don't have target definition.
  auto Size = TheFn.getSummaryList().size();
  if (!Size)
    return false;

  // Don't devirtualize function if we're told to skip it
  // in -wholeprogramdevirt-skip.
  if (FunctionsToSkip.match(S: TheFn.name()))
    return false;

  // If the summary list contains multiple summaries where at least one is
  // a local, give up, as we won't know which (possibly promoted) name to use.
  if (TheFn.hasLocal() && Size > 1)
    return false;

  // Collect functions devirtualized at least for one call site for stats.
  if (PrintSummaryDevirt || AreStatisticsEnabled())
    DevirtTargets.insert(x: TheFn);

  auto &S = TheFn.getSummaryList()[0];
  bool IsExported = addCalls(SlotInfo, Callee: TheFn);
  if (IsExported)
    ExportedGUIDs.insert(x: TheFn.getGUID());

  // Record in summary for use in devirtualization during the ThinLTO import
  // step.
  Res->TheKind = WholeProgramDevirtResolution::SingleImpl;
  if (GlobalValue::isLocalLinkage(Linkage: S->linkage())) {
    if (IsExported) {
      // If target is a local function and we are exporting it by
      // devirtualizing a call in another module, we need to record the
      // promoted name.
      if (ExternallyVisibleSymbolNamesPtr)
        ExternallyVisibleSymbolNamesPtr->insert(V: TheFn.name());
      Res->SingleImplName = ModuleSummaryIndex::getGlobalNameForLocal(
          Name: TheFn.name(), ModHash: ExportSummary.getModuleHash(ModPath: S->modulePath()));
    } else {
      // Not exported (yet): remember the slot so updateIndexWPDForExports can
      // rewrite the name if a cross-module import exports it later.
      LocalWPDTargetsMap[TheFn].push_back(x: SlotSummary);
      Res->SingleImplName = std::string(TheFn.name());
    }
  } else
    Res->SingleImplName = std::string(TheFn.name());

  // Name will be empty if this thin link driven off of serialized combined
  // index (e.g. llvm-lto). However, WPD is not supported/invoked for the
  // legacy LTO API anyway.
  assert(!Res->SingleImplName.empty());

  return true;
}
1467
// For slots with multiple targets that still have non-devirtualized call
// sites, create an llvm.icall.branch.funnel thunk listing all (vtable,
// target) pairs and route eligible calls through it. Only implemented for
// x86-64, and only when the target count is within ClThreshold.
void DevirtModule::tryICallBranchFunnel(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
    WholeProgramDevirtResolution *Res, VTableSlot Slot) {
  Triple T(M.getTargetTriple());
  if (T.getArch() != Triple::x86_64)
    return;

  if (TargetsForSlot.size() > ClThreshold)
    return;

  // Only worthwhile if at least one call site (plain or per-constant-args)
  // has not already been devirtualized by another optimization.
  bool HasNonDevirt = !SlotInfo.CSInfo.AllCallSitesDevirted;
  if (!HasNonDevirt)
    for (auto &P : SlotInfo.ConstCSInfo)
      if (!P.second.AllCallSitesDevirted) {
        HasNonDevirt = true;
        break;
      }

  if (!HasNonDevirt)
    return;

  // If any GV is AvailableExternally, not to generate branch.funnel.
  // NOTE: It is to avoid crash in LowerTypeTest.
  // If the branch.funnel is generated, because GV.isDeclarationForLinker(),
  // in LowerTypeTestsModule::lower(), its GlobalTypeMember would NOT
  // be saved in GlobalTypeMembers[&GV]. Then crash happens in
  // buildBitSetsFromDisjointSet due to GlobalTypeMembers[&GV] is NULL.
  // Even doing experiment to save it in GlobalTypeMembers[&GV] and
  // making GlobalTypeMembers[&GV] be not NULL, crash could avoid from
  // buildBitSetsFromDisjointSet. But still report_fatal_error in Verifier
  // or SelectionDAGBuilder later, because operands linkage type consistency
  // check of icall.branch.funnel can not pass.
  for (auto &T : TargetsForSlot) {
    if (T.TM->Bits->GV->hasAvailableExternallyLinkage())
      return;
  }

  // Create the funnel thunk: variadic, taking the vtable pointer first.
  // Hidden external for string type IDs (shared across modules), internal
  // otherwise.
  FunctionType *FT =
      FunctionType::get(Result: Type::getVoidTy(C&: M.getContext()), Params: {Int8PtrTy}, isVarArg: true);
  Function *JT;
  if (isa<MDString>(Val: Slot.TypeID)) {
    JT = Function::Create(Ty: FT, Linkage: Function::ExternalLinkage,
                          AddrSpace: M.getDataLayout().getProgramAddressSpace(),
                          N: getGlobalName(Slot, Args: {}, Name: "branch_funnel"), M: &M);
    JT->setVisibility(GlobalValue::HiddenVisibility);
  } else {
    JT = Function::Create(Ty: FT, Linkage: Function::InternalLinkage,
                          AddrSpace: M.getDataLayout().getProgramAddressSpace(),
                          N: "branch_funnel", M: &M);
  }
  JT->addParamAttr(ArgNo: 0, Kind: Attribute::Nest);

  // Arguments to the intrinsic: the incoming vtable pointer followed by
  // (vtable address, target function) pairs.
  std::vector<Value *> JTArgs;
  JTArgs.push_back(x: JT->arg_begin());
  for (auto &T : TargetsForSlot) {
    JTArgs.push_back(x: getMemberAddr(M: T.TM));
    JTArgs.push_back(x: T.Fn);
  }

  BasicBlock *BB = BasicBlock::Create(Context&: M.getContext(), Name: "", Parent: JT, InsertBefore: nullptr);
  Function *Intr = Intrinsic::getOrInsertDeclaration(
      M: &M, id: llvm::Intrinsic::icall_branch_funnel, Tys: {});

  // The intrinsic call must be musttail so the funnel forwards its arguments.
  auto *CI = CallInst::Create(Func: Intr, Args: JTArgs, NameStr: "", InsertBefore: BB);
  CI->setTailCallKind(CallInst::TCK_MustTail);
  ReturnInst::Create(C&: M.getContext(), retVal: nullptr, InsertBefore: BB);

  bool IsExported = false;
  applyICallBranchFunnel(SlotInfo, JT&: *JT, IsExported);
  if (IsExported)
    Res->TheKind = WholeProgramDevirtResolution::BranchFunnel;

  if (!JT->getEntryCount().has_value()) {
    // FIXME: we could pass through thinlto the necessary information.
    setExplicitlyUnknownFunctionEntryCount(F&: *JT, DEBUG_TYPE);
  }
}
1545
// Redirect the virtual call sites recorded in SlotInfo to call the branch
// funnel JT (lowered via llvm.icall.branch.funnel) instead of the loaded
// function pointer. Only call sites whose caller is compiled with the
// retpoline mitigation are rewritten. Sets IsExported to true if any
// exported call site info was seen.
void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
                                          Function &JT, bool &IsExported) {
  // Estimated entry counts to assign to freshly synthesized funnels,
  // accumulated across every call site redirected to each funnel.
  DenseMap<Function *, double> FunctionEntryCounts;
  auto Apply = [&](CallSiteInfo &CSInfo) {
    if (CSInfo.isExported())
      IsExported = true;
    if (CSInfo.AllCallSitesDevirted)
      return;

    // Map from original call to its replacement; RAUW/erase is deferred to
    // the end so duplicate entries in CallSites are handled safely.
    std::map<CallBase *, CallBase *> CallBases;
    for (auto &&VCallSite : CSInfo.CallSites) {
      CallBase &CB = VCallSite.CB;

      if (CallBases.find(x: &CB) != CallBases.end()) {
        // When finding devirtualizable calls, it's possible to find the same
        // vtable passed to multiple llvm.type.test or llvm.type.checked.load
        // calls, which can cause duplicate call sites to be recorded in
        // [Const]CallSites. If we've already found one of these
        // call instances, just ignore it. It will be replaced later.
        continue;
      }

      // Jump tables are only profitable if the retpoline mitigation is enabled.
      Attribute FSAttr = CB.getCaller()->getFnAttribute(Kind: "target-features");
      if (!FSAttr.isValid() ||
          !FSAttr.getValueAsString().contains(Other: "+retpoline"))
        continue;

      NumBranchFunnel++;
      if (RemarksEnabled)
        VCallSite.emitRemark(OptName: "branch-funnel", TargetName: JT.getName(), OREGetter);

      // Pass the address of the vtable in the nest register, which is r10 on
      // x86_64.
      std::vector<Type *> NewArgs;
      NewArgs.push_back(x: Int8PtrTy);
      append_range(C&: NewArgs, R: CB.getFunctionType()->params());
      FunctionType *NewFT =
          FunctionType::get(Result: CB.getFunctionType()->getReturnType(), Params: NewArgs,
                            isVarArg: CB.getFunctionType()->isVarArg());
      IRBuilder<> IRB(&CB);
      std::vector<Value *> Args;
      Args.push_back(x: VCallSite.VTable);
      llvm::append_range(C&: Args, R: CB.args());

      CallBase *NewCS = nullptr;
      if (!JT.isDeclaration() && !ProfcheckDisableMetadataFixes) {
        // Accumulate the call frequencies of the original call site, and use
        // that as total entry count for the funnel function.
        auto &F = *CB.getCaller();
        auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(IR&: F);
        auto EC = BFI.getBlockFreq(BB: &F.getEntryBlock());
        auto CC = F.getEntryCount(/*AllowSynthetic=*/true);
        double CallCount = 0.0;
        // Scale the caller's entry count by the block frequency of the call
        // site relative to the caller's entry block.
        if (EC.getFrequency() != 0 && CC && CC->getCount() != 0) {
          double CallFreq =
              static_cast<double>(
                  BFI.getBlockFreq(BB: CB.getParent()).getFrequency()) /
              EC.getFrequency();
          CallCount = CallFreq * CC->getCount();
        }
        FunctionEntryCounts[&JT] += CallCount;
      }
      if (isa<CallInst>(Val: CB))
        NewCS = IRB.CreateCall(FTy: NewFT, Callee: &JT, Args);
      else
        NewCS =
            IRB.CreateInvoke(Ty: NewFT, Callee: &JT, NormalDest: cast<InvokeInst>(Val&: CB).getNormalDest(),
                             UnwindDest: cast<InvokeInst>(Val&: CB).getUnwindDest(), Args);
      NewCS->setCallingConv(CB.getCallingConv());

      AttributeList Attrs = CB.getAttributes();
      std::vector<AttributeSet> NewArgAttrs;
      // The new first argument (the vtable pointer) carries the 'nest'
      // attribute; the original argument attributes follow, shifted by one.
      NewArgAttrs.push_back(x: AttributeSet::get(
          C&: M.getContext(), Attrs: ArrayRef<Attribute>{Attribute::get(
                               Context&: M.getContext(), Kind: Attribute::Nest)}));
      for (unsigned I = 0; I + 2 < Attrs.getNumAttrSets(); ++I)
        NewArgAttrs.push_back(x: Attrs.getParamAttrs(ArgNo: I));
      NewCS->setAttributes(
          AttributeList::get(C&: M.getContext(), FnAttrs: Attrs.getFnAttrs(),
                             RetAttrs: Attrs.getRetAttrs(), ArgAttrs: NewArgAttrs));

      CallBases[&CB] = NewCS;

      // This use is no longer unsafe.
      if (VCallSite.NumUnsafeUses)
        --*VCallSite.NumUnsafeUses;
    }
    // Don't mark as devirtualized because there may be callers compiled without
    // retpoline mitigation, which would mean that they are lowered to
    // llvm.type.test and therefore require an llvm.type.test resolution for the
    // type identifier.

    for (auto &[Old, New] : CallBases) {
      Old->replaceAllUsesWith(V: New);
      Old->eraseFromParent();
    }
  };
  Apply(SlotInfo.CSInfo);
  for (auto &P : SlotInfo.ConstCSInfo)
    Apply(P.second);
  // Funnels synthesized in this run have no entry count yet; install the
  // accumulated estimate, rounded to the nearest integer.
  for (auto &[F, C] : FunctionEntryCounts) {
    assert(!F->getEntryCount(/*AllowSynthetic=*/true) &&
           "Unexpected entry count for funnel that was freshly synthesized");
    F->setEntryCount(Count: static_cast<uint64_t>(std::round(x: C)));
  }
}
1653
// Attempt to constant-evaluate every target function in TargetsForSlot when
// called with the constant integer arguments Args. On success each target's
// RetVal field holds the zero-extended return value and true is returned;
// returns false if any target cannot be evaluated to a ConstantInt.
bool DevirtModule::tryEvaluateFunctionsWithArgs(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot,
    ArrayRef<uint64_t> Args) {
  // Evaluate each function and store the result in each target's RetVal
  // field.
  for (VirtualCallTarget &Target : TargetsForSlot) {
    // TODO: Skip for now if the vtable symbol was an alias to a function,
    // need to evaluate whether it would be correct to analyze the aliasee
    // function for this optimization.
    auto *Fn = dyn_cast<Function>(Val: Target.Fn);
    if (!Fn)
      return false;

    // The +1 accounts for the implicit 'this' parameter.
    if (Fn->arg_size() != Args.size() + 1)
      return false;

    Evaluator Eval(M.getDataLayout(), nullptr);
    SmallVector<Constant *, 2> EvalArgs;
    // Pass null for 'this'; the caller (tryVirtualConstProp) has already
    // verified that the first argument is unused.
    EvalArgs.push_back(
        Elt: Constant::getNullValue(Ty: Fn->getFunctionType()->getParamType(i: 0)));
    for (unsigned I = 0; I != Args.size(); ++I) {
      // Only integer parameters can be materialized from the uint64_t
      // argument values.
      auto *ArgTy =
          dyn_cast<IntegerType>(Val: Fn->getFunctionType()->getParamType(i: I + 1));
      if (!ArgTy)
        return false;
      EvalArgs.push_back(Elt: ConstantInt::get(Ty: ArgTy, V: Args[I]));
    }

    Constant *RetVal;
    if (!Eval.EvaluateFunction(F: Fn, RetVal, ActualArgs: EvalArgs) ||
        !isa<ConstantInt>(Val: RetVal))
      return false;
    Target.RetVal = cast<ConstantInt>(Val: RetVal)->getZExtValue();
  }
  return true;
}
1690
// Replace every not-yet-optimized call site in CSInfo with the constant
// TheRetVal (all possible targets were shown to return the same value) and
// mark the call site info devirtualized. FnName is used only for remarks.
void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
                                         uint64_t TheRetVal) {
  for (auto Call : CSInfo.CallSites) {
    // Skip calls already rewritten by an earlier optimization.
    if (!OptimizedCalls.insert(Ptr: &Call.CB).second)
      continue;
    NumUniformRetVal++;
    Call.replaceAndErase(
        OptName: "uniform-ret-val", TargetName: FnName, RemarksEnabled, OREGetter,
        New: ConstantInt::get(Ty: cast<IntegerType>(Val: Call.CB.getType()), V: TheRetVal));
  }
  CSInfo.markDevirt();
}
1703
// Uniform return value optimization: if every target's evaluated RetVal is
// identical, rewrite all calls to that constant. Records the resolution in
// Res when the call site info is exported. Returns true if applied.
bool DevirtModule::tryUniformRetValOpt(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot, CallSiteInfo &CSInfo,
    WholeProgramDevirtResolution::ByArg *Res) {
  // Uniform return value optimization. If all functions return the same
  // constant, replace all calls with that constant.
  uint64_t TheRetVal = TargetsForSlot[0].RetVal;
  for (const VirtualCallTarget &Target : TargetsForSlot)
    if (Target.RetVal != TheRetVal)
      return false;

  // NOTE(review): Res is dereferenced only on the exported path; callers
  // appear to guarantee it is non-null whenever CSInfo is exported — confirm.
  if (CSInfo.isExported()) {
    Res->TheKind = WholeProgramDevirtResolution::ByArg::UniformRetVal;
    Res->Info = TheRetVal;
  }

  applyUniformRetValOpt(CSInfo, FnName: TargetsForSlot[0].Fn->getName(), TheRetVal);
  // Record devirtualization for remarks/statistics.
  if (RemarksEnabled || AreStatisticsEnabled())
    for (auto &&Target : TargetsForSlot)
      Target.WasDevirt = true;
  return true;
}
1725
1726std::string DevirtModule::getGlobalName(VTableSlot Slot,
1727 ArrayRef<uint64_t> Args,
1728 StringRef Name) {
1729 std::string FullName = "__typeid_";
1730 raw_string_ostream OS(FullName);
1731 OS << cast<MDString>(Val: Slot.TypeID)->getString() << '_' << Slot.ByteOffset;
1732 for (uint64_t Arg : Args)
1733 OS << '_' << Arg;
1734 OS << '_' << Name;
1735 return FullName;
1736}
1737
1738bool DevirtModule::shouldExportConstantsAsAbsoluteSymbols() {
1739 Triple T(M.getTargetTriple());
1740 return T.isX86() && T.getObjectFormat() == Triple::ELF;
1741}
1742
1743void DevirtModule::exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
1744 StringRef Name, Constant *C) {
1745 GlobalAlias *GA = GlobalAlias::create(Ty: Int8Ty, AddressSpace: 0, Linkage: GlobalValue::ExternalLinkage,
1746 Name: getGlobalName(Slot, Args, Name), Aliasee: C, Parent: &M);
1747 GA->setVisibility(GlobalValue::HiddenVisibility);
1748}
1749
1750void DevirtModule::exportConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
1751 StringRef Name, uint32_t Const,
1752 uint32_t &Storage) {
1753 if (shouldExportConstantsAsAbsoluteSymbols()) {
1754 exportGlobal(
1755 Slot, Args, Name,
1756 C: ConstantExpr::getIntToPtr(C: ConstantInt::get(Ty: Int32Ty, V: Const), Ty: Int8PtrTy));
1757 return;
1758 }
1759
1760 Storage = Const;
1761}
1762
1763Constant *DevirtModule::importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
1764 StringRef Name) {
1765 GlobalVariable *GV =
1766 M.getOrInsertGlobal(Name: getGlobalName(Slot, Args, Name), Ty: Int8Arr0Ty);
1767 GV->setVisibility(GlobalValue::HiddenVisibility);
1768 return GV;
1769}
1770
// Import a constant exported by exportConstant. When the target does not use
// absolute symbols, the value is read directly from the summary Storage;
// otherwise the constant is materialized as a ptrtoint of the imported
// symbol, annotated with !absolute_symbol range metadata.
Constant *DevirtModule::importConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
                                       StringRef Name, IntegerType *IntTy,
                                       uint32_t Storage) {
  if (!shouldExportConstantsAsAbsoluteSymbols())
    return ConstantInt::get(Ty: IntTy, V: Storage);

  Constant *C = importGlobal(Slot, Args, Name);
  auto *GV = cast<GlobalVariable>(Val: C->stripPointerCasts());
  C = ConstantExpr::getPtrToInt(C, Ty: IntTy);

  // We only need to set metadata if the global is newly created, in which
  // case it would not have hidden visibility.
  if (GV->hasMetadata(KindID: LLVMContext::MD_absolute_symbol))
    return C;

  // Attach an !absolute_symbol [Min, Max) range to the global.
  auto SetAbsRange = [&](uint64_t Min, uint64_t Max) {
    auto *MinC = ConstantAsMetadata::get(C: ConstantInt::get(Ty: IntPtrTy, V: Min));
    auto *MaxC = ConstantAsMetadata::get(C: ConstantInt::get(Ty: IntPtrTy, V: Max));
    GV->setMetadata(KindID: LLVMContext::MD_absolute_symbol,
                    Node: MDNode::get(Context&: M.getContext(), MDs: {MinC, MaxC}));
  };
  unsigned AbsWidth = IntTy->getBitWidth();
  if (AbsWidth == IntPtrTy->getBitWidth()) {
    // A pointer-width constant may take any value; the all-ones [N, N)
    // degenerate range denotes the full set.
    uint64_t AllOnes = IntTy->getBitMask();
    SetAbsRange(AllOnes, AllOnes); // Full set.
  } else {
    // Narrower constants are bounded by their bit width.
    SetAbsRange(0, 1ull << AbsWidth);
  }
  return C;
}
1801
// Unique return value optimization for i1 returns: rewrite each call site to
// a comparison of the vtable pointer against UniqueMemberAddr (the single
// vtable whose target returns IsOne), zero-extended to the call's return
// type. FnName is used only for remarks.
void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
                                        bool IsOne,
                                        Constant *UniqueMemberAddr) {
  for (auto &&Call : CSInfo.CallSites) {
    // Skip calls already rewritten by an earlier optimization.
    if (!OptimizedCalls.insert(Ptr: &Call.CB).second)
      continue;
    IRBuilder<> B(&Call.CB);
    // vptr == unique-member yields true exactly when the unique target that
    // returns 1 is selected (ICMP_EQ); when the unique target is the one
    // returning 0, the result is true for all *other* vtables (ICMP_NE).
    Value *Cmp =
        B.CreateICmp(P: IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, LHS: Call.VTable,
                     RHS: B.CreateBitCast(V: UniqueMemberAddr, DestTy: Call.VTable->getType()));
    Cmp = B.CreateZExt(V: Cmp, DestTy: Call.CB.getType());
    NumUniqueRetVal++;
    Call.replaceAndErase(OptName: "unique-ret-val", TargetName: FnName, RemarksEnabled, OREGetter,
                         New: Cmp);
  }
  CSInfo.markDevirt();
}
1819
1820Constant *DevirtModule::getMemberAddr(const TypeMemberInfo *M) {
1821 return ConstantExpr::getPtrAdd(Ptr: M->Bits->GV,
1822 Offset: ConstantInt::get(Ty: Int64Ty, V: M->Offset));
1823}
1824
// Unique return value optimization (i1 only): if exactly one vtable's target
// returns 1 (or exactly one returns 0), each call can be rewritten as a
// pointer comparison against that vtable's address. Exports the unique
// member's address when the call site info is exported. Returns true if the
// optimization applied.
bool DevirtModule::tryUniqueRetValOpt(
    unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot,
    CallSiteInfo &CSInfo, WholeProgramDevirtResolution::ByArg *Res,
    VTableSlot Slot, ArrayRef<uint64_t> Args) {
  // IsOne controls whether we look for a 0 or a 1.
  auto tryUniqueRetValOptFor = [&](bool IsOne) {
    // Find the single target returning the sought value; bail out if more
    // than one does.
    const TypeMemberInfo *UniqueMember = nullptr;
    for (const VirtualCallTarget &Target : TargetsForSlot) {
      if (Target.RetVal == (IsOne ? 1 : 0)) {
        if (UniqueMember)
          return false;
        UniqueMember = Target.TM;
      }
    }

    // We should have found a unique member or bailed out by now. We already
    // checked for a uniform return value in tryUniformRetValOpt.
    assert(UniqueMember);

    Constant *UniqueMemberAddr = getMemberAddr(M: UniqueMember);
    if (CSInfo.isExported()) {
      Res->TheKind = WholeProgramDevirtResolution::ByArg::UniqueRetVal;
      Res->Info = IsOne;

      exportGlobal(Slot, Args, Name: "unique_member", C: UniqueMemberAddr);
    }

    // Replace each call with the comparison.
    applyUniqueRetValOpt(CSInfo, FnName: TargetsForSlot[0].Fn->getName(), IsOne,
                         UniqueMemberAddr);

    // Update devirtualization statistics for targets.
    if (RemarksEnabled || AreStatisticsEnabled())
      for (auto &&Target : TargetsForSlot)
        Target.WasDevirt = true;

    return true;
  };

  // Only an i1 return has few enough values for a unique member to exist.
  if (BitWidth == 1) {
    if (tryUniqueRetValOptFor(true))
      return true;
    if (tryUniqueRetValOptFor(false))
      return true;
  }
  return false;
}
1872
// Rewrite each call site to load its (previously evaluated) constant result
// from memory placed next to the vtable, at byte offset Byte from the vtable
// pointer. i1 results are stored as a bit within a byte and tested with the
// mask Bit; wider integers are loaded directly. FnName is for remarks only.
void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
                                         Constant *Byte, Constant *Bit) {
  for (auto Call : CSInfo.CallSites) {
    // Skip calls already rewritten by an earlier optimization.
    if (!OptimizedCalls.insert(Ptr: &Call.CB).second)
      continue;
    auto *RetType = cast<IntegerType>(Val: Call.CB.getType());
    IRBuilder<> B(&Call.CB);
    Value *Addr = B.CreatePtrAdd(Ptr: Call.VTable, Offset: Byte);
    if (RetType->getBitWidth() == 1) {
      // i1: load the containing byte and test the stored bit.
      Value *Bits = B.CreateLoad(Ty: Int8Ty, Ptr: Addr);
      Value *BitsAndBit = B.CreateAnd(LHS: Bits, RHS: Bit);
      auto IsBitSet = B.CreateICmpNE(LHS: BitsAndBit, RHS: ConstantInt::get(Ty: Int8Ty, V: 0));
      NumVirtConstProp1Bit++;
      Call.replaceAndErase(OptName: "virtual-const-prop-1-bit", TargetName: FnName, RemarksEnabled,
                           OREGetter, New: IsBitSet);
    } else {
      // Wider integers: a direct load of the return type.
      Value *Val = B.CreateLoad(Ty: RetType, Ptr: Addr);
      NumVirtConstProp++;
      Call.replaceAndErase(OptName: "virtual-const-prop", TargetName: FnName, RemarksEnabled,
                           OREGetter, New: Val);
    }
  }
  CSInfo.markDevirt();
}
1897
// Virtual constant propagation: for each constant-argument list seen at call
// sites, evaluate every target at compile time, store the results alongside
// the vtables (before or after them, whichever needs less padding), and
// rewrite each call as a load relative to the vtable pointer.
bool DevirtModule::tryVirtualConstProp(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
    WholeProgramDevirtResolution *Res, VTableSlot Slot) {
  // TODO: Skip for now if the vtable symbol was an alias to a function,
  // need to evaluate whether it would be correct to analyze the aliasee
  // function for this optimization.
  auto *Fn = dyn_cast<Function>(Val: TargetsForSlot[0].Fn);
  if (!Fn)
    return false;
  // This only works if the function returns an integer.
  auto *RetType = dyn_cast<IntegerType>(Val: Fn->getReturnType());
  if (!RetType)
    return false;
  unsigned BitWidth = RetType->getBitWidth();

  // TODO: Since we can evaluated these constants at compile-time, we can save
  // some space by calculating the smallest range of values that all these
  // constants can fit in, then only allocate enough space to fit those values.
  // At each callsite, we can get the original type by doing a sign/zero
  // extension. For example, if we would store an i64, but we can see that all
  // the values fit into an i16, then we can store an i16 before/after the
  // vtable and at each callsite do a s/zext.
  if (BitWidth > 64)
    return false;

  Align TypeAlignment = M.getDataLayout().getABIIntegerTypeAlignment(BitWidth);

  // Make sure that each function is defined, does not access memory, takes at
  // least one argument, does not use its first argument (which we assume is
  // 'this'), and has the same return type.
  //
  // Note that we test whether this copy of the function is readnone, rather
  // than testing function attributes, which must hold for any copy of the
  // function, even a less optimized version substituted at link time. This is
  // sound because the virtual constant propagation optimizations effectively
  // inline all implementations of the virtual function into each call site,
  // rather than using function attributes to perform local optimization.
  for (VirtualCallTarget &Target : TargetsForSlot) {
    // TODO: Skip for now if the vtable symbol was an alias to a function,
    // need to evaluate whether it would be correct to analyze the aliasee
    // function for this optimization.
    auto *Fn = dyn_cast<Function>(Val: Target.Fn);
    if (!Fn)
      return false;

    if (Fn->isDeclaration() ||
        !computeFunctionBodyMemoryAccess(F&: *Fn, AAR&: FAM.getResult<AAManager>(IR&: *Fn))
             .doesNotAccessMemory() ||
        Fn->arg_empty() || !Fn->arg_begin()->use_empty() ||
        Fn->getReturnType() != RetType)
      return false;

    // This only works if the integer size is at most the alignment of the
    // vtable. If the table is underaligned, then we can't guarantee that the
    // constant will always be aligned to the integer type alignment. For
    // example, if the table is `align 1`, we can never guarantee that an i32
    // stored before/after the vtable is 32-bit aligned without changing the
    // alignment of the new global.
    GlobalVariable *GV = Target.TM->Bits->GV;
    Align TableAlignment = M.getDataLayout().getValueOrABITypeAlignment(
        Alignment: GV->getAlign(), Ty: GV->getValueType());
    if (TypeAlignment > TableAlignment)
      return false;
  }

  for (auto &&CSByConstantArg : SlotInfo.ConstCSInfo) {
    if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, Args: CSByConstantArg.first))
      continue;

    WholeProgramDevirtResolution::ByArg *ResByArg = nullptr;
    if (Res)
      ResByArg = &Res->ResByArg[CSByConstantArg.first];

    // Prefer the cheaper uniform/unique rewrites before laying out constants.
    if (tryUniformRetValOpt(TargetsForSlot, CSInfo&: CSByConstantArg.second, Res: ResByArg))
      continue;

    if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSInfo&: CSByConstantArg.second,
                           Res: ResByArg, Slot, Args: CSByConstantArg.first))
      continue;

    // Find an allocation offset in bits in all vtables associated with the
    // type.
    // TODO: If there would be "holes" in the vtable that were added by
    // padding, we could place i1s there to reduce any extra padding that
    // would be introduced by the i1s.
    uint64_t AllocBefore =
        findLowestOffset(Targets: TargetsForSlot, /*IsAfter=*/false, Size: BitWidth);
    uint64_t AllocAfter =
        findLowestOffset(Targets: TargetsForSlot, /*IsAfter=*/true, Size: BitWidth);

    // Calculate the total amount of padding needed to store a value at both
    // ends of the object.
    uint64_t TotalPaddingBefore = 0, TotalPaddingAfter = 0;
    for (auto &&Target : TargetsForSlot) {
      TotalPaddingBefore += std::max<int64_t>(
          a: (AllocBefore + 7) / 8 - Target.allocatedBeforeBytes() - 1, b: 0);
      TotalPaddingAfter += std::max<int64_t>(
          a: (AllocAfter + 7) / 8 - Target.allocatedAfterBytes() - 1, b: 0);
    }

    // If the amount of padding is too large, give up.
    // FIXME: do something smarter here.
    if (std::min(a: TotalPaddingBefore, b: TotalPaddingAfter) > 128)
      continue;

    // Calculate the offset to the value as a (possibly negative) byte offset
    // and (if applicable) a bit offset, and store the values in the targets.
    int64_t OffsetByte;
    uint64_t OffsetBit;
    if (TotalPaddingBefore <= TotalPaddingAfter)
      setBeforeReturnValues(Targets: TargetsForSlot, AllocBefore, BitWidth, OffsetByte,
                            OffsetBit);
    else
      setAfterReturnValues(Targets: TargetsForSlot, AllocAfter, BitWidth, OffsetByte,
                           OffsetBit);

    // In an earlier check we forbade constant propagation from operating on
    // tables whose alignment is less than the alignment needed for loading
    // the constant. Thus, the address we take the offset from will always be
    // aligned to at least this integer alignment. Now, we need to ensure that
    // the offset is also aligned to this integer alignment to ensure we always
    // have an aligned load.
    assert(OffsetByte % TypeAlignment.value() == 0);

    if (RemarksEnabled || AreStatisticsEnabled())
      for (auto &&Target : TargetsForSlot)
        Target.WasDevirt = true;

    // Export the byte offset and bit mask so importing modules can apply the
    // same rewrite via importConstant.
    if (CSByConstantArg.second.isExported()) {
      ResByArg->TheKind = WholeProgramDevirtResolution::ByArg::VirtualConstProp;
      exportConstant(Slot, Args: CSByConstantArg.first, Name: "byte", Const: OffsetByte,
                     Storage&: ResByArg->Byte);
      exportConstant(Slot, Args: CSByConstantArg.first, Name: "bit", Const: 1ULL << OffsetBit,
                     Storage&: ResByArg->Bit);
    }

    // Rewrite each call to a load from OffsetByte/OffsetBit.
    Constant *ByteConst = ConstantInt::getSigned(Ty: Int32Ty, V: OffsetByte);
    Constant *BitConst = ConstantInt::get(Ty: Int8Ty, V: 1ULL << OffsetBit);
    applyVirtualConstProp(CSInfo&: CSByConstantArg.second,
                          FnName: TargetsForSlot[0].Fn->getName(), Byte: ByteConst, Bit: BitConst);
  }
  return true;
}
2043
// Splice the constants accumulated in B.Before/B.After around the original
// vtable initializer: build a new private global {before, old init, after},
// then alias the original name to the middle element so existing references
// keep their offsets.
void DevirtModule::rebuildGlobal(VTableBits &B) {
  // Nothing was allocated next to this vtable; leave it untouched.
  if (B.Before.Bytes.empty() && B.After.Bytes.empty())
    return;

  // Align the before byte array to the global's minimum alignment so that we
  // don't break any alignment requirements on the global.
  Align Alignment = M.getDataLayout().getValueOrABITypeAlignment(
      Alignment: B.GV->getAlign(), Ty: B.GV->getValueType());
  B.Before.Bytes.resize(new_size: alignTo(Size: B.Before.Bytes.size(), A: Alignment));

  // Before was stored in reverse order; flip it now.
  for (size_t I = 0, Size = B.Before.Bytes.size(); I != Size / 2; ++I)
    std::swap(a&: B.Before.Bytes[I], b&: B.Before.Bytes[Size - 1 - I]);

  // Build an anonymous global containing the before bytes, followed by the
  // original initializer, followed by the after bytes.
  auto *NewInit = ConstantStruct::getAnon(
      V: {ConstantDataArray::get(Context&: M.getContext(), Elts&: B.Before.Bytes),
         B.GV->getInitializer(),
         ConstantDataArray::get(Context&: M.getContext(), Elts&: B.After.Bytes)});
  auto *NewGV =
      new GlobalVariable(M, NewInit->getType(), B.GV->isConstant(),
                         GlobalVariable::PrivateLinkage, NewInit, "", B.GV);
  // Preserve the placement-relevant properties of the original vtable.
  NewGV->setSection(B.GV->getSection());
  NewGV->setComdat(B.GV->getComdat());
  NewGV->setAlignment(B.GV->getAlign());

  // Copy the original vtable's metadata to the anonymous global, adjusting
  // offsets as required.
  NewGV->copyMetadata(Src: B.GV, Offset: B.Before.Bytes.size());

  // Build an alias named after the original global, pointing at the second
  // element (the original initializer).
  auto *Alias = GlobalAlias::create(
      Ty: B.GV->getInitializer()->getType(), AddressSpace: 0, Linkage: B.GV->getLinkage(), Name: "",
      Aliasee: ConstantExpr::getInBoundsGetElementPtr(
          Ty: NewInit->getType(), C: NewGV,
          IdxList: ArrayRef<Constant *>{ConstantInt::get(Ty: Int32Ty, V: 0),
                                ConstantInt::get(Ty: Int32Ty, V: 1)}),
      Parent: &M);
  Alias->setVisibility(B.GV->getVisibility());
  Alias->takeName(V: B.GV);

  // Redirect all uses of the old vtable to the alias, then delete it.
  B.GV->replaceAllUsesWith(V: Alias);
  B.GV->eraseFromParent();
}
2090
2091bool DevirtModule::areRemarksEnabled() {
2092 const auto &FL = M.getFunctionList();
2093 for (const Function &Fn : FL) {
2094 if (Fn.empty())
2095 continue;
2096 auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &Fn.front());
2097 return DI.isEnabled();
2098 }
2099 return false;
2100}
2101
// Collect devirtualizable call sites guarded by llvm.assume(llvm.type.test)
// into CallSlots, and prune type test assume sequences that later
// LowerTypeTests passes would otherwise lower incorrectly.
void DevirtModule::scanTypeTestUsers(
    Function *TypeTestFunc,
    DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
  // Find all virtual calls via a virtual table pointer %p under an assumption
  // of the form llvm.assume(llvm.type.test(%p, %md)) or
  // llvm.assume(llvm.public.type.test(%p, %md)).
  // This indicates that %p points to a member of the type identifier %md.
  // Group calls by (type ID, offset) pair (effectively the identity of the
  // virtual function) and store to CallSlots.
  for (Use &U : llvm::make_early_inc_range(Range: TypeTestFunc->uses())) {
    auto *CI = dyn_cast<CallInst>(Val: U.getUser());
    if (!CI)
      continue;
    // Search for virtual calls based on %p and add them to DevirtCalls.
    SmallVector<DevirtCallSite, 1> DevirtCalls;
    SmallVector<CallInst *, 1> Assumes;
    auto &DT = FAM.getResult<DominatorTreeAnalysis>(IR&: *CI->getFunction());
    findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI, DT);

    Metadata *TypeId =
        cast<MetadataAsValue>(Val: CI->getArgOperand(i: 1))->getMetadata();
    // If we found any, add them to CallSlots.
    if (!Assumes.empty()) {
      Value *Ptr = CI->getArgOperand(i: 0)->stripPointerCasts();
      for (DevirtCallSite Call : DevirtCalls)
        CallSlots[{.TypeID: TypeId, .ByteOffset: Call.Offset}].addCallSite(VTable: Ptr, CB&: Call.CB, NumUnsafeUses: nullptr);
    }

    auto RemoveTypeTestAssumes = [&]() {
      // We no longer need the assumes or the type test.
      for (auto *Assume : Assumes)
        Assume->eraseFromParent();
      // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we
      // may use the vtable argument later.
      if (CI->use_empty())
        CI->eraseFromParent();
    };

    // At this point we could remove all type test assume sequences, as they
    // were originally inserted for WPD. However, we can keep these in the
    // code stream for later analysis (e.g. to help drive more efficient ICP
    // sequences). They will eventually be removed by a second LowerTypeTests
    // invocation that cleans them up. In order to do this correctly, the first
    // LowerTypeTests invocation needs to know that they have "Unknown" type
    // test resolution, so that they aren't treated as Unsat and lowered to
    // False, which will break any uses on assumes. Below we remove any type
    // test assumes that will not be treated as Unknown by LTT.

    // The type test assumes will be treated by LTT as Unsat if the type id is
    // not used on a global (in which case it has no entry in the TypeIdMap).
    if (!TypeIdMap.count(Val: TypeId))
      RemoveTypeTestAssumes();

    // For ThinLTO importing, we need to remove the type test assumes if this is
    // an MDString type id without a corresponding TypeIdSummary. Any
    // non-MDString type ids are ignored and treated as Unknown by LTT, so their
    // type test assumes can be kept. If the MDString type id is missing a
    // TypeIdSummary (e.g. because there was no use on a vcall, preventing the
    // exporting phase of WPD from analyzing it), then it would be treated as
    // Unsat by LTT and we need to remove its type test assumes here. If not
    // used on a vcall we don't need them for later optimization use in any
    // case.
    else if (ImportSummary && isa<MDString>(Val: TypeId)) {
      const TypeIdSummary *TidSummary =
          ImportSummary->getTypeIdSummary(TypeId: cast<MDString>(Val: TypeId)->getString());
      if (!TidSummary)
        RemoveTypeTestAssumes();
      else
        // If one was created it should not be Unsat, because if we reached here
        // the type id was used on a global.
        assert(TidSummary->TTRes.TheKind != TypeTestResolution::Unsat);
    }
  }
}
2176
// Lower each llvm.type.checked.load[.relative] into an explicit vtable load
// plus an llvm.type.test, recording devirtualizable call sites in CallSlots
// and tracking, per type test, how many uses remain "unsafe" (not yet
// devirtualized).
void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
  Function *TypeTestFunc =
      Intrinsic::getOrInsertDeclaration(M: &M, id: Intrinsic::type_test);

  for (Use &U : llvm::make_early_inc_range(Range: TypeCheckedLoadFunc->uses())) {
    auto *CI = dyn_cast<CallInst>(Val: U.getUser());
    if (!CI)
      continue;

    // Operands: vtable pointer, byte offset, and the type id metadata.
    Value *Ptr = CI->getArgOperand(i: 0);
    Value *Offset = CI->getArgOperand(i: 1);
    Value *TypeIdValue = CI->getArgOperand(i: 2);
    Metadata *TypeId = cast<MetadataAsValue>(Val: TypeIdValue)->getMetadata();

    SmallVector<DevirtCallSite, 1> DevirtCalls;
    SmallVector<Instruction *, 1> LoadedPtrs;
    SmallVector<Instruction *, 1> Preds;
    bool HasNonCallUses = false;
    auto &DT = FAM.getResult<DominatorTreeAnalysis>(IR&: *CI->getFunction());
    findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds,
                                               HasNonCallUses, CI, DT);

    // Start by generating "pessimistic" code that explicitly loads the function
    // pointer from the vtable and performs the type check. If possible, we will
    // eliminate the load and the type check later.

    // If possible, only generate the load at the point where it is used.
    // This helps avoid unnecessary spills.
    IRBuilder<> LoadB(
        (LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI);

    Value *LoadedValue = nullptr;
    if (TypeCheckedLoadFunc->getIntrinsicID() ==
        Intrinsic::type_checked_load_relative) {
      // Relative vtables: the entry is a 32-bit offset, loaded via
      // llvm.load.relative.
      Function *LoadRelFunc = Intrinsic::getOrInsertDeclaration(
          M: &M, id: Intrinsic::load_relative, Tys: {Int32Ty});
      LoadedValue = LoadB.CreateCall(Callee: LoadRelFunc, Args: {Ptr, Offset});
    } else {
      // Classic vtables: plain pointer load at Ptr + Offset.
      Value *GEP = LoadB.CreatePtrAdd(Ptr, Offset);
      LoadedValue = LoadB.CreateLoad(Ty: Int8PtrTy, Ptr: GEP);
    }

    for (Instruction *LoadedPtr : LoadedPtrs) {
      LoadedPtr->replaceAllUsesWith(V: LoadedValue);
      LoadedPtr->eraseFromParent();
    }

    // Likewise for the type test.
    IRBuilder<> CallB((Preds.size() == 1 && !HasNonCallUses) ? Preds[0] : CI);
    CallInst *TypeTestCall = CallB.CreateCall(Callee: TypeTestFunc, Args: {Ptr, TypeIdValue});

    for (Instruction *Pred : Preds) {
      Pred->replaceAllUsesWith(V: TypeTestCall);
      Pred->eraseFromParent();
    }

    // We have already erased any extractvalue instructions that refer to the
    // intrinsic call, but the intrinsic may have other non-extractvalue uses
    // (although this is unlikely). In that case, explicitly build a pair and
    // RAUW it.
    if (!CI->use_empty()) {
      Value *Pair = PoisonValue::get(T: CI->getType());
      IRBuilder<> B(CI);
      Pair = B.CreateInsertValue(Agg: Pair, Val: LoadedValue, Idxs: {0});
      Pair = B.CreateInsertValue(Agg: Pair, Val: TypeTestCall, Idxs: {1});
      CI->replaceAllUsesWith(V: Pair);
    }

    // The number of unsafe uses is initially the number of uses.
    auto &NumUnsafeUses = NumUnsafeUsesForTypeTest[TypeTestCall];
    NumUnsafeUses = DevirtCalls.size();

    // If the function pointer has a non-call user, we cannot eliminate the type
    // check, as one of those users may eventually call the pointer. Increment
    // the unsafe use count to make sure it cannot reach zero.
    if (HasNonCallUses)
      ++NumUnsafeUses;
    for (DevirtCallSite Call : DevirtCalls) {
      CallSlots[{.TypeID: TypeId, .ByteOffset: Call.Offset}].addCallSite(VTable: Ptr, CB&: Call.CB,
                                                 NumUnsafeUses: &NumUnsafeUses);
    }

    CI->eraseFromParent();
  }
}
2262
// ThinLTO import phase: look up the WPD resolution for this vtable slot in
// the imported summary and apply the corresponding rewrite (single-impl,
// uniform/unique return value, virtual const prop, or branch funnel) to the
// call sites recorded in SlotInfo.
void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) {
  // Only MDString type ids can appear in summaries.
  auto *TypeId = dyn_cast<MDString>(Val: Slot.TypeID);
  if (!TypeId)
    return;
  const TypeIdSummary *TidSummary =
      ImportSummary->getTypeIdSummary(TypeId: TypeId->getString());
  if (!TidSummary)
    return;
  auto ResI = TidSummary->WPDRes.find(x: Slot.ByteOffset);
  if (ResI == TidSummary->WPDRes.end())
    return;
  const WholeProgramDevirtResolution &Res = ResI->second;

  if (Res.TheKind == WholeProgramDevirtResolution::SingleImpl) {
    assert(!Res.SingleImplName.empty());
    // The type of the function in the declaration is irrelevant because every
    // call site will cast it to the correct type.
    Constant *SingleImpl =
        cast<Constant>(Val: M.getOrInsertFunction(Name: Res.SingleImplName,
                                             RetTy: Type::getVoidTy(C&: M.getContext()))
                           .getCallee());

    // This is the import phase so we should not be exporting anything.
    bool IsExported = false;
    applySingleImplDevirt(SlotInfo, TheFn: SingleImpl, IsExported);
    assert(!IsExported);
  }

  // Apply any per-constant-argument resolutions.
  for (auto &CSByConstantArg : SlotInfo.ConstCSInfo) {
    auto I = Res.ResByArg.find(x: CSByConstantArg.first);
    if (I == Res.ResByArg.end())
      continue;
    auto &ResByArg = I->second;
    // FIXME: We should figure out what to do about the "function name" argument
    // to the apply* functions, as the function names are unavailable during the
    // importing phase. For now we just pass the empty string. This does not
    // impact correctness because the function names are just used for remarks.
    switch (ResByArg.TheKind) {
    case WholeProgramDevirtResolution::ByArg::UniformRetVal:
      applyUniformRetValOpt(CSInfo&: CSByConstantArg.second, FnName: "", TheRetVal: ResByArg.Info);
      break;
    case WholeProgramDevirtResolution::ByArg::UniqueRetVal: {
      Constant *UniqueMemberAddr =
          importGlobal(Slot, Args: CSByConstantArg.first, Name: "unique_member");
      applyUniqueRetValOpt(CSInfo&: CSByConstantArg.second, FnName: "", IsOne: ResByArg.Info,
                           UniqueMemberAddr);
      break;
    }
    case WholeProgramDevirtResolution::ByArg::VirtualConstProp: {
      Constant *Byte = importConstant(Slot, Args: CSByConstantArg.first, Name: "byte",
                                      IntTy: Int32Ty, Storage: ResByArg.Byte);
      Constant *Bit = importConstant(Slot, Args: CSByConstantArg.first, Name: "bit", IntTy: Int8Ty,
                                     Storage: ResByArg.Bit);
      applyVirtualConstProp(CSInfo&: CSByConstantArg.second, FnName: "", Byte, Bit);
      break;
    }
    default:
      break;
    }
  }

  if (Res.TheKind == WholeProgramDevirtResolution::BranchFunnel) {
    // The type of the function is irrelevant, because it's bitcast at calls
    // anyhow.
    auto *JT = cast<Function>(
        Val: M.getOrInsertFunction(Name: getGlobalName(Slot, Args: {}, Name: "branch_funnel"),
                               RetTy: Type::getVoidTy(C&: M.getContext()))
            .getCallee());
    bool IsExported = false;
    applyICallBranchFunnel(SlotInfo, JT&: *JT, IsExported);
    assert(!IsExported);
  }
}
2336
2337void DevirtModule::removeRedundantTypeTests() {
2338 auto *True = ConstantInt::getTrue(Context&: M.getContext());
2339 for (auto &&U : NumUnsafeUsesForTypeTest) {
2340 if (U.second == 0) {
2341 U.first->replaceAllUsesWith(V: True);
2342 U.first->eraseFromParent();
2343 }
2344 }
2345}
2346
2347ValueInfo
2348DevirtModule::lookUpFunctionValueInfo(Function *TheFn,
2349 ModuleSummaryIndex *ExportSummary) {
2350 assert((ExportSummary != nullptr) &&
2351 "Caller guarantees ExportSummary is not nullptr");
2352
2353 const auto TheFnGUID = TheFn->getGUID();
2354 const auto TheFnGUIDWithExportedName =
2355 GlobalValue::getGUIDAssumingExternalLinkage(GlobalName: TheFn->getName());
2356 // Look up ValueInfo with the GUID in the current linkage.
2357 ValueInfo TheFnVI = ExportSummary->getValueInfo(GUID: TheFnGUID);
2358 // If no entry is found and GUID is different from GUID computed using
2359 // exported name, look up ValueInfo with the exported name unconditionally.
2360 // This is a fallback.
2361 //
2362 // The reason to have a fallback:
2363 // 1. LTO could enable global value internalization via
2364 // `enable-lto-internalization`.
2365 // 2. The GUID in ExportedSummary is computed using exported name.
2366 if ((!TheFnVI) && (TheFnGUID != TheFnGUIDWithExportedName)) {
2367 TheFnVI = ExportSummary->getValueInfo(GUID: TheFnGUIDWithExportedName);
2368 }
2369 return TheFnVI;
2370}
2371
2372bool DevirtModule::mustBeUnreachableFunction(
2373 Function *const F, ModuleSummaryIndex *ExportSummary) {
2374 if (WholeProgramDevirtKeepUnreachableFunction)
2375 return false;
2376 // First, learn unreachability by analyzing function IR.
2377 if (!F->isDeclaration()) {
2378 // A function must be unreachable if its entry block ends with an
2379 // 'unreachable'.
2380 return isa<UnreachableInst>(Val: F->getEntryBlock().getTerminator());
2381 }
2382 // Learn unreachability from ExportSummary if ExportSummary is present.
2383 return ExportSummary &&
2384 ::mustBeUnreachableFunction(
2385 TheFnVI: DevirtModule::lookUpFunctionValueInfo(TheFn: F, ExportSummary));
2386}
2387
// Main IR-level driver for regular LTO and the ThinLTO import/export phases:
// scan the devirtualization intrinsics, collect call sites per vtable slot,
// then decide (or, when importing, replay) an optimization for each slot.
// Returns true if the module may have been modified.
bool DevirtModule::run() {
  // If only some of the modules were split, we cannot correctly perform
  // this transformation. We already checked for the presence of type tests
  // with partially split modules during the thin link, and would have emitted
  // an error if any were found, so here we can simply return.
  if ((ExportSummary && ExportSummary->partiallySplitLTOUnits()) ||
      (ImportSummary && ImportSummary->partiallySplitLTOUnits()))
    return false;

  Function *PublicTypeTestFunc = nullptr;
  // If we are in speculative devirtualization mode, we can work on the public
  // type test intrinsics.
  if (DevirtSpeculatively)
    PublicTypeTestFunc =
        Intrinsic::getDeclarationIfExists(M: &M, id: Intrinsic::public_type_test);
  Function *TypeTestFunc =
      Intrinsic::getDeclarationIfExists(M: &M, id: Intrinsic::type_test);
  Function *TypeCheckedLoadFunc =
      Intrinsic::getDeclarationIfExists(M: &M, id: Intrinsic::type_checked_load);
  Function *TypeCheckedLoadRelativeFunc = Intrinsic::getDeclarationIfExists(
      M: &M, id: Intrinsic::type_checked_load_relative);
  Function *AssumeFunc =
      Intrinsic::getDeclarationIfExists(M: &M, id: Intrinsic::assume);

  // Normally if there are no users of the devirtualization intrinsics in the
  // module, this pass has nothing to do. But if we are exporting, we also need
  // to handle any users that appear only in the function summaries.
  if (!ExportSummary &&
      (((!PublicTypeTestFunc || PublicTypeTestFunc->use_empty()) &&
        (!TypeTestFunc || TypeTestFunc->use_empty())) ||
       !AssumeFunc || AssumeFunc->use_empty()) &&
      (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()) &&
      (!TypeCheckedLoadRelativeFunc ||
       TypeCheckedLoadRelativeFunc->use_empty()))
    return false;

  // Rebuild type metadata into a map for easy lookup.
  std::vector<VTableBits> Bits;
  DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap;
  buildTypeIdentifierMap(Bits, TypeIdMap);

  // Populate CallSlots with the virtual call sites reached through each kind
  // of intrinsic, keyed by (type id, byte offset).
  if (PublicTypeTestFunc && AssumeFunc)
    scanTypeTestUsers(TypeTestFunc: PublicTypeTestFunc, TypeIdMap);

  if (TypeTestFunc && AssumeFunc)
    scanTypeTestUsers(TypeTestFunc, TypeIdMap);

  if (TypeCheckedLoadFunc)
    scanTypeCheckedLoadUsers(TypeCheckedLoadFunc);

  if (TypeCheckedLoadRelativeFunc)
    scanTypeCheckedLoadUsers(TypeCheckedLoadFunc: TypeCheckedLoadRelativeFunc);

  if (ImportSummary) {
    // ThinLTO import phase: only replay the resolutions computed during the
    // thin link; no new devirtualization decisions are made in this module.
    for (auto &S : CallSlots)
      importResolution(Slot: S.first, SlotInfo&: S.second);

    removeRedundantTypeTests();

    // We have lowered or deleted the type intrinsics, so we will no longer have
    // enough information to reason about the liveness of virtual function
    // pointers in GlobalDCE.
    for (GlobalVariable &GV : M.globals())
      GV.eraseMetadata(KindID: LLVMContext::MD_vcall_visibility);

    // The rest of the code is only necessary when exporting or during regular
    // LTO, so we are done.
    return true;
  }

  // Without any type metadata there are no vtables to reason about.
  if (TypeIdMap.empty())
    return true;

  // Collect information from summary about which calls to try to devirtualize.
  if (ExportSummary) {
    // The summary vcall records carry GUIDs; map each GUID back to the type
    // id metadata string(s) used in this module's CallSlots keys.
    DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID;
    for (auto &P : TypeIdMap) {
      if (auto *TypeId = dyn_cast<MDString>(Val: P.first))
        MetadataByGUID[GlobalValue::getGUIDAssumingExternalLinkage(
            GlobalName: TypeId->getString())]
            .push_back(NewVal: TypeId);
    }

    for (auto &P : *ExportSummary) {
      for (auto &S : P.second.getSummaryList()) {
        auto *FS = dyn_cast<FunctionSummary>(Val: S.get());
        if (!FS)
          continue;
        // FIXME: Only add live functions.
        for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) {
          for (Metadata *MD : MetadataByGUID[VF.GUID]) {
            CallSlots[{.TypeID: MD, .ByteOffset: VF.Offset}].CSInfo.addSummaryTypeTestAssumeUser(FS);
          }
        }
        for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) {
          for (Metadata *MD : MetadataByGUID[VF.GUID]) {
            CallSlots[{.TypeID: MD, .ByteOffset: VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS);
          }
        }
        for (const FunctionSummary::ConstVCall &VC :
             FS->type_test_assume_const_vcalls()) {
          for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
            CallSlots[{.TypeID: MD, .ByteOffset: VC.VFunc.Offset}]
                .ConstCSInfo[VC.Args]
                .addSummaryTypeTestAssumeUser(FS);
          }
        }
        for (const FunctionSummary::ConstVCall &VC :
             FS->type_checked_load_const_vcalls()) {
          for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
            CallSlots[{.TypeID: MD, .ByteOffset: VC.VFunc.Offset}]
                .ConstCSInfo[VC.Args]
                .addSummaryTypeCheckedLoadUser(FS);
          }
        }
      }
    }
  }

  // For each (type, offset) pair:
  bool DidVirtualConstProp = false;
  std::map<std::string, GlobalValue *> DevirtTargets;
  for (auto &S : CallSlots) {
    // Search each of the members of the type identifier for the virtual
    // function implementation at offset S.first.ByteOffset, and add to
    // TargetsForSlot.
    std::vector<VirtualCallTarget> TargetsForSlot;
    WholeProgramDevirtResolution *Res = nullptr;
    const std::set<TypeMemberInfo> &TypeMemberInfos = TypeIdMap[S.first.TypeID];
    if (ExportSummary && isa<MDString>(Val: S.first.TypeID) &&
        TypeMemberInfos.size())
      // For any type id used on a global's type metadata, create the type id
      // summary resolution regardless of whether we can devirtualize, so that
      // lower type tests knows the type id is not Unsat. If it was not used on
      // a global's type metadata, the TypeIdMap entry set will be empty, and
      // we don't want to create an entry (with the default Unknown type
      // resolution), which can prevent detection of the Unsat.
      Res = &ExportSummary
                 ->getOrInsertTypeIdSummary(
                     TypeId: cast<MDString>(Val: S.first.TypeID)->getString())
                 .WPDRes[S.first.ByteOffset];
    if (tryFindVirtualCallTargets(TargetsForSlot, TypeMemberInfos,
                                  ByteOffset: S.first.ByteOffset, ExportSummary)) {
      bool SingleImplDevirt =
          trySingleImplDevirt(ExportSummary, TargetsForSlot, SlotInfo&: S.second, Res);
      // Out of speculative devirtualization mode, try to apply virtual constant
      // propagation or branch funneling.
      // TODO: This should eventually be enabled for non-public type tests.
      if (!SingleImplDevirt && !DevirtSpeculatively) {
        DidVirtualConstProp |=
            tryVirtualConstProp(TargetsForSlot, SlotInfo&: S.second, Res, Slot: S.first);

        tryICallBranchFunnel(TargetsForSlot, SlotInfo&: S.second, Res, Slot: S.first);
      }

      // Collect functions devirtualized at least for one call site for stats.
      if (RemarksEnabled || AreStatisticsEnabled())
        for (const auto &T : TargetsForSlot)
          if (T.WasDevirt)
            DevirtTargets[std::string(T.Fn->getName())] = T.Fn;
    }

    // CFI-specific: if we are exporting and any llvm.type.checked.load
    // intrinsics were *not* devirtualized, we need to add the resulting
    // llvm.type.test intrinsics to the function summaries so that the
    // LowerTypeTests pass will export them.
    if (ExportSummary && isa<MDString>(Val: S.first.TypeID)) {
      auto GUID = GlobalValue::getGUIDAssumingExternalLinkage(
          GlobalName: cast<MDString>(Val: S.first.TypeID)->getString());
      auto AddTypeTestsForTypeCheckedLoads = [&](CallSiteInfo &CSI) {
        if (!CSI.AllCallSitesDevirted)
          for (auto *FS : CSI.SummaryTypeCheckedLoadUsers)
            FS->addTypeTest(Guid: GUID);
      };
      AddTypeTestsForTypeCheckedLoads(S.second.CSInfo);
      for (auto &CCS : S.second.ConstCSInfo)
        AddTypeTestsForTypeCheckedLoads(CCS.second);
    }
  }

  if (RemarksEnabled) {
    // Generate remarks for each devirtualized function.
    for (const auto &DT : DevirtTargets) {
      GlobalValue *GV = DT.second;
      auto *F = dyn_cast<Function>(Val: GV);
      if (!F) {
        // A devirtualized target may be an alias; emit the remark on the
        // aliased function instead.
        auto *A = dyn_cast<GlobalAlias>(Val: GV);
        assert(A && isa<Function>(A->getAliasee()));
        F = dyn_cast<Function>(Val: A->getAliasee());
        assert(F);
      }

      using namespace ore;
      OREGetter(*F).emit(OptDiag: OptimizationRemark(DEBUG_TYPE, "Devirtualized", F)
                             << "devirtualized " << NV("FunctionName", DT.first));
    }
  }

  NumDevirtTargets += DevirtTargets.size();

  removeRedundantTypeTests();

  // Rebuild each global we touched as part of virtual constant propagation to
  // include the before and after bytes.
  if (DidVirtualConstProp)
    for (VTableBits &B : Bits)
      rebuildGlobal(B);

  // We have lowered or deleted the type intrinsics, so we will no longer have
  // enough information to reason about the liveness of virtual function
  // pointers in GlobalDCE.
  for (GlobalVariable &GV : M.globals())
    GV.eraseMetadata(KindID: LLVMContext::MD_vcall_visibility);

  // Calls whose ptrauth bundles were removed during scanning are dead now
  // that their replacements are in place.
  for (auto *CI : CallsWithPtrAuthBundleRemoved)
    CI->eraseFromParent();

  return true;
}
2607
// Index-only (ThinLTO thin link) whole-program devirtualization: operates
// purely on summaries, without IR. Only single-implementation
// devirtualization is attempted in this mode; the resolutions recorded here
// are replayed in each module by DevirtModule::importResolution.
void DevirtIndex::run() {
  if (ExportSummary.typeIdCompatibleVtableMap().empty())
    return;

  // Assert that we haven't made any changes that would affect the hasLocal()
  // flag on the GUID summary info.
  assert(!ExportSummary.withInternalizeAndPromote() &&
         "Expect index-based WPD to run before internalization and promotion");

  // The summary vcall records carry GUIDs; map each GUID back to the original
  // type id string(s) used as CallSlots keys.
  DenseMap<GlobalValue::GUID, std::vector<StringRef>> NameByGUID;
  for (const auto &P : ExportSummary.typeIdCompatibleVtableMap()) {
    NameByGUID[GlobalValue::getGUIDAssumingExternalLinkage(GlobalName: P.first)].push_back(
        x: P.first);
    // Create the type id summary resolution regardless of whether we can
    // devirtualize, so that lower type tests knows the type id is used on
    // a global and not Unsat. We do this here rather than in the loop over the
    // CallSlots, since that handling will only see type tests that directly
    // feed assumes, and we would miss any that aren't currently handled by WPD
    // (such as type tests that feed assumes via phis).
    ExportSummary.getOrInsertTypeIdSummary(TypeId: P.first);
  }

  // Collect information from summary about which calls to try to devirtualize.
  for (auto &P : ExportSummary) {
    for (auto &S : P.second.getSummaryList()) {
      auto *FS = dyn_cast<FunctionSummary>(Val: S.get());
      if (!FS)
        continue;
      // FIXME: Only add live functions.
      for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) {
        for (StringRef Name : NameByGUID[VF.GUID]) {
          CallSlots[{.TypeID: Name, .ByteOffset: VF.Offset}].CSInfo.addSummaryTypeTestAssumeUser(FS);
        }
      }
      for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) {
        for (StringRef Name : NameByGUID[VF.GUID]) {
          CallSlots[{.TypeID: Name, .ByteOffset: VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS);
        }
      }
      for (const FunctionSummary::ConstVCall &VC :
           FS->type_test_assume_const_vcalls()) {
        for (StringRef Name : NameByGUID[VC.VFunc.GUID]) {
          CallSlots[{.TypeID: Name, .ByteOffset: VC.VFunc.Offset}]
              .ConstCSInfo[VC.Args]
              .addSummaryTypeTestAssumeUser(FS);
        }
      }
      for (const FunctionSummary::ConstVCall &VC :
           FS->type_checked_load_const_vcalls()) {
        for (StringRef Name : NameByGUID[VC.VFunc.GUID]) {
          CallSlots[{.TypeID: Name, .ByteOffset: VC.VFunc.Offset}]
              .ConstCSInfo[VC.Args]
              .addSummaryTypeCheckedLoadUser(FS);
        }
      }
    }
  }

  std::set<ValueInfo> DevirtTargets;
  // For each (type, offset) pair:
  for (auto &S : CallSlots) {
    // Search each of the members of the type identifier for the virtual
    // function implementation at offset S.first.ByteOffset, and add to
    // TargetsForSlot.
    std::vector<ValueInfo> TargetsForSlot;
    auto TidSummary = ExportSummary.getTypeIdCompatibleVtableSummary(TypeId: S.first.TypeID);
    assert(TidSummary);
    // The type id summary would have been created while building the NameByGUID
    // map earlier.
    WholeProgramDevirtResolution *Res =
        &ExportSummary.getTypeIdSummary(TypeId: S.first.TypeID)
             ->WPDRes[S.first.ByteOffset];
    if (tryFindVirtualCallTargets(TargetsForSlot, TIdInfo: *TidSummary,
                                  ByteOffset: S.first.ByteOffset)) {

      if (!trySingleImplDevirt(TargetsForSlot, SlotSummary&: S.first, SlotInfo&: S.second, Res,
                               DevirtTargets))
        continue;
    }
  }

  // Optionally have the thin link print message for each devirtualized
  // function.
  if (PrintSummaryDevirt)
    for (const auto &DT : DevirtTargets)
      errs() << "Devirtualized call to " << DT << "\n";

  NumDevirtTargets += DevirtTargets.size();
}
2697