1//===-- SPIRVRegularizer.cpp - regularize IR for SPIR-V ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This pass implements regularization of LLVM IR for SPIR-V. The prototype of
// the pass was taken from the SPIRV-LLVM translator.
11//
12//===----------------------------------------------------------------------===//
13
14#include "SPIRV.h"
15#include "llvm/IR/Constants.h"
16#include "llvm/IR/IRBuilder.h"
17#include "llvm/IR/InstIterator.h"
18#include "llvm/IR/Instructions.h"
19#include "llvm/IR/PassManager.h"
20
21#include <list>
22
23#define DEBUG_TYPE "spirv-regularizer"
24
25using namespace llvm;
26
27namespace {
28struct SPIRVRegularizer : public FunctionPass {
29public:
30 static char ID;
31 SPIRVRegularizer() : FunctionPass(ID) {}
32 bool runOnFunction(Function &F) override;
33 StringRef getPassName() const override { return "SPIR-V Regularizer"; }
34
35 void getAnalysisUsage(AnalysisUsage &AU) const override {
36 FunctionPass::getAnalysisUsage(AU);
37 }
38
39private:
40 void runLowerConstExpr(Function &F);
41 void runLowerI1Comparisons(Function &F);
42};
43} // namespace
44
45char SPIRVRegularizer::ID = 0;
46
47INITIALIZE_PASS(SPIRVRegularizer, DEBUG_TYPE, "SPIR-V Regularizer", false,
48 false)
49
50// Since SPIR-V cannot represent constant expression, constant expressions
51// in LLVM IR need to be lowered to instructions. For each function,
52// the constant expressions used by instructions of the function are replaced
53// by instructions placed in the entry block since it dominates all other BBs.
54// Each constant expression only needs to be lowered once in each function
55// and all uses of it by instructions in that function are replaced by
56// one instruction.
57// TODO: remove redundant instructions for common subexpression.
58void SPIRVRegularizer::runLowerConstExpr(Function &F) {
59 LLVMContext &Ctx = F.getContext();
60 std::list<Instruction *> WorkList;
61 for (auto &II : instructions(F))
62 WorkList.push_back(x: &II);
63
64 auto FBegin = F.begin();
65 while (!WorkList.empty()) {
66 Instruction *II = WorkList.front();
67
68 auto LowerOp = [&II, &FBegin, &F](Value *V) -> Value * {
69 if (isa<Function>(Val: V))
70 return V;
71 auto *CE = cast<ConstantExpr>(Val: V);
72 LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] " << *CE);
73 auto ReplInst = CE->getAsInstruction();
74 auto InsPoint = II->getParent() == &*FBegin ? II : &FBegin->back();
75 ReplInst->insertBefore(InsertPos: InsPoint->getIterator());
76 LLVM_DEBUG(dbgs() << " -> " << *ReplInst << '\n');
77 std::vector<Instruction *> Users;
78 // Do not replace use during iteration of use. Do it in another loop.
79 for (auto U : CE->users()) {
80 LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] Use: " << *U << '\n');
81 auto InstUser = dyn_cast<Instruction>(Val: U);
82 // Only replace users in scope of current function.
83 if (InstUser && InstUser->getParent()->getParent() == &F)
84 Users.push_back(x: InstUser);
85 }
86 for (auto &User : Users) {
87 if (ReplInst->getParent() == User->getParent() &&
88 User->comesBefore(Other: ReplInst))
89 ReplInst->moveBefore(InsertPos: User->getIterator());
90 User->replaceUsesOfWith(From: CE, To: ReplInst);
91 }
92 return ReplInst;
93 };
94
95 WorkList.pop_front();
96 auto LowerConstantVec = [&II, &LowerOp, &WorkList,
97 &Ctx](ConstantVector *Vec,
98 unsigned NumOfOp) -> Value * {
99 if (std::all_of(first: Vec->op_begin(), last: Vec->op_end(), pred: [](Value *V) {
100 return isa<ConstantExpr>(Val: V) || isa<Function>(Val: V);
101 })) {
102 // Expand a vector of constexprs and construct it back with
103 // series of insertelement instructions.
104 std::list<Value *> OpList;
105 std::transform(first: Vec->op_begin(), last: Vec->op_end(),
106 result: std::back_inserter(x&: OpList),
107 unary_op: [LowerOp](Value *V) { return LowerOp(V); });
108 Value *Repl = nullptr;
109 unsigned Idx = 0;
110 auto *PhiII = dyn_cast<PHINode>(Val: II);
111 Instruction *InsPoint =
112 PhiII ? &PhiII->getIncomingBlock(i: NumOfOp)->back() : II;
113 std::list<Instruction *> ReplList;
114 for (auto V : OpList) {
115 if (auto *Inst = dyn_cast<Instruction>(Val: V))
116 ReplList.push_back(x: Inst);
117 Repl = InsertElementInst::Create(
118 Vec: (Repl ? Repl : PoisonValue::get(T: Vec->getType())), NewElt: V,
119 Idx: ConstantInt::get(Ty: Type::getInt32Ty(C&: Ctx), V: Idx++), NameStr: "",
120 InsertBefore: InsPoint->getIterator());
121 }
122 WorkList.splice(position: WorkList.begin(), x&: ReplList);
123 return Repl;
124 }
125 return nullptr;
126 };
127 for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) {
128 auto *Op = II->getOperand(i: OI);
129 if (auto *Vec = dyn_cast<ConstantVector>(Val: Op)) {
130 Value *ReplInst = LowerConstantVec(Vec, OI);
131 if (ReplInst)
132 II->replaceUsesOfWith(From: Op, To: ReplInst);
133 } else if (auto CE = dyn_cast<ConstantExpr>(Val: Op)) {
134 WorkList.push_front(x: cast<Instruction>(Val: LowerOp(CE)));
135 } else if (auto MDAsVal = dyn_cast<MetadataAsValue>(Val: Op)) {
136 auto ConstMD = dyn_cast<ConstantAsMetadata>(Val: MDAsVal->getMetadata());
137 if (!ConstMD)
138 continue;
139 Constant *C = ConstMD->getValue();
140 Value *ReplInst = nullptr;
141 if (auto *Vec = dyn_cast<ConstantVector>(Val: C))
142 ReplInst = LowerConstantVec(Vec, OI);
143 if (auto *CE = dyn_cast<ConstantExpr>(Val: C))
144 ReplInst = LowerOp(CE);
145 if (!ReplInst)
146 continue;
147 Metadata *RepMD = ValueAsMetadata::get(V: ReplInst);
148 Value *RepMDVal = MetadataAsValue::get(Context&: Ctx, MD: RepMD);
149 II->setOperand(i: OI, Val: RepMDVal);
150 WorkList.push_front(x: cast<Instruction>(Val: ReplInst));
151 }
152 }
153 }
154}
155
156// Lower i1 comparisons with certain predicates to logical operations.
157// The backend treats i1 as boolean values, and SPIR-V only allows logical
158// operations for boolean values. This function lowers i1 comparisons with
159// certain predicates to logical operations to generate valid SPIR-V.
160void SPIRVRegularizer::runLowerI1Comparisons(Function &F) {
161 for (auto &I : make_early_inc_range(Range: instructions(F))) {
162 auto *Cmp = dyn_cast<ICmpInst>(Val: &I);
163 if (!Cmp)
164 continue;
165
166 bool IsI1 = Cmp->getOperand(i_nocapture: 0)->getType()->isIntegerTy(Bitwidth: 1);
167 if (!IsI1)
168 continue;
169
170 auto Pred = Cmp->getPredicate();
171 bool IsTargetPred =
172 Pred >= ICmpInst::ICMP_UGT && Pred <= ICmpInst::ICMP_SLE;
173 if (!IsTargetPred)
174 continue;
175
176 Value *P = Cmp->getOperand(i_nocapture: 0);
177 Value *Q = Cmp->getOperand(i_nocapture: 1);
178
179 IRBuilder<> Builder(Cmp);
180 Value *Result = nullptr;
181 switch (Pred) {
182 case ICmpInst::ICMP_UGT:
183 case ICmpInst::ICMP_SLT:
184 // Result = p & !q
185 Result = Builder.CreateAnd(LHS: P, RHS: Builder.CreateNot(V: Q));
186 break;
187 case ICmpInst::ICMP_ULT:
188 case ICmpInst::ICMP_SGT:
189 // Result = q & !p
190 Result = Builder.CreateAnd(LHS: Q, RHS: Builder.CreateNot(V: P));
191 break;
192 case ICmpInst::ICMP_ULE:
193 case ICmpInst::ICMP_SGE:
194 // Result = q | !p
195 Result = Builder.CreateOr(LHS: Q, RHS: Builder.CreateNot(V: P));
196 break;
197 case ICmpInst::ICMP_UGE:
198 case ICmpInst::ICMP_SLE:
199 // Result = p | !q
200 Result = Builder.CreateOr(LHS: P, RHS: Builder.CreateNot(V: Q));
201 break;
202 default:
203 llvm_unreachable("Unexpected predicate");
204 }
205
206 Result->takeName(V: Cmp);
207 Cmp->replaceAllUsesWith(V: Result);
208 Cmp->eraseFromParent();
209 }
210}
211
212bool SPIRVRegularizer::runOnFunction(Function &F) {
213 runLowerI1Comparisons(F);
214 runLowerConstExpr(F);
215 return true;
216}
217
218FunctionPass *llvm::createSPIRVRegularizerPass() {
219 return new SPIRVRegularizer();
220}
221