//===--- ExpandLargeDivRem.cpp - Expand large div/rem ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass expands div/rem instructions whose integer bit width is above a
// threshold into the IR produced by expandDivision()/expandRemainder();
// fixed-width vector operations are scalarized first. The threshold is the
// target's TargetLowering::getMaxDivRemBitWidthSupported() unless it is
// overridden with -expand-div-rem-bits.
// This is useful for targets like x86_64 that cannot lower divisions
// with more than 128 bits or targets like x86_32 that cannot lower divisions
// with more than 64 bits.
//
//===----------------------------------------------------------------------===//
16 | |
17 | #include "llvm/CodeGen/ExpandLargeDivRem.h" |
18 | #include "llvm/ADT/SmallVector.h" |
19 | #include "llvm/ADT/StringExtras.h" |
20 | #include "llvm/Analysis/GlobalsModRef.h" |
21 | #include "llvm/CodeGen/Passes.h" |
22 | #include "llvm/CodeGen/TargetLowering.h" |
23 | #include "llvm/CodeGen/TargetPassConfig.h" |
24 | #include "llvm/CodeGen/TargetSubtargetInfo.h" |
25 | #include "llvm/IR/IRBuilder.h" |
26 | #include "llvm/IR/InstIterator.h" |
27 | #include "llvm/IR/PassManager.h" |
28 | #include "llvm/InitializePasses.h" |
29 | #include "llvm/Pass.h" |
30 | #include "llvm/Support/CommandLine.h" |
31 | #include "llvm/Target/TargetMachine.h" |
32 | #include "llvm/Transforms/Utils/IntegerDivision.h" |
33 | |
34 | using namespace llvm; |
35 | |
36 | static cl::opt<unsigned> |
37 | ExpandDivRemBits("expand-div-rem-bits" , cl::Hidden, |
38 | cl::init(Val: llvm::IntegerType::MAX_INT_BITS), |
39 | cl::desc("div and rem instructions on integers with " |
40 | "more than <N> bits are expanded." )); |
41 | |
42 | static bool isConstantPowerOfTwo(llvm::Value *V, bool SignedOp) { |
43 | auto *C = dyn_cast<ConstantInt>(Val: V); |
44 | if (!C) |
45 | return false; |
46 | |
47 | APInt Val = C->getValue(); |
48 | if (SignedOp && Val.isNegative()) |
49 | Val = -Val; |
50 | return Val.isPowerOf2(); |
51 | } |
52 | |
53 | static bool isSigned(unsigned int Opcode) { |
54 | return Opcode == Instruction::SDiv || Opcode == Instruction::SRem; |
55 | } |
56 | |
57 | static void scalarize(BinaryOperator *BO, |
58 | SmallVectorImpl<BinaryOperator *> &Replace) { |
59 | VectorType *VTy = cast<FixedVectorType>(Val: BO->getType()); |
60 | |
61 | IRBuilder<> Builder(BO); |
62 | |
63 | unsigned NumElements = VTy->getElementCount().getFixedValue(); |
64 | Value *Result = PoisonValue::get(T: VTy); |
65 | for (unsigned Idx = 0; Idx < NumElements; ++Idx) { |
66 | Value *LHS = Builder.CreateExtractElement(Vec: BO->getOperand(i_nocapture: 0), Idx); |
67 | Value *RHS = Builder.CreateExtractElement(Vec: BO->getOperand(i_nocapture: 1), Idx); |
68 | Value *Op = Builder.CreateBinOp(Opc: BO->getOpcode(), LHS, RHS); |
69 | Result = Builder.CreateInsertElement(Vec: Result, NewElt: Op, Idx); |
70 | if (auto *NewBO = dyn_cast<BinaryOperator>(Val: Op)) { |
71 | NewBO->copyIRFlags(V: Op, IncludeWrapFlags: true); |
72 | Replace.push_back(Elt: NewBO); |
73 | } |
74 | } |
75 | BO->replaceAllUsesWith(V: Result); |
76 | BO->dropAllReferences(); |
77 | BO->eraseFromParent(); |
78 | } |
79 | |
80 | static bool runImpl(Function &F, const TargetLowering &TLI) { |
81 | SmallVector<BinaryOperator *, 4> Replace; |
82 | SmallVector<BinaryOperator *, 4> ReplaceVector; |
83 | bool Modified = false; |
84 | |
85 | unsigned MaxLegalDivRemBitWidth = TLI.getMaxDivRemBitWidthSupported(); |
86 | if (ExpandDivRemBits != llvm::IntegerType::MAX_INT_BITS) |
87 | MaxLegalDivRemBitWidth = ExpandDivRemBits; |
88 | |
89 | if (MaxLegalDivRemBitWidth >= llvm::IntegerType::MAX_INT_BITS) |
90 | return false; |
91 | |
92 | for (auto &I : instructions(F)) { |
93 | switch (I.getOpcode()) { |
94 | case Instruction::UDiv: |
95 | case Instruction::SDiv: |
96 | case Instruction::URem: |
97 | case Instruction::SRem: { |
98 | // TODO: This pass doesn't handle scalable vectors. |
99 | if (I.getOperand(i: 0)->getType()->isScalableTy()) |
100 | continue; |
101 | |
102 | auto *IntTy = dyn_cast<IntegerType>(Val: I.getType()->getScalarType()); |
103 | if (!IntTy || IntTy->getIntegerBitWidth() <= MaxLegalDivRemBitWidth) |
104 | continue; |
105 | |
106 | // The backend has peephole optimizations for powers of two. |
107 | // TODO: We don't consider vectors here. |
108 | if (isConstantPowerOfTwo(V: I.getOperand(i: 1), SignedOp: isSigned(Opcode: I.getOpcode()))) |
109 | continue; |
110 | |
111 | if (I.getOperand(i: 0)->getType()->isVectorTy()) |
112 | ReplaceVector.push_back(Elt: &cast<BinaryOperator>(Val&: I)); |
113 | else |
114 | Replace.push_back(Elt: &cast<BinaryOperator>(Val&: I)); |
115 | Modified = true; |
116 | break; |
117 | } |
118 | default: |
119 | break; |
120 | } |
121 | } |
122 | |
123 | while (!ReplaceVector.empty()) { |
124 | BinaryOperator *BO = ReplaceVector.pop_back_val(); |
125 | scalarize(BO, Replace); |
126 | } |
127 | |
128 | if (Replace.empty()) |
129 | return false; |
130 | |
131 | while (!Replace.empty()) { |
132 | BinaryOperator *I = Replace.pop_back_val(); |
133 | |
134 | if (I->getOpcode() == Instruction::UDiv || |
135 | I->getOpcode() == Instruction::SDiv) { |
136 | expandDivision(Div: I); |
137 | } else { |
138 | expandRemainder(Rem: I); |
139 | } |
140 | } |
141 | |
142 | return Modified; |
143 | } |
144 | |
145 | namespace { |
146 | class ExpandLargeDivRemLegacyPass : public FunctionPass { |
147 | public: |
148 | static char ID; |
149 | |
150 | ExpandLargeDivRemLegacyPass() : FunctionPass(ID) { |
151 | initializeExpandLargeDivRemLegacyPassPass(*PassRegistry::getPassRegistry()); |
152 | } |
153 | |
154 | bool runOnFunction(Function &F) override { |
155 | auto *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>(); |
156 | auto *TLI = TM->getSubtargetImpl(F)->getTargetLowering(); |
157 | return runImpl(F, TLI: *TLI); |
158 | } |
159 | |
160 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
161 | AU.addRequired<TargetPassConfig>(); |
162 | AU.addPreserved<AAResultsWrapperPass>(); |
163 | AU.addPreserved<GlobalsAAWrapperPass>(); |
164 | } |
165 | }; |
166 | } // namespace |
167 | |
168 | PreservedAnalyses ExpandLargeDivRemPass::run(Function &F, |
169 | FunctionAnalysisManager &FAM) { |
170 | const TargetSubtargetInfo *STI = TM->getSubtargetImpl(F); |
171 | return runImpl(F, TLI: *STI->getTargetLowering()) ? PreservedAnalyses::none() |
172 | : PreservedAnalyses::all(); |
173 | } |
174 | |
175 | char ExpandLargeDivRemLegacyPass::ID = 0; |
176 | INITIALIZE_PASS_BEGIN(ExpandLargeDivRemLegacyPass, "expand-large-div-rem" , |
177 | "Expand large div/rem" , false, false) |
178 | INITIALIZE_PASS_END(ExpandLargeDivRemLegacyPass, "expand-large-div-rem" , |
179 | "Expand large div/rem" , false, false) |
180 | |
181 | FunctionPass *llvm::createExpandLargeDivRemPass() { |
182 | return new ExpandLargeDivRemLegacyPass(); |
183 | } |
184 | |