//===-- X86PartialReduction.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass looks for add instructions used by a horizontal reduction to see
// if we might be able to use pmaddwd or psadbw. Some cases of this require
// cross basic block knowledge and can't be done in SelectionDAG.
//
//===----------------------------------------------------------------------===//
14 | |
#include "X86.h"
#include "X86TargetMachine.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"
#include "llvm/Support/KnownBits.h"
#include <numeric>
28 | |
29 | using namespace llvm; |
30 | |
31 | #define DEBUG_TYPE "x86-partial-reduction" |
32 | |
33 | namespace { |
34 | |
35 | class X86PartialReduction : public FunctionPass { |
36 | const DataLayout *DL = nullptr; |
37 | const X86Subtarget *ST = nullptr; |
38 | |
39 | public: |
40 | static char ID; // Pass identification, replacement for typeid. |
41 | |
42 | X86PartialReduction() : FunctionPass(ID) { } |
43 | |
44 | bool runOnFunction(Function &Fn) override; |
45 | |
46 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
47 | AU.setPreservesCFG(); |
48 | } |
49 | |
50 | StringRef getPassName() const override { |
51 | return "X86 Partial Reduction" ; |
52 | } |
53 | |
54 | private: |
55 | bool tryMAddReplacement(Instruction *Op, bool ReduceInOneBB); |
56 | bool trySADReplacement(Instruction *Op); |
57 | }; |
58 | } |
59 | |
60 | FunctionPass *llvm::createX86PartialReductionPass() { |
61 | return new X86PartialReduction(); |
62 | } |
63 | |
64 | char X86PartialReduction::ID = 0; |
65 | |
66 | INITIALIZE_PASS(X86PartialReduction, DEBUG_TYPE, |
67 | "X86 Partial Reduction" , false, false) |
68 | |
69 | // This function should be aligned with detectExtMul() in X86ISelLowering.cpp. |
70 | static bool matchVPDPBUSDPattern(const X86Subtarget *ST, BinaryOperator *Mul, |
71 | const DataLayout *DL) { |
72 | if (!ST->hasVNNI() && !ST->hasAVXVNNI()) |
73 | return false; |
74 | |
75 | Value *LHS = Mul->getOperand(i_nocapture: 0); |
76 | Value *RHS = Mul->getOperand(i_nocapture: 1); |
77 | |
78 | if (isa<SExtInst>(Val: LHS)) |
79 | std::swap(a&: LHS, b&: RHS); |
80 | |
81 | auto IsFreeTruncation = [&](Value *Op) { |
82 | if (auto *Cast = dyn_cast<CastInst>(Val: Op)) { |
83 | if (Cast->getParent() == Mul->getParent() && |
84 | (Cast->getOpcode() == Instruction::SExt || |
85 | Cast->getOpcode() == Instruction::ZExt) && |
86 | Cast->getOperand(i_nocapture: 0)->getType()->getScalarSizeInBits() <= 8) |
87 | return true; |
88 | } |
89 | |
90 | return isa<Constant>(Val: Op); |
91 | }; |
92 | |
93 | // (dpbusd (zext a), (sext, b)). Since the first operand should be unsigned |
94 | // value, we need to check LHS is zero extended value. RHS should be signed |
95 | // value, so we just check the signed bits. |
96 | if ((IsFreeTruncation(LHS) && |
97 | computeKnownBits(V: LHS, DL: *DL).countMaxActiveBits() <= 8) && |
98 | (IsFreeTruncation(RHS) && ComputeMaxSignificantBits(Op: RHS, DL: *DL) <= 8)) |
99 | return true; |
100 | |
101 | return false; |
102 | } |
103 | |
104 | bool X86PartialReduction::tryMAddReplacement(Instruction *Op, |
105 | bool ReduceInOneBB) { |
106 | if (!ST->hasSSE2()) |
107 | return false; |
108 | |
109 | // Need at least 8 elements. |
110 | if (cast<FixedVectorType>(Val: Op->getType())->getNumElements() < 8) |
111 | return false; |
112 | |
113 | // Element type should be i32. |
114 | if (!cast<VectorType>(Val: Op->getType())->getElementType()->isIntegerTy(Bitwidth: 32)) |
115 | return false; |
116 | |
117 | auto *Mul = dyn_cast<BinaryOperator>(Val: Op); |
118 | if (!Mul || Mul->getOpcode() != Instruction::Mul) |
119 | return false; |
120 | |
121 | Value *LHS = Mul->getOperand(i_nocapture: 0); |
122 | Value *RHS = Mul->getOperand(i_nocapture: 1); |
123 | |
124 | // If the target support VNNI, leave it to ISel to combine reduce operation |
125 | // to VNNI instruction. |
126 | // TODO: we can support transforming reduce to VNNI intrinsic for across block |
127 | // in this pass. |
128 | if (ReduceInOneBB && matchVPDPBUSDPattern(ST, Mul, DL)) |
129 | return false; |
130 | |
131 | // LHS and RHS should be only used once or if they are the same then only |
132 | // used twice. Only check this when SSE4.1 is enabled and we have zext/sext |
133 | // instructions, otherwise we use punpck to emulate zero extend in stages. The |
134 | // trunc/ we need to do likely won't introduce new instructions in that case. |
135 | if (ST->hasSSE41()) { |
136 | if (LHS == RHS) { |
137 | if (!isa<Constant>(Val: LHS) && !LHS->hasNUses(N: 2)) |
138 | return false; |
139 | } else { |
140 | if (!isa<Constant>(Val: LHS) && !LHS->hasOneUse()) |
141 | return false; |
142 | if (!isa<Constant>(Val: RHS) && !RHS->hasOneUse()) |
143 | return false; |
144 | } |
145 | } |
146 | |
147 | auto CanShrinkOp = [&](Value *Op) { |
148 | auto IsFreeTruncation = [&](Value *Op) { |
149 | if (auto *Cast = dyn_cast<CastInst>(Val: Op)) { |
150 | if (Cast->getParent() == Mul->getParent() && |
151 | (Cast->getOpcode() == Instruction::SExt || |
152 | Cast->getOpcode() == Instruction::ZExt) && |
153 | Cast->getOperand(i_nocapture: 0)->getType()->getScalarSizeInBits() <= 16) |
154 | return true; |
155 | } |
156 | |
157 | return isa<Constant>(Val: Op); |
158 | }; |
159 | |
160 | // If the operation can be freely truncated and has enough sign bits we |
161 | // can shrink. |
162 | if (IsFreeTruncation(Op) && |
163 | ComputeNumSignBits(Op, DL: *DL, Depth: 0, AC: nullptr, CxtI: Mul) > 16) |
164 | return true; |
165 | |
166 | // SelectionDAG has limited support for truncating through an add or sub if |
167 | // the inputs are freely truncatable. |
168 | if (auto *BO = dyn_cast<BinaryOperator>(Val: Op)) { |
169 | if (BO->getParent() == Mul->getParent() && |
170 | IsFreeTruncation(BO->getOperand(i_nocapture: 0)) && |
171 | IsFreeTruncation(BO->getOperand(i_nocapture: 1)) && |
172 | ComputeNumSignBits(Op, DL: *DL, Depth: 0, AC: nullptr, CxtI: Mul) > 16) |
173 | return true; |
174 | } |
175 | |
176 | return false; |
177 | }; |
178 | |
179 | // Both Ops need to be shrinkable. |
180 | if (!CanShrinkOp(LHS) && !CanShrinkOp(RHS)) |
181 | return false; |
182 | |
183 | IRBuilder<> Builder(Mul); |
184 | |
185 | auto *MulTy = cast<FixedVectorType>(Val: Op->getType()); |
186 | unsigned NumElts = MulTy->getNumElements(); |
187 | |
188 | // Extract even elements and odd elements and add them together. This will |
189 | // be pattern matched by SelectionDAG to pmaddwd. This instruction will be |
190 | // half the original width. |
191 | SmallVector<int, 16> EvenMask(NumElts / 2); |
192 | SmallVector<int, 16> OddMask(NumElts / 2); |
193 | for (int i = 0, e = NumElts / 2; i != e; ++i) { |
194 | EvenMask[i] = i * 2; |
195 | OddMask[i] = i * 2 + 1; |
196 | } |
197 | // Creating a new mul so the replaceAllUsesWith below doesn't replace the |
198 | // uses in the shuffles we're creating. |
199 | Value *NewMul = Builder.CreateMul(LHS: Mul->getOperand(i_nocapture: 0), RHS: Mul->getOperand(i_nocapture: 1)); |
200 | Value *EvenElts = Builder.CreateShuffleVector(V1: NewMul, V2: NewMul, Mask: EvenMask); |
201 | Value *OddElts = Builder.CreateShuffleVector(V1: NewMul, V2: NewMul, Mask: OddMask); |
202 | Value *MAdd = Builder.CreateAdd(LHS: EvenElts, RHS: OddElts); |
203 | |
204 | // Concatenate zeroes to extend back to the original type. |
205 | SmallVector<int, 32> ConcatMask(NumElts); |
206 | std::iota(first: ConcatMask.begin(), last: ConcatMask.end(), value: 0); |
207 | Value *Zero = Constant::getNullValue(Ty: MAdd->getType()); |
208 | Value *Concat = Builder.CreateShuffleVector(V1: MAdd, V2: Zero, Mask: ConcatMask); |
209 | |
210 | Mul->replaceAllUsesWith(V: Concat); |
211 | Mul->eraseFromParent(); |
212 | |
213 | return true; |
214 | } |
215 | |
216 | bool X86PartialReduction::trySADReplacement(Instruction *Op) { |
217 | if (!ST->hasSSE2()) |
218 | return false; |
219 | |
220 | // TODO: There's nothing special about i32, any integer type above i16 should |
221 | // work just as well. |
222 | if (!cast<VectorType>(Val: Op->getType())->getElementType()->isIntegerTy(Bitwidth: 32)) |
223 | return false; |
224 | |
225 | Value *LHS; |
226 | if (match(V: Op, P: PatternMatch::m_Intrinsic<Intrinsic::abs>())) { |
227 | LHS = Op->getOperand(i: 0); |
228 | } else { |
229 | // Operand should be a select. |
230 | auto *SI = dyn_cast<SelectInst>(Val: Op); |
231 | if (!SI) |
232 | return false; |
233 | |
234 | Value *RHS; |
235 | // Select needs to implement absolute value. |
236 | auto SPR = matchSelectPattern(V: SI, LHS, RHS); |
237 | if (SPR.Flavor != SPF_ABS) |
238 | return false; |
239 | } |
240 | |
241 | // Need a subtract of two values. |
242 | auto *Sub = dyn_cast<BinaryOperator>(Val: LHS); |
243 | if (!Sub || Sub->getOpcode() != Instruction::Sub) |
244 | return false; |
245 | |
246 | // Look for zero extend from i8. |
247 | auto getZeroExtendedVal = [](Value *Op) -> Value * { |
248 | if (auto *ZExt = dyn_cast<ZExtInst>(Val: Op)) |
249 | if (cast<VectorType>(Val: ZExt->getOperand(i_nocapture: 0)->getType()) |
250 | ->getElementType() |
251 | ->isIntegerTy(Bitwidth: 8)) |
252 | return ZExt->getOperand(i_nocapture: 0); |
253 | |
254 | return nullptr; |
255 | }; |
256 | |
257 | // Both operands of the subtract should be extends from vXi8. |
258 | Value *Op0 = getZeroExtendedVal(Sub->getOperand(i_nocapture: 0)); |
259 | Value *Op1 = getZeroExtendedVal(Sub->getOperand(i_nocapture: 1)); |
260 | if (!Op0 || !Op1) |
261 | return false; |
262 | |
263 | IRBuilder<> Builder(Op); |
264 | |
265 | auto *OpTy = cast<FixedVectorType>(Val: Op->getType()); |
266 | unsigned NumElts = OpTy->getNumElements(); |
267 | |
268 | unsigned IntrinsicNumElts; |
269 | Intrinsic::ID IID; |
270 | if (ST->hasBWI() && NumElts >= 64) { |
271 | IID = Intrinsic::x86_avx512_psad_bw_512; |
272 | IntrinsicNumElts = 64; |
273 | } else if (ST->hasAVX2() && NumElts >= 32) { |
274 | IID = Intrinsic::x86_avx2_psad_bw; |
275 | IntrinsicNumElts = 32; |
276 | } else { |
277 | IID = Intrinsic::x86_sse2_psad_bw; |
278 | IntrinsicNumElts = 16; |
279 | } |
280 | |
281 | Function *PSADBWFn = Intrinsic::getDeclaration(M: Op->getModule(), id: IID); |
282 | |
283 | if (NumElts < 16) { |
284 | // Pad input with zeroes. |
285 | SmallVector<int, 32> ConcatMask(16); |
286 | for (unsigned i = 0; i != NumElts; ++i) |
287 | ConcatMask[i] = i; |
288 | for (unsigned i = NumElts; i != 16; ++i) |
289 | ConcatMask[i] = (i % NumElts) + NumElts; |
290 | |
291 | Value *Zero = Constant::getNullValue(Ty: Op0->getType()); |
292 | Op0 = Builder.CreateShuffleVector(V1: Op0, V2: Zero, Mask: ConcatMask); |
293 | Op1 = Builder.CreateShuffleVector(V1: Op1, V2: Zero, Mask: ConcatMask); |
294 | NumElts = 16; |
295 | } |
296 | |
297 | // Intrinsics produce vXi64 and need to be casted to vXi32. |
298 | auto *I32Ty = |
299 | FixedVectorType::get(ElementType: Builder.getInt32Ty(), NumElts: IntrinsicNumElts / 4); |
300 | |
301 | assert(NumElts % IntrinsicNumElts == 0 && "Unexpected number of elements!" ); |
302 | unsigned NumSplits = NumElts / IntrinsicNumElts; |
303 | |
304 | // First collect the pieces we need. |
305 | SmallVector<Value *, 4> Ops(NumSplits); |
306 | for (unsigned i = 0; i != NumSplits; ++i) { |
307 | SmallVector<int, 64> (IntrinsicNumElts); |
308 | std::iota(first: ExtractMask.begin(), last: ExtractMask.end(), value: i * IntrinsicNumElts); |
309 | Value * = Builder.CreateShuffleVector(V1: Op0, V2: Op0, Mask: ExtractMask); |
310 | Value * = Builder.CreateShuffleVector(V1: Op1, V2: Op0, Mask: ExtractMask); |
311 | Ops[i] = Builder.CreateCall(Callee: PSADBWFn, Args: {ExtractOp0, ExtractOp1}); |
312 | Ops[i] = Builder.CreateBitCast(V: Ops[i], DestTy: I32Ty); |
313 | } |
314 | |
315 | assert(isPowerOf2_32(NumSplits) && "Expected power of 2 splits" ); |
316 | unsigned Stages = Log2_32(Value: NumSplits); |
317 | for (unsigned s = Stages; s > 0; --s) { |
318 | unsigned NumConcatElts = |
319 | cast<FixedVectorType>(Val: Ops[0]->getType())->getNumElements() * 2; |
320 | for (unsigned i = 0; i != 1U << (s - 1); ++i) { |
321 | SmallVector<int, 64> ConcatMask(NumConcatElts); |
322 | std::iota(first: ConcatMask.begin(), last: ConcatMask.end(), value: 0); |
323 | Ops[i] = Builder.CreateShuffleVector(V1: Ops[i*2], V2: Ops[i*2+1], Mask: ConcatMask); |
324 | } |
325 | } |
326 | |
327 | // At this point the final value should be in Ops[0]. Now we need to adjust |
328 | // it to the final original type. |
329 | NumElts = cast<FixedVectorType>(Val: OpTy)->getNumElements(); |
330 | if (NumElts == 2) { |
331 | // Extract down to 2 elements. |
332 | Ops[0] = Builder.CreateShuffleVector(V1: Ops[0], V2: Ops[0], Mask: ArrayRef<int>{0, 1}); |
333 | } else if (NumElts >= 8) { |
334 | SmallVector<int, 32> ConcatMask(NumElts); |
335 | unsigned SubElts = |
336 | cast<FixedVectorType>(Val: Ops[0]->getType())->getNumElements(); |
337 | for (unsigned i = 0; i != SubElts; ++i) |
338 | ConcatMask[i] = i; |
339 | for (unsigned i = SubElts; i != NumElts; ++i) |
340 | ConcatMask[i] = (i % SubElts) + SubElts; |
341 | |
342 | Value *Zero = Constant::getNullValue(Ty: Ops[0]->getType()); |
343 | Ops[0] = Builder.CreateShuffleVector(V1: Ops[0], V2: Zero, Mask: ConcatMask); |
344 | } |
345 | |
346 | Op->replaceAllUsesWith(V: Ops[0]); |
347 | Op->eraseFromParent(); |
348 | |
349 | return true; |
350 | } |
351 | |
352 | // Walk backwards from the ExtractElementInst and determine if it is the end of |
353 | // a horizontal reduction. Return the input to the reduction if we find one. |
354 | static Value *(const ExtractElementInst &EE, |
355 | bool &ReduceInOneBB) { |
356 | ReduceInOneBB = true; |
357 | // Make sure we're extracting index 0. |
358 | auto *Index = dyn_cast<ConstantInt>(Val: EE.getIndexOperand()); |
359 | if (!Index || !Index->isNullValue()) |
360 | return nullptr; |
361 | |
362 | const auto *BO = dyn_cast<BinaryOperator>(Val: EE.getVectorOperand()); |
363 | if (!BO || BO->getOpcode() != Instruction::Add || !BO->hasOneUse()) |
364 | return nullptr; |
365 | if (EE.getParent() != BO->getParent()) |
366 | ReduceInOneBB = false; |
367 | |
368 | unsigned NumElems = cast<FixedVectorType>(Val: BO->getType())->getNumElements(); |
369 | // Ensure the reduction size is a power of 2. |
370 | if (!isPowerOf2_32(Value: NumElems)) |
371 | return nullptr; |
372 | |
373 | const Value *Op = BO; |
374 | unsigned Stages = Log2_32(Value: NumElems); |
375 | for (unsigned i = 0; i != Stages; ++i) { |
376 | const auto *BO = dyn_cast<BinaryOperator>(Val: Op); |
377 | if (!BO || BO->getOpcode() != Instruction::Add) |
378 | return nullptr; |
379 | if (EE.getParent() != BO->getParent()) |
380 | ReduceInOneBB = false; |
381 | |
382 | // If this isn't the first add, then it should only have 2 users, the |
383 | // shuffle and another add which we checked in the previous iteration. |
384 | if (i != 0 && !BO->hasNUses(N: 2)) |
385 | return nullptr; |
386 | |
387 | Value *LHS = BO->getOperand(i_nocapture: 0); |
388 | Value *RHS = BO->getOperand(i_nocapture: 1); |
389 | |
390 | auto *Shuffle = dyn_cast<ShuffleVectorInst>(Val: LHS); |
391 | if (Shuffle) { |
392 | Op = RHS; |
393 | } else { |
394 | Shuffle = dyn_cast<ShuffleVectorInst>(Val: RHS); |
395 | Op = LHS; |
396 | } |
397 | |
398 | // The first operand of the shuffle should be the same as the other operand |
399 | // of the bin op. |
400 | if (!Shuffle || Shuffle->getOperand(i_nocapture: 0) != Op) |
401 | return nullptr; |
402 | |
403 | // Verify the shuffle has the expected (at this stage of the pyramid) mask. |
404 | unsigned MaskEnd = 1 << i; |
405 | for (unsigned Index = 0; Index < MaskEnd; ++Index) |
406 | if (Shuffle->getMaskValue(Elt: Index) != (int)(MaskEnd + Index)) |
407 | return nullptr; |
408 | } |
409 | |
410 | return const_cast<Value *>(Op); |
411 | } |
412 | |
413 | // See if this BO is reachable from this Phi by walking forward through single |
414 | // use BinaryOperators with the same opcode. If we get back then we know we've |
415 | // found a loop and it is safe to step through this Add to find more leaves. |
416 | static bool isReachableFromPHI(PHINode *Phi, BinaryOperator *BO) { |
417 | // The PHI itself should only have one use. |
418 | if (!Phi->hasOneUse()) |
419 | return false; |
420 | |
421 | Instruction *U = cast<Instruction>(Val: *Phi->user_begin()); |
422 | if (U == BO) |
423 | return true; |
424 | |
425 | while (U->hasOneUse() && U->getOpcode() == BO->getOpcode()) |
426 | U = cast<Instruction>(Val: *U->user_begin()); |
427 | |
428 | return U == BO; |
429 | } |
430 | |
431 | // Collect all the leaves of the tree of adds that feeds into the horizontal |
432 | // reduction. Root is the Value that is used by the horizontal reduction. |
433 | // We look through single use phis, single use adds, or adds that are used by |
434 | // a phi that forms a loop with the add. |
435 | static void collectLeaves(Value *Root, SmallVectorImpl<Instruction *> &Leaves) { |
436 | SmallPtrSet<Value *, 8> Visited; |
437 | SmallVector<Value *, 8> Worklist; |
438 | Worklist.push_back(Elt: Root); |
439 | |
440 | while (!Worklist.empty()) { |
441 | Value *V = Worklist.pop_back_val(); |
442 | if (!Visited.insert(Ptr: V).second) |
443 | continue; |
444 | |
445 | if (auto *PN = dyn_cast<PHINode>(Val: V)) { |
446 | // PHI node should have single use unless it is the root node, then it |
447 | // has 2 uses. |
448 | if (!PN->hasNUses(N: PN == Root ? 2 : 1)) |
449 | break; |
450 | |
451 | // Push incoming values to the worklist. |
452 | append_range(C&: Worklist, R: PN->incoming_values()); |
453 | |
454 | continue; |
455 | } |
456 | |
457 | if (auto *BO = dyn_cast<BinaryOperator>(Val: V)) { |
458 | if (BO->getOpcode() == Instruction::Add) { |
459 | // Simple case. Single use, just push its operands to the worklist. |
460 | if (BO->hasNUses(N: BO == Root ? 2 : 1)) { |
461 | append_range(C&: Worklist, R: BO->operands()); |
462 | continue; |
463 | } |
464 | |
465 | // If there is additional use, make sure it is an unvisited phi that |
466 | // gets us back to this node. |
467 | if (BO->hasNUses(N: BO == Root ? 3 : 2)) { |
468 | PHINode *PN = nullptr; |
469 | for (auto *U : BO->users()) |
470 | if (auto *P = dyn_cast<PHINode>(Val: U)) |
471 | if (!Visited.count(Ptr: P)) |
472 | PN = P; |
473 | |
474 | // If we didn't find a 2-input PHI then this isn't a case we can |
475 | // handle. |
476 | if (!PN || PN->getNumIncomingValues() != 2) |
477 | continue; |
478 | |
479 | // Walk forward from this phi to see if it reaches back to this add. |
480 | if (!isReachableFromPHI(Phi: PN, BO)) |
481 | continue; |
482 | |
483 | // The phi forms a loop with this Add, push its operands. |
484 | append_range(C&: Worklist, R: BO->operands()); |
485 | } |
486 | } |
487 | } |
488 | |
489 | // Not an add or phi, make it a leaf. |
490 | if (auto *I = dyn_cast<Instruction>(Val: V)) { |
491 | if (!V->hasNUses(N: I == Root ? 2 : 1)) |
492 | continue; |
493 | |
494 | // Add this as a leaf. |
495 | Leaves.push_back(Elt: I); |
496 | } |
497 | } |
498 | } |
499 | |
500 | bool X86PartialReduction::runOnFunction(Function &F) { |
501 | if (skipFunction(F)) |
502 | return false; |
503 | |
504 | auto *TPC = getAnalysisIfAvailable<TargetPassConfig>(); |
505 | if (!TPC) |
506 | return false; |
507 | |
508 | auto &TM = TPC->getTM<X86TargetMachine>(); |
509 | ST = TM.getSubtargetImpl(F); |
510 | |
511 | DL = &F.getDataLayout(); |
512 | |
513 | bool MadeChange = false; |
514 | for (auto &BB : F) { |
515 | for (auto &I : BB) { |
516 | auto *EE = dyn_cast<ExtractElementInst>(Val: &I); |
517 | if (!EE) |
518 | continue; |
519 | |
520 | bool ReduceInOneBB; |
521 | // First find a reduction tree. |
522 | // FIXME: Do we need to handle other opcodes than Add? |
523 | Value *Root = matchAddReduction(EE: *EE, ReduceInOneBB); |
524 | if (!Root) |
525 | continue; |
526 | |
527 | SmallVector<Instruction *, 8> Leaves; |
528 | collectLeaves(Root, Leaves); |
529 | |
530 | for (Instruction *I : Leaves) { |
531 | if (tryMAddReplacement(Op: I, ReduceInOneBB)) { |
532 | MadeChange = true; |
533 | continue; |
534 | } |
535 | |
536 | // Don't do SAD matching on the root node. SelectionDAG already |
537 | // has support for that and currently generates better code. |
538 | if (I != Root && trySADReplacement(Op: I)) |
539 | MadeChange = true; |
540 | } |
541 | } |
542 | } |
543 | |
544 | return MadeChange; |
545 | } |
546 | |