1 | //===-- X86PartialReduction.cpp -------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This pass looks for add instructions used by a horizontal reduction to see |
10 | // if we might be able to use pmaddwd or psadbw. Some cases of this require |
11 | // cross basic block knowledge and can't be done in SelectionDAG. |
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #include "X86.h" |
16 | #include "X86TargetMachine.h" |
17 | #include "llvm/Analysis/ValueTracking.h" |
18 | #include "llvm/CodeGen/TargetPassConfig.h" |
19 | #include "llvm/IR/Constants.h" |
20 | #include "llvm/IR/IRBuilder.h" |
21 | #include "llvm/IR/Instructions.h" |
22 | #include "llvm/IR/IntrinsicsX86.h" |
23 | #include "llvm/IR/PatternMatch.h" |
24 | #include "llvm/Pass.h" |
25 | #include "llvm/Support/KnownBits.h" |
26 | |
27 | using namespace llvm; |
28 | |
29 | #define DEBUG_TYPE "x86-partial-reduction" |
30 | |
31 | namespace { |
32 | |
// Legacy function pass that rewrites horizontal-reduction feeders into forms
// SelectionDAG can match to pmaddwd (tryMAddReplacement) or psadbw
// (trySADReplacement). State (DL, ST) is per-function, set in runOnFunction.
class X86PartialReduction : public FunctionPass {
  const DataLayout *DL = nullptr;   // Data layout of the current function's module.
  const X86Subtarget *ST = nullptr; // Subtarget for the current function.

public:
  static char ID; // Pass identification, replacement for typeid.

  X86PartialReduction() : FunctionPass(ID) { }

  bool runOnFunction(Function &Fn) override;

  // This pass only rewrites instructions inside blocks; the CFG is untouched.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
  }

  StringRef getPassName() const override {
    return "X86 Partial Reduction" ;
  }

private:
  // Try to rewrite Op (a vXi32 mul feeding a reduction) for pmaddwd.
  bool tryMAddReplacement(Instruction *Op, bool ReduceInOneBB);
  // Try to rewrite Op (an absolute-difference pattern) for psadbw.
  bool trySADReplacement(Instruction *Op);
};
56 | } |
57 | |
// Factory used by the X86 target pass pipeline to instantiate this pass.
FunctionPass *llvm::createX86PartialReductionPass() {
  return new X86PartialReduction();
}

// Address of ID serves as the unique pass identity for the legacy PM.
char X86PartialReduction::ID = 0;

INITIALIZE_PASS(X86PartialReduction, DEBUG_TYPE,
                "X86 Partial Reduction" , false, false)
