1 | //- NVPTXForwardParams.cpp - NVPTX Forward Device Params Removing Local Copy -// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // PTX supports 2 methods of accessing device function parameters: |
10 | // |
// - "simple" case: If a parameter is only loaded, and all loads can address
//   the parameter via a constant offset, then the parameter may be loaded via
//   the ".param" address space. This case is not possible if the parameter
//   is stored to or has its address taken. This method is preferable when
//   possible. Ex:
16 | // |
17 | // ld.param.u32 %r1, [foo_param_1]; |
18 | // ld.param.u32 %r2, [foo_param_1+4]; |
19 | // |
20 | // - "move param" case: For more complex cases the address of the param may be |
21 | // placed in a register via a "mov" instruction. This "mov" also implicitly |
22 | // moves the param to the ".local" address space and allows for it to be |
//   written to. This essentially defers the responsibility of the byval copy
24 | // to the PTX calling convention. |
25 | // |
26 | // mov.b64 %rd1, foo_param_0; |
27 | // st.local.u32 [%rd1], 42; |
28 | // add.u64 %rd3, %rd1, %rd2; |
29 | // ld.local.u32 %r2, [%rd3]; |
30 | // |
31 | // In NVPTXLowerArgs and SelectionDAG, we pessimistically assume that all |
32 | // parameters will use the "move param" case and the local address space. This |
33 | // pass is responsible for switching to the "simple" case when possible, as it |
34 | // is more efficient. |
35 | // |
// We do this by simply traversing uses of the param "mov" instructions and
// trivially checking if they are all loads.
38 | // |
39 | //===----------------------------------------------------------------------===// |
40 | |
41 | #include "NVPTX.h" |
42 | #include "llvm/ADT/SmallVector.h" |
43 | #include "llvm/CodeGen/MachineFunctionPass.h" |
44 | #include "llvm/CodeGen/MachineInstr.h" |
45 | #include "llvm/CodeGen/MachineOperand.h" |
46 | #include "llvm/CodeGen/MachineRegisterInfo.h" |
47 | #include "llvm/CodeGen/TargetRegisterInfo.h" |
48 | #include "llvm/Support/ErrorHandling.h" |
49 | |
50 | using namespace llvm; |
51 | |
52 | static bool traverseMoveUse(MachineInstr &U, const MachineRegisterInfo &MRI, |
53 | SmallVectorImpl<MachineInstr *> &RemoveList, |
54 | SmallVectorImpl<MachineInstr *> &LoadInsts) { |
55 | switch (U.getOpcode()) { |
56 | case NVPTX::LD_i16: |
57 | case NVPTX::LD_i32: |
58 | case NVPTX::LD_i64: |
59 | case NVPTX::LD_i8: |
60 | case NVPTX::LDV_i16_v2: |
61 | case NVPTX::LDV_i16_v4: |
62 | case NVPTX::LDV_i32_v2: |
63 | case NVPTX::LDV_i32_v4: |
64 | case NVPTX::LDV_i64_v2: |
65 | case NVPTX::LDV_i64_v4: |
66 | case NVPTX::LDV_i8_v2: |
67 | case NVPTX::LDV_i8_v4: { |
68 | LoadInsts.push_back(Elt: &U); |
69 | return true; |
70 | } |
71 | case NVPTX::cvta_local: |
72 | case NVPTX::cvta_local_64: |
73 | case NVPTX::cvta_to_local: |
74 | case NVPTX::cvta_to_local_64: { |
75 | for (auto &U2 : MRI.use_instructions(Reg: U.operands_begin()->getReg())) |
76 | if (!traverseMoveUse(U&: U2, MRI, RemoveList, LoadInsts)) |
77 | return false; |
78 | |
79 | RemoveList.push_back(Elt: &U); |
80 | return true; |
81 | } |
82 | default: |
83 | return false; |
84 | } |
85 | } |
86 | |
87 | static bool eliminateMove(MachineInstr &Mov, const MachineRegisterInfo &MRI, |
88 | SmallVectorImpl<MachineInstr *> &RemoveList) { |
89 | SmallVector<MachineInstr *, 16> MaybeRemoveList; |
90 | SmallVector<MachineInstr *, 16> LoadInsts; |
91 | |
92 | for (auto &U : MRI.use_instructions(Reg: Mov.operands_begin()->getReg())) |
93 | if (!traverseMoveUse(U, MRI, RemoveList&: MaybeRemoveList, LoadInsts)) |
94 | return false; |
95 | |
96 | RemoveList.append(RHS: MaybeRemoveList); |
97 | RemoveList.push_back(Elt: &Mov); |
98 | |
99 | const MachineOperand *ParamSymbol = Mov.uses().begin(); |
100 | assert(ParamSymbol->isSymbol()); |
101 | |
102 | constexpr unsigned LDInstBasePtrOpIdx = 5; |
103 | constexpr unsigned LDInstAddrSpaceOpIdx = 2; |
104 | for (auto *LI : LoadInsts) { |
105 | (LI->uses().begin() + LDInstBasePtrOpIdx) |
106 | ->ChangeToES(SymName: ParamSymbol->getSymbolName()); |
107 | (LI->uses().begin() + LDInstAddrSpaceOpIdx) |
108 | ->ChangeToImmediate(ImmVal: NVPTX::AddressSpace::Param); |
109 | } |
110 | return true; |
111 | } |
112 | |
113 | static bool forwardDeviceParams(MachineFunction &MF) { |
114 | const auto &MRI = MF.getRegInfo(); |
115 | |
116 | bool Changed = false; |
117 | SmallVector<MachineInstr *, 16> RemoveList; |
118 | for (auto &MI : make_early_inc_range(Range&: *MF.begin())) |
119 | if (MI.getOpcode() == NVPTX::MOV32_PARAM || |
120 | MI.getOpcode() == NVPTX::MOV64_PARAM) |
121 | Changed |= eliminateMove(Mov&: MI, MRI, RemoveList); |
122 | |
123 | for (auto *MI : RemoveList) |
124 | MI->eraseFromParent(); |
125 | |
126 | return Changed; |
127 | } |
128 | |
129 | /// ---------------------------------------------------------------------------- |
130 | /// Pass (Manager) Boilerplate |
131 | /// ---------------------------------------------------------------------------- |
132 | |
133 | namespace { |
134 | struct NVPTXForwardParamsPass : public MachineFunctionPass { |
135 | static char ID; |
136 | NVPTXForwardParamsPass() : MachineFunctionPass(ID) {} |
137 | |
138 | bool runOnMachineFunction(MachineFunction &MF) override; |
139 | |
140 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
141 | MachineFunctionPass::getAnalysisUsage(AU); |
142 | } |
143 | }; |
144 | } // namespace |
145 | |
// Out-of-line definition of the pass identifier declared in the struct.
char NVPTXForwardParamsPass::ID = 0;

// Registers the pass with the argument string "nvptx-forward-params" and the
// display name "NVPTX Forward Params".
// NOTE(review): the two trailing `false` arguments are presumably the
// CFG-only and is-analysis flags of INITIALIZE_PASS -- confirm against
// llvm/PassSupport.h.
INITIALIZE_PASS(NVPTXForwardParamsPass, "nvptx-forward-params" ,
"NVPTX Forward Params" , false, false)
150 | |
151 | bool NVPTXForwardParamsPass::runOnMachineFunction(MachineFunction &MF) { |
152 | return forwardDeviceParams(MF); |
153 | } |
154 | |
155 | MachineFunctionPass *llvm::createNVPTXForwardParamsPass() { |
156 | return new NVPTXForwardParamsPass(); |
157 | } |
158 | |