1//===----- RISCVCodeGenPrepare.cpp ----------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This is a RISC-V specific version of CodeGenPrepare.
10// It munges the code in the input function to better prepare it for
// SelectionDAG-based code generation. This works around limitations in its
12// basic-block-at-a-time approach.
13//
14//===----------------------------------------------------------------------===//
15
16#include "RISCV.h"
17#include "RISCVTargetMachine.h"
18#include "llvm/ADT/Statistic.h"
19#include "llvm/Analysis/ValueTracking.h"
20#include "llvm/CodeGen/TargetPassConfig.h"
21#include "llvm/IR/Dominators.h"
22#include "llvm/IR/IRBuilder.h"
23#include "llvm/IR/InstVisitor.h"
24#include "llvm/IR/IntrinsicInst.h"
25#include "llvm/IR/Intrinsics.h"
26#include "llvm/IR/PatternMatch.h"
27#include "llvm/InitializePasses.h"
28#include "llvm/Pass.h"
29#include "llvm/Transforms/Utils/Local.h"
30
31using namespace llvm;
32
33#define DEBUG_TYPE "riscv-codegenprepare"
34#define PASS_NAME "RISC-V CodeGenPrepare"
35
36namespace {
37class RISCVCodeGenPrepare : public InstVisitor<RISCVCodeGenPrepare, bool> {
38 Function &F;
39 const DataLayout *DL;
40 const DominatorTree *DT;
41 const RISCVSubtarget *ST;
42
43public:
44 RISCVCodeGenPrepare(Function &F, const DominatorTree *DT,
45 const RISCVSubtarget *ST)
46 : F(F), DL(&F.getDataLayout()), DT(DT), ST(ST) {}
47 bool run();
48 bool visitInstruction(Instruction &I) { return false; }
49 bool visitAnd(BinaryOperator &BO);
50 bool visitIntrinsicInst(IntrinsicInst &I);
51 bool expandVPStrideLoad(IntrinsicInst &I);
52 bool expandMulReduction(IntrinsicInst &I);
53 bool widenVPMerge(Instruction *I);
54 bool visitFreezeInst(FreezeInst &BO);
55};
56} // namespace
57
58namespace {
59class RISCVCodeGenPrepareLegacyPass : public FunctionPass {
60public:
61 static char ID;
62
63 RISCVCodeGenPrepareLegacyPass() : FunctionPass(ID) {}
64
65 bool runOnFunction(Function &F) override;
66 StringRef getPassName() const override { return PASS_NAME; }
67
68 void getAnalysisUsage(AnalysisUsage &AU) const override {
69 AU.setPreservesCFG();
70 AU.addRequired<DominatorTreeWrapperPass>();
71 AU.addRequired<TargetPassConfig>();
72 }
73};
74} // namespace
75
// Try to optimize (i64 (and (zext nneg (i32 X)), C1)) if C1 has bit 31 set,
// but bits 63:32 are zero. Since nneg tells us that bit 31 of X is 0, we can
// fill the upper 32 bits of C1 with ones.
79bool RISCVCodeGenPrepare::visitAnd(BinaryOperator &BO) {
80 if (!ST->is64Bit())
81 return false;
82
83 if (!BO.getType()->isIntegerTy(BitWidth: 64))
84 return false;
85
86 using namespace PatternMatch;
87
88 // Left hand side should be a zext nneg.
89 Value *LHSSrc;
90 if (!match(V: BO.getOperand(i_nocapture: 0), P: m_NNegZExt(Op: m_Value(V&: LHSSrc))))
91 return false;
92
93 if (!LHSSrc->getType()->isIntegerTy(BitWidth: 32))
94 return false;
95
96 // Right hand side should be a constant.
97 Value *RHS = BO.getOperand(i_nocapture: 1);
98
99 auto *CI = dyn_cast<ConstantInt>(Val: RHS);
100 if (!CI)
101 return false;
102 uint64_t C = CI->getZExtValue();
103
104 // Look for constants that fit in 32 bits but not simm12, and can be made
105 // into simm12 by sign extending bit 31. This will allow use of ANDI.
106 // TODO: Is worth making simm32?
107 if (!isUInt<32>(x: C) || isInt<12>(x: C) || !isInt<12>(x: SignExtend64<32>(x: C)))
108 return false;
109
110 // Sign extend the constant and replace the And operand.
111 C = SignExtend64<32>(x: C);
112 BO.setOperand(i_nocapture: 1, Val_nocapture: ConstantInt::get(Ty: RHS->getType(), V: C));
113
114 return true;
115}
116
117// With EVL tail folding, an AnyOf reduction will generate an i1 vp.merge like
118// follows:
119//
120// loop:
121// %phi = phi <vscale x 4 x i1> [zeroinitializer, %entry], [%freeze, %loop]
122// %cmp = icmp ...
123// %rec = call <vscale x 4 x i1> @llvm.vp.merge(%cmp, i1 true, %phi, %evl)
124// %freeze = freeze <vscale x 4 x i1> %rec [optional]
125// ...
126// middle:
127// %res = call i1 @llvm.vector.reduce.or(<vscale x 4 x i1> %freeze)
128//
129// However RVV doesn't have any tail undisturbed mask instructions and so we
130// need a convoluted sequence of mask instructions to lower the i1 vp.merge: see
131// llvm/test/CodeGen/RISCV/rvv/vpmerge-sdnode.ll.
132//
133// To avoid that this widens the i1 vp.merge to an i8 vp.merge, which will
134// generate a single vmerge.vim:
135//
136// loop:
137// %phi = phi <vscale x 4 x i8> [zeroinitializer, %entry], [%freeze, %loop]
138// %cmp = icmp ...
// %rec = call <vscale x 4 x i8> @llvm.vp.merge(%cmp, i8 1, %phi, %evl)
140// %freeze = freeze <vscale x 4 x i8> %rec
141// %trunc = trunc <vscale x 4 x i8> %freeze to <vscale x 4 x i1>
142// ...
143// middle:
144// %res = call i1 @llvm.vector.reduce.or(<vscale x 4 x i1> %trunc)
145//
146// The trunc will normally be sunk outside of the loop, but even if there are
147// users inside the loop it is still profitable.
148bool RISCVCodeGenPrepare::widenVPMerge(Instruction *Root) {
149 if (!Root->getType()->getScalarType()->isIntegerTy(BitWidth: 1))
150 return false;
151
152 Value *Mask, *True, *PhiV, *EVL;
153 using namespace PatternMatch;
154 auto m_VPMerge = m_Intrinsic<Intrinsic::vp_merge>(
155 Op0: m_Value(V&: Mask), Op1: m_Value(V&: True), Op2: m_Value(V&: PhiV), Op3: m_Value(V&: EVL));
156 if (!match(V: Root, P: m_CombineOr(Ps: m_VPMerge, Ps: m_Freeze(Op: m_VPMerge))))
157 return false;
158
159 auto *Phi = dyn_cast<PHINode>(Val: PhiV);
160 if (!Phi || !Phi->hasOneUse() || Phi->getNumIncomingValues() != 2 ||
161 !match(V: Phi->getIncomingValue(i: 0), P: m_Zero()) ||
162 Phi->getIncomingValue(i: 1) != Root)
163 return false;
164
165 Type *WideTy =
166 VectorType::get(ElementType: IntegerType::getInt8Ty(C&: Root->getContext()),
167 EC: cast<VectorType>(Val: Root->getType())->getElementCount());
168
169 IRBuilder<> Builder(Phi);
170 PHINode *WidePhi = Builder.CreatePHI(Ty: WideTy, NumReservedValues: 2);
171 WidePhi->addIncoming(V: ConstantAggregateZero::get(Ty: WideTy),
172 BB: Phi->getIncomingBlock(i: 0));
173 Builder.SetInsertPoint(Root);
174 Value *WideTrue = Builder.CreateZExt(V: True, DestTy: WideTy);
175 Value *WideMerge = Builder.CreateIntrinsic(ID: Intrinsic::vp_merge, OverloadTypes: {WideTy},
176 Args: {Mask, WideTrue, WidePhi, EVL});
177 if (isa<FreezeInst>(Val: Root))
178 WideMerge = Builder.CreateFreeze(V: WideMerge);
179 WidePhi->addIncoming(V: WideMerge, BB: Phi->getIncomingBlock(i: 1));
180 Value *Trunc = Builder.CreateTrunc(V: WideMerge, DestTy: Root->getType());
181
182 Root->replaceAllUsesWith(V: Trunc);
183
184 // Break the cycle and delete the old chain.
185 Phi->setIncomingValue(i: 1, V: Phi->getIncomingValue(i: 0));
186 llvm::RecursivelyDeleteTriviallyDeadInstructions(V: Root);
187
188 return true;
189}
190
191bool RISCVCodeGenPrepare::visitFreezeInst(FreezeInst &I) {
192 if (auto *II = dyn_cast<IntrinsicInst>(Val: I.getOperand(i_nocapture: 0)))
193 if (II->getIntrinsicID() == Intrinsic::vp_merge)
194 return widenVPMerge(Root: &I);
195 return false;
196}
197
198// LLVM vector reduction intrinsics return a scalar result, but on RISC-V vector
199// reduction instructions write the result in the first element of a vector
200// register. So when a reduction in a loop uses a scalar phi, we end up with
201// unnecessary scalar moves:
202//
203// loop:
204// vfmv.s.f v10, fa0
205// vfredosum.vs v8, v8, v10
206// vfmv.f.s fa0, v8
207//
208// This mainly affects ordered fadd reductions and VP reductions that have a
209// scalar start value, since other types of reduction typically use element-wise
210// vectorisation in the loop body. This tries to vectorize any scalar phis that
211// feed into these reductions:
212//
213// loop:
// %phi = phi float [ ..., %entry ], [ %acc, %loop ]
215// %acc = call float @llvm.vector.reduce.fadd.nxv2f32(float %phi,
216// <vscale x 2 x float> %vec)
217//
218// ->
219//
220// loop:
221// %phi = phi <vscale x 2 x float> [ ..., %entry ], [ %acc.vec, %loop ]
222// %phi.scalar = extractelement <vscale x 2 x float> %phi, i64 0
// %acc = call float @llvm.vector.reduce.fadd.nxv2f32(float %phi.scalar,
// <vscale x 2 x float> %vec)
// %acc.vec = insertelement <vscale x 2 x float> poison, float %acc, i64 0
226//
227// Which eliminates the scalar -> vector -> scalar crossing during instruction
228// selection.
229bool RISCVCodeGenPrepare::visitIntrinsicInst(IntrinsicInst &I) {
230 if (expandVPStrideLoad(I))
231 return true;
232
233 if (expandMulReduction(I))
234 return true;
235
236 if (widenVPMerge(Root: &I))
237 return true;
238
239 if (I.getIntrinsicID() != Intrinsic::vector_reduce_fadd &&
240 !isa<VPReductionIntrinsic>(Val: &I))
241 return false;
242
243 auto *PHI = dyn_cast<PHINode>(Val: I.getOperand(i_nocapture: 0));
244 if (!PHI || !PHI->hasOneUse() ||
245 !llvm::is_contained(Range: PHI->incoming_values(), Element: &I))
246 return false;
247
248 Type *VecTy = I.getOperand(i_nocapture: 1)->getType();
249 IRBuilder<> Builder(PHI);
250 auto *VecPHI = Builder.CreatePHI(Ty: VecTy, NumReservedValues: PHI->getNumIncomingValues());
251
252 for (auto *BB : PHI->blocks()) {
253 Builder.SetInsertPoint(BB->getTerminator());
254 Value *InsertElt = Builder.CreateInsertElement(
255 VecTy, NewElt: PHI->getIncomingValueForBlock(BB), Idx: (uint64_t)0);
256 VecPHI->addIncoming(V: InsertElt, BB);
257 }
258
259 Builder.SetInsertPoint(&I);
260 I.setOperand(i_nocapture: 0, Val_nocapture: Builder.CreateExtractElement(Vec: VecPHI, Idx: (uint64_t)0));
261
262 PHI->eraseFromParent();
263
264 return true;
265}
266
267// Extract pieces of size PieceEC from Vec, then build a binary tree of
268// element-wise multiplies reducing to a single piece.
269static Value *buildMulTree(IRBuilder<> &Builder, ElementCount PieceEC,
270 Value *Vec) {
271 auto *VecTy = cast<VectorType>(Val: Vec->getType());
272 auto *PieceTy = VectorType::get(ElementType: VecTy->getElementType(), EC: PieceEC);
273 unsigned PieceElts = PieceEC.getKnownMinValue();
274 unsigned NumPieces = VecTy->getElementCount().getKnownMinValue() / PieceElts;
275 assert(isPowerOf2_32(NumPieces));
276
277 SmallVector<Value *, 8> Pieces(NumPieces);
278 for (unsigned i = 0; i < NumPieces; i++)
279 Pieces[i] = Builder.CreateExtractVector(DstType: PieceTy, SrcVec: Vec, Idx: i * PieceElts);
280
281 while (Pieces.size() > 1) {
282 for (unsigned i = 0; i < Pieces.size() / 2; i++)
283 Pieces[i] =
284 Builder.CreateMul(LHS: Pieces[i * 2], RHS: Pieces[i * 2 + 1], Name: "bin.rdx");
285 Pieces.truncate(N: Pieces.size() / 2);
286 }
287 return Pieces[0];
288}
289
290// Partially expand a vector_reduce_mul wider than M1 to reduce
291// register pressure and the number of vsetvlis required.
292bool RISCVCodeGenPrepare::expandMulReduction(IntrinsicInst &II) {
293 if (II.getIntrinsicID() != Intrinsic::vector_reduce_mul)
294 return false;
295
296 if (!ST->hasVInstructions())
297 return false;
298
299 Value *TmpVec = II.getArgOperand(i: 0);
300 auto *VecTy = cast<VectorType>(Val: TmpVec->getType());
301 unsigned EltSize = VecTy->getScalarSizeInBits();
302
303 if (auto *ScalTy = dyn_cast<ScalableVectorType>(Val: VecTy)) {
304 unsigned MinElts = ScalTy->getMinNumElements();
305
306 if (auto VLen = ST->getRealVLen()) {
307 // If VLEN is exactly known, convert to a fixed vector reduction and
308 // recurse to let the fixed path handle it (shuffle reduction instead
309 // of a scalar loop).
310 unsigned VScale = *VLen / RISCV::RVVBitsPerBlock;
311 auto *FixedTy =
312 FixedVectorType::get(ElementType: VecTy->getElementType(), NumElts: MinElts * VScale);
313 IRBuilder<> Builder(&II);
314 Value *Fixed = Builder.CreateExtractVector(DstType: FixedTy, SrcVec: TmpVec, Idx: (uint64_t)0);
315 auto *FixedRdx = cast<IntrinsicInst>(Val: Builder.CreateIntrinsic(
316 ID: Intrinsic::vector_reduce_mul, OverloadTypes: {FixedTy}, Args: {Fixed}));
317 II.replaceAllUsesWith(V: FixedRdx);
318 II.eraseFromParent();
319 expandMulReduction(II&: *FixedRdx);
320 return true;
321 }
322
323 unsigned M1MinElts = RISCV::RVVBitsPerBlock / EltSize;
324 if (MinElts <= M1MinElts || !isPowerOf2_32(Value: MinElts / M1MinElts))
325 return false;
326
327 IRBuilder<> Builder(&II);
328 auto M1EC = ElementCount::getScalable(MinVal: M1MinElts);
329 Value *Reduced = buildMulTree(Builder, PieceEC: M1EC, Vec: TmpVec);
330 Value *Rdx = Builder.CreateIntrinsic(ID: Intrinsic::vector_reduce_mul,
331 OverloadTypes: {Reduced->getType()}, Args: {Reduced});
332 II.replaceAllUsesWith(V: Rdx);
333 II.eraseFromParent();
334 return true;
335 }
336
337 unsigned VF = cast<FixedVectorType>(Val: VecTy)->getNumElements();
338 unsigned MinVLen = ST->getRealMinVLen();
339 unsigned M1VF = MinVLen / EltSize;
340
341 if (!isPowerOf2_32(Value: VF) || VF <= M1VF)
342 return false;
343
344 IRBuilder<> Builder(&II);
345 auto M1EC = ElementCount::getFixed(MinVal: M1VF);
346 auto *M1Ty = VectorType::get(ElementType: VecTy->getElementType(), EC: M1EC);
347
348 // When VLEN is exactly known, extract m1 pieces and build a mul tree.
349 // This greatly reduces register pressure during the reduction, and
350 // avoids all but one vsetvli (the one from original LMUL to m1).
351 // TODO: Generalize to handle the splitting case.
352 if (MinVLen == ST->getRealMaxVLen() && VF <= 8 * M1VF) {
353 TmpVec = buildMulTree(Builder, PieceEC: M1EC, Vec: TmpVec);
354 } else {
355 // For non-exact VLEN, shuffle-reduce at the original vector width down to
356 // m1, then extract. This prioritizes reducing the number of vsetvli
357 // over maximal reduction of LMUL for the intermediate states.
358 SmallVector<int, 32> ShuffleMask(VF);
359 for (unsigned LiveElts = VF; LiveElts > M1VF; LiveElts /= 2) {
360 unsigned Half = LiveElts / 2;
361 std::iota(first: ShuffleMask.begin(), last: ShuffleMask.begin() + Half, value: Half);
362 std::fill(first: ShuffleMask.begin() + Half, last: ShuffleMask.end(), value: -1);
363 Value *Shuf =
364 Builder.CreateShuffleVector(V: TmpVec, Mask: ShuffleMask, Name: "rdx.shuf");
365 TmpVec = Builder.CreateMul(LHS: TmpVec, RHS: Shuf, Name: "bin.rdx");
366 }
367 // Extract the M1-sized subvector and emit the final reduction intrinsic.
368 // This is the reason we're here - to force a vsetvli toggle once at m1.
369 TmpVec = Builder.CreateExtractVector(DstType: M1Ty, SrcVec: TmpVec, Idx: (uint64_t)0, Name: "rdx.sub");
370 }
371
372 Value *Rdx =
373 Builder.CreateIntrinsic(ID: Intrinsic::vector_reduce_mul, OverloadTypes: {M1Ty}, Args: {TmpVec});
374 II.replaceAllUsesWith(V: Rdx);
375 II.eraseFromParent();
376 return true;
377}
378
379// Always expand zero strided loads so we match more .vx splat patterns, even if
380// we have +optimized-zero-stride-loads. RISCVDAGToDAGISel::Select will convert
381// it back to a strided load if it's optimized.
382bool RISCVCodeGenPrepare::expandVPStrideLoad(IntrinsicInst &II) {
383 Value *BasePtr, *VL;
384
385 using namespace PatternMatch;
386 if (!match(V: &II, P: m_Intrinsic<Intrinsic::experimental_vp_strided_load>(
387 Op0: m_Value(V&: BasePtr), Op1: m_Zero(), Op2: m_AllOnes(), Op3: m_Value(V&: VL))))
388 return false;
389
390 // If SEW>XLEN then a splat will get lowered as a zero strided load anyway, so
391 // avoid expanding here.
392 if (II.getType()->getScalarSizeInBits() > ST->getXLen())
393 return false;
394
395 if (!isKnownNonZero(V: VL, Q: {*DL, DT, nullptr, &II}))
396 return false;
397
398 auto *VTy = cast<VectorType>(Val: II.getType());
399
400 IRBuilder<> Builder(&II);
401 Type *STy = VTy->getElementType();
402 Value *Val = Builder.CreateLoad(Ty: STy, Ptr: BasePtr);
403 Value *Res = Builder.CreateIntrinsic(
404 ID: Intrinsic::vp_merge, OverloadTypes: VTy,
405 Args: {II.getOperand(i_nocapture: 2), Builder.CreateVectorSplat(EC: VTy->getElementCount(), V: Val),
406 PoisonValue::get(T: VTy), VL});
407
408 II.replaceAllUsesWith(V: Res);
409 II.eraseFromParent();
410 return true;
411}
412
413bool RISCVCodeGenPrepare::run() {
414 bool MadeChange = false;
415 for (auto &BB : F)
416 for (Instruction &I : llvm::make_early_inc_range(Range&: BB))
417 MadeChange |= visit(I);
418
419 return MadeChange;
420}
421
422bool RISCVCodeGenPrepareLegacyPass::runOnFunction(Function &F) {
423 if (skipFunction(F))
424 return false;
425
426 auto &TPC = getAnalysis<TargetPassConfig>();
427 auto &TM = TPC.getTM<RISCVTargetMachine>();
428 auto ST = &TM.getSubtarget<RISCVSubtarget>(F);
429 auto DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
430
431 RISCVCodeGenPrepare RVCGP(F, DT, ST);
432 return RVCGP.run();
433}
434
435INITIALIZE_PASS_BEGIN(RISCVCodeGenPrepareLegacyPass, DEBUG_TYPE, PASS_NAME,
436 false, false)
437INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
438INITIALIZE_PASS_END(RISCVCodeGenPrepareLegacyPass, DEBUG_TYPE, PASS_NAME, false,
439 false)
440
441char RISCVCodeGenPrepareLegacyPass::ID = 0;
442
443FunctionPass *llvm::createRISCVCodeGenPrepareLegacyPass() {
444 return new RISCVCodeGenPrepareLegacyPass();
445}
446
447PreservedAnalyses RISCVCodeGenPreparePass::run(Function &F,
448 FunctionAnalysisManager &FAM) {
449 DominatorTree *DT = &FAM.getResult<DominatorTreeAnalysis>(IR&: F);
450 auto ST = &TM->getSubtarget<RISCVSubtarget>(F);
451 bool Changed = RISCVCodeGenPrepare(F, DT, ST).run();
452 if (!Changed)
453 return PreservedAnalyses::all();
454
455 PreservedAnalyses PA = PreservedAnalyses::none();
456 PA.preserveSet<CFGAnalyses>();
457 return PA;
458}
459