1 | //===- MLInlineAdvisor.h - ML - based InlineAdvisor factories ---*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef LLVM_ANALYSIS_MLINLINEADVISOR_H |
10 | #define LLVM_ANALYSIS_MLINLINEADVISOR_H |
11 | |
12 | #include "llvm/Analysis/FunctionPropertiesAnalysis.h" |
13 | #include "llvm/Analysis/InlineAdvisor.h" |
14 | #include "llvm/Analysis/LazyCallGraph.h" |
15 | #include "llvm/Analysis/MLModelRunner.h" |
16 | #include "llvm/IR/PassManager.h" |
17 | |
18 | #include <map> |
19 | #include <memory> |
20 | #include <optional> |
21 | |
22 | namespace llvm { |
// The declarations below need these types only by reference, pointer, or
// smart pointer, so forward declarations are enough. MLInlineAdvice is
// defined further down in this header but is referenced by MLInlineAdvisor
// first.
class DiagnosticInfoOptimizationBase;
class MLInlineAdvice;
class Module;
class ProfileSummaryInfo;
27 | |
28 | class MLInlineAdvisor : public InlineAdvisor { |
29 | public: |
30 | MLInlineAdvisor(Module &M, ModuleAnalysisManager &MAM, |
31 | std::unique_ptr<MLModelRunner> ModelRunner, |
32 | std::function<bool(CallBase &)> GetDefaultAdvice); |
33 | |
34 | virtual ~MLInlineAdvisor() = default; |
35 | |
36 | void onPassEntry(LazyCallGraph::SCC *SCC) override; |
37 | void onPassExit(LazyCallGraph::SCC *SCC) override; |
38 | |
39 | int64_t getIRSize(Function &F) const { |
40 | return getCachedFPI(F).TotalInstructionCount; |
41 | } |
42 | void onSuccessfulInlining(const MLInlineAdvice &Advice, |
43 | bool CalleeWasDeleted); |
44 | |
45 | bool isForcedToStop() const { return ForceStop; } |
46 | int64_t getLocalCalls(Function &F); |
47 | const MLModelRunner &getModelRunner() const { return *ModelRunner; } |
48 | FunctionPropertiesInfo &getCachedFPI(Function &) const; |
49 | |
50 | protected: |
51 | std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override; |
52 | |
53 | std::unique_ptr<InlineAdvice> getMandatoryAdvice(CallBase &CB, |
54 | bool Advice) override; |
55 | |
56 | virtual std::unique_ptr<MLInlineAdvice> getMandatoryAdviceImpl(CallBase &CB); |
57 | |
58 | virtual std::unique_ptr<MLInlineAdvice> |
59 | (CallBase &CB, OptimizationRemarkEmitter &ORE); |
60 | |
61 | // Get the initial 'level' of the function, or 0 if the function has been |
62 | // introduced afterwards. |
63 | // TODO: should we keep this updated? |
64 | unsigned getInitialFunctionLevel(const Function &F) const; |
65 | |
66 | std::unique_ptr<MLModelRunner> ModelRunner; |
67 | std::function<bool(CallBase &)> GetDefaultAdvice; |
68 | |
69 | private: |
70 | int64_t getModuleIRSize() const; |
71 | std::unique_ptr<InlineAdvice> |
72 | getSkipAdviceIfUnreachableCallsite(CallBase &CB); |
73 | void print(raw_ostream &OS) const override; |
74 | |
75 | // Using std::map to benefit from its iterator / reference non-invalidating |
76 | // semantics, which make it easy to use `getCachedFPI` results from multiple |
77 | // calls without needing to copy to avoid invalidation effects. |
78 | mutable std::map<const Function *, FunctionPropertiesInfo> FPICache; |
79 | |
80 | LazyCallGraph &CG; |
81 | |
82 | int64_t NodeCount = 0; |
83 | int64_t EdgeCount = 0; |
84 | int64_t EdgesOfLastSeenNodes = 0; |
85 | const bool UseIR2Vec; |
86 | |
87 | std::map<const LazyCallGraph::Node *, unsigned> FunctionLevels; |
88 | const int32_t InitialIRSize = 0; |
89 | int32_t CurrentIRSize = 0; |
90 | llvm::SmallPtrSet<const LazyCallGraph::Node *, 1> NodesInLastSCC; |
91 | DenseSet<const LazyCallGraph::Node *> AllNodes; |
92 | DenseSet<Function *> DeadFunctions; |
93 | bool ForceStop = false; |
94 | ProfileSummaryInfo &PSI; |
95 | }; |
96 | |
97 | /// InlineAdvice that tracks changes post inlining. For that reason, it only |
98 | /// overrides the "successful inlining" extension points. |
99 | class MLInlineAdvice : public InlineAdvice { |
100 | public: |
101 | (MLInlineAdvisor *Advisor, CallBase &CB, |
102 | OptimizationRemarkEmitter &ORE, bool Recommendation); |
103 | virtual ~MLInlineAdvice() = default; |
104 | |
105 | void recordInliningImpl() override; |
106 | void recordInliningWithCalleeDeletedImpl() override; |
107 | void recordUnsuccessfulInliningImpl(const InlineResult &Result) override; |
108 | void recordUnattemptedInliningImpl() override; |
109 | |
110 | Function *getCaller() const { return Caller; } |
111 | Function *getCallee() const { return Callee; } |
112 | |
113 | const int64_t CallerIRSize; |
114 | const int64_t CalleeIRSize; |
115 | const int64_t CallerAndCalleeEdges; |
116 | void updateCachedCallerFPI(FunctionAnalysisManager &FAM) const; |
117 | |
118 | private: |
119 | void (DiagnosticInfoOptimizationBase &OR); |
120 | MLInlineAdvisor *getAdvisor() const { |
121 | return static_cast<MLInlineAdvisor *>(Advisor); |
122 | }; |
123 | // Make a copy of the FPI of the caller right before inlining. If inlining |
124 | // fails, we can just update the cache with that value. |
125 | const FunctionPropertiesInfo PreInlineCallerFPI; |
126 | std::optional<FunctionPropertiesUpdater> FPU; |
127 | }; |
128 | |
129 | } // namespace llvm |
130 | |
131 | #endif // LLVM_ANALYSIS_MLINLINEADVISOR_H |
132 | |