//===- MLInlineAdvisor.cpp - machine learned InlineAdvisor ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the interface between the inliner and a learned model.
// It delegates model evaluation to either the AOT compiled model (the
// 'release' mode) or a runtime-loaded model (the 'development' case).
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/MLInlineAdvisor.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/FunctionPropertiesAnalysis.h"
#include "llvm/Analysis/InlineCost.h"
#include "llvm/Analysis/InlineModelFeatureMaps.h"
#include "llvm/Analysis/InteractiveModelRunner.h"
#include "llvm/Analysis/LazyCallGraph.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/MLModelRunner.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/Analysis/ReleaseModeModelRunner.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Support/CommandLine.h"

using namespace llvm;

static cl::opt<std::string> InteractiveChannelBaseName(
    "inliner-interactive-channel-base", cl::Hidden,
    cl::desc(
        "Base file path for the interactive mode. The incoming filename should "
        "have the name <inliner-interactive-channel-base>.in, while the "
        "outgoing name should be <inliner-interactive-channel-base>.out"));
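// For example (hypothetical path), passing
//   -inliner-interactive-channel-base=/tmp/inliner-channel
// makes the compiler write feature tensors to /tmp/inliner-channel.out and
// read inlining decisions from /tmp/inliner-channel.in.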
static const std::string InclDefaultMsg =
    (Twine("In interactive mode, also send the default policy decision: ") +
     DefaultDecisionName + ".")
        .str();
static cl::opt<bool>
    InteractiveIncludeDefault("inliner-interactive-include-default", cl::Hidden,
                              cl::desc(InclDefaultMsg));

enum class SkipMLPolicyCriteria { Never, IfCallerIsNotCold };

static cl::opt<SkipMLPolicyCriteria> SkipPolicy(
    "ml-inliner-skip-policy", cl::Hidden, cl::init(SkipMLPolicyCriteria::Never),
    cl::values(clEnumValN(SkipMLPolicyCriteria::Never, "never", "never"),
               clEnumValN(SkipMLPolicyCriteria::IfCallerIsNotCold,
                          "if-caller-not-cold", "if the caller is not cold")));

static cl::opt<std::string> ModelSelector("ml-inliner-model-selector",
                                          cl::Hidden, cl::init(""));

static cl::opt<bool> StopImmediatelyForTest("ml-inliner-stop-immediately",
                                            cl::Hidden);

#if defined(LLVM_HAVE_TF_AOT_INLINERSIZEMODEL)
// codegen-ed file
#include "InlinerSizeModel.h" // NOLINT
using CompiledModelType = llvm::InlinerSizeModel;
#else
using CompiledModelType = NoopSavedModelImpl;
#endif
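
// When no AOT-compiled model is linked in, NoopSavedModelImpl acts as a
// placeholder for which isEmbeddedModelEvaluatorValid<> is false, so the
// release-mode advisor below is only created if an interactive channel is
// configured.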

std::unique_ptr<InlineAdvisor>
llvm::getReleaseModeAdvisor(Module &M, ModuleAnalysisManager &MAM,
                            std::function<bool(CallBase &)> GetDefaultAdvice) {
  if (!llvm::isEmbeddedModelEvaluatorValid<CompiledModelType>() &&
      InteractiveChannelBaseName.empty())
    return nullptr;
  auto RunnerFactory = [&](const std::vector<TensorSpec> &InputFeatures)
      -> std::unique_ptr<MLModelRunner> {
    std::unique_ptr<MLModelRunner> AOTRunner;
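    // The interactive channel, when configured, takes precedence over the
    // embedded (AOT-compiled) model.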
    if (InteractiveChannelBaseName.empty())
      AOTRunner = std::make_unique<ReleaseModeModelRunner<CompiledModelType>>(
          M.getContext(), InputFeatures, DecisionName,
          EmbeddedModelRunnerOptions().setModelSelector(ModelSelector));
    else {
      AOTRunner = std::make_unique<InteractiveModelRunner>(
          M.getContext(), InputFeatures, InlineDecisionSpec,
          InteractiveChannelBaseName + ".out",
          InteractiveChannelBaseName + ".in");
    }
    return AOTRunner;
  };
  return std::make_unique<MLInlineAdvisor>(M, MAM, RunnerFactory,
                                           GetDefaultAdvice);
}

#define DEBUG_TYPE "inline-ml"

static cl::opt<float> SizeIncreaseThreshold(
    "ml-advisor-size-increase-threshold", cl::Hidden,
    cl::desc("Maximum factor by which expected native size may increase before "
             "blocking any further inlining."),
    cl::init(2.0));

static cl::opt<bool> KeepFPICache(
    "ml-advisor-keep-fpi-cache", cl::Hidden,
    cl::desc(
        "For test - keep the ML Inline advisor's FunctionPropertiesInfo cache"),
    cl::init(false));

const std::vector<TensorSpec> &MLInlineAdvisor::getInitialFeatureMap() {
  // clang-format off
  static std::vector<TensorSpec> FeatureMap{
#define POPULATE_NAMES(DTYPE, SHAPE, NAME, __) TensorSpec::createSpec<DTYPE>(#NAME, SHAPE),
  // InlineCost features - these must come first
    INLINE_COST_FEATURE_ITERATOR(POPULATE_NAMES)

  // Non-cost features
    INLINE_FEATURE_ITERATOR(POPULATE_NAMES)
#undef POPULATE_NAMES
  };
  // clang-format on
  return FeatureMap;
}
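
// Note: the cost features must stay first; inlineCostFeatureToMlFeature (used
// when populating the tensors in getAdviceImpl) maps each
// InlineCostFeatureIndex directly onto the corresponding leading FeatureIndex.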

const char *const llvm::DecisionName = "inlining_decision";
const TensorSpec llvm::InlineDecisionSpec =
    TensorSpec::createSpec<int64_t>(DecisionName, {1});
const char *const llvm::DefaultDecisionName = "inlining_default";
const TensorSpec llvm::DefaultDecisionSpec =
    TensorSpec::createSpec<int64_t>(DefaultDecisionName, {1});
const char *const llvm::RewardName = "delta_size";
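// 'delta_size' is the reward signal reported when training in development
// mode; as the name suggests, it tracks the change in native size from an
// inlining decision.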

CallBase *getInlinableCS(Instruction &I) {
  if (auto *CS = dyn_cast<CallBase>(&I))
    if (Function *Callee = CS->getCalledFunction()) {
      if (!Callee->isDeclaration()) {
        return CS;
      }
    }
  return nullptr;
}

MLInlineAdvisor::MLInlineAdvisor(
    Module &M, ModuleAnalysisManager &MAM,
    std::function<
        std::unique_ptr<MLModelRunner>(const std::vector<TensorSpec> &)>
        GetModelRunner,
    std::function<bool(CallBase &)> GetDefaultAdvice)
    : InlineAdvisor(
          M, MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager()),
      GetDefaultAdvice(GetDefaultAdvice), FeatureMap(getInitialFeatureMap()),
      CG(MAM.getResult<LazyCallGraphAnalysis>(M)),
      UseIR2Vec(MAM.getCachedResult<IR2VecVocabAnalysis>(M) != nullptr),
      InitialIRSize(getModuleIRSize()), CurrentIRSize(InitialIRSize),
      PSI(MAM.getResult<ProfileSummaryAnalysis>(M)) {
  // Extract the 'call site height' feature - the position of a call site
  // relative to the farthest statically reachable SCC node. We don't mutate
  // this value while inlining happens. Empirically, this feature proved
  // critical in behavioral cloning - i.e. training a model to mimic the manual
  // heuristic's decisions - and, thus, equally important for training for
  // improvement.
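  // For example, in an acyclic module where main calls f, and f calls g (which
  // makes no calls), g's level is 0, f's is 1, and main's is 2.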
  CallGraph CGraph(M);
  for (auto I = scc_begin(&CGraph); !I.isAtEnd(); ++I) {
    const std::vector<CallGraphNode *> &CGNodes = *I;
    unsigned Level = 0;
    for (auto *CGNode : CGNodes) {
      Function *F = CGNode->getFunction();
      if (!F || F->isDeclaration())
        continue;
      for (auto &I : instructions(F)) {
        if (auto *CS = getInlinableCS(I)) {
          auto *Called = CS->getCalledFunction();
          auto Pos = FunctionLevels.find(&CG.get(*Called));
          // In bottom-up traversal, an inlinable call site refers either to a
          // function in the same SCC, or to one in an already-visited SCC. So
          // not finding its level means we haven't visited it yet, meaning
          // it's in this SCC.
          if (Pos == FunctionLevels.end())
            continue;
          Level = std::max(Level, Pos->second + 1);
        }
      }
    }
    for (auto *CGNode : CGNodes) {
      Function *F = CGNode->getFunction();
      if (F && !F->isDeclaration())
        FunctionLevels[&CG.get(*F)] = Level;
    }
  }
  for (auto KVP : FunctionLevels) {
    AllNodes.insert(KVP.first);
    EdgeCount += getLocalCalls(KVP.first->getFunction());
  }
  NodeCount = AllNodes.size();

  if (auto *IR2VecVocabResult = MAM.getCachedResult<IR2VecVocabAnalysis>(M)) {
    if (!IR2VecVocabResult->isValid()) {
      M.getContext().emitError("IR2VecVocabAnalysis is not valid");
      return;
    }
    // Add the IR2Vec features to the feature map.
    auto IR2VecDim = IR2VecVocabResult->getDimension();
    FeatureMap.push_back(
        TensorSpec::createSpec<float>("callee_embedding", {IR2VecDim}));
    FeatureMap.push_back(
        TensorSpec::createSpec<float>("caller_embedding", {IR2VecDim}));
  }
  if (InteractiveIncludeDefault)
    FeatureMap.push_back(DefaultDecisionSpec);

  ModelRunner = GetModelRunner(getFeatureMap());
  if (!ModelRunner) {
    M.getContext().emitError("Could not create model runner");
    return;
  }
  ModelRunner->switchContext("");
  ForceStop = StopImmediatelyForTest;
}

unsigned MLInlineAdvisor::getInitialFunctionLevel(const Function &F) const {
  return CG.lookup(F) ? FunctionLevels.at(CG.lookup(F)) : 0;
}

void MLInlineAdvisor::onPassEntry(LazyCallGraph::SCC *CurSCC) {
  if (!CurSCC || ForceStop)
    return;
  FPICache.clear();
  // Function passes executed between InlinerPass runs may have changed the
  // module-wide features.
  // The cgscc pass manager rules are such that:
  // - if a pass leads to merging SCCs, then the pipeline is restarted on the
  // merged SCC
  // - if a pass leads to splitting the SCC, then we continue with one of the
  // splits
  // This means that NodesInLastSCC is a (not necessarily strict) superset of
  // the nodes that subsequent passes would have processed
  // - in addition, if new Nodes were created by a pass (e.g. CoroSplit),
  // they'd be adjacent to Nodes in the last SCC. So we just need to check the
  // boundary of Nodes in NodesInLastSCC for Nodes we haven't seen. We don't
  // care about the nature of the Edge (call or ref). `FunctionLevels`-wise, we
  // record them at the same level as the original node (this is a choice, may
  // need revisiting).
  // - nodes are only deleted at the end of a call graph walk where they are
  // batch deleted, so we shouldn't see any dead nodes here.
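  // Drain NodesInLastSCC: account for the local calls of each node seen last
  // time, and transitively discover any new adjacent nodes created since then.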
  while (!NodesInLastSCC.empty()) {
    const auto *N = *NodesInLastSCC.begin();
    assert(!N->isDead());
    NodesInLastSCC.erase(N);
    EdgeCount += getLocalCalls(N->getFunction());
    const auto NLevel = FunctionLevels.at(N);
    for (const auto &E : *(*N)) {
      const auto *AdjNode = &E.getNode();
      assert(!AdjNode->isDead() && !AdjNode->getFunction().isDeclaration());
      auto I = AllNodes.insert(AdjNode);
      // We've discovered a new function.
      if (I.second) {
        ++NodeCount;
        NodesInLastSCC.insert(AdjNode);
        FunctionLevels[AdjNode] = NLevel;
      }
    }
  }

  EdgeCount -= EdgesOfLastSeenNodes;
  EdgesOfLastSeenNodes = 0;

  // (Re)use NodesInLastSCC to remember the nodes in the SCC right now,
  // in case the SCC is split before onPassExit and some nodes are split out.
  assert(NodesInLastSCC.empty());
  for (const auto &N : *CurSCC)
    NodesInLastSCC.insert(&N);
}

void MLInlineAdvisor::onPassExit(LazyCallGraph::SCC *CurSCC) {
  // No need to keep this around - function passes will invalidate it.
  if (!KeepFPICache)
    FPICache.clear();
  if (!CurSCC || ForceStop)
    return;
  // Keep track of the nodes and edges we last saw. Then, in onPassEntry,
  // we update the node count and edge count from the subset of these nodes
  // that survived.
  EdgesOfLastSeenNodes = 0;

  // Check on nodes that were in the SCC at onPassEntry.
  for (const LazyCallGraph::Node *N : NodesInLastSCC) {
    assert(!N->isDead());
    EdgesOfLastSeenNodes += getLocalCalls(N->getFunction());
  }

  // Check on nodes that may have been added to the SCC since.
  for (const auto &N : *CurSCC) {
    assert(!N.isDead());
    auto I = NodesInLastSCC.insert(&N);
    if (I.second)
      EdgesOfLastSeenNodes += getLocalCalls(N.getFunction());
  }
  assert(NodeCount >= NodesInLastSCC.size());
  assert(EdgeCount >= EdgesOfLastSeenNodes);
}

int64_t MLInlineAdvisor::getLocalCalls(Function &F) {
  return getCachedFPI(F).DirectCallsToDefinedFunctions;
}

// Update the internal state of the advisor, and force invalidate feature
// analysis. Currently, we maintain minimal (and very simple) global state - the
// number of functions and the number of static calls. We also keep track of the
// total IR size in this module, to stop misbehaving policies at a certain bloat
// factor (SizeIncreaseThreshold).
void MLInlineAdvisor::onSuccessfulInlining(const MLInlineAdvice &Advice,
                                           bool CalleeWasDeleted) {
  assert(!ForceStop);
  Function *Caller = Advice.getCaller();
  Function *Callee = Advice.getCallee();
  // The caller features aren't valid anymore.
  {
    PreservedAnalyses PA = PreservedAnalyses::all();
    PA.abandon<FunctionPropertiesAnalysis>();
    PA.abandon<LoopAnalysis>();
    FAM.invalidate(*Caller, PA);
  }
  Advice.updateCachedCallerFPI(FAM);
  if (Caller == Callee) {
    assert(!CalleeWasDeleted);
    // We double-counted CallerAndCalleeEdges - since the caller and callee
    // would be the same.
    assert(Advice.CallerAndCalleeEdges % 2 == 0);
    CurrentIRSize += getIRSize(*Caller) - Advice.CallerIRSize;
    EdgeCount += getCachedFPI(*Caller).DirectCallsToDefinedFunctions -
                 Advice.CallerAndCalleeEdges / 2;
    // The NodeCount would stay the same.
  } else {
    int64_t IRSizeAfter =
        getIRSize(*Caller) + (CalleeWasDeleted ? 0 : Advice.CalleeIRSize);
    CurrentIRSize += IRSizeAfter - (Advice.CallerIRSize + Advice.CalleeIRSize);

    // We can delta-update module-wide features. We know the inlining only
    // changed the caller, and maybe the callee (by deleting the latter). Nodes
    // are simple to update. For edges, we 'forget' the edges that the caller
    // and callee used to have before inlining, and add back what they currently
    // have together.
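    // For example (hypothetical counts): if the caller had 3 local calls and
    // the callee 2 before inlining (CallerAndCalleeEdges == 5), and afterwards
    // the caller has 4 while the surviving callee still has 2, EdgeCount is
    // adjusted by (4 + 2) - 5 = +1.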
    int64_t NewCallerAndCalleeEdges =
        getCachedFPI(*Caller).DirectCallsToDefinedFunctions;

    // A dead function's node is not actually removed from the call graph until
    // the end of the call graph walk, but the node no longer belongs to any
    // valid SCC.
    if (CalleeWasDeleted) {
      --NodeCount;
      NodesInLastSCC.erase(CG.lookup(*Callee));
      DeadFunctions.insert(Callee);
    } else {
      NewCallerAndCalleeEdges +=
          getCachedFPI(*Callee).DirectCallsToDefinedFunctions;
    }
    EdgeCount += (NewCallerAndCalleeEdges - Advice.CallerAndCalleeEdges);
  }
  if (CurrentIRSize > SizeIncreaseThreshold * InitialIRSize)
    ForceStop = true;

  assert(CurrentIRSize >= 0 && EdgeCount >= 0 && NodeCount >= 0);
}

int64_t MLInlineAdvisor::getModuleIRSize() const {
  int64_t Ret = 0;
  for (auto &F : M)
    if (!F.isDeclaration())
      Ret += getIRSize(F);
  return Ret;
}

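// Return the cached FunctionPropertiesInfo for F, computing it via
// FunctionPropertiesAnalysis on first use. The cache is cleared around pass
// runs (see onPassEntry/onPassExit), since function passes executing in
// between may invalidate the properties.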
FunctionPropertiesInfo &MLInlineAdvisor::getCachedFPI(Function &F) const {
  auto InsertPair = FPICache.try_emplace(&F);
  if (!InsertPair.second)
    return InsertPair.first->second;
  InsertPair.first->second = FAM.getResult<FunctionPropertiesAnalysis>(F);
  return InsertPair.first->second;
}

std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) {
  if (auto Skip = getSkipAdviceIfUnreachableCallsite(CB))
    return Skip;

  auto &Caller = *CB.getCaller();
  auto &Callee = *CB.getCalledFunction();

  auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {
    return FAM.getResult<AssumptionAnalysis>(F);
  };
  auto &TIR = FAM.getResult<TargetIRAnalysis>(Callee);
  auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(Caller);

  if (SkipPolicy == SkipMLPolicyCriteria::IfCallerIsNotCold) {
    if (!PSI.isFunctionEntryCold(&Caller)) {
      // Return an MLInlineAdvice, despite delegating to the default advice,
      // because we need to keep track of the internal state. This is different
      // from the other instances where we return a "default" InlineAdvice,
      // which happen at points we won't come back to the MLAdvisor for
      // decisions requiring that state.
      return ForceStop ? std::make_unique<InlineAdvice>(this, CB, ORE,
                                                        GetDefaultAdvice(CB))
                       : std::make_unique<MLInlineAdvice>(this, CB, ORE,
                                                          GetDefaultAdvice(CB));
    }
  }
  auto MandatoryKind = InlineAdvisor::getMandatoryKind(CB, FAM, ORE);
  // If this is a "never inline" case, there won't be any changes to internal
  // state we need to track, so we can just return the base InlineAdvice, which
  // will do nothing interesting.
  // Same thing if this is a recursive case.
  if (MandatoryKind == InlineAdvisor::MandatoryInliningKind::Never ||
      &Caller == &Callee)
    return getMandatoryAdvice(CB, false);

  bool Mandatory =
      MandatoryKind == InlineAdvisor::MandatoryInliningKind::Always;

  // If we need to stop, we no longer want to track any state changes, so we
  // just return the base InlineAdvice, which acts as a noop.
  if (ForceStop) {
    ORE.emit([&] {
      return OptimizationRemarkMissed(DEBUG_TYPE, "ForceStop", &CB)
             << "Won't attempt inlining because module size grew too much.";
    });
    return std::make_unique<InlineAdvice>(this, CB, ORE, Mandatory);
  }

  int CostEstimate = 0;
  if (!Mandatory) {
    auto IsCallSiteInlinable =
        llvm::getInliningCostEstimate(CB, TIR, GetAssumptionCache);
    if (!IsCallSiteInlinable) {
      // We can't inline this for correctness reasons, so return the base
      // InlineAdvice, as we don't care about tracking any state changes (which
      // won't happen).
      return std::make_unique<InlineAdvice>(this, CB, ORE, false);
    }
    CostEstimate = *IsCallSiteInlinable;
  }

  const auto CostFeatures =
      llvm::getInliningCostFeatures(CB, TIR, GetAssumptionCache);
  if (!CostFeatures) {
    return std::make_unique<InlineAdvice>(this, CB, ORE, false);
  }

  if (Mandatory)
    return getMandatoryAdvice(CB, true);

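  // Count the constant arguments at this call site; this feeds the
  // nr_ctant_params feature below.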
  auto NumCtantParams = 0;
  for (auto I = CB.arg_begin(), E = CB.arg_end(); I != E; ++I) {
    NumCtantParams += (isa<Constant>(*I));
  }

  auto &CallerBefore = getCachedFPI(Caller);
  auto &CalleeBefore = getCachedFPI(Callee);

  *ModelRunner->getTensor<int64_t>(FeatureIndex::callee_basic_block_count) =
      CalleeBefore.BasicBlockCount;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::callsite_height) =
      getInitialFunctionLevel(Caller);
  *ModelRunner->getTensor<int64_t>(FeatureIndex::node_count) = NodeCount;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::nr_ctant_params) =
      NumCtantParams;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::edge_count) = EdgeCount;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::caller_users) =
      CallerBefore.Uses;
  *ModelRunner->getTensor<int64_t>(
      FeatureIndex::caller_conditionally_executed_blocks) =
      CallerBefore.BlocksReachedFromConditionalInstruction;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::caller_basic_block_count) =
      CallerBefore.BasicBlockCount;
  *ModelRunner->getTensor<int64_t>(
      FeatureIndex::callee_conditionally_executed_blocks) =
      CalleeBefore.BlocksReachedFromConditionalInstruction;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::callee_users) =
      CalleeBefore.Uses;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::cost_estimate) = CostEstimate;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::is_callee_avail_external) =
      Callee.hasAvailableExternallyLinkage();
  *ModelRunner->getTensor<int64_t>(FeatureIndex::is_caller_avail_external) =
      Caller.hasAvailableExternallyLinkage();

  if (UseIR2Vec) {
    // Python side expects float embeddings. The IR2Vec embeddings are doubles
    // as of now due to the restriction of the fromJSON method used by the
    // readVocabulary method in ir2vec::Embeddings.
    auto setEmbedding = [&](const ir2vec::Embedding &Embedding,
                            FeatureIndex Index) {
      llvm::transform(Embedding, ModelRunner->getTensor<float>(Index),
                      [](double Val) { return static_cast<float>(Val); });
    };

    setEmbedding(CalleeBefore.getFunctionEmbedding(),
                 FeatureIndex::callee_embedding);
    setEmbedding(CallerBefore.getFunctionEmbedding(),
                 FeatureIndex::caller_embedding);
  }

  // Add the cost features.
  for (size_t I = 0;
       I < static_cast<size_t>(InlineCostFeatureIndex::NumberOfFeatures); ++I) {
    *ModelRunner->getTensor<int64_t>(inlineCostFeatureToMlFeature(
        static_cast<InlineCostFeatureIndex>(I))) = CostFeatures->at(I);
  }
  // The default decision, if requested, was appended as the last feature in
  // the map (see the constructor).
  if (!InteractiveChannelBaseName.empty() && InteractiveIncludeDefault)
    *ModelRunner->getTensor<int64_t>(getFeatureMap().size() - 1) =
        GetDefaultAdvice(CB);
  return getAdviceFromModel(CB, ORE);
}

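// The model emits a single int64 tensor (InlineDecisionSpec); any nonzero
// value is interpreted as a recommendation to inline.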
std::unique_ptr<MLInlineAdvice>
MLInlineAdvisor::getAdviceFromModel(CallBase &CB,
                                    OptimizationRemarkEmitter &ORE) {
  return std::make_unique<MLInlineAdvice>(
      this, CB, ORE, static_cast<bool>(ModelRunner->evaluate<int64_t>()));
}

std::unique_ptr<InlineAdvice>
MLInlineAdvisor::getSkipAdviceIfUnreachableCallsite(CallBase &CB) {
  if (!FAM.getResult<DominatorTreeAnalysis>(*CB.getCaller())
           .isReachableFromEntry(CB.getParent()))
    return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), false);
  return nullptr;
}

std::unique_ptr<InlineAdvice> MLInlineAdvisor::getMandatoryAdvice(CallBase &CB,
                                                                  bool Advice) {
  // Make sure we track inlinings in all cases - mandatory or not.
  if (auto Skip = getSkipAdviceIfUnreachableCallsite(CB))
    return Skip;
  if (Advice && !ForceStop)
    return getMandatoryAdviceImpl(CB);

  // If this is a "never inline" case, there won't be any changes to internal
  // state we need to track, so we can just return the base InlineAdvice, which
  // will do nothing interesting.
  // Same if we are forced to stop - we don't track anymore.
  return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), Advice);
}

std::unique_ptr<MLInlineAdvice>
MLInlineAdvisor::getMandatoryAdviceImpl(CallBase &CB) {
  return std::make_unique<MLInlineAdvice>(this, CB, getCallerORE(CB), true);
}

void MLInlineAdvisor::print(raw_ostream &OS) const {
  OS << "[MLInlineAdvisor] Nodes: " << NodeCount << " Edges: " << EdgeCount
     << " EdgesOfLastSeenNodes: " << EdgesOfLastSeenNodes << "\n";
  OS << "[MLInlineAdvisor] FPI:\n";
  for (auto I : FPICache) {
    OS << I.first->getName() << ":\n";
    I.second.print(OS);
    OS << "\n";
  }
  OS << "\n";
  OS << "[MLInlineAdvisor] FuncLevels:\n";
  for (auto I : FunctionLevels)
    OS << (DeadFunctions.contains(&I.first->getFunction())
               ? "<deleted>"
               : I.first->getFunction().getName())
       << " : " << I.second << "\n";

  OS << "\n";
}

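// Snapshot pre-inlining IR sizes and local call counts so that, on successful
// inlining, the advisor can delta-update its module-wide node/edge/size state.
// When the advisor is force-stopped these snapshots go unused, so they are
// left at zero.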
MLInlineAdvice::MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB,
                               OptimizationRemarkEmitter &ORE,
                               bool Recommendation)
    : InlineAdvice(Advisor, CB, ORE, Recommendation),
      CallerIRSize(Advisor->isForcedToStop() ? 0 : Advisor->getIRSize(*Caller)),
      CalleeIRSize(Advisor->isForcedToStop() ? 0 : Advisor->getIRSize(*Callee)),
      CallerAndCalleeEdges(Advisor->isForcedToStop()
                               ? 0
                               : (Advisor->getLocalCalls(*Caller) +
                                  Advisor->getLocalCalls(*Callee))),
      PreInlineCallerFPI(Advisor->getCachedFPI(*Caller)) {
  if (Recommendation)
    FPU.emplace(Advisor->getCachedFPI(*getCaller()), CB);
}

void MLInlineAdvice::reportContextForRemark(
    DiagnosticInfoOptimizationBase &OR) {
  using namespace ore;
  OR << NV("Callee", Callee->getName());
  for (size_t I = 0; I < getAdvisor()->getFeatureMap().size(); ++I)
    OR << NV(getAdvisor()->getFeatureMap()[I].name(),
             *getAdvisor()->getModelRunner().getTensor<int64_t>(I));
  OR << NV("ShouldInline", isInliningRecommended());
}

void MLInlineAdvice::updateCachedCallerFPI(FunctionAnalysisManager &FAM) const {
  FPU->finish(FAM);
}

void MLInlineAdvice::recordInliningImpl() {
  ORE.emit([&]() {
    OptimizationRemark R(DEBUG_TYPE, "InliningSuccess", DLoc, Block);
    reportContextForRemark(R);
    return R;
  });
  getAdvisor()->onSuccessfulInlining(*this, /*CalleeWasDeleted*/ false);
}

void MLInlineAdvice::recordInliningWithCalleeDeletedImpl() {
  ORE.emit([&]() {
    OptimizationRemark R(DEBUG_TYPE, "InliningSuccessWithCalleeDeleted", DLoc,
                         Block);
    reportContextForRemark(R);
    return R;
  });
  getAdvisor()->onSuccessfulInlining(*this, /*CalleeWasDeleted*/ true);
}

void MLInlineAdvice::recordUnsuccessfulInliningImpl(
    const InlineResult &Result) {
  getAdvisor()->getCachedFPI(*Caller) = PreInlineCallerFPI;
  ORE.emit([&]() {
    OptimizationRemarkMissed R(DEBUG_TYPE, "InliningAttemptedAndUnsuccessful",
                               DLoc, Block);
    reportContextForRemark(R);
    return R;
  });
}

void MLInlineAdvice::recordUnattemptedInliningImpl() {
  assert(!FPU);
  ORE.emit([&]() {
    OptimizationRemarkMissed R(DEBUG_TYPE, "InliningNotAttempted", DLoc, Block);
    reportContextForRemark(R);
    return R;
  });
}