1//== WebAssemblyMemIntrinsicResults.cpp - Optimize memory intrinsic results ==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file implements an optimization pass using memory intrinsic results.
11///
12/// Calls to memory intrinsics (memcpy, memmove, memset) return the destination
13/// address. They are in the form of
14/// %dst_new = call @memcpy %dst, %src, %len
15/// where %dst and %dst_new registers contain the same value.
16///
17/// This is to enable an optimization wherein uses of the %dst register used in
18/// the parameter can be replaced by uses of the %dst_new register used in the
19/// result, making the %dst register more likely to be single-use, thus more
20/// likely to be useful to register stackifying, and potentially also exposing
21/// the call instruction itself to register stackifying. These both can reduce
22/// local.get/local.set traffic.
23///
24/// The LLVM intrinsics for these return void so they can't use the returned
25/// attribute and consequently aren't handled by the OptimizeReturned pass.
26///
27//===----------------------------------------------------------------------===//
28
29#include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
30#include "WebAssembly.h"
31#include "WebAssemblyMachineFunctionInfo.h"
32#include "WebAssemblySubtarget.h"
33#include "llvm/Analysis/TargetLibraryInfo.h"
34#include "llvm/CodeGen/LiveIntervals.h"
35#include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
36#include "llvm/CodeGen/MachineDominators.h"
37#include "llvm/CodeGen/MachineRegisterInfo.h"
38#include "llvm/CodeGen/Passes.h"
39#include "llvm/Support/Debug.h"
40#include "llvm/Support/raw_ostream.h"
41using namespace llvm;
42
43#define DEBUG_TYPE "wasm-mem-intrinsic-results"
44
45namespace {
46class WebAssemblyMemIntrinsicResults final : public MachineFunctionPass {
47public:
48 static char ID; // Pass identification, replacement for typeid
49 WebAssemblyMemIntrinsicResults() : MachineFunctionPass(ID) {}
50
51 StringRef getPassName() const override {
52 return "WebAssembly Memory Intrinsic Results";
53 }
54
55 void getAnalysisUsage(AnalysisUsage &AU) const override {
56 AU.setPreservesCFG();
57 AU.addRequired<MachineBlockFrequencyInfoWrapperPass>();
58 AU.addPreserved<MachineBlockFrequencyInfoWrapperPass>();
59 AU.addRequired<MachineDominatorTreeWrapperPass>();
60 AU.addPreserved<MachineDominatorTreeWrapperPass>();
61 AU.addRequired<LiveIntervalsWrapperPass>();
62 AU.addPreserved<SlotIndexesWrapperPass>();
63 AU.addPreserved<LiveIntervalsWrapperPass>();
64 AU.addRequired<TargetLibraryInfoWrapperPass>();
65 AU.addRequired<LibcallLoweringInfoWrapper>();
66 MachineFunctionPass::getAnalysisUsage(AU);
67 }
68
69 bool runOnMachineFunction(MachineFunction &MF) override;
70
71private:
72 MachineDominatorTree *MDT;
73 LiveIntervals *LIS;
74 const TargetLibraryInfo *LibInfo;
75
76 StringRef MemcpyName, MemmoveName, MemsetName;
77
78 bool optimizeCall(MachineBasicBlock &MBB, MachineInstr &MI,
79 const MachineRegisterInfo &MRI) const;
80};
81} // end anonymous namespace
82
83char WebAssemblyMemIntrinsicResults::ID = 0;
84INITIALIZE_PASS(WebAssemblyMemIntrinsicResults, DEBUG_TYPE,
85 "Optimize memory intrinsic result values for WebAssembly",
86 false, false)
87
88FunctionPass *llvm::createWebAssemblyMemIntrinsicResults() {
89 return new WebAssemblyMemIntrinsicResults();
90}
91
92// Replace uses of FromReg with ToReg if they are dominated by MI.
93static bool replaceDominatedUses(MachineBasicBlock &MBB, MachineInstr &MI,
94 unsigned FromReg, unsigned ToReg,
95 const MachineRegisterInfo &MRI,
96 MachineDominatorTree &MDT,
97 LiveIntervals &LIS) {
98 bool Changed = false;
99
100 LiveInterval *FromLI = &LIS.getInterval(Reg: FromReg);
101 LiveInterval *ToLI = &LIS.getInterval(Reg: ToReg);
102
103 SlotIndex FromIdx = LIS.getInstructionIndex(Instr: MI).getRegSlot();
104 VNInfo *FromVNI = FromLI->getVNInfoAt(Idx: FromIdx);
105
106 SmallVector<SlotIndex, 4> Indices;
107
108 for (MachineOperand &O :
109 llvm::make_early_inc_range(Range: MRI.use_nodbg_operands(Reg: FromReg))) {
110 MachineInstr *Where = O.getParent();
111
112 // Check that MI dominates the instruction in the normal way.
113 if (&MI == Where || !MDT.dominates(A: &MI, B: Where))
114 continue;
115
116 // If this use gets a different value, skip it.
117 SlotIndex WhereIdx = LIS.getInstructionIndex(Instr: *Where);
118 VNInfo *WhereVNI = FromLI->getVNInfoAt(Idx: WhereIdx);
119 if (WhereVNI && WhereVNI != FromVNI)
120 continue;
121
122 // Make sure ToReg isn't clobbered before it gets there.
123 VNInfo *ToVNI = ToLI->getVNInfoAt(Idx: WhereIdx);
124 if (ToVNI && ToVNI != FromVNI)
125 continue;
126
127 Changed = true;
128 LLVM_DEBUG(dbgs() << "Setting operand " << O << " in " << *Where << " from "
129 << MI << "\n");
130 O.setReg(ToReg);
131
132 // If the store's def was previously dead, it is no longer.
133 if (!O.isUndef()) {
134 MI.getOperand(i: 0).setIsDead(false);
135
136 Indices.push_back(Elt: WhereIdx.getRegSlot());
137 }
138 }
139
140 if (Changed) {
141 // Extend ToReg's liveness.
142 LIS.extendToIndices(LR&: *ToLI, Indices);
143
144 // Shrink FromReg's liveness.
145 LIS.shrinkToUses(li: FromLI);
146
147 // If we replaced all dominated uses, FromReg is now killed at MI.
148 if (!FromLI->liveAt(index: FromIdx.getDeadSlot()))
149 MI.addRegisterKilled(IncomingReg: FromReg, RegInfo: MBB.getParent()
150 ->getSubtarget<WebAssemblySubtarget>()
151 .getRegisterInfo());
152 }
153
154 return Changed;
155}
156
157bool WebAssemblyMemIntrinsicResults::optimizeCall(
158 MachineBasicBlock &MBB, MachineInstr &MI,
159 const MachineRegisterInfo &MRI) const {
160 MachineOperand &Op1 = MI.getOperand(i: 1);
161 if (!Op1.isSymbol())
162 return false;
163
164 StringRef Name(Op1.getSymbolName());
165
166 // TODO: Could generalize by parsing to LibcallImpl and checking signature
167 // attributes
168 bool CallReturnsInput =
169 Name == MemcpyName || Name == MemmoveName || Name == MemsetName;
170 if (!CallReturnsInput)
171 return false;
172
173 LibFunc Func;
174 if (!LibInfo->getLibFunc(funcName: Name, F&: Func))
175 return false;
176
177 Register FromReg = MI.getOperand(i: 2).getReg();
178 Register ToReg = MI.getOperand(i: 0).getReg();
179 if (MRI.getRegClass(Reg: FromReg) != MRI.getRegClass(Reg: ToReg))
180 report_fatal_error(reason: "Memory Intrinsic results: call to builtin function "
181 "with wrong signature, from/to mismatch");
182 return replaceDominatedUses(MBB, MI, FromReg, ToReg, MRI, MDT&: *MDT, LIS&: *LIS);
183}
184
185bool WebAssemblyMemIntrinsicResults::runOnMachineFunction(MachineFunction &MF) {
186 LLVM_DEBUG({
187 dbgs() << "********** Memory Intrinsic Results **********\n"
188 << "********** Function: " << MF.getName() << '\n';
189 });
190
191 MachineRegisterInfo &MRI = MF.getRegInfo();
192 LIS = &getAnalysis<LiveIntervalsWrapperPass>().getLIS();
193 MDT = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
194 const WebAssemblySubtarget &Subtarget =
195 MF.getSubtarget<WebAssemblySubtarget>();
196 LibInfo =
197 &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F: MF.getFunction());
198 const LibcallLoweringInfo &Libcalls =
199 getAnalysis<LibcallLoweringInfoWrapper>().getLibcallLowering(
200 M: *MF.getFunction().getParent(), Subtarget);
201
202 MemcpyName = RTLIB::RuntimeLibcallsInfo::getLibcallImplName(
203 CallImpl: Libcalls.getLibcallImpl(Call: RTLIB::MEMCPY));
204 MemmoveName = RTLIB::RuntimeLibcallsInfo::getLibcallImplName(
205 CallImpl: Libcalls.getLibcallImpl(Call: RTLIB::MEMMOVE));
206 MemsetName = RTLIB::RuntimeLibcallsInfo::getLibcallImplName(
207 CallImpl: Libcalls.getLibcallImpl(Call: RTLIB::MEMSET));
208
209 bool Changed = false;
210
211 // We don't preserve SSA form.
212 MRI.leaveSSA();
213
214 assert(MRI.tracksLiveness() &&
215 "MemIntrinsicResults expects liveness tracking");
216
217 for (auto &MBB : MF) {
218 LLVM_DEBUG(dbgs() << "Basic Block: " << MBB.getName() << '\n');
219 for (auto &MI : MBB)
220 switch (MI.getOpcode()) {
221 default:
222 break;
223 case WebAssembly::CALL:
224 Changed |= optimizeCall(MBB, MI, MRI);
225 break;
226 }
227 }
228
229 return Changed;
230}
231