1//===- ExpandReductions.cpp - Expand reduction intrinsics -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements IR expansion for reduction intrinsics, allowing targets
10// to enable the intrinsics until just before codegen.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/CodeGen/ExpandReductions.h"
15#include "llvm/Analysis/LoopInfo.h"
16#include "llvm/Analysis/TargetTransformInfo.h"
17#include "llvm/CodeGen/Passes.h"
18#include "llvm/IR/Dominators.h"
19#include "llvm/IR/IRBuilder.h"
20#include "llvm/IR/InstIterator.h"
21#include "llvm/IR/IntrinsicInst.h"
22#include "llvm/IR/Intrinsics.h"
23#include "llvm/InitializePasses.h"
24#include "llvm/Pass.h"
25#include "llvm/Transforms/Utils/LoopUtils.h"
26
27using namespace llvm;
28
29namespace {
30
31bool expandReductions(Function &F, const TargetTransformInfo *TTI,
32 DominatorTree *DT, LoopInfo *LI) {
33 bool Changed = false;
34 SmallVector<IntrinsicInst *, 4> Worklist;
35 for (auto &I : instructions(F)) {
36 if (auto *II = dyn_cast<IntrinsicInst>(Val: &I)) {
37 switch (II->getIntrinsicID()) {
38 default: break;
39 case Intrinsic::vector_reduce_fadd:
40 case Intrinsic::vector_reduce_fmul:
41 case Intrinsic::vector_reduce_add:
42 case Intrinsic::vector_reduce_mul:
43 case Intrinsic::vector_reduce_and:
44 case Intrinsic::vector_reduce_or:
45 case Intrinsic::vector_reduce_xor:
46 case Intrinsic::vector_reduce_smax:
47 case Intrinsic::vector_reduce_smin:
48 case Intrinsic::vector_reduce_umax:
49 case Intrinsic::vector_reduce_umin:
50 case Intrinsic::vector_reduce_fmax:
51 case Intrinsic::vector_reduce_fmin:
52 if (TTI->shouldExpandReduction(II))
53 Worklist.push_back(Elt: II);
54
55 break;
56 }
57 }
58 }
59
60 for (auto *II : Worklist) {
61 FastMathFlags FMF = II->getFastMathFlagsOrNone();
62 Intrinsic::ID ID = II->getIntrinsicID();
63 RecurKind RK = getMinMaxReductionRecurKind(RdxID: ID);
64 TargetTransformInfo::ReductionShuffle RS =
65 TTI->getPreferredExpandedReductionShuffle(II);
66
67 Value *Rdx = nullptr;
68 IRBuilder<> Builder(II);
69 IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
70 Builder.setFastMathFlags(FMF);
71 switch (ID) {
72 default: llvm_unreachable("Unexpected intrinsic!");
73 case Intrinsic::vector_reduce_fadd:
74 case Intrinsic::vector_reduce_fmul: {
75 // FMFs must be attached to the call, otherwise it's an ordered reduction
76 // and it can't be handled by generating a shuffle sequence.
77 Value *Acc = II->getArgOperand(i: 0);
78 Value *Vec = II->getArgOperand(i: 1);
79 unsigned RdxOpcode = getArithmeticReductionInstruction(RdxID: ID);
80 if (isa<ScalableVectorType>(Val: Vec->getType())) {
81 Rdx = expandReductionViaLoop(Builder, Vec, RdxOpcode, Acc, DT, LI);
82 break;
83 }
84 if (!FMF.allowReassoc())
85 Rdx = getOrderedReduction(Builder, Acc, Src: Vec, Op: RdxOpcode, MinMaxKind: RK);
86 else {
87 if (!isPowerOf2_32(
88 Value: cast<FixedVectorType>(Val: Vec->getType())->getNumElements()))
89 continue;
90 Rdx = getShuffleReduction(Builder, Src: Vec, Op: RdxOpcode, RS, MinMaxKind: RK);
91 Rdx = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)RdxOpcode, LHS: Acc, RHS: Rdx,
92 Name: "bin.rdx");
93 }
94 break;
95 }
96 case Intrinsic::vector_reduce_and:
97 case Intrinsic::vector_reduce_or: {
98 // Canonicalize logical or/and reductions:
99 // Or reduction for i1 is represented as:
100 // %val = bitcast <ReduxWidth x i1> to iReduxWidth
101 // %res = cmp ne iReduxWidth %val, 0
102 // And reduction for i1 is represented as:
103 // %val = bitcast <ReduxWidth x i1> to iReduxWidth
104 // %res = cmp eq iReduxWidth %val, 11111
105 Value *Vec = II->getArgOperand(i: 0);
106 auto *FTy = cast<FixedVectorType>(Val: Vec->getType());
107 unsigned NumElts = FTy->getNumElements();
108 if (!isPowerOf2_32(Value: NumElts))
109 continue;
110
111 if (FTy->getElementType() == Builder.getInt1Ty()) {
112 Rdx = Builder.CreateBitCast(V: Vec, DestTy: Builder.getIntNTy(N: NumElts));
113 if (ID == Intrinsic::vector_reduce_and) {
114 Rdx = Builder.CreateICmpEQ(
115 LHS: Rdx, RHS: ConstantInt::getAllOnesValue(Ty: Rdx->getType()));
116 } else {
117 assert(ID == Intrinsic::vector_reduce_or && "Expected or reduction.");
118 Rdx = Builder.CreateIsNotNull(Arg: Rdx);
119 }
120 break;
121 }
122 unsigned RdxOpcode = getArithmeticReductionInstruction(RdxID: ID);
123 Rdx = getShuffleReduction(Builder, Src: Vec, Op: RdxOpcode, RS, MinMaxKind: RK);
124 break;
125 }
126 case Intrinsic::vector_reduce_add:
127 case Intrinsic::vector_reduce_mul:
128 case Intrinsic::vector_reduce_xor:
129 case Intrinsic::vector_reduce_smax:
130 case Intrinsic::vector_reduce_smin:
131 case Intrinsic::vector_reduce_umax:
132 case Intrinsic::vector_reduce_umin: {
133 Value *Vec = II->getArgOperand(i: 0);
134 unsigned RdxOpcode = getArithmeticReductionInstruction(RdxID: ID);
135 if (isa<ScalableVectorType>(Val: Vec->getType())) {
136 Type *EltTy = Vec->getType()->getScalarType();
137 Value *Ident = getReductionIdentity(RdxID: ID, Ty: EltTy, FMF);
138 Rdx = expandReductionViaLoop(Builder, Vec, RdxOpcode, Acc: Ident, DT, LI);
139 break;
140 }
141 if (!isPowerOf2_32(
142 Value: cast<FixedVectorType>(Val: Vec->getType())->getNumElements()))
143 continue;
144 Rdx = getShuffleReduction(Builder, Src: Vec, Op: RdxOpcode, RS, MinMaxKind: RK);
145 break;
146 }
147 case Intrinsic::vector_reduce_fmax:
148 case Intrinsic::vector_reduce_fmin: {
149 // We require "nnan" to use a shuffle reduction; "nsz" is implied by the
150 // semantics of the reduction.
151 Value *Vec = II->getArgOperand(i: 0);
152 if (!isPowerOf2_32(
153 Value: cast<FixedVectorType>(Val: Vec->getType())->getNumElements()) ||
154 !FMF.noNaNs())
155 continue;
156 unsigned RdxOpcode = getArithmeticReductionInstruction(RdxID: ID);
157 Rdx = getShuffleReduction(Builder, Src: Vec, Op: RdxOpcode, RS, MinMaxKind: RK);
158 break;
159 }
160 }
161 II->replaceAllUsesWith(V: Rdx);
162 II->eraseFromParent();
163 Changed = true;
164 }
165 return Changed;
166}
167
168class ExpandReductions : public FunctionPass {
169public:
170 static char ID;
171 ExpandReductions() : FunctionPass(ID) {}
172
173 bool runOnFunction(Function &F) override {
174 const auto *TTI =&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
175 auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
176 auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
177 auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
178 auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
179 return expandReductions(F, TTI, DT, LI);
180 }
181
182 void getAnalysisUsage(AnalysisUsage &AU) const override {
183 AU.addRequired<TargetTransformInfoWrapperPass>();
184 AU.addPreserved<DominatorTreeWrapperPass>();
185 AU.addPreserved<LoopInfoWrapperPass>();
186 }
187};
188}
189
190char ExpandReductions::ID;
191INITIALIZE_PASS_BEGIN(ExpandReductions, "expand-reductions",
192 "Expand reduction intrinsics", false, false)
193INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
194INITIALIZE_PASS_END(ExpandReductions, "expand-reductions",
195 "Expand reduction intrinsics", false, false)
196
197FunctionPass *llvm::createExpandReductionsPass() {
198 return new ExpandReductions();
199}
200
201PreservedAnalyses ExpandReductionsPass::run(Function &F,
202 FunctionAnalysisManager &AM) {
203 const auto &TTI = AM.getResult<TargetIRAnalysis>(IR&: F);
204 auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(IR&: F);
205 auto *LI = AM.getCachedResult<LoopAnalysis>(IR&: F);
206 if (!expandReductions(F, TTI: &TTI, DT, LI))
207 return PreservedAnalyses::all();
208 PreservedAnalyses PA;
209 PA.preserve<DominatorTreeAnalysis>();
210 PA.preserve<LoopAnalysis>();
211 return PA;
212}
213