1//==- RISCVPromoteConstant.cpp - Promote constant fp to global for RISC-V --==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "RISCV.h"
10#include "RISCVSubtarget.h"
11#include "llvm/ADT/DenseMap.h"
12#include "llvm/ADT/SmallVector.h"
13#include "llvm/ADT/Statistic.h"
14#include "llvm/CodeGen/TargetLowering.h"
15#include "llvm/CodeGen/TargetPassConfig.h"
16#include "llvm/IR/BasicBlock.h"
17#include "llvm/IR/Constant.h"
18#include "llvm/IR/Constants.h"
19#include "llvm/IR/Function.h"
20#include "llvm/IR/GlobalValue.h"
21#include "llvm/IR/GlobalVariable.h"
22#include "llvm/IR/IRBuilder.h"
23#include "llvm/IR/InstIterator.h"
24#include "llvm/IR/Instruction.h"
25#include "llvm/IR/Instructions.h"
26#include "llvm/IR/IntrinsicInst.h"
27#include "llvm/IR/Module.h"
28#include "llvm/IR/Type.h"
29#include "llvm/InitializePasses.h"
30#include "llvm/Pass.h"
31#include "llvm/Support/Casting.h"
32#include "llvm/Support/Debug.h"
33
34using namespace llvm;
35
36#define DEBUG_TYPE "riscv-promote-const"
37#define RISCV_PROMOTE_CONSTANT_NAME "RISC-V Promote Constants"
38
39STATISTIC(NumPromoted, "Number of constant literals promoted to globals");
40STATISTIC(NumPromotedUses, "Number of uses of promoted literal constants");
41
42namespace {
43
44class RISCVPromoteConstant : public ModulePass {
45public:
46 static char ID;
47 RISCVPromoteConstant() : ModulePass(ID) {}
48
49 StringRef getPassName() const override { return RISCV_PROMOTE_CONSTANT_NAME; }
50
51 void getAnalysisUsage(AnalysisUsage &AU) const override {
52 AU.addRequired<TargetPassConfig>();
53 AU.setPreservesCFG();
54 }
55
56 /// Iterate over the functions and promote the double fp constants that
57 /// would otherwise go into the constant pool to a constant array.
58 bool runOnModule(Module &M) override {
59 if (skipModule(M))
60 return false;
61 // TargetMachine and Subtarget are needed to query isFPImmlegal.
62 const TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
63 const TargetMachine &TM = TPC.getTM<TargetMachine>();
64 bool Changed = false;
65 for (Function &F : M) {
66 const RISCVSubtarget &ST = TM.getSubtarget<RISCVSubtarget>(F);
67 const RISCVTargetLowering *TLI = ST.getTargetLowering();
68 Changed |= runOnFunction(F, TLI);
69 }
70 return Changed;
71 }
72
73private:
74 bool runOnFunction(Function &F, const RISCVTargetLowering *TLI);
75};
76} // end anonymous namespace
77
78char RISCVPromoteConstant::ID = 0;
79
80INITIALIZE_PASS(RISCVPromoteConstant, DEBUG_TYPE, RISCV_PROMOTE_CONSTANT_NAME,
81 false, false)
82
83ModulePass *llvm::createRISCVPromoteConstantPass() {
84 return new RISCVPromoteConstant();
85}
86
87bool RISCVPromoteConstant::runOnFunction(Function &F,
88 const RISCVTargetLowering *TLI) {
89 if (F.hasOptNone() || F.hasOptSize())
90 return false;
91
92 // Bail out and make no transformation if the target doesn't support
93 // doubles, or if we're not targeting RV64 as we currently see some
94 // regressions for those targets.
95 if (!TLI->isTypeLegal(VT: MVT::f64) || !TLI->isTypeLegal(VT: MVT::i64))
96 return false;
97
98 // Collect all unique double constants and their uses in the function. Use
99 // MapVector to preserve insertion order.
100 MapVector<ConstantFP *, SmallVector<Use *, 8>> ConstUsesMap;
101
102 for (Instruction &I : instructions(F)) {
103 for (Use &U : I.operands()) {
104 auto *C = dyn_cast<ConstantFP>(Val: U.get());
105 if (!C || !C->getType()->isDoubleTy())
106 continue;
107 // Do not promote if it wouldn't be loaded from the constant pool.
108 if (TLI->isFPImmLegal(Imm: C->getValueAPF(), VT: MVT::f64,
109 /*ForCodeSize=*/false))
110 continue;
111 // Do not promote a constant if it is used as an immediate argument
112 // for an intrinsic.
113 if (auto *II = dyn_cast<IntrinsicInst>(Val: U.getUser())) {
114 Function *IntrinsicFunc = II->getFunction();
115 unsigned OperandIdx = U.getOperandNo();
116 if (IntrinsicFunc && IntrinsicFunc->getAttributes().hasParamAttr(
117 ArgNo: OperandIdx, Kind: Attribute::ImmArg)) {
118 LLVM_DEBUG(dbgs() << "Skipping promotion of constant in: " << *II
119 << " because operand " << OperandIdx
120 << " must be an immediate.\n");
121 continue;
122 }
123 }
124 // Note: FP args to inline asm would be problematic if we had a
125 // constraint that required an immediate floating point operand. At the
126 // time of writing LLVM doesn't recognise such a constraint.
127 ConstUsesMap[C].push_back(Elt: &U);
128 }
129 }
130
131 int PromotableConstants = ConstUsesMap.size();
132 LLVM_DEBUG(dbgs() << "Found " << PromotableConstants
133 << " promotable constants in " << F.getName() << "\n");
134 // Bail out if no promotable constants found, or if only one is found.
135 if (PromotableConstants < 2) {
136 LLVM_DEBUG(dbgs() << "Performing no promotions as insufficient promotable "
137 "constants found\n");
138 return false;
139 }
140
141 NumPromoted += PromotableConstants;
142
143 // Create a global array containing the promoted constants.
144 Module *M = F.getParent();
145 Type *DoubleTy = Type::getDoubleTy(C&: M->getContext());
146
147 SmallVector<Constant *, 16> ConstantVector;
148 for (auto const &Pair : ConstUsesMap)
149 ConstantVector.push_back(Elt: Pair.first);
150
151 ArrayType *ArrayTy = ArrayType::get(ElementType: DoubleTy, NumElements: ConstantVector.size());
152 Constant *GlobalArrayInitializer =
153 ConstantArray::get(T: ArrayTy, V: ConstantVector);
154
155 auto *GlobalArray = new GlobalVariable(
156 *M, ArrayTy,
157 /*isConstant=*/true, GlobalValue::InternalLinkage, GlobalArrayInitializer,
158 ".promoted_doubles." + F.getName());
159
160 // A cache to hold the loaded value for a given constant within a basic block.
161 DenseMap<std::pair<ConstantFP *, BasicBlock *>, Value *> LocalLoads;
162
163 // Replace all uses with the loaded value.
164 unsigned Idx = 0;
165 for (auto const &Pair : ConstUsesMap) {
166 ConstantFP *Const = Pair.first;
167 const SmallVector<Use *, 8> &Uses = Pair.second;
168
169 for (Use *U : Uses) {
170 Instruction *UserInst = cast<Instruction>(Val: U->getUser());
171 BasicBlock *InsertionBB;
172
173 // If the user is a PHI node, we must insert the load in the
174 // corresponding predecessor basic block. Otherwise, it's inserted into
175 // the same block as the use.
176 if (auto *PN = dyn_cast<PHINode>(Val: UserInst))
177 InsertionBB = PN->getIncomingBlock(U: *U);
178 else
179 InsertionBB = UserInst->getParent();
180
181 if (isa<CatchSwitchInst>(Val: InsertionBB->getTerminator())) {
182 LLVM_DEBUG(dbgs() << "Bailing out: catchswitch means thre is no valid "
183 "insertion point.\n");
184 return false;
185 }
186
187 auto CacheKey = std::make_pair(x&: Const, y&: InsertionBB);
188 Value *LoadedVal = nullptr;
189
190 // Re-use a load if it exists in the insertion block.
191 if (LocalLoads.count(Val: CacheKey)) {
192 LoadedVal = LocalLoads.at(Val: CacheKey);
193 } else {
194 // Otherwise, create a new GEP and Load at the correct insertion point.
195 // It is always safe to insert in the first insertion point in the BB,
196 // so do that and let other passes reorder.
197 IRBuilder<> Builder(InsertionBB, InsertionBB->getFirstInsertionPt());
198 Value *ElementPtr = Builder.CreateConstInBoundsGEP2_64(
199 Ty: GlobalArray->getValueType(), Ptr: GlobalArray, Idx0: 0, Idx1: Idx, Name: "double.addr");
200 LoadedVal = Builder.CreateLoad(Ty: DoubleTy, Ptr: ElementPtr, Name: "double.val");
201
202 // Cache the newly created load for this block.
203 LocalLoads[CacheKey] = LoadedVal;
204 }
205
206 U->set(LoadedVal);
207 ++NumPromotedUses;
208 }
209 ++Idx;
210 }
211
212 return true;
213}
214