//===- MLInlineAdvisor.cpp - machine learned InlineAdvisor ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the interface between the inliner and a learned model.
// It delegates model evaluation to either the AOT compiled model (the
// 'release' mode) or a runtime-loaded model (the 'development' case).
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/MLInlineAdvisor.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/FunctionPropertiesAnalysis.h"
#include "llvm/Analysis/InlineCost.h"
#include "llvm/Analysis/InlineModelFeatureMaps.h"
#include "llvm/Analysis/InteractiveModelRunner.h"
#include "llvm/Analysis/LazyCallGraph.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/MLModelRunner.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/Analysis/ReleaseModeModelRunner.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Support/CommandLine.h"

36using namespace llvm;
37
38static cl::opt<std::string> InteractiveChannelBaseName(
39 "inliner-interactive-channel-base", cl::Hidden,
40 cl::desc(
41 "Base file path for the interactive mode. The incoming filename should "
42 "have the name <inliner-interactive-channel-base>.in, while the "
43 "outgoing name should be <inliner-interactive-channel-base>.out"));
44static const std::string InclDefaultMsg =
45 (Twine("In interactive mode, also send the default policy decision: ") +
46 DefaultDecisionName + ".")
47 .str();
48static cl::opt<bool>
49 InteractiveIncludeDefault("inliner-interactive-include-default", cl::Hidden,
50 cl::desc(InclDefaultMsg));
51
52enum class SkipMLPolicyCriteria { Never, IfCallerIsNotCold };
53
54static cl::opt<SkipMLPolicyCriteria> SkipPolicy(
55 "ml-inliner-skip-policy", cl::Hidden, cl::init(Val: SkipMLPolicyCriteria::Never),
56 cl::values(clEnumValN(SkipMLPolicyCriteria::Never, "never", "never"),
57 clEnumValN(SkipMLPolicyCriteria::IfCallerIsNotCold,
58 "if-caller-not-cold", "if the caller is not cold")));
59
60static cl::opt<std::string> ModelSelector("ml-inliner-model-selector",
61 cl::Hidden, cl::init(Val: ""));
62
63#if defined(LLVM_HAVE_TF_AOT_INLINERSIZEMODEL)
64// codegen-ed file
65#include "InlinerSizeModel.h" // NOLINT
66using CompiledModelType = llvm::InlinerSizeModel;
67#else
68using CompiledModelType = NoopSavedModelImpl;
69#endif
70
71std::unique_ptr<InlineAdvisor>
72llvm::getReleaseModeAdvisor(Module &M, ModuleAnalysisManager &MAM,
73 std::function<bool(CallBase &)> GetDefaultAdvice) {
74 if (!llvm::isEmbeddedModelEvaluatorValid<CompiledModelType>() &&
75 InteractiveChannelBaseName.empty())
76 return nullptr;
77 std::unique_ptr<MLModelRunner> AOTRunner;
78 if (InteractiveChannelBaseName.empty())
79 AOTRunner = std::make_unique<ReleaseModeModelRunner<CompiledModelType>>(
80 args&: M.getContext(), args&: FeatureMap, args: DecisionName,
81 args&: EmbeddedModelRunnerOptions().setModelSelector(ModelSelector));
82 else {
83 auto Features = FeatureMap;
84 if (InteractiveIncludeDefault)
85 Features.push_back(x: DefaultDecisionSpec);
86 AOTRunner = std::make_unique<InteractiveModelRunner>(
87 args&: M.getContext(), args&: Features, args: InlineDecisionSpec,
88 args: InteractiveChannelBaseName + ".out",
89 args: InteractiveChannelBaseName + ".in");
90 }
91 return std::make_unique<MLInlineAdvisor>(args&: M, args&: MAM, args: std::move(AOTRunner),
92 args&: GetDefaultAdvice);
93}
94
95#define DEBUG_TYPE "inline-ml"
96
97static cl::opt<float> SizeIncreaseThreshold(
98 "ml-advisor-size-increase-threshold", cl::Hidden,
99 cl::desc("Maximum factor by which expected native size may increase before "
100 "blocking any further inlining."),
101 cl::init(Val: 2.0));
102
103static cl::opt<bool> KeepFPICache(
104 "ml-advisor-keep-fpi-cache", cl::Hidden,
105 cl::desc(
106 "For test - keep the ML Inline advisor's FunctionPropertiesInfo cache"),
107 cl::init(Val: false));
108
109// clang-format off
110std::vector<TensorSpec> llvm::FeatureMap{
111#define POPULATE_NAMES(DTYPE, SHAPE, NAME, __) TensorSpec::createSpec<DTYPE>(#NAME, SHAPE),
112// InlineCost features - these must come first
113 INLINE_COST_FEATURE_ITERATOR(POPULATE_NAMES)
114
115// Non-cost features
116 INLINE_FEATURE_ITERATOR(POPULATE_NAMES)
117#undef POPULATE_NAMES
118};
119// clang-format on
120
121const char *const llvm::DecisionName = "inlining_decision";
122const TensorSpec llvm::InlineDecisionSpec =
123 TensorSpec::createSpec<int64_t>(Name: DecisionName, Shape: {1});
124const char *const llvm::DefaultDecisionName = "inlining_default";
125const TensorSpec llvm::DefaultDecisionSpec =
126 TensorSpec::createSpec<int64_t>(Name: DefaultDecisionName, Shape: {1});
127const char *const llvm::RewardName = "delta_size";
128
129CallBase *getInlinableCS(Instruction &I) {
130 if (auto *CS = dyn_cast<CallBase>(Val: &I))
131 if (Function *Callee = CS->getCalledFunction()) {
132 if (!Callee->isDeclaration()) {
133 return CS;
134 }
135 }
136 return nullptr;
137}
138
139MLInlineAdvisor::MLInlineAdvisor(
140 Module &M, ModuleAnalysisManager &MAM,
141 std::unique_ptr<MLModelRunner> Runner,
142 std::function<bool(CallBase &)> GetDefaultAdvice)
143 : InlineAdvisor(
144 M, MAM.getResult<FunctionAnalysisManagerModuleProxy>(IR&: M).getManager()),
145 ModelRunner(std::move(Runner)), GetDefaultAdvice(GetDefaultAdvice),
146 CG(MAM.getResult<LazyCallGraphAnalysis>(IR&: M)),
147 UseIR2Vec(MAM.getCachedResult<IR2VecVocabAnalysis>(IR&: M) != nullptr),
148 InitialIRSize(getModuleIRSize()), CurrentIRSize(InitialIRSize),
149 PSI(MAM.getResult<ProfileSummaryAnalysis>(IR&: M)) {
150 assert(ModelRunner);
151 ModelRunner->switchContext(Name: "");
152 // Extract the 'call site height' feature - the position of a call site
153 // relative to the farthest statically reachable SCC node. We don't mutate
154 // this value while inlining happens. Empirically, this feature proved
155 // critical in behavioral cloning - i.e. training a model to mimic the manual
156 // heuristic's decisions - and, thus, equally important for training for
157 // improvement.
158 CallGraph CGraph(M);
159 for (auto I = scc_begin(G: &CGraph); !I.isAtEnd(); ++I) {
160 const std::vector<CallGraphNode *> &CGNodes = *I;
161 unsigned Level = 0;
162 for (auto *CGNode : CGNodes) {
163 Function *F = CGNode->getFunction();
164 if (!F || F->isDeclaration())
165 continue;
166 for (auto &I : instructions(F)) {
167 if (auto *CS = getInlinableCS(I)) {
168 auto *Called = CS->getCalledFunction();
169 auto Pos = FunctionLevels.find(x: &CG.get(F&: *Called));
170 // In bottom up traversal, an inlinable callee is either in the
171 // same SCC, or to a function in a visited SCC. So not finding its
172 // level means we haven't visited it yet, meaning it's in this SCC.
173 if (Pos == FunctionLevels.end())
174 continue;
175 Level = std::max(a: Level, b: Pos->second + 1);
176 }
177 }
178 }
179 for (auto *CGNode : CGNodes) {
180 Function *F = CGNode->getFunction();
181 if (F && !F->isDeclaration())
182 FunctionLevels[&CG.get(F&: *F)] = Level;
183 }
184 }
185 for (auto KVP : FunctionLevels) {
186 AllNodes.insert(V: KVP.first);
187 EdgeCount += getLocalCalls(F&: KVP.first->getFunction());
188 }
189 NodeCount = AllNodes.size();
190
191 if (auto IR2VecVocabResult = MAM.getCachedResult<IR2VecVocabAnalysis>(IR&: M)) {
192 if (!IR2VecVocabResult->isValid()) {
193 M.getContext().emitError(ErrorStr: "IR2VecVocabAnalysis is not valid");
194 return;
195 }
196 // Add the IR2Vec features to the feature map
197 auto IR2VecDim = IR2VecVocabResult->getDimension();
198 FeatureMap.push_back(
199 x: TensorSpec::createSpec<float>(Name: "callee_embedding", Shape: {IR2VecDim}));
200 FeatureMap.push_back(
201 x: TensorSpec::createSpec<float>(Name: "caller_embedding", Shape: {IR2VecDim}));
202 }
203}
204
205unsigned MLInlineAdvisor::getInitialFunctionLevel(const Function &F) const {
206 return CG.lookup(F) ? FunctionLevels.at(k: CG.lookup(F)) : 0;
207}
208
209void MLInlineAdvisor::onPassEntry(LazyCallGraph::SCC *CurSCC) {
210 if (!CurSCC || ForceStop)
211 return;
212 FPICache.clear();
213 // Function passes executed between InlinerPass runs may have changed the
214 // module-wide features.
215 // The cgscc pass manager rules are such that:
216 // - if a pass leads to merging SCCs, then the pipeline is restarted on the
217 // merged SCC
218 // - if a pass leads to splitting the SCC, then we continue with one of the
219 // splits
220 // This means that the NodesInLastSCC is a superset (not strict) of the nodes
221 // that subsequent passes would have processed
222 // - in addition, if new Nodes were created by a pass (e.g. CoroSplit),
223 // they'd be adjacent to Nodes in the last SCC. So we just need to check the
224 // boundary of Nodes in NodesInLastSCC for Nodes we haven't seen. We don't
225 // care about the nature of the Edge (call or ref). `FunctionLevels`-wise, we
226 // record them at the same level as the original node (this is a choice, may
227 // need revisiting).
228 // - nodes are only deleted at the end of a call graph walk where they are
229 // batch deleted, so we shouldn't see any dead nodes here.
230 while (!NodesInLastSCC.empty()) {
231 const auto *N = *NodesInLastSCC.begin();
232 assert(!N->isDead());
233 NodesInLastSCC.erase(Ptr: N);
234 EdgeCount += getLocalCalls(F&: N->getFunction());
235 const auto NLevel = FunctionLevels.at(k: N);
236 for (const auto &E : *(*N)) {
237 const auto *AdjNode = &E.getNode();
238 assert(!AdjNode->isDead() && !AdjNode->getFunction().isDeclaration());
239 auto I = AllNodes.insert(V: AdjNode);
240 // We've discovered a new function.
241 if (I.second) {
242 ++NodeCount;
243 NodesInLastSCC.insert(Ptr: AdjNode);
244 FunctionLevels[AdjNode] = NLevel;
245 }
246 }
247 }
248
249 EdgeCount -= EdgesOfLastSeenNodes;
250 EdgesOfLastSeenNodes = 0;
251
252 // (Re)use NodesInLastSCC to remember the nodes in the SCC right now,
253 // in case the SCC is split before onPassExit and some nodes are split out
254 assert(NodesInLastSCC.empty());
255 for (const auto &N : *CurSCC)
256 NodesInLastSCC.insert(Ptr: &N);
257}
258
259void MLInlineAdvisor::onPassExit(LazyCallGraph::SCC *CurSCC) {
260 // No need to keep this around - function passes will invalidate it.
261 if (!KeepFPICache)
262 FPICache.clear();
263 if (!CurSCC || ForceStop)
264 return;
265 // Keep track of the nodes and edges we last saw. Then, in onPassEntry,
266 // we update the node count and edge count from the subset of these nodes that
267 // survived.
268 EdgesOfLastSeenNodes = 0;
269
270 // Check on nodes that were in SCC onPassEntry
271 for (const LazyCallGraph::Node *N : NodesInLastSCC) {
272 assert(!N->isDead());
273 EdgesOfLastSeenNodes += getLocalCalls(F&: N->getFunction());
274 }
275
276 // Check on nodes that may have got added to SCC
277 for (const auto &N : *CurSCC) {
278 assert(!N.isDead());
279 auto I = NodesInLastSCC.insert(Ptr: &N);
280 if (I.second)
281 EdgesOfLastSeenNodes += getLocalCalls(F&: N.getFunction());
282 }
283 assert(NodeCount >= NodesInLastSCC.size());
284 assert(EdgeCount >= EdgesOfLastSeenNodes);
285}
286
287int64_t MLInlineAdvisor::getLocalCalls(Function &F) {
288 return getCachedFPI(F).DirectCallsToDefinedFunctions;
289}
290
291// Update the internal state of the advisor, and force invalidate feature
292// analysis. Currently, we maintain minimal (and very simple) global state - the
293// number of functions and the number of static calls. We also keep track of the
294// total IR size in this module, to stop misbehaving policies at a certain bloat
295// factor (SizeIncreaseThreshold)
296void MLInlineAdvisor::onSuccessfulInlining(const MLInlineAdvice &Advice,
297 bool CalleeWasDeleted) {
298 assert(!ForceStop);
299 Function *Caller = Advice.getCaller();
300 Function *Callee = Advice.getCallee();
301 // The caller features aren't valid anymore.
302 {
303 PreservedAnalyses PA = PreservedAnalyses::all();
304 PA.abandon<FunctionPropertiesAnalysis>();
305 PA.abandon<LoopAnalysis>();
306 FAM.invalidate(IR&: *Caller, PA);
307 }
308 Advice.updateCachedCallerFPI(FAM);
309 int64_t IRSizeAfter =
310 getIRSize(F&: *Caller) + (CalleeWasDeleted ? 0 : Advice.CalleeIRSize);
311 CurrentIRSize += IRSizeAfter - (Advice.CallerIRSize + Advice.CalleeIRSize);
312 if (CurrentIRSize > SizeIncreaseThreshold * InitialIRSize)
313 ForceStop = true;
314
315 // We can delta-update module-wide features. We know the inlining only changed
316 // the caller, and maybe the callee (by deleting the latter).
317 // Nodes are simple to update.
318 // For edges, we 'forget' the edges that the caller and callee used to have
319 // before inlining, and add back what they currently have together.
320 int64_t NewCallerAndCalleeEdges =
321 getCachedFPI(*Caller).DirectCallsToDefinedFunctions;
322
323 // A dead function's node is not actually removed from the call graph until
324 // the end of the call graph walk, but the node no longer belongs to any valid
325 // SCC.
326 if (CalleeWasDeleted) {
327 --NodeCount;
328 NodesInLastSCC.erase(Ptr: CG.lookup(F: *Callee));
329 DeadFunctions.insert(V: Callee);
330 } else {
331 NewCallerAndCalleeEdges +=
332 getCachedFPI(*Callee).DirectCallsToDefinedFunctions;
333 }
334 EdgeCount += (NewCallerAndCalleeEdges - Advice.CallerAndCalleeEdges);
335 assert(CurrentIRSize >= 0 && EdgeCount >= 0 && NodeCount >= 0);
336}
337
338int64_t MLInlineAdvisor::getModuleIRSize() const {
339 int64_t Ret = 0;
340 for (auto &F : M)
341 if (!F.isDeclaration())
342 Ret += getIRSize(F);
343 return Ret;
344}
345
346FunctionPropertiesInfo &MLInlineAdvisor::getCachedFPI(Function &F) const {
347 auto InsertPair = FPICache.try_emplace(k: &F);
348 if (!InsertPair.second)
349 return InsertPair.first->second;
350 InsertPair.first->second = FAM.getResult<FunctionPropertiesAnalysis>(IR&: F);
351 return InsertPair.first->second;
352}
353
354std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) {
355 if (auto Skip = getSkipAdviceIfUnreachableCallsite(CB))
356 return Skip;
357
358 auto &Caller = *CB.getCaller();
359 auto &Callee = *CB.getCalledFunction();
360
361 auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {
362 return FAM.getResult<AssumptionAnalysis>(IR&: F);
363 };
364 auto &TIR = FAM.getResult<TargetIRAnalysis>(IR&: Callee);
365 auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(IR&: Caller);
366
367 if (SkipPolicy == SkipMLPolicyCriteria::IfCallerIsNotCold) {
368 if (!PSI.isFunctionEntryCold(F: &Caller))
369 return std::make_unique<InlineAdvice>(args: this, args&: CB, args&: ORE,
370 args: GetDefaultAdvice(CB));
371 }
372 auto MandatoryKind = InlineAdvisor::getMandatoryKind(CB, FAM, ORE);
373 // If this is a "never inline" case, there won't be any changes to internal
374 // state we need to track, so we can just return the base InlineAdvice, which
375 // will do nothing interesting.
376 // Same thing if this is a recursive case.
377 if (MandatoryKind == InlineAdvisor::MandatoryInliningKind::Never ||
378 &Caller == &Callee)
379 return getMandatoryAdvice(CB, Advice: false);
380
381 bool Mandatory =
382 MandatoryKind == InlineAdvisor::MandatoryInliningKind::Always;
383
384 // If we need to stop, we won't want to track anymore any state changes, so
385 // we just return the base InlineAdvice, which acts as a noop.
386 if (ForceStop) {
387 ORE.emit(RemarkBuilder: [&] {
388 return OptimizationRemarkMissed(DEBUG_TYPE, "ForceStop", &CB)
389 << "Won't attempt inlining because module size grew too much.";
390 });
391 return std::make_unique<InlineAdvice>(args: this, args&: CB, args&: ORE, args&: Mandatory);
392 }
393
394 int CostEstimate = 0;
395 if (!Mandatory) {
396 auto IsCallSiteInlinable =
397 llvm::getInliningCostEstimate(Call&: CB, CalleeTTI&: TIR, GetAssumptionCache);
398 if (!IsCallSiteInlinable) {
399 // We can't inline this for correctness reasons, so return the base
400 // InlineAdvice, as we don't care about tracking any state changes (which
401 // won't happen).
402 return std::make_unique<InlineAdvice>(args: this, args&: CB, args&: ORE, args: false);
403 }
404 CostEstimate = *IsCallSiteInlinable;
405 }
406
407 const auto CostFeatures =
408 llvm::getInliningCostFeatures(Call&: CB, CalleeTTI&: TIR, GetAssumptionCache);
409 if (!CostFeatures) {
410 return std::make_unique<InlineAdvice>(args: this, args&: CB, args&: ORE, args: false);
411 }
412
413 if (Mandatory)
414 return getMandatoryAdvice(CB, Advice: true);
415
416 auto NumCtantParams = 0;
417 for (auto I = CB.arg_begin(), E = CB.arg_end(); I != E; ++I) {
418 NumCtantParams += (isa<Constant>(Val: *I));
419 }
420
421 auto &CallerBefore = getCachedFPI(F&: Caller);
422 auto &CalleeBefore = getCachedFPI(F&: Callee);
423
424 *ModelRunner->getTensor<int64_t>(FeatureID: FeatureIndex::callee_basic_block_count) =
425 CalleeBefore.BasicBlockCount;
426 *ModelRunner->getTensor<int64_t>(FeatureID: FeatureIndex::callsite_height) =
427 getInitialFunctionLevel(F: Caller);
428 *ModelRunner->getTensor<int64_t>(FeatureID: FeatureIndex::node_count) = NodeCount;
429 *ModelRunner->getTensor<int64_t>(FeatureID: FeatureIndex::nr_ctant_params) =
430 NumCtantParams;
431 *ModelRunner->getTensor<int64_t>(FeatureID: FeatureIndex::edge_count) = EdgeCount;
432 *ModelRunner->getTensor<int64_t>(FeatureID: FeatureIndex::caller_users) =
433 CallerBefore.Uses;
434 *ModelRunner->getTensor<int64_t>(
435 FeatureID: FeatureIndex::caller_conditionally_executed_blocks) =
436 CallerBefore.BlocksReachedFromConditionalInstruction;
437 *ModelRunner->getTensor<int64_t>(FeatureID: FeatureIndex::caller_basic_block_count) =
438 CallerBefore.BasicBlockCount;
439 *ModelRunner->getTensor<int64_t>(
440 FeatureID: FeatureIndex::callee_conditionally_executed_blocks) =
441 CalleeBefore.BlocksReachedFromConditionalInstruction;
442 *ModelRunner->getTensor<int64_t>(FeatureID: FeatureIndex::callee_users) =
443 CalleeBefore.Uses;
444 *ModelRunner->getTensor<int64_t>(FeatureID: FeatureIndex::cost_estimate) = CostEstimate;
445 *ModelRunner->getTensor<int64_t>(FeatureID: FeatureIndex::is_callee_avail_external) =
446 Callee.hasAvailableExternallyLinkage();
447 *ModelRunner->getTensor<int64_t>(FeatureID: FeatureIndex::is_caller_avail_external) =
448 Caller.hasAvailableExternallyLinkage();
449
450 if (UseIR2Vec) {
451 // Python side expects float embeddings. The IR2Vec embeddings are doubles
452 // as of now due to the restriction of fromJSON method used by the
453 // readVocabulary method in ir2vec::Embeddings.
454 auto setEmbedding = [&](const ir2vec::Embedding &Embedding,
455 FeatureIndex Index) {
456 llvm::transform(Range: Embedding, d_first: ModelRunner->getTensor<float>(FeatureID: Index),
457 F: [](double Val) { return static_cast<float>(Val); });
458 };
459
460 setEmbedding(CalleeBefore.getFunctionEmbedding(),
461 FeatureIndex::callee_embedding);
462 setEmbedding(CallerBefore.getFunctionEmbedding(),
463 FeatureIndex::caller_embedding);
464 }
465
466 // Add the cost features
467 for (size_t I = 0;
468 I < static_cast<size_t>(InlineCostFeatureIndex::NumberOfFeatures); ++I) {
469 *ModelRunner->getTensor<int64_t>(FeatureID: inlineCostFeatureToMlFeature(
470 Feature: static_cast<InlineCostFeatureIndex>(I))) = CostFeatures->at(n: I);
471 }
472 // This one would have been set up to be right at the end.
473 if (!InteractiveChannelBaseName.empty() && InteractiveIncludeDefault)
474 *ModelRunner->getTensor<int64_t>(FeatureID: FeatureMap.size()) = GetDefaultAdvice(CB);
475 return getAdviceFromModel(CB, ORE);
476}
477
478std::unique_ptr<MLInlineAdvice>
479MLInlineAdvisor::getAdviceFromModel(CallBase &CB,
480 OptimizationRemarkEmitter &ORE) {
481 return std::make_unique<MLInlineAdvice>(
482 args: this, args&: CB, args&: ORE, args: static_cast<bool>(ModelRunner->evaluate<int64_t>()));
483}
484
485std::unique_ptr<InlineAdvice>
486MLInlineAdvisor::getSkipAdviceIfUnreachableCallsite(CallBase &CB) {
487 if (!FAM.getResult<DominatorTreeAnalysis>(IR&: *CB.getCaller())
488 .isReachableFromEntry(A: CB.getParent()))
489 return std::make_unique<InlineAdvice>(args: this, args&: CB, args&: getCallerORE(CB), args: false);
490 return nullptr;
491}
492
493std::unique_ptr<InlineAdvice> MLInlineAdvisor::getMandatoryAdvice(CallBase &CB,
494 bool Advice) {
495 // Make sure we track inlinings in all cases - mandatory or not.
496 if (auto Skip = getSkipAdviceIfUnreachableCallsite(CB))
497 return Skip;
498 if (Advice && !ForceStop)
499 return getMandatoryAdviceImpl(CB);
500
501 // If this is a "never inline" case, there won't be any changes to internal
502 // state we need to track, so we can just return the base InlineAdvice, which
503 // will do nothing interesting.
504 // Same if we are forced to stop - we don't track anymore.
505 return std::make_unique<InlineAdvice>(args: this, args&: CB, args&: getCallerORE(CB), args&: Advice);
506}
507
508std::unique_ptr<MLInlineAdvice>
509MLInlineAdvisor::getMandatoryAdviceImpl(CallBase &CB) {
510 return std::make_unique<MLInlineAdvice>(args: this, args&: CB, args&: getCallerORE(CB), args: true);
511}
512
513void MLInlineAdvisor::print(raw_ostream &OS) const {
514 OS << "[MLInlineAdvisor] Nodes: " << NodeCount << " Edges: " << EdgeCount
515 << " EdgesOfLastSeenNodes: " << EdgesOfLastSeenNodes << "\n";
516 OS << "[MLInlineAdvisor] FPI:\n";
517 for (auto I : FPICache) {
518 OS << I.first->getName() << ":\n";
519 I.second.print(OS);
520 OS << "\n";
521 }
522 OS << "\n";
523 OS << "[MLInlineAdvisor] FuncLevels:\n";
524 for (auto I : FunctionLevels)
525 OS << (DeadFunctions.contains(V: &I.first->getFunction())
526 ? "<deleted>"
527 : I.first->getFunction().getName())
528 << " : " << I.second << "\n";
529
530 OS << "\n";
531}
532
533MLInlineAdvice::MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB,
534 OptimizationRemarkEmitter &ORE,
535 bool Recommendation)
536 : InlineAdvice(Advisor, CB, ORE, Recommendation),
537 CallerIRSize(Advisor->isForcedToStop() ? 0 : Advisor->getIRSize(F&: *Caller)),
538 CalleeIRSize(Advisor->isForcedToStop() ? 0 : Advisor->getIRSize(F&: *Callee)),
539 CallerAndCalleeEdges(Advisor->isForcedToStop()
540 ? 0
541 : (Advisor->getLocalCalls(F&: *Caller) +
542 Advisor->getLocalCalls(F&: *Callee))),
543 PreInlineCallerFPI(Advisor->getCachedFPI(F&: *Caller)) {
544 if (Recommendation)
545 FPU.emplace(args&: Advisor->getCachedFPI(F&: *getCaller()), args&: CB);
546}
547
548void MLInlineAdvice::reportContextForRemark(
549 DiagnosticInfoOptimizationBase &OR) {
550 using namespace ore;
551 OR << NV("Callee", Callee->getName());
552 for (size_t I = 0; I < FeatureMap.size(); ++I)
553 OR << NV(FeatureMap[I].name(),
554 *getAdvisor()->getModelRunner().getTensor<int64_t>(FeatureID: I));
555 OR << NV("ShouldInline", isInliningRecommended());
556}
557
558void MLInlineAdvice::updateCachedCallerFPI(FunctionAnalysisManager &FAM) const {
559 FPU->finish(FAM);
560}
561
562void MLInlineAdvice::recordInliningImpl() {
563 ORE.emit(RemarkBuilder: [&]() {
564 OptimizationRemark R(DEBUG_TYPE, "InliningSuccess", DLoc, Block);
565 reportContextForRemark(OR&: R);
566 return R;
567 });
568 getAdvisor()->onSuccessfulInlining(Advice: *this, /*CalleeWasDeleted*/ false);
569}
570
571void MLInlineAdvice::recordInliningWithCalleeDeletedImpl() {
572 ORE.emit(RemarkBuilder: [&]() {
573 OptimizationRemark R(DEBUG_TYPE, "InliningSuccessWithCalleeDeleted", DLoc,
574 Block);
575 reportContextForRemark(OR&: R);
576 return R;
577 });
578 getAdvisor()->onSuccessfulInlining(Advice: *this, /*CalleeWasDeleted*/ true);
579}
580
581void MLInlineAdvice::recordUnsuccessfulInliningImpl(
582 const InlineResult &Result) {
583 getAdvisor()->getCachedFPI(F&: *Caller) = PreInlineCallerFPI;
584 ORE.emit(RemarkBuilder: [&]() {
585 OptimizationRemarkMissed R(DEBUG_TYPE, "InliningAttemptedAndUnsuccessful",
586 DLoc, Block);
587 reportContextForRemark(OR&: R);
588 return R;
589 });
590}
591void MLInlineAdvice::recordUnattemptedInliningImpl() {
592 assert(!FPU);
593 ORE.emit(RemarkBuilder: [&]() {
594 OptimizationRemarkMissed R(DEBUG_TYPE, "IniningNotAttempted", DLoc, Block);
595 reportContextForRemark(OR&: R);
596 return R;
597 });
598}
599