66 | |
67 | // This function should be aligned with detectExtMul() in X86ISelLowering.cpp. |
68 | static bool matchVPDPBUSDPattern(const X86Subtarget *ST, BinaryOperator *Mul, |
69 | const DataLayout *DL) { |
70 | if (!ST->hasVNNI() && !ST->hasAVXVNNI()) |
71 | return false; |
72 | |
73 | Value *LHS = Mul->getOperand(i_nocapture: 0); |
74 | Value *RHS = Mul->getOperand(i_nocapture: 1); |
75 | |
76 | if (isa<SExtInst>(Val: LHS)) |
77 | std::swap(a&: LHS, b&: RHS); |
78 | |
79 | auto IsFreeTruncation = [&](Value *Op) { |
80 | if (auto *Cast = dyn_cast<CastInst>(Val: Op)) { |
81 | if (Cast->getParent() == Mul->getParent() && |
82 | (Cast->getOpcode() == Instruction::SExt || |
83 | Cast->getOpcode() == Instruction::ZExt) && |
84 | Cast->getOperand(i_nocapture: 0)->getType()->getScalarSizeInBits() <= 8) |
85 | return true; |
86 | } |
87 | |
88 | return isa<Constant>(Val: Op); |
89 | }; |
90 | |
91 | // (dpbusd (zext a), (sext, b)). Since the first operand should be unsigned |
92 | // value, we need to check LHS is zero extended value. RHS should be signed |
93 | // value, so we just check the signed bits. |
94 | if ((IsFreeTruncation(LHS) && |
95 | computeKnownBits(V: LHS, DL: *DL).countMaxActiveBits() <= 8) && |
96 | (IsFreeTruncation(RHS) && ComputeMaxSignificantBits(Op: RHS, DL: *DL) <= 8)) |
97 | return true; |
98 | |
99 | return false; |
100 | } |
101 | |
102 | bool X86PartialReduction::tryMAddReplacement(Instruction *Op, |
103 | bool ReduceInOneBB) { |
104 | if (!ST->hasSSE2()) |
105 | return false; |
106 | |
107 | // Need at least 8 elements. |
108 | if (cast<FixedVectorType>(Val: Op->getType())->getNumElements() < 8) |
109 | return false; |
110 | |
111 | // Element type should be i32. |
112 | if (!cast<VectorType>(Val: Op->getType())->getElementType()->isIntegerTy(Bitwidth: 32)) |
113 | return false; |
114 | |
115 | auto *Mul = dyn_cast<BinaryOperator>(Val: Op); |
116 | if (!Mul || Mul->getOpcode() != Instruction::Mul) |
117 | return false; |
118 | |
119 | Value *LHS = Mul->getOperand(i_nocapture: 0); |
120 | Value *RHS = Mul->getOperand(i_nocapture: 1); |
121 | |
122 | // If the target support VNNI, leave it to ISel to combine reduce operation |
123 | // to VNNI instruction. |
124 | // TODO: we can support transforming reduce to VNNI intrinsic for across block |
125 | // in this pass. |
126 | if (ReduceInOneBB && matchVPDPBUSDPattern(ST, Mul, DL)) |
127 | return false; |
128 | |
129 | // LHS and RHS should be only used once or if they are the same then only |
130 | // used twice. Only check this when SSE4.1 is enabled and we have zext/sext |
131 | // instructions, otherwise we use punpck to emulate zero extend in stages. The |
132 | // trunc/ we need to do likely won't introduce new instructions in that case. |
133 | if (ST->hasSSE41()) { |
134 | if (LHS == RHS) { |
135 | if (!isa<Constant>(Val: LHS) && !LHS->hasNUses(N: 2)) |
136 | return false; |
137 | } else { |
138 | if (!isa<Constant>(Val: LHS) && !LHS->hasOneUse()) |
139 | return false; |
140 | if (!isa<Constant>(Val: RHS) && !RHS->hasOneUse()) |
141 | return false; |
142 | } |
143 | } |
144 | |
145 | auto CanShrinkOp = [&](Value *Op) { |
146 | auto IsFreeTruncation = [&](Value *Op) { |
147 | if (auto *Cast = dyn_cast<CastInst>(Val: Op)) { |
148 | if (Cast->getParent() == Mul->getParent() && |
149 | (Cast->getOpcode() == Instruction::SExt || |
150 | Cast->getOpcode() == Instruction::ZExt) && |
151 | Cast->getOperand(i_nocapture: 0)->getType()->getScalarSizeInBits() <= 16) |
152 | return true; |
153 | } |
154 | |
155 | return isa<Constant>(Val: Op); |
156 | }; |
157 | |
158 | // If the operation can be freely truncated and has enough sign bits we |
159 | // can shrink. |
160 | if (IsFreeTruncation(Op) && ComputeNumSignBits(Op, DL: *DL, AC: nullptr, CxtI: Mul) > 16) |
161 | return true; |
162 | |
163 | // SelectionDAG has limited support for truncating through an add or sub if |
164 | // the inputs are freely truncatable. |
165 | if (auto *BO = dyn_cast<BinaryOperator>(Val: Op)) { |
166 | if (BO->getParent() == Mul->getParent() && |
167 | IsFreeTruncation(BO->getOperand(i_nocapture: 0)) && |
168 | IsFreeTruncation(BO->getOperand(i_nocapture: 1)) && |
169 | ComputeNumSignBits(Op, DL: *DL, AC: nullptr, CxtI: Mul) > 16) |
170 | return true; |
171 | } |
172 | |
173 | return false; |
174 | }; |
175 | |
176 | // Both Ops need to be shrinkable. |
177 | if (!CanShrinkOp(LHS) && !CanShrinkOp(RHS)) |
178 | return false; |
179 | |
180 | IRBuilder<> Builder(Mul); |
181 | |
182 | auto *MulTy = cast<FixedVectorType>(Val: Op->getType()); |
183 | unsigned NumElts = MulTy->getNumElements(); |
184 | |
185 | // Extract even elements and odd elements and add them together. This will |
186 | // be pattern matched by SelectionDAG to pmaddwd. This instruction will be |
187 | // half the original width. |
188 | SmallVector<int, 16> EvenMask(NumElts / 2); |
189 | SmallVector<int, 16> OddMask(NumElts / 2); |
190 | for (int i = 0, e = NumElts / 2; i != e; ++i) { |
191 | EvenMask[i] = i * 2; |
192 | OddMask[i] = i * 2 + 1; |
193 | } |
194 | // Creating a new mul so the replaceAllUsesWith below doesn't replace the |
195 | // uses in the shuffles we're creating. |
196 | Value *NewMul = Builder.CreateMul(LHS: Mul->getOperand(i_nocapture: 0), RHS: Mul->getOperand(i_nocapture: 1)); |
197 | Value *EvenElts = Builder.CreateShuffleVector(V1: NewMul, V2: NewMul, Mask: EvenMask); |
198 | Value *OddElts = Builder.CreateShuffleVector(V1: NewMul, V2: NewMul, Mask: OddMask); |
199 | Value *MAdd = Builder.CreateAdd(LHS: EvenElts, RHS: OddElts); |
200 | |
201 | // Concatenate zeroes to extend back to the original type. |
202 | SmallVector<int, 32> ConcatMask(NumElts); |
203 | std::iota(first: ConcatMask.begin(), last: ConcatMask.end(), value: 0); |
204 | Value *Zero = Constant::getNullValue(Ty: MAdd->getType()); |
205 | Value *Concat = Builder.CreateShuffleVector(V1: MAdd, V2: Zero, Mask: ConcatMask); |
206 | |
207 | Mul->replaceAllUsesWith(V: Concat); |
208 | Mul->eraseFromParent(); |
209 | |
210 | return true; |
211 | } |
212 | |
213 | bool X86PartialReduction::trySADReplacement(Instruction *Op) { |
214 | if (!ST->hasSSE2()) |
215 | return false; |
216 | |
217 | // TODO: There's nothing special about i32, any integer type above i16 should |
218 | // work just as well. |
219 | if (!cast<VectorType>(Val: Op->getType())->getElementType()->isIntegerTy(Bitwidth: 32)) |
220 | return false; |
221 | |
222 | Value *LHS; |
223 | if (match(V: Op, P: PatternMatch::m_Intrinsic<Intrinsic::abs>())) { |
224 | LHS = Op->getOperand(i: 0); |
225 | } else { |
226 | // Operand should be a select. |
227 | auto *SI = dyn_cast<SelectInst>(Val: Op); |
228 | if (!SI) |
229 | return false; |
230 | |
231 | Value *RHS; |
232 | // Select needs to implement absolute value. |
233 | auto SPR = matchSelectPattern(V: SI, LHS, RHS); |
234 | if (SPR.Flavor != SPF_ABS) |
235 | return false; |
236 | } |
237 | |
238 | // Need a subtract of two values. |
239 | auto *Sub = dyn_cast<BinaryOperator>(Val: LHS); |
240 | if (!Sub || Sub->getOpcode() != Instruction::Sub) |
241 | return false; |
242 | |
243 | // Look for zero extend from i8. |
244 | auto getZeroExtendedVal = [](Value *Op) -> Value * { |
245 | if (auto *ZExt = dyn_cast<ZExtInst>(Val: Op)) |
246 | if (cast<VectorType>(Val: ZExt->getOperand(i_nocapture: 0)->getType()) |
247 | ->getElementType() |
248 | ->isIntegerTy(Bitwidth: 8)) |
249 | return ZExt->getOperand(i_nocapture: 0); |
250 | |
251 | return nullptr; |
252 | }; |
253 | |
254 | // Both operands of the subtract should be extends from vXi8. |
255 | Value *Op0 = getZeroExtendedVal(Sub->getOperand(i_nocapture: 0)); |
256 | Value *Op1 = getZeroExtendedVal(Sub->getOperand(i_nocapture: 1)); |
257 | if (!Op0 || !Op1) |
258 | return false; |
259 | |
260 | IRBuilder<> Builder(Op); |
261 | |
262 | auto *OpTy = cast<FixedVectorType>(Val: Op->getType()); |
263 | unsigned NumElts = OpTy->getNumElements(); |
264 | |
265 | unsigned IntrinsicNumElts; |
266 | Intrinsic::ID IID; |
267 | if (ST->hasBWI() && NumElts >= 64) { |
268 | IID = Intrinsic::x86_avx512_psad_bw_512; |
269 | IntrinsicNumElts = 64; |
270 | } else if (ST->hasAVX2() && NumElts >= 32) { |
271 | IID = Intrinsic::x86_avx2_psad_bw; |
272 | IntrinsicNumElts = 32; |
273 | } else { |
274 | IID = Intrinsic::x86_sse2_psad_bw; |
275 | IntrinsicNumElts = 16; |
276 | } |
277 | |
278 | Function *PSADBWFn = Intrinsic::getOrInsertDeclaration(M: Op->getModule(), id: IID); |
279 | |
280 | if (NumElts < 16) { |
281 | // Pad input with zeroes. |
282 | SmallVector<int, 32> ConcatMask(16); |
283 | for (unsigned i = 0; i != NumElts; ++i) |
284 | ConcatMask[i] = i; |
285 | for (unsigned i = NumElts; i != 16; ++i) |
286 | ConcatMask[i] = (i % NumElts) + NumElts; |
287 | |
288 | Value *Zero = Constant::getNullValue(Ty: Op0->getType()); |
289 | Op0 = Builder.CreateShuffleVector(V1: Op0, V2: Zero, Mask: ConcatMask); |
290 | Op1 = Builder.CreateShuffleVector(V1: Op1, V2: Zero, Mask: ConcatMask); |
291 | NumElts = 16; |
292 | } |
293 | |
294 | // Intrinsics produce vXi64 and need to be casted to vXi32. |
295 | auto *I32Ty = |
296 | FixedVectorType::get(ElementType: Builder.getInt32Ty(), NumElts: IntrinsicNumElts / 4); |
297 | |
298 | assert(NumElts % IntrinsicNumElts == 0 && "Unexpected number of elements!" ); |
299 | unsigned NumSplits = NumElts / IntrinsicNumElts; |
300 | |
301 | // First collect the pieces we need. |
302 | SmallVector<Value *, 4> Ops(NumSplits); |
303 | for (unsigned i = 0; i != NumSplits; ++i) { |
304 | SmallVector<int, 64> (IntrinsicNumElts); |
305 | std::iota(first: ExtractMask.begin(), last: ExtractMask.end(), value: i * IntrinsicNumElts); |
306 | Value * = Builder.CreateShuffleVector(V1: Op0, V2: Op0, Mask: ExtractMask); |
307 | Value * = Builder.CreateShuffleVector(V1: Op1, V2: Op0, Mask: ExtractMask); |
308 | Ops[i] = Builder.CreateCall(Callee: PSADBWFn, Args: {ExtractOp0, ExtractOp1}); |
309 | Ops[i] = Builder.CreateBitCast(V: Ops[i], DestTy: I32Ty); |
310 | } |
311 | |
312 | assert(isPowerOf2_32(NumSplits) && "Expected power of 2 splits" ); |
313 | unsigned Stages = Log2_32(Value: NumSplits); |
314 | for (unsigned s = Stages; s > 0; --s) { |
315 | unsigned NumConcatElts = |
316 | cast<FixedVectorType>(Val: Ops[0]->getType())->getNumElements() * 2; |
317 | for (unsigned i = 0; i != 1U << (s - 1); ++i) { |
318 | SmallVector<int, 64> ConcatMask(NumConcatElts); |
319 | std::iota(first: ConcatMask.begin(), last: ConcatMask.end(), value: 0); |
320 | Ops[i] = Builder.CreateShuffleVector(V1: Ops[i*2], V2: Ops[i*2+1], Mask: ConcatMask); |
321 | } |
322 | } |
323 | |
324 | // At this point the final value should be in Ops[0]. Now we need to adjust |
325 | // it to the final original type. |
326 | NumElts = cast<FixedVectorType>(Val: OpTy)->getNumElements(); |
327 | if (NumElts == 2) { |
328 | // Extract down to 2 elements. |
329 | Ops[0] = Builder.CreateShuffleVector(V1: Ops[0], V2: Ops[0], Mask: ArrayRef<int>{0, 1}); |
330 | } else if (NumElts >= 8) { |
331 | SmallVector<int, 32> ConcatMask(NumElts); |
332 | unsigned SubElts = |
333 | cast<FixedVectorType>(Val: Ops[0]->getType())->getNumElements(); |
334 | for (unsigned i = 0; i != SubElts; ++i) |
335 | ConcatMask[i] = i; |
336 | for (unsigned i = SubElts; i != NumElts; ++i) |
337 | ConcatMask[i] = (i % SubElts) + SubElts; |
338 | |
339 | Value *Zero = Constant::getNullValue(Ty: Ops[0]->getType()); |
340 | Ops[0] = Builder.CreateShuffleVector(V1: Ops[0], V2: Zero, Mask: ConcatMask); |
341 | } |
342 | |
343 | Op->replaceAllUsesWith(V: Ops[0]); |
344 | Op->eraseFromParent(); |
345 | |
346 | return true; |
347 | } |
348 | |
349 | // Walk backwards from the ExtractElementInst and determine if it is the end of |
350 | // a horizontal reduction. Return the input to the reduction if we find one. |
351 | static Value *(const ExtractElementInst &EE, |
352 | bool &ReduceInOneBB) { |
353 | ReduceInOneBB = true; |
354 | // Make sure we're extracting index 0. |
355 | auto *Index = dyn_cast<ConstantInt>(Val: EE.getIndexOperand()); |
356 | if (!Index || !Index->isNullValue()) |
357 | return nullptr; |
358 | |
359 | const auto *BO = dyn_cast<BinaryOperator>(Val: EE.getVectorOperand()); |
360 | if (!BO || BO->getOpcode() != Instruction::Add || !BO->hasOneUse()) |
361 | return nullptr; |
362 | if (EE.getParent() != BO->getParent()) |
363 | ReduceInOneBB = false; |
364 | |
365 | unsigned NumElems = cast<FixedVectorType>(Val: BO->getType())->getNumElements(); |
366 | // Ensure the reduction size is a power of 2. |
367 | if (!isPowerOf2_32(Value: NumElems)) |
368 | return nullptr; |
369 | |
370 | const Value *Op = BO; |
371 | unsigned Stages = Log2_32(Value: NumElems); |
372 | for (unsigned i = 0; i != Stages; ++i) { |
373 | const auto *BO = dyn_cast<BinaryOperator>(Val: Op); |
374 | if (!BO || BO->getOpcode() != Instruction::Add) |
375 | return nullptr; |
376 | if (EE.getParent() != BO->getParent()) |
377 | ReduceInOneBB = false; |
378 | |
379 | // If this isn't the first add, then it should only have 2 users, the |
380 | // shuffle and another add which we checked in the previous iteration. |
381 | if (i != 0 && !BO->hasNUses(N: 2)) |
382 | return nullptr; |
383 | |
384 | Value *LHS = BO->getOperand(i_nocapture: 0); |
385 | Value *RHS = BO->getOperand(i_nocapture: 1); |
386 | |
387 | auto *Shuffle = dyn_cast<ShuffleVectorInst>(Val: LHS); |
388 | if (Shuffle) { |
389 | Op = RHS; |
390 | } else { |
391 | Shuffle = dyn_cast<ShuffleVectorInst>(Val: RHS); |
392 | Op = LHS; |
393 | } |
394 | |
395 | // The first operand of the shuffle should be the same as the other operand |
396 | // of the bin op. |
397 | if (!Shuffle || Shuffle->getOperand(i_nocapture: 0) != Op) |
398 | return nullptr; |
399 | |
400 | // Verify the shuffle has the expected (at this stage of the pyramid) mask. |
401 | unsigned MaskEnd = 1 << i; |
402 | for (unsigned Index = 0; Index < MaskEnd; ++Index) |
403 | if (Shuffle->getMaskValue(Elt: Index) != (int)(MaskEnd + Index)) |
404 | return nullptr; |
405 | } |
406 | |
407 | return const_cast<Value *>(Op); |
408 | } |
409 | |
410 | // See if this BO is reachable from this Phi by walking forward through single |
411 | // use BinaryOperators with the same opcode. If we get back then we know we've |
412 | // found a loop and it is safe to step through this Add to find more leaves. |
413 | static bool isReachableFromPHI(PHINode *Phi, BinaryOperator *BO) { |
414 | // The PHI itself should only have one use. |
415 | if (!Phi->hasOneUse()) |
416 | return false; |
417 | |
418 | Instruction *U = cast<Instruction>(Val: *Phi->user_begin()); |
419 | if (U == BO) |
420 | return true; |
421 | |
422 | while (U->hasOneUse() && U->getOpcode() == BO->getOpcode()) |
423 | U = cast<Instruction>(Val: *U->user_begin()); |
424 | |
425 | return U == BO; |
426 | } |
427 | |
428 | // Collect all the leaves of the tree of adds that feeds into the horizontal |
429 | // reduction. Root is the Value that is used by the horizontal reduction. |
430 | // We look through single use phis, single use adds, or adds that are used by |
431 | // a phi that forms a loop with the add. |
432 | static void collectLeaves(Value *Root, SmallVectorImpl<Instruction *> &Leaves) { |
433 | SmallPtrSet<Value *, 8> Visited; |
434 | SmallVector<Value *, 8> Worklist; |
435 | Worklist.push_back(Elt: Root); |
436 | |
437 | while (!Worklist.empty()) { |
438 | Value *V = Worklist.pop_back_val(); |
439 | if (!Visited.insert(Ptr: V).second) |
440 | continue; |
441 | |
442 | if (auto *PN = dyn_cast<PHINode>(Val: V)) { |
443 | // PHI node should have single use unless it is the root node, then it |
444 | // has 2 uses. |
445 | if (!PN->hasNUses(N: PN == Root ? 2 : 1)) |
446 | break; |
447 | |
448 | // Push incoming values to the worklist. |
449 | append_range(C&: Worklist, R: PN->incoming_values()); |
450 | |
451 | continue; |
452 | } |
453 | |
454 | if (auto *BO = dyn_cast<BinaryOperator>(Val: V)) { |
455 | if (BO->getOpcode() == Instruction::Add) { |
456 | // Simple case. Single use, just push its operands to the worklist. |
457 | if (BO->hasNUses(N: BO == Root ? 2 : 1)) { |
458 | append_range(C&: Worklist, R: BO->operands()); |
459 | continue; |
460 | } |
461 | |
462 | // If there is additional use, make sure it is an unvisited phi that |
463 | // gets us back to this node. |
464 | if (BO->hasNUses(N: BO == Root ? 3 : 2)) { |
465 | PHINode *PN = nullptr; |
466 | for (auto *U : BO->users()) |
467 | if (auto *P = dyn_cast<PHINode>(Val: U)) |
468 | if (!Visited.count(Ptr: P)) |
469 | PN = P; |
470 | |
471 | // If we didn't find a 2-input PHI then this isn't a case we can |
472 | // handle. |
473 | if (!PN || PN->getNumIncomingValues() != 2) |
474 | continue; |
475 | |
476 | // Walk forward from this phi to see if it reaches back to this add. |
477 | if (!isReachableFromPHI(Phi: PN, BO)) |
478 | continue; |
479 | |
480 | // The phi forms a loop with this Add, push its operands. |
481 | append_range(C&: Worklist, R: BO->operands()); |
482 | } |
483 | } |
484 | } |
485 | |
486 | // Not an add or phi, make it a leaf. |
487 | if (auto *I = dyn_cast<Instruction>(Val: V)) { |
488 | if (!V->hasNUses(N: I == Root ? 2 : 1)) |
489 | continue; |
490 | |
491 | // Add this as a leaf. |
492 | Leaves.push_back(Elt: I); |
493 | } |
494 | } |
495 | } |
496 | |
497 | bool X86PartialReduction::runOnFunction(Function &F) { |
498 | if (skipFunction(F)) |
499 | return false; |
500 | |
501 | auto *TPC = getAnalysisIfAvailable<TargetPassConfig>(); |
502 | if (!TPC) |
503 | return false; |
504 | |
505 | auto &TM = TPC->getTM<X86TargetMachine>(); |
506 | ST = TM.getSubtargetImpl(F); |
507 | |
508 | DL = &F.getDataLayout(); |
509 | |
510 | bool MadeChange = false; |
511 | for (auto &BB : F) { |
512 | for (auto &I : BB) { |
513 | auto *EE = dyn_cast<ExtractElementInst>(Val: &I); |
514 | if (!EE) |
515 | continue; |
516 | |
517 | bool ReduceInOneBB; |
518 | // First find a reduction tree. |
519 | // FIXME: Do we need to handle other opcodes than Add? |
520 | Value *Root = matchAddReduction(EE: *EE, ReduceInOneBB); |
521 | if (!Root) |
522 | continue; |
523 | |
524 | SmallVector<Instruction *, 8> Leaves; |
525 | collectLeaves(Root, Leaves); |
526 | |
527 | for (Instruction *I : Leaves) { |
528 | if (tryMAddReplacement(Op: I, ReduceInOneBB)) { |
529 | MadeChange = true; |
530 | continue; |
531 | } |
532 | |
533 | // Don't do SAD matching on the root node. SelectionDAG already |
534 | // has support for that and currently generates better code. |
535 | if (I != Root && trySADReplacement(Op: I)) |
536 | MadeChange = true; |
537 | } |
538 | } |
539 | } |
540 | |
541 | return MadeChange; |
542 | } |
543 | |