1//- NVPTXForwardParams.cpp - NVPTX Forward Device Params Removing Local Copy -//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// PTX supports 2 methods of accessing device function parameters:
10//
// - "simple" case: If a parameter is only loaded, and all loads can address
//   the parameter via a constant offset, then the parameter may be loaded via
//   the ".param" address space. This case is not possible if the parameter
//   is stored to or has its address taken. This method is preferable when
//   possible. Ex:
16//
17// ld.param.u32 %r1, [foo_param_1];
18// ld.param.u32 %r2, [foo_param_1+4];
19//
20// - "move param" case: For more complex cases the address of the param may be
21// placed in a register via a "mov" instruction. This "mov" also implicitly
22// moves the param to the ".local" address space and allows for it to be
//   written to. This essentially defers the responsibility of the byval copy
24// to the PTX calling convention.
25//
26// mov.b64 %rd1, foo_param_0;
27// st.local.u32 [%rd1], 42;
28// add.u64 %rd3, %rd1, %rd2;
29// ld.local.u32 %r2, [%rd3];
30//
31// In NVPTXLowerArgs and SelectionDAG, we pessimistically assume that all
32// parameters will use the "move param" case and the local address space. This
33// pass is responsible for switching to the "simple" case when possible, as it
34// is more efficient.
35//
// We do this by traversing the uses of each param "mov" instruction (looking
// through cvta conversions to and from the local address space) and checking
// that they are all loads.
38//
39//===----------------------------------------------------------------------===//
40
41#include "NVPTX.h"
42#include "llvm/ADT/SmallVector.h"
43#include "llvm/CodeGen/MachineFunctionPass.h"
44#include "llvm/CodeGen/MachineInstr.h"
45#include "llvm/CodeGen/MachineOperand.h"
46#include "llvm/CodeGen/MachineRegisterInfo.h"
47#include "llvm/CodeGen/TargetRegisterInfo.h"
48#include "llvm/Support/ErrorHandling.h"
49
50using namespace llvm;
51
52static bool traverseMoveUse(MachineInstr &U, const MachineRegisterInfo &MRI,
53 SmallVectorImpl<MachineInstr *> &RemoveList,
54 SmallVectorImpl<MachineInstr *> &LoadInsts) {
55 switch (U.getOpcode()) {
56 case NVPTX::LD_i16:
57 case NVPTX::LD_i32:
58 case NVPTX::LD_i64:
59 case NVPTX::LD_i8:
60 case NVPTX::LDV_i16_v2:
61 case NVPTX::LDV_i16_v4:
62 case NVPTX::LDV_i32_v2:
63 case NVPTX::LDV_i32_v4:
64 case NVPTX::LDV_i64_v2:
65 case NVPTX::LDV_i64_v4:
66 case NVPTX::LDV_i8_v2:
67 case NVPTX::LDV_i8_v4: {
68 LoadInsts.push_back(Elt: &U);
69 return true;
70 }
71 case NVPTX::cvta_local:
72 case NVPTX::cvta_local_64:
73 case NVPTX::cvta_to_local:
74 case NVPTX::cvta_to_local_64: {
75 for (auto &U2 : MRI.use_instructions(Reg: U.operands_begin()->getReg()))
76 if (!traverseMoveUse(U&: U2, MRI, RemoveList, LoadInsts))
77 return false;
78
79 RemoveList.push_back(Elt: &U);
80 return true;
81 }
82 default:
83 return false;
84 }
85}
86
87static bool eliminateMove(MachineInstr &Mov, const MachineRegisterInfo &MRI,
88 SmallVectorImpl<MachineInstr *> &RemoveList) {
89 SmallVector<MachineInstr *, 16> MaybeRemoveList;
90 SmallVector<MachineInstr *, 16> LoadInsts;
91
92 for (auto &U : MRI.use_instructions(Reg: Mov.operands_begin()->getReg()))
93 if (!traverseMoveUse(U, MRI, RemoveList&: MaybeRemoveList, LoadInsts))
94 return false;
95
96 RemoveList.append(RHS: MaybeRemoveList);
97 RemoveList.push_back(Elt: &Mov);
98
99 const MachineOperand *ParamSymbol = Mov.uses().begin();
100 assert(ParamSymbol->isSymbol());
101
102 constexpr unsigned LDInstBasePtrOpIdx = 5;
103 constexpr unsigned LDInstAddrSpaceOpIdx = 2;
104 for (auto *LI : LoadInsts) {
105 (LI->uses().begin() + LDInstBasePtrOpIdx)
106 ->ChangeToES(SymName: ParamSymbol->getSymbolName());
107 (LI->uses().begin() + LDInstAddrSpaceOpIdx)
108 ->ChangeToImmediate(ImmVal: NVPTX::AddressSpace::Param);
109 }
110 return true;
111}
112
113static bool forwardDeviceParams(MachineFunction &MF) {
114 const auto &MRI = MF.getRegInfo();
115
116 bool Changed = false;
117 SmallVector<MachineInstr *, 16> RemoveList;
118 for (auto &MI : make_early_inc_range(Range&: *MF.begin()))
119 if (MI.getOpcode() == NVPTX::MOV32_PARAM ||
120 MI.getOpcode() == NVPTX::MOV64_PARAM)
121 Changed |= eliminateMove(Mov&: MI, MRI, RemoveList);
122
123 for (auto *MI : RemoveList)
124 MI->eraseFromParent();
125
126 return Changed;
127}
128
129/// ----------------------------------------------------------------------------
130/// Pass (Manager) Boilerplate
131/// ----------------------------------------------------------------------------
132
133namespace {
134struct NVPTXForwardParamsPass : public MachineFunctionPass {
135 static char ID;
136 NVPTXForwardParamsPass() : MachineFunctionPass(ID) {}
137
138 bool runOnMachineFunction(MachineFunction &MF) override;
139
140 void getAnalysisUsage(AnalysisUsage &AU) const override {
141 MachineFunctionPass::getAnalysisUsage(AU);
142 }
143};
144} // namespace
145
146char NVPTXForwardParamsPass::ID = 0;
147
148INITIALIZE_PASS(NVPTXForwardParamsPass, "nvptx-forward-params",
149 "NVPTX Forward Params", false, false)
150
151bool NVPTXForwardParamsPass::runOnMachineFunction(MachineFunction &MF) {
152 return forwardDeviceParams(MF);
153}
154
155MachineFunctionPass *llvm::createNVPTXForwardParamsPass() {
156 return new NVPTXForwardParamsPass();
157}
158