1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "ReduceOperands.h"
10#include "llvm/IR/Constants.h"
11#include "llvm/IR/InstIterator.h"
12#include "llvm/IR/InstrTypes.h"
13#include "llvm/IR/Operator.h"
14#include "llvm/IR/PatternMatch.h"
15#include "llvm/IR/Type.h"
16
17using namespace llvm;
18using namespace PatternMatch;
19
20static void
21extractOperandsFromModule(Oracle &O, ReducerWorkItem &WorkItem,
22 function_ref<Value *(Use &)> ReduceValue) {
23 Module &Program = WorkItem.getModule();
24
25 for (auto &F : Program.functions()) {
26 for (auto &I : instructions(F: &F)) {
27 if (PHINode *Phi = dyn_cast<PHINode>(Val: &I)) {
28 for (auto &Op : Phi->incoming_values()) {
29 if (Value *Reduced = ReduceValue(Op)) {
30 if (!O.shouldKeep())
31 Phi->setIncomingValueForBlock(BB: Phi->getIncomingBlock(U: Op), V: Reduced);
32 }
33 }
34
35 continue;
36 }
37
38 for (auto &Op : I.operands()) {
39 if (Value *Reduced = ReduceValue(Op)) {
40 if (!O.shouldKeep())
41 Op.set(Reduced);
42 }
43 }
44 }
45 }
46}
47
48static bool isOne(Use &Op) {
49 auto *C = dyn_cast<Constant>(Val&: Op);
50 return C && C->isOneValue();
51}
52
53static bool isZero(Use &Op) {
54 auto *C = dyn_cast<Constant>(Val&: Op);
55 return C && C->isNullValue();
56}
57
58static bool isZeroOrOneFP(Value *Op) {
59 const APFloat *C;
60 return match(V: Op, P: m_APFloat(Res&: C)) &&
61 ((C->isZero() && !C->isNegative()) || C->isExactlyValue(V: 1.0));
62}
63
64static bool shouldReduceOperand(Use &Op) {
65 Type *Ty = Op->getType();
66 if (Ty->isLabelTy() || Ty->isMetadataTy())
67 return false;
68 // TODO: be more precise about which GEP operands we can reduce (e.g. array
69 // indexes)
70 if (isa<GEPOperator>(Val: Op.getUser()))
71 return false;
72 if (auto *CB = dyn_cast<CallBase>(Val: Op.getUser())) {
73 if (&CB->getCalledOperandUse() == &Op)
74 return false;
75 }
76 return true;
77}
78
79static bool switchCaseExists(Use &Op, ConstantInt *CI) {
80 SwitchInst *SI = dyn_cast<SwitchInst>(Val: Op.getUser());
81 if (!SI)
82 return false;
83 return SI->findCaseValue(C: CI) != SI->case_default();
84}
85
86void llvm::reduceOperandsOneDeltaPass(Oracle &O, ReducerWorkItem &WorkItem) {
87 auto ReduceValue = [](Use &Op) -> Value * {
88 if (!shouldReduceOperand(Op))
89 return nullptr;
90
91 Type *Ty = Op->getType();
92 if (auto *IntTy = dyn_cast<IntegerType>(Val: Ty)) {
93 // Don't duplicate an existing switch case.
94 if (switchCaseExists(Op, CI: ConstantInt::get(Ty: IntTy, V: 1)))
95 return nullptr;
96 // Don't replace existing ones and zeroes.
97 return (isOne(Op) || isZero(Op)) ? nullptr : ConstantInt::get(Ty: IntTy, V: 1);
98 }
99
100 if (Ty->isFloatingPointTy())
101 return isZeroOrOneFP(Op) ? nullptr : ConstantFP::get(Ty, V: 1.0);
102
103 if (VectorType *VT = dyn_cast<VectorType>(Val: Ty)) {
104 if (isOne(Op) || isZero(Op) || isZeroOrOneFP(Op))
105 return nullptr;
106
107 Type *ElementType = VT->getElementType();
108 Constant *C;
109 if (ElementType->isFloatingPointTy()) {
110 C = ConstantFP::get(Ty: ElementType, V: 1.0);
111 } else if (IntegerType *IntTy = dyn_cast<IntegerType>(Val: ElementType)) {
112 C = ConstantInt::get(Ty: IntTy, V: 1);
113 } else {
114 return nullptr;
115 }
116 return ConstantVector::getSplat(EC: VT->getElementCount(), Elt: C);
117 }
118
119 return nullptr;
120 };
121 extractOperandsFromModule(O, WorkItem, ReduceValue);
122}
123
124void llvm::reduceOperandsZeroDeltaPass(Oracle &O, ReducerWorkItem &WorkItem) {
125 auto ReduceValue = [](Use &Op) -> Value * {
126 if (!shouldReduceOperand(Op))
127 return nullptr;
128
129 // Avoid introducing 0-sized allocations.
130 if (isa<AllocaInst>(Val: Op.getUser()))
131 return nullptr;
132
133 // Don't duplicate an existing switch case.
134 if (auto *IntTy = dyn_cast<IntegerType>(Val: Op->getType()))
135 if (switchCaseExists(Op, CI: ConstantInt::get(Ty: IntTy, V: 0)))
136 return nullptr;
137
138 if (auto *TET = dyn_cast<TargetExtType>(Val: Op->getType())) {
139 if (isa<ConstantTargetNone, PoisonValue>(Val: Op))
140 return nullptr;
141 if (TET->hasProperty(Prop: TargetExtType::HasZeroInit))
142 return ConstantTargetNone::get(T: TET);
143 return nullptr;
144 }
145
146 // Don't replace existing zeroes.
147 return isZero(Op) ? nullptr : Constant::getNullValue(Ty: Op->getType());
148 };
149 extractOperandsFromModule(O, WorkItem, ReduceValue);
150}
151
152void llvm::reduceOperandsNaNDeltaPass(Oracle &O, ReducerWorkItem &WorkItem) {
153 auto ReduceValue = [](Use &Op) -> Value * {
154 Type *Ty = Op->getType();
155 if (!Ty->isFPOrFPVectorTy())
156 return nullptr;
157
158 // Prefer 0.0 or 1.0 over NaN.
159 //
160 // TODO: Preferring NaN may make more sense because FP operations are more
161 // universally foldable.
162 if (match(V: Op.get(), P: m_NaN()) || isZeroOrOneFP(Op: Op.get()))
163 return nullptr;
164
165 if (VectorType *VT = dyn_cast<VectorType>(Val: Ty)) {
166 return ConstantVector::getSplat(EC: VT->getElementCount(),
167 Elt: ConstantFP::getQNaN(Ty: VT->getElementType()));
168 }
169
170 return ConstantFP::getQNaN(Ty);
171 };
172 extractOperandsFromModule(O, WorkItem, ReduceValue);
173}
174
175void llvm::reduceOperandsPoisonDeltaPass(Oracle &O, ReducerWorkItem &WorkItem) {
176 auto ReduceValue = [](Use &Op) -> Value * {
177 Type *Ty = Op->getType();
178 if (auto *TET = dyn_cast<TargetExtType>(Val: Ty)) {
179 if (isa<ConstantTargetNone, PoisonValue>(Val: Op))
180 return nullptr;
181 return PoisonValue::get(T: TET);
182 }
183
184 return nullptr;
185 };
186
187 extractOperandsFromModule(O, WorkItem, ReduceValue);
188}
189