1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "ReduceOperands.h"
10#include "llvm/IR/Constants.h"
11#include "llvm/IR/InstIterator.h"
12#include "llvm/IR/InstrTypes.h"
13#include "llvm/IR/Operator.h"
14#include "llvm/IR/PatternMatch.h"
15#include "llvm/IR/Type.h"
16
17using namespace llvm;
18using namespace PatternMatch;
19
20static void
21extractOperandsFromModule(Oracle &O, ReducerWorkItem &WorkItem,
22 function_ref<Value *(Use &)> ReduceValue) {
23 Module &Program = WorkItem.getModule();
24
25 for (auto &F : Program.functions()) {
26 for (auto &I : instructions(F: &F)) {
27 if (PHINode *Phi = dyn_cast<PHINode>(Val: &I)) {
28 for (auto &Op : Phi->incoming_values()) {
29 if (!O.shouldKeep()) {
30 if (Value *Reduced = ReduceValue(Op))
31 Phi->setIncomingValueForBlock(BB: Phi->getIncomingBlock(U: Op), V: Reduced);
32 }
33 }
34
35 continue;
36 }
37
38 for (auto &Op : I.operands()) {
39 if (Value *Reduced = ReduceValue(Op)) {
40 if (!O.shouldKeep())
41 Op.set(Reduced);
42 }
43 }
44 }
45 }
46}
47
48static bool isOne(Use &Op) {
49 auto *C = dyn_cast<Constant>(Val&: Op);
50 return C && C->isOneValue();
51}
52
53static bool isZero(Use &Op) {
54 auto *C = dyn_cast<Constant>(Val&: Op);
55 return C && C->isNullValue();
56}
57
58static bool isZeroOrOneFP(Value *Op) {
59 const APFloat *C;
60 return match(V: Op, P: m_APFloat(Res&: C)) &&
61 ((C->isZero() && !C->isNegative()) || C->isExactlyValue(V: 1.0));
62}
63
64static bool shouldReduceOperand(Use &Op) {
65 Type *Ty = Op->getType();
66 if (Ty->isLabelTy() || Ty->isMetadataTy())
67 return false;
68 // TODO: be more precise about which GEP operands we can reduce (e.g. array
69 // indexes)
70 if (isa<GEPOperator>(Val: Op.getUser()))
71 return false;
72 if (auto *CB = dyn_cast<CallBase>(Val: Op.getUser())) {
73 if (&CB->getCalledOperandUse() == &Op)
74 return false;
75 }
76 return true;
77}
78
79static bool switchCaseExists(Use &Op, ConstantInt *CI) {
80 SwitchInst *SI = dyn_cast<SwitchInst>(Val: Op.getUser());
81 if (!SI)
82 return false;
83 return SI->findCaseValue(C: CI) != SI->case_default();
84}
85
86void llvm::reduceOperandsOneDeltaPass(TestRunner &Test) {
87 auto ReduceValue = [](Use &Op) -> Value * {
88 if (!shouldReduceOperand(Op))
89 return nullptr;
90
91 Type *Ty = Op->getType();
92 if (auto *IntTy = dyn_cast<IntegerType>(Val: Ty)) {
93 // Don't duplicate an existing switch case.
94 if (switchCaseExists(Op, CI: ConstantInt::get(Ty: IntTy, V: 1)))
95 return nullptr;
96 // Don't replace existing ones and zeroes.
97 return (isOne(Op) || isZero(Op)) ? nullptr : ConstantInt::get(Ty: IntTy, V: 1);
98 }
99
100 if (Ty->isFloatingPointTy())
101 return isZeroOrOneFP(Op) ? nullptr : ConstantFP::get(Ty, V: 1.0);
102
103 if (VectorType *VT = dyn_cast<VectorType>(Val: Ty)) {
104 if (isOne(Op) || isZero(Op) || isZeroOrOneFP(Op))
105 return nullptr;
106
107 Type *ElementType = VT->getElementType();
108 Constant *C;
109 if (ElementType->isFloatingPointTy()) {
110 C = ConstantFP::get(Ty: ElementType, V: 1.0);
111 } else if (IntegerType *IntTy = dyn_cast<IntegerType>(Val: ElementType)) {
112 C = ConstantInt::get(Ty: IntTy, V: 1);
113 } else {
114 return nullptr;
115 }
116 return ConstantVector::getSplat(EC: VT->getElementCount(), Elt: C);
117 }
118
119 return nullptr;
120 };
121 runDeltaPass(
122 Test,
123 ExtractChunksFromModule: [ReduceValue](Oracle &O, ReducerWorkItem &WorkItem) {
124 extractOperandsFromModule(O, WorkItem, ReduceValue);
125 },
126 Message: "Reducing Operands to one");
127}
128
129void llvm::reduceOperandsZeroDeltaPass(TestRunner &Test) {
130 auto ReduceValue = [](Use &Op) -> Value * {
131 if (!shouldReduceOperand(Op))
132 return nullptr;
133 // Don't duplicate an existing switch case.
134 if (auto *IntTy = dyn_cast<IntegerType>(Val: Op->getType()))
135 if (switchCaseExists(Op, CI: ConstantInt::get(Ty: IntTy, V: 0)))
136 return nullptr;
137 // Don't replace existing zeroes.
138 return isZero(Op) ? nullptr : Constant::getNullValue(Ty: Op->getType());
139 };
140 runDeltaPass(
141 Test,
142 ExtractChunksFromModule: [ReduceValue](Oracle &O, ReducerWorkItem &Program) {
143 extractOperandsFromModule(O, WorkItem&: Program, ReduceValue);
144 },
145 Message: "Reducing Operands to zero");
146}
147
148void llvm::reduceOperandsNaNDeltaPass(TestRunner &Test) {
149 auto ReduceValue = [](Use &Op) -> Value * {
150 Type *Ty = Op->getType();
151 if (!Ty->isFPOrFPVectorTy())
152 return nullptr;
153
154 // Prefer 0.0 or 1.0 over NaN.
155 //
156 // TODO: Preferring NaN may make more sense because FP operations are more
157 // universally foldable.
158 if (match(V: Op.get(), P: m_NaN()) || isZeroOrOneFP(Op: Op.get()))
159 return nullptr;
160
161 if (VectorType *VT = dyn_cast<VectorType>(Val: Ty)) {
162 return ConstantVector::getSplat(EC: VT->getElementCount(),
163 Elt: ConstantFP::getQNaN(Ty: VT->getElementType()));
164 }
165
166 return ConstantFP::getQNaN(Ty);
167 };
168 runDeltaPass(
169 Test,
170 ExtractChunksFromModule: [ReduceValue](Oracle &O, ReducerWorkItem &Program) {
171 extractOperandsFromModule(O, WorkItem&: Program, ReduceValue);
172 },
173 Message: "Reducing Operands to NaN");
174}
175