1//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Identification:
10// This step is responsible for finding the patterns that can be lowered to
11// complex instructions, and building a graph to represent the complex
12// structures. Starting from the "Converging Shuffle" (a shuffle that
13// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14// operands are evaluated and identified as "Composite Nodes" (collections of
15// instructions that can potentially be lowered to a single complex
16// instruction). This is performed by checking the real and imaginary components
17// and tracking the data flow for each component while following the operand
// pairs. Each node is expected to be validated upon creation, and any
// validation error should halt traversal and prevent further graph
// construction.
// Instead of relying on shufflevector operations, vector interleaving and
// deinterleaving can be represented by the vector.interleave2 and
// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
// these intrinsics, whereas fixed-width vectors are recognized in both forms
// (shufflevector instructions and intrinsics).
26//
27// Replacement:
28// This step traverses the graph built up by identification, delegating to the
29// target to validate and generate the correct intrinsics, and plumbs them
30// together connecting each end of the new intrinsics graph to the existing
31// use-def chain. This step is assumed to finish successfully, as all
32// information is expected to be correct by this point.
33//
34//
35// Internal data structure:
36// ComplexDeinterleavingGraph:
37// Keeps references to all the valid CompositeNodes formed as part of the
38// transformation, and every Instruction contained within said nodes. It also
39// holds onto a reference to the root Instruction, and the root node that should
40// replace it.
41//
42// ComplexDeinterleavingCompositeNode:
43// A CompositeNode represents a single transformation point; each node should
44// transform into a single complex instruction (ignoring vector splitting, which
45// would generate more instructions per node). They are identified in a
46// depth-first manner, traversing and identifying the operands of each
47// instruction in the order they appear in the IR.
48// Each node maintains a reference to its Real and Imaginary instructions,
49// as well as any additional instructions that make up the identified operation
50// (Internal instructions should only have uses within their containing node).
51// A Node also contains the rotation and operation type that it represents.
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementNode is the transformed Value* that has been emitted
// to the IR.
55//
// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
// ReplacementNode fields of that Node are relevant, where the ReplacementNode
// should be pre-populated.
59//
60//===----------------------------------------------------------------------===//
61
62#include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63#include "llvm/ADT/MapVector.h"
64#include "llvm/ADT/Statistic.h"
65#include "llvm/Analysis/TargetLibraryInfo.h"
66#include "llvm/Analysis/TargetTransformInfo.h"
67#include "llvm/CodeGen/TargetLowering.h"
68#include "llvm/CodeGen/TargetSubtargetInfo.h"
69#include "llvm/IR/IRBuilder.h"
70#include "llvm/IR/PatternMatch.h"
71#include "llvm/InitializePasses.h"
72#include "llvm/Target/TargetMachine.h"
73#include "llvm/Transforms/Utils/Local.h"
74#include <algorithm>
75
76using namespace llvm;
77using namespace PatternMatch;
78
79#define DEBUG_TYPE "complex-deinterleaving"
80
STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");

// Hidden command-line switch, enabled by default. When it is false,
// ComplexDeinterleaving::runOnFunction returns immediately without touching
// the function.
static cl::opt<bool> ComplexDeinterleavingEnabled(
    "enable-complex-deinterleaving",
    cl::desc("Enable generation of complex instructions"), cl::init(true),
    cl::Hidden);
87
88/// Checks the given mask, and determines whether said mask is interleaving.
89///
90/// To be interleaving, a mask must alternate between `i` and `i + (Length /
91/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
92/// 4x vector interleaving mask would be <0, 2, 1, 3>).
93static bool isInterleavingMask(ArrayRef<int> Mask);
94
95/// Checks the given mask, and determines whether said mask is deinterleaving.
96///
97/// To be deinterleaving, a mask must increment in steps of 2, and either start
98/// with 0 or 1.
99/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
100/// <1, 3, 5, 7>).
101static bool isDeinterleavingMask(ArrayRef<int> Mask);
102
/// Returns true if V is a negation; both integer and floating-point negations
/// are recognized.
105static bool isNeg(Value *V);
106
107/// Returns the operand for negation operation.
108static Value *getNegOperand(Value *V);
109
110namespace {
111template <typename T, typename IterT>
112std::optional<T> findCommonBetweenCollections(IterT A, IterT B) {
113 auto Common = llvm::find_if(A, [B](T I) { return llvm::is_contained(B, I); });
114 if (Common != A.end())
115 return std::make_optional(*Common);
116 return std::nullopt;
117}
118
119class ComplexDeinterleavingLegacyPass : public FunctionPass {
120public:
121 static char ID;
122
123 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
124 : FunctionPass(ID), TM(TM) {
125 initializeComplexDeinterleavingLegacyPassPass(
126 *PassRegistry::getPassRegistry());
127 }
128
129 StringRef getPassName() const override {
130 return "Complex Deinterleaving Pass";
131 }
132
133 bool runOnFunction(Function &F) override;
134 void getAnalysisUsage(AnalysisUsage &AU) const override {
135 AU.addRequired<TargetLibraryInfoWrapperPass>();
136 AU.setPreservesCFG();
137 }
138
139private:
140 const TargetMachine *TM;
141};
142
143class ComplexDeinterleavingGraph;
144struct ComplexDeinterleavingCompositeNode {
145
146 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
147 Value *R, Value *I)
148 : Operation(Op), Real(R), Imag(I) {}
149
150private:
151 friend class ComplexDeinterleavingGraph;
152 using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
153 using RawNodePtr = ComplexDeinterleavingCompositeNode *;
154 bool OperandsValid = true;
155
156public:
157 ComplexDeinterleavingOperation Operation;
158 Value *Real;
159 Value *Imag;
160
161 // This two members are required exclusively for generating
162 // ComplexDeinterleavingOperation::Symmetric operations.
163 unsigned Opcode;
164 std::optional<FastMathFlags> Flags;
165
166 ComplexDeinterleavingRotation Rotation =
167 ComplexDeinterleavingRotation::Rotation_0;
168 SmallVector<RawNodePtr> Operands;
169 Value *ReplacementNode = nullptr;
170
171 void addOperand(NodePtr Node) {
172 if (!Node || !Node.get())
173 OperandsValid = false;
174 Operands.push_back(Elt: Node.get());
175 }
176
177 void dump() { dump(OS&: dbgs()); }
178 void dump(raw_ostream &OS) {
179 auto PrintValue = [&](Value *V) {
180 if (V) {
181 OS << "\"";
182 V->print(O&: OS, IsForDebug: true);
183 OS << "\"\n";
184 } else
185 OS << "nullptr\n";
186 };
187 auto PrintNodeRef = [&](RawNodePtr Ptr) {
188 if (Ptr)
189 OS << Ptr << "\n";
190 else
191 OS << "nullptr\n";
192 };
193
194 OS << "- CompositeNode: " << this << "\n";
195 OS << " Real: ";
196 PrintValue(Real);
197 OS << " Imag: ";
198 PrintValue(Imag);
199 OS << " ReplacementNode: ";
200 PrintValue(ReplacementNode);
201 OS << " Operation: " << (int)Operation << "\n";
202 OS << " Rotation: " << ((int)Rotation * 90) << "\n";
203 OS << " Operands: \n";
204 for (const auto &Op : Operands) {
205 OS << " - ";
206 PrintNodeRef(Op);
207 }
208 }
209
210 bool areOperandsValid() { return OperandsValid; }
211};
212
/// Holds all composite nodes discovered in one basic block. It drives the
/// three phases of the transform: identification (identify*), validation
/// (checkNodes) and IR replacement (replaceNodes).
class ComplexDeinterleavingGraph {
public:
  /// A single multiplication term `Multiplier * Multiplicand` together with
  /// its sign flag.
  struct Product {
    Value *Multiplier;
    Value *Multiplicand;
    bool IsPositive;
  };

  /// An addend value paired with a sign flag.
  using Addend = std::pair<Value *, bool>;
  using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
  using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;

  // Helper struct for holding info about potential partial multiplication
  // candidates.
  struct PartialMulCandidate {
    Value *Common;
    NodePtr Node;
    unsigned RealIdx;
    unsigned ImagIdx;
    bool IsNodeInverted;
  };

  explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
                                      const TargetLibraryInfo *TLI)
      : TL(TL), TLI(TLI) {}

private:
  const TargetLowering *TL = nullptr;
  const TargetLibraryInfo *TLI = nullptr;
  /// Every node passed to submitCompositeNode, in submission order.
  SmallVector<NodePtr> CompositeNodes;
  /// Submitted nodes keyed by their (Real, Imag) pair. Only nodes with a
  /// non-null Real part are recorded.
  DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult;

  SmallPtrSet<Instruction *, 16> FinalInstructions;

  /// Root instructions are instructions from which complex computation starts.
  std::map<Instruction *, NodePtr> RootToNode;

  /// Topologically sorted root instructions.
  SmallVector<Instruction *, 1> OrderedRoots;

  /// When examining a basic block for complex deinterleaving, if it is a simple
  /// one-block loop, then the only incoming block is 'Incoming' and the
  /// 'BackEdge' block is the block itself.
  BasicBlock *BackEdge = nullptr;
  BasicBlock *Incoming = nullptr;

  /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
  /// %OutsideUser as it is shown in the IR:
  ///
  /// vector.body:
  ///   %PHInode = phi <vector type> [ zeroinitializer, %entry ],
  ///                                [ %ReductionOp, %vector.body ]
  ///   ...
  ///   %ReductionOp = fadd i64 ...
  ///   ...
  ///   br i1 %condition, label %vector.body, %middle.block
  ///
  /// middle.block:
  ///   %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
  ///
  /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
  /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
  MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;

  /// In the process of detecting a reduction, we consider a pair of
  /// %ReductionOP, which we refer to as real and imag (or vice versa), and
  /// traverse the use-tree to detect complex operations. As this is a reduction
  /// operation, it will eventually reach RealPHI and ImagPHI, which correspond
  /// to the %ReductionOPs that we suspect to be complex.
  /// RealPHI and ImagPHI are used by the identifyPHINode method.
  PHINode *RealPHI = nullptr;
  PHINode *ImagPHI = nullptr;

  /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
  /// detection.
  bool PHIsFound = false;

  /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
  /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
  /// This mapping is populated during
  /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
  /// used in the ComplexDeinterleavingOperation::ReductionOperation node
  /// replacement process.
  std::map<PHINode *, PHINode *> OldToNewPHI;

  /// Allocates a node for \p Operation over (\p R, \p I) without registering
  /// it; pass the result to submitCompositeNode to record it. Reduction nodes
  /// must have both halves set (asserted).
  NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
                               Value *R, Value *I) {
    assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
             Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
            (R && I)) &&
           "Reduction related nodes must have Real and Imaginary parts");
    return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
                                                                I);
  }

  /// Records \p Node in CompositeNodes and, when it has a Real part, caches it
  /// under its (Real, Imag) pair. Returns \p Node unchanged.
  NodePtr submitCompositeNode(NodePtr Node) {
    CompositeNodes.push_back(Node);
    if (Node->Real)
      CachedResult[{Node->Real, Node->Imag}] = Node;
    return Node;
  }

  /// Identifies a complex partial multiply pattern and its rotation, based on
  /// the following patterns
  ///
  ///  0:  r: cr + ar * br
  ///      i: ci + ar * bi
  /// 90:  r: cr - ai * bi
  ///      i: ci + ai * br
  /// 180: r: cr - ar * br
  ///      i: ci - ar * bi
  /// 270: r: cr + ai * bi
  ///      i: ci - ai * br
  NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);

  /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
  /// is partially known from identifyPartialMul, filling in the other half of
  /// the complex pair.
  NodePtr
  identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
                              std::pair<Value *, Value *> &CommonOperandI);

  /// Identifies a complex add pattern and its rotation, based on the following
  /// patterns.
  ///
  /// 90:  r: ar - bi
  ///      i: ai + br
  /// 270: r: ar + bi
  ///      i: ai - br
  NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
  /// Identifies two instructions with the same arithmetic opcode applied to
  /// corresponding real/imaginary operand pairs, producing a Symmetric node.
  NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
  NodePtr identifyPartialReduction(Value *R, Value *I);
  /// Attempts to build a CDot node for \p Inst, provided the target supports
  /// CDot for its type.
  NodePtr identifyDotProduct(Value *Inst);

  /// Identifies the composite node whose real and imaginary parts are \p R
  /// and \p I. Returns nullptr if no pattern matches.
  NodePtr identifyNode(Value *R, Value *I);

  /// Determine if a sum of complex numbers can be formed from \p RealAddends
  /// and \p ImagAddends. If \p Accumulator is not null, add the result to it.
  /// Return nullptr if it is not possible to construct a complex number.
  /// \p Flags are needed to generate symmetric Add and Sub operations.
  NodePtr identifyAdditions(std::list<Addend> &RealAddends,
                            std::list<Addend> &ImagAddends,
                            std::optional<FastMathFlags> Flags,
                            NodePtr Accumulator);

  /// Extract one addend that has both real and imaginary parts positive.
  NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
                                std::list<Addend> &ImagAddends);

  /// Determine if sum of multiplications of complex numbers can be formed from
  /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
  /// to it. Return nullptr if it is not possible to construct a complex number.
  NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
                                  std::vector<Product> &ImagMuls,
                                  NodePtr Accumulator);

  /// Go through pairs of multiplication (one Real and one Imag) and find all
  /// possible candidates for partial multiplication and put them into \p
  /// Candidates. Returns true if every Product has a paired Product with a
  /// common operand.
  bool collectPartialMuls(const std::vector<Product> &RealMuls,
                          const std::vector<Product> &ImagMuls,
                          std::vector<PartialMulCandidate> &Candidates);

  /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
  /// the order of complex computation operations may be significantly altered,
  /// and the real and imaginary parts may not be executed in parallel. This
  /// function takes this into consideration and employs a more general approach
  /// to identify complex computations. Initially, it gathers all the addends
  /// and multiplicands and then constructs a complex expression from them.
  NodePtr identifyReassocNodes(Instruction *I, Instruction *J);

  NodePtr identifyRoot(Instruction *I);

  /// Identifies the Deinterleave operation applied to a vector containing
  /// complex numbers. There are two ways to represent the Deinterleave
  /// operation:
  /// * Using two shufflevectors with even indices for \p Real instruction and
  ///   odd indices for \p Imag instructions (only for fixed-width vectors)
  /// * Using two extractvalue instructions applied to `vector.deinterleave2`
  ///   intrinsic (for both fixed and scalable vectors)
  NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);

  /// Identifies the operation that represents a complex number repeated in a
  /// Splat vector. There are two possible types of splats: ConstantExpr with
  /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
  /// initialization mask with all values set to zero.
  NodePtr identifySplat(Value *Real, Value *Imag);

  /// Identifies the reduction PHI pair (see RealPHI / ImagPHI above).
  NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);

  /// Identifies SelectInsts in a loop that has reduction with predication masks
  /// and/or predicated tail folding.
  NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);

  Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);

  /// Complete IR modifications after producing new reduction operation:
  /// * Populate the PHINode generated for
  ///   ComplexDeinterleavingOperation::ReductionPHI
  /// * Deinterleave the final value outside of the loop and repurpose original
  ///   reduction users
  void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
  void processReductionSingle(Value *OperationReplacement, RawNodePtr Node);

public:
  /// Dumps every node in CompositeNodes.
  void dump() { dump(dbgs()); }
  void dump(raw_ostream &OS) {
    for (const auto &Node : CompositeNodes)
      Node->dump(OS);
  }

  /// Returns false if the deinterleaving operation should be cancelled for the
  /// current graph.
  bool identifyNodes(Instruction *RootI);

  /// In case \p B is a one-block loop, this function seeks potential reductions
  /// and populates ReductionInfo. Returns true if any reductions were
  /// identified.
  bool collectPotentialReductions(BasicBlock *B);

  void identifyReductionNodes();

  /// Check that every instruction, from the roots to the leaves, has internal
  /// uses.
  bool checkNodes();

  /// Perform the actual replacement of the underlying instruction graph.
  void replaceNodes();
};
442
443class ComplexDeinterleaving {
444public:
445 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
446 : TL(tl), TLI(tli) {}
447 bool runOnFunction(Function &F);
448
449private:
450 bool evaluateBasicBlock(BasicBlock *B);
451
452 const TargetLowering *TL = nullptr;
453 const TargetLibraryInfo *TLI = nullptr;
454};
455
456} // namespace
457
458char ComplexDeinterleavingLegacyPass::ID = 0;
459
460INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
461 "Complex Deinterleaving", false, false)
462INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
463 "Complex Deinterleaving", false, false)
464
465PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
466 FunctionAnalysisManager &AM) {
467 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
468 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(IR&: F);
469 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
470 return PreservedAnalyses::all();
471
472 PreservedAnalyses PA;
473 PA.preserve<FunctionAnalysisManagerModuleProxy>();
474 return PA;
475}
476
477FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
478 return new ComplexDeinterleavingLegacyPass(TM);
479}
480
481bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
482 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
483 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
484 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
485}
486
487bool ComplexDeinterleaving::runOnFunction(Function &F) {
488 if (!ComplexDeinterleavingEnabled) {
489 LLVM_DEBUG(
490 dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
491 return false;
492 }
493
494 if (!TL->isComplexDeinterleavingSupported()) {
495 LLVM_DEBUG(
496 dbgs() << "Complex deinterleaving has been disabled, target does "
497 "not support lowering of complex number operations.\n");
498 return false;
499 }
500
501 bool Changed = false;
502 for (auto &B : F)
503 Changed |= evaluateBasicBlock(B: &B);
504
505 return Changed;
506}
507
508static bool isInterleavingMask(ArrayRef<int> Mask) {
509 // If the size is not even, it's not an interleaving mask
510 if ((Mask.size() & 1))
511 return false;
512
513 int HalfNumElements = Mask.size() / 2;
514 for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
515 int MaskIdx = Idx * 2;
516 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
517 return false;
518 }
519
520 return true;
521}
522
523static bool isDeinterleavingMask(ArrayRef<int> Mask) {
524 int Offset = Mask[0];
525 int HalfNumElements = Mask.size() / 2;
526
527 for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
528 if (Mask[Idx] != (Idx * 2) + Offset)
529 return false;
530 }
531
532 return true;
533}
534
535bool isNeg(Value *V) {
536 return match(V, P: m_FNeg(X: m_Value())) || match(V, P: m_Neg(V: m_Value()));
537}
538
539Value *getNegOperand(Value *V) {
540 assert(isNeg(V));
541 auto *I = cast<Instruction>(Val: V);
542 if (I->getOpcode() == Instruction::FNeg)
543 return I->getOperand(i: 0);
544
545 return I->getOperand(i: 1);
546}
547
548bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
549 ComplexDeinterleavingGraph Graph(TL, TLI);
550 if (Graph.collectPotentialReductions(B))
551 Graph.identifyReductionNodes();
552
553 for (auto &I : *B)
554 Graph.identifyNodes(RootI: &I);
555
556 if (Graph.checkNodes()) {
557 Graph.replaceNodes();
558 return true;
559 }
560
561 return false;
562}
563
564ComplexDeinterleavingGraph::NodePtr
565ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
566 Instruction *Real, Instruction *Imag,
567 std::pair<Value *, Value *> &PartialMatch) {
568 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
569 << "\n");
570
571 if (!Real->hasOneUse() || !Imag->hasOneUse()) {
572 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
573 return nullptr;
574 }
575
576 if ((Real->getOpcode() != Instruction::FMul &&
577 Real->getOpcode() != Instruction::Mul) ||
578 (Imag->getOpcode() != Instruction::FMul &&
579 Imag->getOpcode() != Instruction::Mul)) {
580 LLVM_DEBUG(
581 dbgs() << " - Real or imaginary instruction is not fmul or mul\n");
582 return nullptr;
583 }
584
585 Value *R0 = Real->getOperand(i: 0);
586 Value *R1 = Real->getOperand(i: 1);
587 Value *I0 = Imag->getOperand(i: 0);
588 Value *I1 = Imag->getOperand(i: 1);
589
590 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
591 // rotations and use the operand.
592 unsigned Negs = 0;
593 Value *Op;
594 if (match(V: R0, P: m_Neg(V: m_Value(V&: Op)))) {
595 Negs |= 1;
596 R0 = Op;
597 } else if (match(V: R1, P: m_Neg(V: m_Value(V&: Op)))) {
598 Negs |= 1;
599 R1 = Op;
600 }
601
602 if (isNeg(V: I0)) {
603 Negs |= 2;
604 Negs ^= 1;
605 I0 = Op;
606 } else if (match(V: I1, P: m_Neg(V: m_Value(V&: Op)))) {
607 Negs |= 2;
608 Negs ^= 1;
609 I1 = Op;
610 }
611
612 ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
613
614 Value *CommonOperand;
615 Value *UncommonRealOp;
616 Value *UncommonImagOp;
617
618 if (R0 == I0 || R0 == I1) {
619 CommonOperand = R0;
620 UncommonRealOp = R1;
621 } else if (R1 == I0 || R1 == I1) {
622 CommonOperand = R1;
623 UncommonRealOp = R0;
624 } else {
625 LLVM_DEBUG(dbgs() << " - No equal operand\n");
626 return nullptr;
627 }
628
629 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
630 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
631 Rotation == ComplexDeinterleavingRotation::Rotation_270)
632 std::swap(a&: UncommonRealOp, b&: UncommonImagOp);
633
634 // Between identifyPartialMul and here we need to have found a complete valid
635 // pair from the CommonOperand of each part.
636 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
637 Rotation == ComplexDeinterleavingRotation::Rotation_180)
638 PartialMatch.first = CommonOperand;
639 else
640 PartialMatch.second = CommonOperand;
641
642 if (!PartialMatch.first || !PartialMatch.second) {
643 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
644 return nullptr;
645 }
646
647 NodePtr CommonNode = identifyNode(R: PartialMatch.first, I: PartialMatch.second);
648 if (!CommonNode) {
649 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
650 return nullptr;
651 }
652
653 NodePtr UncommonNode = identifyNode(R: UncommonRealOp, I: UncommonImagOp);
654 if (!UncommonNode) {
655 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
656 return nullptr;
657 }
658
659 NodePtr Node = prepareCompositeNode(
660 Operation: ComplexDeinterleavingOperation::CMulPartial, R: Real, I: Imag);
661 Node->Rotation = Rotation;
662 Node->addOperand(Node: CommonNode);
663 Node->addOperand(Node: UncommonNode);
664 return submitCompositeNode(Node);
665}
666
/// Matches a complex partial multiply. \p Real and \p Imag must each be an
/// add or sub whose operand 0 is the accumulator and whose operand 1 is a
/// single-use multiply instruction. The add/sub combination selects the
/// rotation. The two multiplies must share one operand. The accumulators are
/// then matched by identifyNodeWithImplicitAdd, which completes the
/// (real, imag) pair of that shared operand.
///
/// Returns the new CMulPartial node, or nullptr if the pattern does not match.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
                                               Instruction *Imag) {
  LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
                    << "\n");
  // Determine rotation
  auto IsAdd = [](unsigned Op) {
    return Op == Instruction::FAdd || Op == Instruction::Add;
  };
  auto IsSub = [](unsigned Op) {
    return Op == Instruction::FSub || Op == Instruction::Sub;
  };
  ComplexDeinterleavingRotation Rotation;
  if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
    Rotation = ComplexDeinterleavingRotation::Rotation_0;
  else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
    Rotation = ComplexDeinterleavingRotation::Rotation_90;
  else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
    Rotation = ComplexDeinterleavingRotation::Rotation_180;
  else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
    Rotation = ComplexDeinterleavingRotation::Rotation_270;
  else {
    LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
    return nullptr;
  }

  // Folding the multiply into the add/sub is only allowed for floating-point
  // operations that permit contraction on both halves.
  if (isa<FPMathOperator>(Real) &&
      (!Real->getFastMathFlags().allowContract() ||
       !Imag->getFastMathFlags().allowContract())) {
    LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
    return nullptr;
  }

  // Operand 0 is the accumulator; operand 1 must be the multiply.
  Value *CR = Real->getOperand(0);
  Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
  if (!RealMulI)
    return nullptr;
  Value *CI = Imag->getOperand(0);
  Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
  if (!ImagMulI)
    return nullptr;

  if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
    LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
    return nullptr;
  }

  Value *R0 = RealMulI->getOperand(0);
  Value *R1 = RealMulI->getOperand(1);
  Value *I0 = ImagMulI->getOperand(0);
  Value *I1 = ImagMulI->getOperand(1);

  Value *CommonOperand;
  Value *UncommonRealOp;
  Value *UncommonImagOp;

  // Find the operand shared by both multiplies; the other operand of each
  // multiply forms the "uncommon" pair.
  if (R0 == I0 || R0 == I1) {
    CommonOperand = R0;
    UncommonRealOp = R1;
  } else if (R1 == I0 || R1 == I1) {
    CommonOperand = R1;
    UncommonRealOp = R0;
  } else {
    LLVM_DEBUG(dbgs() << " - No equal operand\n");
    return nullptr;
  }

  UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
  if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
      Rotation == ComplexDeinterleavingRotation::Rotation_270)
    std::swap(UncommonRealOp, UncommonImagOp);

  // For rotations 0/180 the shared operand fills the first slot of the pair;
  // for 90/270 it fills the second. The missing slot is filled in by
  // identifyNodeWithImplicitAdd below.
  std::pair<Value *, Value *> PartialMatch(
      (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
       Rotation == ComplexDeinterleavingRotation::Rotation_180)
          ? CommonOperand
          : nullptr,
      (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
       Rotation == ComplexDeinterleavingRotation::Rotation_270)
          ? CommonOperand
          : nullptr);

  auto *CRInst = dyn_cast<Instruction>(CR);
  auto *CIInst = dyn_cast<Instruction>(CI);

  if (!CRInst || !CIInst) {
    LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");
    return nullptr;
  }

  // Must run before CommonRes is identified: it completes PartialMatch.
  NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
  if (!CNode) {
    LLVM_DEBUG(dbgs() << " - No cnode identified\n");
    return nullptr;
  }

  NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
  if (!UncommonRes) {
    LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
    return nullptr;
  }

  assert(PartialMatch.first && PartialMatch.second);
  NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
  if (!CommonRes) {
    LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
    return nullptr;
  }

  NodePtr Node = prepareCompositeNode(
      ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
  Node->Rotation = Rotation;
  Node->addOperand(CommonRes);
  Node->addOperand(UncommonRes);
  Node->addOperand(CNode);
  return submitCompositeNode(Node);
}
784
785ComplexDeinterleavingGraph::NodePtr
786ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
787 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
788
789 // Determine rotation
790 ComplexDeinterleavingRotation Rotation;
791 if ((Real->getOpcode() == Instruction::FSub &&
792 Imag->getOpcode() == Instruction::FAdd) ||
793 (Real->getOpcode() == Instruction::Sub &&
794 Imag->getOpcode() == Instruction::Add))
795 Rotation = ComplexDeinterleavingRotation::Rotation_90;
796 else if ((Real->getOpcode() == Instruction::FAdd &&
797 Imag->getOpcode() == Instruction::FSub) ||
798 (Real->getOpcode() == Instruction::Add &&
799 Imag->getOpcode() == Instruction::Sub))
800 Rotation = ComplexDeinterleavingRotation::Rotation_270;
801 else {
802 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
803 return nullptr;
804 }
805
806 auto *AR = dyn_cast<Instruction>(Val: Real->getOperand(i: 0));
807 auto *BI = dyn_cast<Instruction>(Val: Real->getOperand(i: 1));
808 auto *AI = dyn_cast<Instruction>(Val: Imag->getOperand(i: 0));
809 auto *BR = dyn_cast<Instruction>(Val: Imag->getOperand(i: 1));
810
811 if (!AR || !AI || !BR || !BI) {
812 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
813 return nullptr;
814 }
815
816 NodePtr ResA = identifyNode(R: AR, I: AI);
817 if (!ResA) {
818 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
819 return nullptr;
820 }
821 NodePtr ResB = identifyNode(R: BR, I: BI);
822 if (!ResB) {
823 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
824 return nullptr;
825 }
826
827 NodePtr Node =
828 prepareCompositeNode(Operation: ComplexDeinterleavingOperation::CAdd, R: Real, I: Imag);
829 Node->Rotation = Rotation;
830 Node->addOperand(Node: ResA);
831 Node->addOperand(Node: ResB);
832 return submitCompositeNode(Node);
833}
834
835static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
836 unsigned OpcA = A->getOpcode();
837 unsigned OpcB = B->getOpcode();
838
839 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
840 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
841 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
842 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
843}
844
845static bool isInstructionPairMul(Instruction *A, Instruction *B) {
846 auto Pattern =
847 m_BinOp(L: m_FMul(L: m_Value(), R: m_Value()), R: m_FMul(L: m_Value(), R: m_Value()));
848
849 return match(V: A, P: Pattern) && match(V: B, P: Pattern);
850}
851
852static bool isInstructionPotentiallySymmetric(Instruction *I) {
853 switch (I->getOpcode()) {
854 case Instruction::FAdd:
855 case Instruction::FSub:
856 case Instruction::FMul:
857 case Instruction::FNeg:
858 case Instruction::Add:
859 case Instruction::Sub:
860 case Instruction::Mul:
861 return true;
862 default:
863 return false;
864 }
865}
866
867ComplexDeinterleavingGraph::NodePtr
868ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
869 Instruction *Imag) {
870 if (Real->getOpcode() != Imag->getOpcode())
871 return nullptr;
872
873 if (!isInstructionPotentiallySymmetric(I: Real) ||
874 !isInstructionPotentiallySymmetric(I: Imag))
875 return nullptr;
876
877 auto *R0 = Real->getOperand(i: 0);
878 auto *I0 = Imag->getOperand(i: 0);
879
880 NodePtr Op0 = identifyNode(R: R0, I: I0);
881 NodePtr Op1 = nullptr;
882 if (Op0 == nullptr)
883 return nullptr;
884
885 if (Real->isBinaryOp()) {
886 auto *R1 = Real->getOperand(i: 1);
887 auto *I1 = Imag->getOperand(i: 1);
888 Op1 = identifyNode(R: R1, I: I1);
889 if (Op1 == nullptr)
890 return nullptr;
891 }
892
893 if (isa<FPMathOperator>(Val: Real) &&
894 Real->getFastMathFlags() != Imag->getFastMathFlags())
895 return nullptr;
896
897 auto Node = prepareCompositeNode(Operation: ComplexDeinterleavingOperation::Symmetric,
898 R: Real, I: Imag);
899 Node->Opcode = Real->getOpcode();
900 if (isa<FPMathOperator>(Val: Real))
901 Node->Flags = Real->getFastMathFlags();
902
903 Node->addOperand(Node: Op0);
904 if (Real->isBinaryOp())
905 Node->addOperand(Node: Op1);
906
907 return submitCompositeNode(Node);
908}
909
910ComplexDeinterleavingGraph::NodePtr
911ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
912
913 if (!TL->isComplexDeinterleavingOperationSupported(
914 Operation: ComplexDeinterleavingOperation::CDot, Ty: V->getType())) {
915 LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving "
916 "operation CDot with the type "
917 << *V->getType() << "\n");
918 return nullptr;
919 }
920
921 auto *Inst = cast<Instruction>(Val: V);
922 auto *RealUser = cast<Instruction>(Val: *Inst->user_begin());
923
924 NodePtr CN =
925 prepareCompositeNode(Operation: ComplexDeinterleavingOperation::CDot, R: Inst, I: nullptr);
926
927 NodePtr ANode;
928
929 const Intrinsic::ID PartialReduceInt =
930 Intrinsic::experimental_vector_partial_reduce_add;
931
932 Value *AReal = nullptr;
933 Value *AImag = nullptr;
934 Value *BReal = nullptr;
935 Value *BImag = nullptr;
936 Value *Phi = nullptr;
937
938 auto UnwrapCast = [](Value *V) -> Value * {
939 if (auto *CI = dyn_cast<CastInst>(Val: V))
940 return CI->getOperand(i_nocapture: 0);
941 return V;
942 };
943
944 auto PatternRot0 = m_Intrinsic<PartialReduceInt>(
945 Op0: m_Intrinsic<PartialReduceInt>(Op0: m_Value(V&: Phi),
946 Op1: m_Mul(L: m_Value(V&: BReal), R: m_Value(V&: AReal))),
947 Op1: m_Neg(V: m_Mul(L: m_Value(V&: BImag), R: m_Value(V&: AImag))));
948
949 auto PatternRot270 = m_Intrinsic<PartialReduceInt>(
950 Op0: m_Intrinsic<PartialReduceInt>(
951 Op0: m_Value(V&: Phi), Op1: m_Neg(V: m_Mul(L: m_Value(V&: BReal), R: m_Value(V&: AImag)))),
952 Op1: m_Mul(L: m_Value(V&: BImag), R: m_Value(V&: AReal)));
953
954 if (match(V: Inst, P: PatternRot0)) {
955 CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
956 } else if (match(V: Inst, P: PatternRot270)) {
957 CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
958 } else {
959 Value *A0, *A1;
960 // The rotations 90 and 180 share the same operation pattern, so inspect the
961 // order of the operands, identifying where the real and imaginary
962 // components of A go, to discern between the aforementioned rotations.
963 auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>(
964 Op0: m_Intrinsic<PartialReduceInt>(Op0: m_Value(V&: Phi),
965 Op1: m_Mul(L: m_Value(V&: BReal), R: m_Value(V&: A0))),
966 Op1: m_Mul(L: m_Value(V&: BImag), R: m_Value(V&: A1)));
967
968 if (!match(V: Inst, P: PatternRot90Rot180))
969 return nullptr;
970
971 A0 = UnwrapCast(A0);
972 A1 = UnwrapCast(A1);
973
974 // Test if A0 is real/A1 is imag
975 ANode = identifyNode(R: A0, I: A1);
976 if (!ANode) {
977 // Test if A0 is imag/A1 is real
978 ANode = identifyNode(R: A1, I: A0);
979 // Unable to identify operand components, thus unable to identify rotation
980 if (!ANode)
981 return nullptr;
982 CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
983 AReal = A1;
984 AImag = A0;
985 } else {
986 AReal = A0;
987 AImag = A1;
988 CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
989 }
990 }
991
992 AReal = UnwrapCast(AReal);
993 AImag = UnwrapCast(AImag);
994 BReal = UnwrapCast(BReal);
995 BImag = UnwrapCast(BImag);
996
997 VectorType *VTy = cast<VectorType>(Val: V->getType());
998 Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, NumSubdivs: 2);
999 if (AReal->getType() != ExpectedOperandTy)
1000 return nullptr;
1001 if (AImag->getType() != ExpectedOperandTy)
1002 return nullptr;
1003 if (BReal->getType() != ExpectedOperandTy)
1004 return nullptr;
1005 if (BImag->getType() != ExpectedOperandTy)
1006 return nullptr;
1007
1008 if (Phi->getType() != VTy && RealUser->getType() != VTy)
1009 return nullptr;
1010
1011 NodePtr Node = identifyNode(R: AReal, I: AImag);
1012
1013 // In the case that a node was identified to figure out the rotation, ensure
1014 // that trying to identify a node with AReal and AImag post-unwrap results in
1015 // the same node
1016 if (ANode && Node != ANode) {
1017 LLVM_DEBUG(
1018 dbgs()
1019 << "Identified node is different from previously identified node. "
1020 "Unable to confidently generate a complex operation node\n");
1021 return nullptr;
1022 }
1023
1024 CN->addOperand(Node);
1025 CN->addOperand(Node: identifyNode(R: BReal, I: BImag));
1026 CN->addOperand(Node: identifyNode(R: Phi, I: RealUser));
1027
1028 return submitCompositeNode(Node: CN);
1029}
1030
1031ComplexDeinterleavingGraph::NodePtr
1032ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
1033 // Partial reductions don't support non-vector types, so check these first
1034 if (!isa<VectorType>(Val: R->getType()) || !isa<VectorType>(Val: I->getType()))
1035 return nullptr;
1036
1037 if (!R->hasUseList() || !I->hasUseList())
1038 return nullptr;
1039
1040 auto CommonUser =
1041 findCommonBetweenCollections<Value *>(A: R->users(), B: I->users());
1042 if (!CommonUser)
1043 return nullptr;
1044
1045 auto *IInst = dyn_cast<IntrinsicInst>(Val: *CommonUser);
1046 if (!IInst || IInst->getIntrinsicID() !=
1047 Intrinsic::experimental_vector_partial_reduce_add)
1048 return nullptr;
1049
1050 if (NodePtr CN = identifyDotProduct(V: IInst))
1051 return CN;
1052
1053 return nullptr;
1054}
1055
1056ComplexDeinterleavingGraph::NodePtr
1057ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
1058 auto It = CachedResult.find(Val: {R, I});
1059 if (It != CachedResult.end()) {
1060 LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
1061 return It->second;
1062 }
1063
1064 if (NodePtr CN = identifyPartialReduction(R, I))
1065 return CN;
1066
1067 bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I);
1068 if (!IsReduction && R->getType() != I->getType())
1069 return nullptr;
1070
1071 if (NodePtr CN = identifySplat(Real: R, Imag: I))
1072 return CN;
1073
1074 auto *Real = dyn_cast<Instruction>(Val: R);
1075 auto *Imag = dyn_cast<Instruction>(Val: I);
1076 if (!Real || !Imag)
1077 return nullptr;
1078
1079 if (NodePtr CN = identifyDeinterleave(Real, Imag))
1080 return CN;
1081
1082 if (NodePtr CN = identifyPHINode(Real, Imag))
1083 return CN;
1084
1085 if (NodePtr CN = identifySelectNode(Real, Imag))
1086 return CN;
1087
1088 auto *VTy = cast<VectorType>(Val: Real->getType());
1089 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1090
1091 bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
1092 Operation: ComplexDeinterleavingOperation::CMulPartial, Ty: NewVTy);
1093 bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
1094 Operation: ComplexDeinterleavingOperation::CAdd, Ty: NewVTy);
1095
1096 if (HasCMulSupport && isInstructionPairMul(A: Real, B: Imag)) {
1097 if (NodePtr CN = identifyPartialMul(Real, Imag))
1098 return CN;
1099 }
1100
1101 if (HasCAddSupport && isInstructionPairAdd(A: Real, B: Imag)) {
1102 if (NodePtr CN = identifyAdd(Real, Imag))
1103 return CN;
1104 }
1105
1106 if (HasCMulSupport && HasCAddSupport) {
1107 if (NodePtr CN = identifyReassocNodes(I: Real, J: Imag))
1108 return CN;
1109 }
1110
1111 if (NodePtr CN = identifySymmetricOperation(Real, Imag))
1112 return CN;
1113
1114 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
1115 CachedResult[{R, I}] = nullptr;
1116 return nullptr;
1117}
1118
1119ComplexDeinterleavingGraph::NodePtr
1120ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
1121 Instruction *Imag) {
1122 auto IsOperationSupported = [](unsigned Opcode) -> bool {
1123 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
1124 Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
1125 Opcode == Instruction::Sub;
1126 };
1127
1128 if (!IsOperationSupported(Real->getOpcode()) ||
1129 !IsOperationSupported(Imag->getOpcode()))
1130 return nullptr;
1131
1132 std::optional<FastMathFlags> Flags;
1133 if (isa<FPMathOperator>(Val: Real)) {
1134 if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
1135 LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
1136 "not identical\n");
1137 return nullptr;
1138 }
1139
1140 Flags = Real->getFastMathFlags();
1141 if (!Flags->allowReassoc()) {
1142 LLVM_DEBUG(
1143 dbgs()
1144 << "the 'Reassoc' attribute is missing in the FastMath flags\n");
1145 return nullptr;
1146 }
1147 }
1148
1149 // Collect multiplications and addend instructions from the given instruction
1150 // while traversing it operands. Additionally, verify that all instructions
1151 // have the same fast math flags.
1152 auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
1153 std::list<Addend> &Addends) -> bool {
1154 SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
1155 SmallPtrSet<Value *, 8> Visited;
1156 while (!Worklist.empty()) {
1157 auto [V, IsPositive] = Worklist.pop_back_val();
1158 if (!Visited.insert(Ptr: V).second)
1159 continue;
1160
1161 Instruction *I = dyn_cast<Instruction>(Val: V);
1162 if (!I) {
1163 Addends.emplace_back(args&: V, args&: IsPositive);
1164 continue;
1165 }
1166
1167 // If an instruction has more than one user, it indicates that it either
1168 // has an external user, which will be later checked by the checkNodes
1169 // function, or it is a subexpression utilized by multiple expressions. In
1170 // the latter case, we will attempt to separately identify the complex
1171 // operation from here in order to create a shared
1172 // ComplexDeinterleavingCompositeNode.
1173 if (I != Insn && I->hasNUsesOrMore(N: 2)) {
1174 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
1175 Addends.emplace_back(args&: I, args&: IsPositive);
1176 continue;
1177 }
1178 switch (I->getOpcode()) {
1179 case Instruction::FAdd:
1180 case Instruction::Add:
1181 Worklist.emplace_back(Args: I->getOperand(i: 1), Args&: IsPositive);
1182 Worklist.emplace_back(Args: I->getOperand(i: 0), Args&: IsPositive);
1183 break;
1184 case Instruction::FSub:
1185 Worklist.emplace_back(Args: I->getOperand(i: 1), Args: !IsPositive);
1186 Worklist.emplace_back(Args: I->getOperand(i: 0), Args&: IsPositive);
1187 break;
1188 case Instruction::Sub:
1189 if (isNeg(V: I)) {
1190 Worklist.emplace_back(Args: getNegOperand(V: I), Args: !IsPositive);
1191 } else {
1192 Worklist.emplace_back(Args: I->getOperand(i: 1), Args: !IsPositive);
1193 Worklist.emplace_back(Args: I->getOperand(i: 0), Args&: IsPositive);
1194 }
1195 break;
1196 case Instruction::FMul:
1197 case Instruction::Mul: {
1198 Value *A, *B;
1199 if (isNeg(V: I->getOperand(i: 0))) {
1200 A = getNegOperand(V: I->getOperand(i: 0));
1201 IsPositive = !IsPositive;
1202 } else {
1203 A = I->getOperand(i: 0);
1204 }
1205
1206 if (isNeg(V: I->getOperand(i: 1))) {
1207 B = getNegOperand(V: I->getOperand(i: 1));
1208 IsPositive = !IsPositive;
1209 } else {
1210 B = I->getOperand(i: 1);
1211 }
1212 Muls.push_back(x: Product{.Multiplier: A, .Multiplicand: B, .IsPositive: IsPositive});
1213 break;
1214 }
1215 case Instruction::FNeg:
1216 Worklist.emplace_back(Args: I->getOperand(i: 0), Args: !IsPositive);
1217 break;
1218 default:
1219 Addends.emplace_back(args&: I, args&: IsPositive);
1220 continue;
1221 }
1222
1223 if (Flags && I->getFastMathFlags() != *Flags) {
1224 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1225 "inconsistent with the root instructions' flags: "
1226 << *I << "\n");
1227 return false;
1228 }
1229 }
1230 return true;
1231 };
1232
1233 std::vector<Product> RealMuls, ImagMuls;
1234 std::list<Addend> RealAddends, ImagAddends;
1235 if (!Collect(Real, RealMuls, RealAddends) ||
1236 !Collect(Imag, ImagMuls, ImagAddends))
1237 return nullptr;
1238
1239 if (RealAddends.size() != ImagAddends.size())
1240 return nullptr;
1241
1242 NodePtr FinalNode;
1243 if (!RealMuls.empty() || !ImagMuls.empty()) {
1244 // If there are multiplicands, extract positive addend and use it as an
1245 // accumulator
1246 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1247 FinalNode = identifyMultiplications(RealMuls, ImagMuls, Accumulator: FinalNode);
1248 if (!FinalNode)
1249 return nullptr;
1250 }
1251
1252 // Identify and process remaining additions
1253 if (!RealAddends.empty() || !ImagAddends.empty()) {
1254 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, Accumulator: FinalNode);
1255 if (!FinalNode)
1256 return nullptr;
1257 }
1258 assert(FinalNode && "FinalNode can not be nullptr here");
1259 // Set the Real and Imag fields of the final node and submit it
1260 FinalNode->Real = Real;
1261 FinalNode->Imag = Imag;
1262 submitCompositeNode(Node: FinalNode);
1263 return FinalNode;
1264}
1265
1266bool ComplexDeinterleavingGraph::collectPartialMuls(
1267 const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
1268 std::vector<PartialMulCandidate> &PartialMulCandidates) {
1269 // Helper function to extract a common operand from two products
1270 auto FindCommonInstruction = [](const Product &Real,
1271 const Product &Imag) -> Value * {
1272 if (Real.Multiplicand == Imag.Multiplicand ||
1273 Real.Multiplicand == Imag.Multiplier)
1274 return Real.Multiplicand;
1275
1276 if (Real.Multiplier == Imag.Multiplicand ||
1277 Real.Multiplier == Imag.Multiplier)
1278 return Real.Multiplier;
1279
1280 return nullptr;
1281 };
1282
1283 // Iterating over real and imaginary multiplications to find common operands
1284 // If a common operand is found, a partial multiplication candidate is created
1285 // and added to the candidates vector The function returns false if no common
1286 // operands are found for any product
1287 for (unsigned i = 0; i < RealMuls.size(); ++i) {
1288 bool FoundCommon = false;
1289 for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1290 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1291 if (!Common)
1292 continue;
1293
1294 auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1295 : RealMuls[i].Multiplicand;
1296 auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1297 : ImagMuls[j].Multiplicand;
1298
1299 auto Node = identifyNode(R: A, I: B);
1300 if (Node) {
1301 FoundCommon = true;
1302 PartialMulCandidates.push_back(x: {.Common: Common, .Node: Node, .RealIdx: i, .ImagIdx: j, .IsNodeInverted: false});
1303 }
1304
1305 Node = identifyNode(R: B, I: A);
1306 if (Node) {
1307 FoundCommon = true;
1308 PartialMulCandidates.push_back(x: {.Common: Common, .Node: Node, .RealIdx: i, .ImagIdx: j, .IsNodeInverted: true});
1309 }
1310 }
1311 if (!FoundCommon)
1312 return false;
1313 }
1314 return true;
1315}
1316
1317ComplexDeinterleavingGraph::NodePtr
1318ComplexDeinterleavingGraph::identifyMultiplications(
1319 std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
1320 NodePtr Accumulator = nullptr) {
1321 if (RealMuls.size() != ImagMuls.size())
1322 return nullptr;
1323
1324 std::vector<PartialMulCandidate> Info;
1325 if (!collectPartialMuls(RealMuls, ImagMuls, PartialMulCandidates&: Info))
1326 return nullptr;
1327
1328 // Map to store common instruction to node pointers
1329 std::map<Value *, NodePtr> CommonToNode;
1330 std::vector<bool> Processed(Info.size(), false);
1331 for (unsigned I = 0; I < Info.size(); ++I) {
1332 if (Processed[I])
1333 continue;
1334
1335 PartialMulCandidate &InfoA = Info[I];
1336 for (unsigned J = I + 1; J < Info.size(); ++J) {
1337 if (Processed[J])
1338 continue;
1339
1340 PartialMulCandidate &InfoB = Info[J];
1341 auto *InfoReal = &InfoA;
1342 auto *InfoImag = &InfoB;
1343
1344 auto NodeFromCommon = identifyNode(R: InfoReal->Common, I: InfoImag->Common);
1345 if (!NodeFromCommon) {
1346 std::swap(a&: InfoReal, b&: InfoImag);
1347 NodeFromCommon = identifyNode(R: InfoReal->Common, I: InfoImag->Common);
1348 }
1349 if (!NodeFromCommon)
1350 continue;
1351
1352 CommonToNode[InfoReal->Common] = NodeFromCommon;
1353 CommonToNode[InfoImag->Common] = NodeFromCommon;
1354 Processed[I] = true;
1355 Processed[J] = true;
1356 }
1357 }
1358
1359 std::vector<bool> ProcessedReal(RealMuls.size(), false);
1360 std::vector<bool> ProcessedImag(ImagMuls.size(), false);
1361 NodePtr Result = Accumulator;
1362 for (auto &PMI : Info) {
1363 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1364 continue;
1365
1366 auto It = CommonToNode.find(x: PMI.Common);
1367 // TODO: Process independent complex multiplications. Cases like this:
1368 // A.real() * B where both A and B are complex numbers.
1369 if (It == CommonToNode.end()) {
1370 LLVM_DEBUG({
1371 dbgs() << "Unprocessed independent partial multiplication:\n";
1372 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1373 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1374 << " multiplied by " << *Mul->Multiplicand << "\n";
1375 });
1376 return nullptr;
1377 }
1378
1379 auto &RealMul = RealMuls[PMI.RealIdx];
1380 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1381
1382 auto NodeA = It->second;
1383 auto NodeB = PMI.Node;
1384 auto IsMultiplicandReal = PMI.Common == NodeA->Real;
1385 // The following table illustrates the relationship between multiplications
1386 // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1387 // can see:
1388 //
1389 // Rotation | Real | Imag |
1390 // ---------+--------+--------+
1391 // 0 | x * u | x * v |
1392 // 90 | -y * v | y * u |
1393 // 180 | -x * u | -x * v |
1394 // 270 | y * v | -y * u |
1395 //
1396 // Check if the candidate can indeed be represented by partial
1397 // multiplication
1398 // TODO: Add support for multiplication by complex one
1399 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1400 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1401 continue;
1402
1403 // Determine the rotation based on the multiplications
1404 ComplexDeinterleavingRotation Rotation;
1405 if (IsMultiplicandReal) {
1406 // Detect 0 and 180 degrees rotation
1407 if (RealMul.IsPositive && ImagMul.IsPositive)
1408 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;
1409 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1410 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;
1411 else
1412 continue;
1413
1414 } else {
1415 // Detect 90 and 270 degrees rotation
1416 if (!RealMul.IsPositive && ImagMul.IsPositive)
1417 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;
1418 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1419 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;
1420 else
1421 continue;
1422 }
1423
1424 LLVM_DEBUG({
1425 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1426 dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
1427 dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
1428 dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
1429 dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
1430 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1431 });
1432
1433 NodePtr NodeMul = prepareCompositeNode(
1434 Operation: ComplexDeinterleavingOperation::CMulPartial, R: nullptr, I: nullptr);
1435 NodeMul->Rotation = Rotation;
1436 NodeMul->addOperand(Node: NodeA);
1437 NodeMul->addOperand(Node: NodeB);
1438 if (Result)
1439 NodeMul->addOperand(Node: Result);
1440 submitCompositeNode(Node: NodeMul);
1441 Result = NodeMul;
1442 ProcessedReal[PMI.RealIdx] = true;
1443 ProcessedImag[PMI.ImagIdx] = true;
1444 }
1445
1446 // Ensure all products have been processed, if not return nullptr.
1447 if (!all_of(Range&: ProcessedReal, P: [](bool V) { return V; }) ||
1448 !all_of(Range&: ProcessedImag, P: [](bool V) { return V; })) {
1449
1450 // Dump debug information about which partial multiplications are not
1451 // processed.
1452 LLVM_DEBUG({
1453 dbgs() << "Unprocessed products (Real):\n";
1454 for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1455 if (!ProcessedReal[i])
1456 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1457 << *RealMuls[i].Multiplier << " multiplied by "
1458 << *RealMuls[i].Multiplicand << "\n";
1459 }
1460 dbgs() << "Unprocessed products (Imag):\n";
1461 for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1462 if (!ProcessedImag[i])
1463 dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1464 << *ImagMuls[i].Multiplier << " multiplied by "
1465 << *ImagMuls[i].Multiplicand << "\n";
1466 }
1467 });
1468 return nullptr;
1469 }
1470
1471 return Result;
1472}
1473
1474ComplexDeinterleavingGraph::NodePtr
1475ComplexDeinterleavingGraph::identifyAdditions(
1476 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
1477 std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {
1478 if (RealAddends.size() != ImagAddends.size())
1479 return nullptr;
1480
1481 NodePtr Result;
1482 // If we have accumulator use it as first addend
1483 if (Accumulator)
1484 Result = Accumulator;
1485 // Otherwise find an element with both positive real and imaginary parts.
1486 else
1487 Result = extractPositiveAddend(RealAddends, ImagAddends);
1488
1489 if (!Result)
1490 return nullptr;
1491
1492 while (!RealAddends.empty()) {
1493 auto ItR = RealAddends.begin();
1494 auto [R, IsPositiveR] = *ItR;
1495
1496 bool FoundImag = false;
1497 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1498 auto [I, IsPositiveI] = *ItI;
1499 ComplexDeinterleavingRotation Rotation;
1500 if (IsPositiveR && IsPositiveI)
1501 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1502 else if (!IsPositiveR && IsPositiveI)
1503 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1504 else if (!IsPositiveR && !IsPositiveI)
1505 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1506 else
1507 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1508
1509 NodePtr AddNode;
1510 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1511 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1512 AddNode = identifyNode(R, I);
1513 } else {
1514 AddNode = identifyNode(R: I, I: R);
1515 }
1516 if (AddNode) {
1517 LLVM_DEBUG({
1518 dbgs() << "Identified addition:\n";
1519 dbgs().indent(4) << "X: " << *R << "\n";
1520 dbgs().indent(4) << "Y: " << *I << "\n";
1521 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1522 });
1523
1524 NodePtr TmpNode;
1525 if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
1526 TmpNode = prepareCompositeNode(
1527 Operation: ComplexDeinterleavingOperation::Symmetric, R: nullptr, I: nullptr);
1528 if (Flags) {
1529 TmpNode->Opcode = Instruction::FAdd;
1530 TmpNode->Flags = *Flags;
1531 } else {
1532 TmpNode->Opcode = Instruction::Add;
1533 }
1534 } else if (Rotation ==
1535 llvm::ComplexDeinterleavingRotation::Rotation_180) {
1536 TmpNode = prepareCompositeNode(
1537 Operation: ComplexDeinterleavingOperation::Symmetric, R: nullptr, I: nullptr);
1538 if (Flags) {
1539 TmpNode->Opcode = Instruction::FSub;
1540 TmpNode->Flags = *Flags;
1541 } else {
1542 TmpNode->Opcode = Instruction::Sub;
1543 }
1544 } else {
1545 TmpNode = prepareCompositeNode(Operation: ComplexDeinterleavingOperation::CAdd,
1546 R: nullptr, I: nullptr);
1547 TmpNode->Rotation = Rotation;
1548 }
1549
1550 TmpNode->addOperand(Node: Result);
1551 TmpNode->addOperand(Node: AddNode);
1552 submitCompositeNode(Node: TmpNode);
1553 Result = TmpNode;
1554 RealAddends.erase(position: ItR);
1555 ImagAddends.erase(position: ItI);
1556 FoundImag = true;
1557 break;
1558 }
1559 }
1560 if (!FoundImag)
1561 return nullptr;
1562 }
1563 return Result;
1564}
1565
1566ComplexDeinterleavingGraph::NodePtr
1567ComplexDeinterleavingGraph::extractPositiveAddend(
1568 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
1569 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1570 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1571 auto [R, IsPositiveR] = *ItR;
1572 auto [I, IsPositiveI] = *ItI;
1573 if (IsPositiveR && IsPositiveI) {
1574 auto Result = identifyNode(R, I);
1575 if (Result) {
1576 RealAddends.erase(position: ItR);
1577 ImagAddends.erase(position: ItI);
1578 return Result;
1579 }
1580 }
1581 }
1582 }
1583 return nullptr;
1584}
1585
1586bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1587 // This potential root instruction might already have been recognized as
1588 // reduction. Because RootToNode maps both Real and Imaginary parts to
1589 // CompositeNode we should choose only one either Real or Imag instruction to
1590 // use as an anchor for generating complex instruction.
1591 auto It = RootToNode.find(x: RootI);
1592 if (It != RootToNode.end()) {
1593 auto RootNode = It->second;
1594 assert(RootNode->Operation ==
1595 ComplexDeinterleavingOperation::ReductionOperation ||
1596 RootNode->Operation ==
1597 ComplexDeinterleavingOperation::ReductionSingle);
1598 // Find out which part, Real or Imag, comes later, and only if we come to
1599 // the latest part, add it to OrderedRoots.
1600 auto *R = cast<Instruction>(Val: RootNode->Real);
1601 auto *I = RootNode->Imag ? cast<Instruction>(Val: RootNode->Imag) : nullptr;
1602
1603 Instruction *ReplacementAnchor;
1604 if (I)
1605 ReplacementAnchor = R->comesBefore(Other: I) ? I : R;
1606 else
1607 ReplacementAnchor = R;
1608
1609 if (ReplacementAnchor != RootI)
1610 return false;
1611 OrderedRoots.push_back(Elt: RootI);
1612 return true;
1613 }
1614
1615 auto RootNode = identifyRoot(I: RootI);
1616 if (!RootNode)
1617 return false;
1618
1619 LLVM_DEBUG({
1620 Function *F = RootI->getFunction();
1621 BasicBlock *B = RootI->getParent();
1622 dbgs() << "Complex deinterleaving graph for " << F->getName()
1623 << "::" << B->getName() << ".\n";
1624 dump(dbgs());
1625 dbgs() << "\n";
1626 });
1627 RootToNode[RootI] = RootNode;
1628 OrderedRoots.push_back(Elt: RootI);
1629 return true;
1630}
1631
1632bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1633 bool FoundPotentialReduction = false;
1634
1635 auto *Br = dyn_cast<BranchInst>(Val: B->getTerminator());
1636 if (!Br || Br->getNumSuccessors() != 2)
1637 return false;
1638
1639 // Identify simple one-block loop
1640 if (Br->getSuccessor(i: 0) != B && Br->getSuccessor(i: 1) != B)
1641 return false;
1642
1643 for (auto &PHI : B->phis()) {
1644 if (PHI.getNumIncomingValues() != 2)
1645 continue;
1646
1647 if (!PHI.getType()->isVectorTy())
1648 continue;
1649
1650 auto *ReductionOp = dyn_cast<Instruction>(Val: PHI.getIncomingValueForBlock(BB: B));
1651 if (!ReductionOp)
1652 continue;
1653
1654 // Check if final instruction is reduced outside of current block
1655 Instruction *FinalReduction = nullptr;
1656 auto NumUsers = 0u;
1657 for (auto *U : ReductionOp->users()) {
1658 ++NumUsers;
1659 if (U == &PHI)
1660 continue;
1661 FinalReduction = dyn_cast<Instruction>(Val: U);
1662 }
1663
1664 if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1665 isa<PHINode>(Val: FinalReduction))
1666 continue;
1667
1668 ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1669 BackEdge = B;
1670 auto BackEdgeIdx = PHI.getBasicBlockIndex(BB: B);
1671 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1672 Incoming = PHI.getIncomingBlock(i: IncomingIdx);
1673 FoundPotentialReduction = true;
1674
1675 // If the initial value of PHINode is an Instruction, consider it a leaf
1676 // value of a complex deinterleaving graph.
1677 if (auto *InitPHI =
1678 dyn_cast<Instruction>(Val: PHI.getIncomingValueForBlock(BB: Incoming)))
1679 FinalInstructions.insert(Ptr: InitPHI);
1680 }
1681 return FoundPotentialReduction;
1682}
1683
/// Tries to identify complex computations rooted at the loop-carried
/// reductions recorded in ReductionInfo.
///
/// Reduction operations are visited in ReductionInfo iteration order:
///  1. Each one is first tried as a (Real, Imag) pair with every later,
///     not-yet-processed reduction of the same type, in both orders.
///     RealPHI/ImagPHI are pointed at the matching PHIs and PHIsFound is
///     reset before each attempt, because identifyNode consults this member
///     state. PHIsFound is presumably set by the PHI recogniser when the
///     reduction PHIs are reached (not visible here; verify).
///  2. If no pair was formed, the operation has at least two operands, and
///     its final reduction has (scalar) integer type, operands 0 and 1 of the
///     operation are tried as a (Real, Imag) pair on their own.
/// A pair is accepted only when a node is found and PHIsFound is set. It is
/// then wrapped in a ReductionOperation (pair) or ReductionSingle root node
/// and registered in RootToNode. RealPHI/ImagPHI are cleared on exit.
void ComplexDeinterleavingGraph::identifyReductionNodes() {
  SmallVector<bool> Processed(ReductionInfo.size(), false);
  SmallVector<Instruction *> OperationInstruction;
  for (auto &P : ReductionInfo)
    OperationInstruction.push_back(Elt: P.first);

  // Identify a complex computation by evaluating two reduction operations that
  // potentially could be involved
  for (size_t i = 0; i < OperationInstruction.size(); ++i) {
    if (Processed[i])
      continue;
    for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
      if (Processed[j])
        continue;
      auto *Real = OperationInstruction[i];
      auto *Imag = OperationInstruction[j];
      if (Real->getType() != Imag->getType())
        continue;

      // identifyNode reads RealPHI/ImagPHI, so they must describe the pair
      // under test before each attempt.
      RealPHI = ReductionInfo[Real].first;
      ImagPHI = ReductionInfo[Imag].first;
      PHIsFound = false;
      auto Node = identifyNode(R: Real, I: Imag);
      if (!Node) {
        // Retry with the roles of the two reductions exchanged.
        std::swap(a&: Real, b&: Imag);
        std::swap(a&: RealPHI, b&: ImagPHI);
        Node = identifyNode(R: Real, I: Imag);
      }

      // If a node is identified and reduction PHINode is used in the chain of
      // operations, mark its operation instructions as used to prevent
      // re-identification and attach the node to the real part
      if (Node && PHIsFound) {
        LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
                          << *Real << " / " << *Imag << "\n");
        Processed[i] = true;
        Processed[j] = true;
        auto RootNode = prepareCompositeNode(
            Operation: ComplexDeinterleavingOperation::ReductionOperation, R: Real, I: Imag);
        RootNode->addOperand(Node);
        RootToNode[Real] = RootNode;
        RootToNode[Imag] = RootNode;
        submitCompositeNode(Node: RootNode);
        break;
      }
    }

    // Fall back to treating this reduction on its own: its first two
    // operands are tried as the real and imaginary parts.
    auto *Real = OperationInstruction[i];
    // We want to check that we have 2 operands, but the function attributes
    // being counted as operands bloats this value.
    if (Processed[i] || Real->getNumOperands() < 2)
      continue;

    // Can only combined integer reductions at the moment.
    if (!ReductionInfo[Real].second->getType()->isIntegerTy())
      continue;

    // Single reduction: there is no imaginary PHI.
    RealPHI = ReductionInfo[Real].first;
    ImagPHI = nullptr;
    PHIsFound = false;
    auto Node = identifyNode(R: Real->getOperand(i: 0), I: Real->getOperand(i: 1));
    if (Node && PHIsFound) {
      LLVM_DEBUG(
          dbgs() << "Identified single reduction starting from instruction: "
                 << *Real << "/" << *ReductionInfo[Real].second << "\n");

      // Reducing to a single vector is not supported, only permit reducing down
      // to scalar values.
      // Doing this here will leave the prior node in the graph,
      // however with no uses the node will be unreachable by the replacement
      // process. That along with the usage outside the graph should prevent the
      // replacement process from kicking off at all for this graph.
      // TODO Add support for reducing to a single vector value
      // NOTE(review): this looks unreachable -- the isIntegerTy() check above
      // already rejects vector-typed final reductions. Confirm before relying
      // on it.
      if (ReductionInfo[Real].second->getType()->isVectorTy())
        continue;

      Processed[i] = true;
      auto RootNode = prepareCompositeNode(
          Operation: ComplexDeinterleavingOperation::ReductionSingle, R: Real, I: nullptr);
      RootNode->addOperand(Node);
      RootToNode[Real] = RootNode;
      submitCompositeNode(Node: RootNode);
    }
  }

  // Clear the reduction context so later identifyNode calls are unaffected.
  RealPHI = nullptr;
  ImagPHI = nullptr;
}
1772
1773bool ComplexDeinterleavingGraph::checkNodes() {
1774
1775 bool FoundDeinterleaveNode = false;
1776 for (NodePtr N : CompositeNodes) {
1777 if (!N->areOperandsValid())
1778 return false;
1779 if (N->Operation == ComplexDeinterleavingOperation::Deinterleave)
1780 FoundDeinterleaveNode = true;
1781 }
1782
1783 // We need a deinterleave node in order to guarantee that we're working with
1784 // complex numbers.
1785 if (!FoundDeinterleaveNode) {
1786 LLVM_DEBUG(
1787 dbgs() << "Couldn't find a deinterleave node within the graph, cannot "
1788 "guarantee safety during graph transformation.\n");
1789 return false;
1790 }
1791
1792 // Collect all instructions from roots to leaves
1793 SmallPtrSet<Instruction *, 16> AllInstructions;
1794 SmallVector<Instruction *, 8> Worklist;
1795 for (auto &Pair : RootToNode)
1796 Worklist.push_back(Elt: Pair.first);
1797
1798 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1799 // chains
1800 while (!Worklist.empty()) {
1801 auto *I = Worklist.pop_back_val();
1802
1803 if (!AllInstructions.insert(Ptr: I).second)
1804 continue;
1805
1806 for (Value *Op : I->operands()) {
1807 if (auto *OpI = dyn_cast<Instruction>(Val: Op)) {
1808 if (!FinalInstructions.count(Ptr: I))
1809 Worklist.emplace_back(Args&: OpI);
1810 }
1811 }
1812 }
1813
1814 // Find instructions that have users outside of chain
1815 for (auto *I : AllInstructions) {
1816 // Skip root nodes
1817 if (RootToNode.count(x: I))
1818 continue;
1819
1820 for (User *U : I->users()) {
1821 if (AllInstructions.count(Ptr: cast<Instruction>(Val: U)))
1822 continue;
1823
1824 // Found an instruction that is not used by XCMLA/XCADD chain
1825 Worklist.emplace_back(Args&: I);
1826 break;
1827 }
1828 }
1829
1830 // If any instructions are found to be used outside, find and remove roots
1831 // that somehow connect to those instructions.
1832 SmallPtrSet<Instruction *, 16> Visited;
1833 while (!Worklist.empty()) {
1834 auto *I = Worklist.pop_back_val();
1835 if (!Visited.insert(Ptr: I).second)
1836 continue;
1837
1838 // Found an impacted root node. Removing it from the nodes to be
1839 // deinterleaved
1840 if (RootToNode.count(x: I)) {
1841 LLVM_DEBUG(dbgs() << "Instruction " << *I
1842 << " could be deinterleaved but its chain of complex "
1843 "operations have an outside user\n");
1844 RootToNode.erase(x: I);
1845 }
1846
1847 if (!AllInstructions.count(Ptr: I) || FinalInstructions.count(Ptr: I))
1848 continue;
1849
1850 for (User *U : I->users())
1851 Worklist.emplace_back(Args: cast<Instruction>(Val: U));
1852
1853 for (Value *Op : I->operands()) {
1854 if (auto *OpI = dyn_cast<Instruction>(Val: Op))
1855 Worklist.emplace_back(Args&: OpI);
1856 }
1857 }
1858 return !RootToNode.empty();
1859}
1860
1861ComplexDeinterleavingGraph::NodePtr
1862ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1863 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(Val: RootI)) {
1864 if (Intrinsic->getIntrinsicID() != Intrinsic::vector_interleave2)
1865 return nullptr;
1866
1867 auto *Real = dyn_cast<Instruction>(Val: Intrinsic->getOperand(i_nocapture: 0));
1868 auto *Imag = dyn_cast<Instruction>(Val: Intrinsic->getOperand(i_nocapture: 1));
1869 if (!Real || !Imag)
1870 return nullptr;
1871
1872 return identifyNode(R: Real, I: Imag);
1873 }
1874
1875 auto *SVI = dyn_cast<ShuffleVectorInst>(Val: RootI);
1876 if (!SVI)
1877 return nullptr;
1878
1879 // Look for a shufflevector that takes separate vectors of the real and
1880 // imaginary components and recombines them into a single vector.
1881 if (!isInterleavingMask(Mask: SVI->getShuffleMask()))
1882 return nullptr;
1883
1884 Instruction *Real;
1885 Instruction *Imag;
1886 if (!match(V: RootI, P: m_Shuffle(v1: m_Instruction(I&: Real), v2: m_Instruction(I&: Imag))))
1887 return nullptr;
1888
1889 return identifyNode(R: Real, I: Imag);
1890}
1891
1892ComplexDeinterleavingGraph::NodePtr
1893ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
1894 Instruction *Imag) {
1895 Instruction *I = nullptr;
1896 Value *FinalValue = nullptr;
1897 if (match(V: Real, P: m_ExtractValue<0>(V: m_Instruction(I))) &&
1898 match(V: Imag, P: m_ExtractValue<1>(V: m_Specific(V: I))) &&
1899 match(V: I, P: m_Intrinsic<Intrinsic::vector_deinterleave2>(
1900 Op0: m_Value(V&: FinalValue)))) {
1901 NodePtr PlaceholderNode = prepareCompositeNode(
1902 Operation: llvm::ComplexDeinterleavingOperation::Deinterleave, R: Real, I: Imag);
1903 PlaceholderNode->ReplacementNode = FinalValue;
1904 FinalInstructions.insert(Ptr: Real);
1905 FinalInstructions.insert(Ptr: Imag);
1906 return submitCompositeNode(Node: PlaceholderNode);
1907 }
1908
1909 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Val: Real);
1910 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Val: Imag);
1911 if (!RealShuffle || !ImagShuffle) {
1912 if (RealShuffle || ImagShuffle)
1913 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
1914 return nullptr;
1915 }
1916
1917 Value *RealOp1 = RealShuffle->getOperand(i_nocapture: 1);
1918 if (!isa<UndefValue>(Val: RealOp1) && !isa<ConstantAggregateZero>(Val: RealOp1)) {
1919 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1920 return nullptr;
1921 }
1922 Value *ImagOp1 = ImagShuffle->getOperand(i_nocapture: 1);
1923 if (!isa<UndefValue>(Val: ImagOp1) && !isa<ConstantAggregateZero>(Val: ImagOp1)) {
1924 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1925 return nullptr;
1926 }
1927
1928 Value *RealOp0 = RealShuffle->getOperand(i_nocapture: 0);
1929 Value *ImagOp0 = ImagShuffle->getOperand(i_nocapture: 0);
1930
1931 if (RealOp0 != ImagOp0) {
1932 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1933 return nullptr;
1934 }
1935
1936 ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
1937 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
1938 if (!isDeinterleavingMask(Mask: RealMask) || !isDeinterleavingMask(Mask: ImagMask)) {
1939 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1940 return nullptr;
1941 }
1942
1943 if (RealMask[0] != 0 || ImagMask[0] != 1) {
1944 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1945 return nullptr;
1946 }
1947
1948 // Type checking, the shuffle type should be a vector type of the same
1949 // scalar type, but half the size
1950 auto CheckType = [&](ShuffleVectorInst *Shuffle) {
1951 Value *Op = Shuffle->getOperand(i_nocapture: 0);
1952 auto *ShuffleTy = cast<FixedVectorType>(Val: Shuffle->getType());
1953 auto *OpTy = cast<FixedVectorType>(Val: Op->getType());
1954
1955 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1956 return false;
1957 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1958 return false;
1959
1960 return true;
1961 };
1962
1963 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
1964 if (!CheckType(Shuffle))
1965 return false;
1966
1967 ArrayRef<int> Mask = Shuffle->getShuffleMask();
1968 int Last = *Mask.rbegin();
1969
1970 Value *Op = Shuffle->getOperand(i_nocapture: 0);
1971 auto *OpTy = cast<FixedVectorType>(Val: Op->getType());
1972 int NumElements = OpTy->getNumElements();
1973
1974 // Ensure that the deinterleaving shuffle only pulls from the first
1975 // shuffle operand.
1976 return Last < NumElements;
1977 };
1978
1979 if (RealShuffle->getType() != ImagShuffle->getType()) {
1980 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1981 return nullptr;
1982 }
1983 if (!CheckDeinterleavingShuffle(RealShuffle)) {
1984 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1985 return nullptr;
1986 }
1987 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1988 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1989 return nullptr;
1990 }
1991
1992 NodePtr PlaceholderNode =
1993 prepareCompositeNode(Operation: llvm::ComplexDeinterleavingOperation::Deinterleave,
1994 R: RealShuffle, I: ImagShuffle);
1995 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(i_nocapture: 0);
1996 FinalInstructions.insert(Ptr: RealShuffle);
1997 FinalInstructions.insert(Ptr: ImagShuffle);
1998 return submitCompositeNode(Node: PlaceholderNode);
1999}
2000
2001ComplexDeinterleavingGraph::NodePtr
2002ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
2003 auto IsSplat = [](Value *V) -> bool {
2004 // Fixed-width vector with constants
2005 if (isa<ConstantDataVector>(Val: V))
2006 return true;
2007
2008 if (isa<ConstantInt>(Val: V) || isa<ConstantFP>(Val: V))
2009 return isa<VectorType>(Val: V->getType());
2010
2011 VectorType *VTy;
2012 ArrayRef<int> Mask;
2013 // Splats are represented differently depending on whether the repeated
2014 // value is a constant or an Instruction
2015 if (auto *Const = dyn_cast<ConstantExpr>(Val: V)) {
2016 if (Const->getOpcode() != Instruction::ShuffleVector)
2017 return false;
2018 VTy = cast<VectorType>(Val: Const->getType());
2019 Mask = Const->getShuffleMask();
2020 } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(Val: V)) {
2021 VTy = Shuf->getType();
2022 Mask = Shuf->getShuffleMask();
2023 } else {
2024 return false;
2025 }
2026
2027 // When the data type is <1 x Type>, it's not possible to differentiate
2028 // between the ComplexDeinterleaving::Deinterleave and
2029 // ComplexDeinterleaving::Splat operations.
2030 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
2031 return false;
2032
2033 return all_equal(Range&: Mask) && Mask[0] == 0;
2034 };
2035
2036 if (!IsSplat(R) || !IsSplat(I))
2037 return nullptr;
2038
2039 auto *Real = dyn_cast<Instruction>(Val: R);
2040 auto *Imag = dyn_cast<Instruction>(Val: I);
2041 if ((!Real && Imag) || (Real && !Imag))
2042 return nullptr;
2043
2044 if (Real && Imag) {
2045 // Non-constant splats should be in the same basic block
2046 if (Real->getParent() != Imag->getParent())
2047 return nullptr;
2048
2049 FinalInstructions.insert(Ptr: Real);
2050 FinalInstructions.insert(Ptr: Imag);
2051 }
2052 NodePtr PlaceholderNode =
2053 prepareCompositeNode(Operation: ComplexDeinterleavingOperation::Splat, R, I);
2054 return submitCompositeNode(Node: PlaceholderNode);
2055}
2056
2057ComplexDeinterleavingGraph::NodePtr
2058ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
2059 Instruction *Imag) {
2060 if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
2061 return nullptr;
2062
2063 PHIsFound = true;
2064 NodePtr PlaceholderNode = prepareCompositeNode(
2065 Operation: ComplexDeinterleavingOperation::ReductionPHI, R: Real, I: Imag);
2066 return submitCompositeNode(Node: PlaceholderNode);
2067}
2068
2069ComplexDeinterleavingGraph::NodePtr
2070ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
2071 Instruction *Imag) {
2072 auto *SelectReal = dyn_cast<SelectInst>(Val: Real);
2073 auto *SelectImag = dyn_cast<SelectInst>(Val: Imag);
2074 if (!SelectReal || !SelectImag)
2075 return nullptr;
2076
2077 Instruction *MaskA, *MaskB;
2078 Instruction *AR, *AI, *RA, *BI;
2079 if (!match(V: Real, P: m_Select(C: m_Instruction(I&: MaskA), L: m_Instruction(I&: AR),
2080 R: m_Instruction(I&: RA))) ||
2081 !match(V: Imag, P: m_Select(C: m_Instruction(I&: MaskB), L: m_Instruction(I&: AI),
2082 R: m_Instruction(I&: BI))))
2083 return nullptr;
2084
2085 if (MaskA != MaskB && !MaskA->isIdenticalTo(I: MaskB))
2086 return nullptr;
2087
2088 if (!MaskA->getType()->isVectorTy())
2089 return nullptr;
2090
2091 auto NodeA = identifyNode(R: AR, I: AI);
2092 if (!NodeA)
2093 return nullptr;
2094
2095 auto NodeB = identifyNode(R: RA, I: BI);
2096 if (!NodeB)
2097 return nullptr;
2098
2099 NodePtr PlaceholderNode = prepareCompositeNode(
2100 Operation: ComplexDeinterleavingOperation::ReductionSelect, R: Real, I: Imag);
2101 PlaceholderNode->addOperand(Node: NodeA);
2102 PlaceholderNode->addOperand(Node: NodeB);
2103 FinalInstructions.insert(Ptr: MaskA);
2104 FinalInstructions.insert(Ptr: MaskB);
2105 return submitCompositeNode(Node: PlaceholderNode);
2106}
2107
2108static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
2109 std::optional<FastMathFlags> Flags,
2110 Value *InputA, Value *InputB) {
2111 Value *I;
2112 switch (Opcode) {
2113 case Instruction::FNeg:
2114 I = B.CreateFNeg(V: InputA);
2115 break;
2116 case Instruction::FAdd:
2117 I = B.CreateFAdd(L: InputA, R: InputB);
2118 break;
2119 case Instruction::Add:
2120 I = B.CreateAdd(LHS: InputA, RHS: InputB);
2121 break;
2122 case Instruction::FSub:
2123 I = B.CreateFSub(L: InputA, R: InputB);
2124 break;
2125 case Instruction::Sub:
2126 I = B.CreateSub(LHS: InputA, RHS: InputB);
2127 break;
2128 case Instruction::FMul:
2129 I = B.CreateFMul(L: InputA, R: InputB);
2130 break;
2131 case Instruction::Mul:
2132 I = B.CreateMul(LHS: InputA, RHS: InputB);
2133 break;
2134 default:
2135 llvm_unreachable("Incorrect symmetric opcode");
2136 }
2137 if (Flags)
2138 cast<Instruction>(Val: I)->setFastMathFlags(*Flags);
2139 return I;
2140}
2141
2142Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
2143 RawNodePtr Node) {
2144 if (Node->ReplacementNode)
2145 return Node->ReplacementNode;
2146
2147 auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
2148 return Node->Operands.size() > Idx
2149 ? replaceNode(Builder, Node: Node->Operands[Idx])
2150 : nullptr;
2151 };
2152
2153 Value *ReplacementNode;
2154 switch (Node->Operation) {
2155 case ComplexDeinterleavingOperation::CDot: {
2156 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2157 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2158 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
2159 assert(!Input1 || (Input0->getType() == Input1->getType() &&
2160 "Node inputs need to be of the same type"));
2161 ReplacementNode = TL->createComplexDeinterleavingIR(
2162 B&: Builder, OperationType: Node->Operation, Rotation: Node->Rotation, InputA: Input0, InputB: Input1, Accumulator);
2163 break;
2164 }
2165 case ComplexDeinterleavingOperation::CAdd:
2166 case ComplexDeinterleavingOperation::CMulPartial:
2167 case ComplexDeinterleavingOperation::Symmetric: {
2168 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2169 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2170 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
2171 assert(!Input1 || (Input0->getType() == Input1->getType() &&
2172 "Node inputs need to be of the same type"));
2173 assert(!Accumulator ||
2174 (Input0->getType() == Accumulator->getType() &&
2175 "Accumulator and input need to be of the same type"));
2176 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
2177 ReplacementNode = replaceSymmetricNode(B&: Builder, Opcode: Node->Opcode, Flags: Node->Flags,
2178 InputA: Input0, InputB: Input1);
2179 else
2180 ReplacementNode = TL->createComplexDeinterleavingIR(
2181 B&: Builder, OperationType: Node->Operation, Rotation: Node->Rotation, InputA: Input0, InputB: Input1,
2182 Accumulator);
2183 break;
2184 }
2185 case ComplexDeinterleavingOperation::Deinterleave:
2186 llvm_unreachable("Deinterleave node should already have ReplacementNode");
2187 break;
2188 case ComplexDeinterleavingOperation::Splat: {
2189 auto *NewTy = VectorType::getDoubleElementsVectorType(
2190 VTy: cast<VectorType>(Val: Node->Real->getType()));
2191 auto *R = dyn_cast<Instruction>(Val: Node->Real);
2192 auto *I = dyn_cast<Instruction>(Val: Node->Imag);
2193 if (R && I) {
2194 // Splats that are not constant are interleaved where they are located
2195 Instruction *InsertPoint = (I->comesBefore(Other: R) ? R : I)->getNextNode();
2196 IRBuilder<> IRB(InsertPoint);
2197 ReplacementNode = IRB.CreateIntrinsic(ID: Intrinsic::vector_interleave2,
2198 Types: NewTy, Args: {Node->Real, Node->Imag});
2199 } else {
2200 ReplacementNode = Builder.CreateIntrinsic(
2201 ID: Intrinsic::vector_interleave2, Types: NewTy, Args: {Node->Real, Node->Imag});
2202 }
2203 break;
2204 }
2205 case ComplexDeinterleavingOperation::ReductionPHI: {
2206 // If Operation is ReductionPHI, a new empty PHINode is created.
2207 // It is filled later when the ReductionOperation is processed.
2208 auto *OldPHI = cast<PHINode>(Val: Node->Real);
2209 auto *VTy = cast<VectorType>(Val: Node->Real->getType());
2210 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2211 auto *NewPHI = PHINode::Create(Ty: NewVTy, NumReservedValues: 0, NameStr: "", InsertBefore: BackEdge->getFirstNonPHIIt());
2212 OldToNewPHI[OldPHI] = NewPHI;
2213 ReplacementNode = NewPHI;
2214 break;
2215 }
2216 case ComplexDeinterleavingOperation::ReductionSingle:
2217 ReplacementNode = replaceNode(Builder, Node: Node->Operands[0]);
2218 processReductionSingle(OperationReplacement: ReplacementNode, Node);
2219 break;
2220 case ComplexDeinterleavingOperation::ReductionOperation:
2221 ReplacementNode = replaceNode(Builder, Node: Node->Operands[0]);
2222 processReductionOperation(OperationReplacement: ReplacementNode, Node);
2223 break;
2224 case ComplexDeinterleavingOperation::ReductionSelect: {
2225 auto *MaskReal = cast<Instruction>(Val: Node->Real)->getOperand(i: 0);
2226 auto *MaskImag = cast<Instruction>(Val: Node->Imag)->getOperand(i: 0);
2227 auto *A = replaceNode(Builder, Node: Node->Operands[0]);
2228 auto *B = replaceNode(Builder, Node: Node->Operands[1]);
2229 auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
2230 VTy: cast<VectorType>(Val: MaskReal->getType()));
2231 auto *NewMask = Builder.CreateIntrinsic(ID: Intrinsic::vector_interleave2,
2232 Types: NewMaskTy, Args: {MaskReal, MaskImag});
2233 ReplacementNode = Builder.CreateSelect(C: NewMask, True: A, False: B);
2234 break;
2235 }
2236 }
2237
2238 assert(ReplacementNode && "Target failed to create Intrinsic call.");
2239 NumComplexTransformations += 1;
2240 Node->ReplacementNode = ReplacementNode;
2241 return ReplacementNode;
2242}
2243
2244void ComplexDeinterleavingGraph::processReductionSingle(
2245 Value *OperationReplacement, RawNodePtr Node) {
2246 auto *Real = cast<Instruction>(Val: Node->Real);
2247 auto *OldPHI = ReductionInfo[Real].first;
2248 auto *NewPHI = OldToNewPHI[OldPHI];
2249 auto *VTy = cast<VectorType>(Val: Real->getType());
2250 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2251
2252 Value *Init = OldPHI->getIncomingValueForBlock(BB: Incoming);
2253
2254 IRBuilder<> Builder(Incoming->getTerminator());
2255
2256 Value *NewInit = nullptr;
2257 if (auto *C = dyn_cast<Constant>(Val: Init)) {
2258 if (C->isZeroValue())
2259 NewInit = Constant::getNullValue(Ty: NewVTy);
2260 }
2261
2262 if (!NewInit)
2263 NewInit = Builder.CreateIntrinsic(ID: Intrinsic::vector_interleave2, Types: NewVTy,
2264 Args: {Init, Constant::getNullValue(Ty: VTy)});
2265
2266 NewPHI->addIncoming(V: NewInit, BB: Incoming);
2267 NewPHI->addIncoming(V: OperationReplacement, BB: BackEdge);
2268
2269 auto *FinalReduction = ReductionInfo[Real].second;
2270 Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
2271
2272 auto *AddReduce = Builder.CreateAddReduce(Src: OperationReplacement);
2273 FinalReduction->replaceAllUsesWith(V: AddReduce);
2274}
2275
2276void ComplexDeinterleavingGraph::processReductionOperation(
2277 Value *OperationReplacement, RawNodePtr Node) {
2278 auto *Real = cast<Instruction>(Val: Node->Real);
2279 auto *Imag = cast<Instruction>(Val: Node->Imag);
2280 auto *OldPHIReal = ReductionInfo[Real].first;
2281 auto *OldPHIImag = ReductionInfo[Imag].first;
2282 auto *NewPHI = OldToNewPHI[OldPHIReal];
2283
2284 auto *VTy = cast<VectorType>(Val: Real->getType());
2285 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2286
2287 // We have to interleave initial origin values coming from IncomingBlock
2288 Value *InitReal = OldPHIReal->getIncomingValueForBlock(BB: Incoming);
2289 Value *InitImag = OldPHIImag->getIncomingValueForBlock(BB: Incoming);
2290
2291 IRBuilder<> Builder(Incoming->getTerminator());
2292 auto *NewInit = Builder.CreateIntrinsic(ID: Intrinsic::vector_interleave2, Types: NewVTy,
2293 Args: {InitReal, InitImag});
2294
2295 NewPHI->addIncoming(V: NewInit, BB: Incoming);
2296 NewPHI->addIncoming(V: OperationReplacement, BB: BackEdge);
2297
2298 // Deinterleave complex vector outside of loop so that it can be finally
2299 // reduced
2300 auto *FinalReductionReal = ReductionInfo[Real].second;
2301 auto *FinalReductionImag = ReductionInfo[Imag].second;
2302
2303 Builder.SetInsertPoint(
2304 &*FinalReductionReal->getParent()->getFirstInsertionPt());
2305 auto *Deinterleave = Builder.CreateIntrinsic(ID: Intrinsic::vector_deinterleave2,
2306 Types: OperationReplacement->getType(),
2307 Args: OperationReplacement);
2308
2309 auto *NewReal = Builder.CreateExtractValue(Agg: Deinterleave, Idxs: (uint64_t)0);
2310 FinalReductionReal->replaceUsesOfWith(From: Real, To: NewReal);
2311
2312 Builder.SetInsertPoint(FinalReductionImag);
2313 auto *NewImag = Builder.CreateExtractValue(Agg: Deinterleave, Idxs: 1);
2314 FinalReductionImag->replaceUsesOfWith(From: Imag, To: NewImag);
2315}
2316
2317void ComplexDeinterleavingGraph::replaceNodes() {
2318 SmallVector<Instruction *, 16> DeadInstrRoots;
2319 for (auto *RootInstruction : OrderedRoots) {
2320 // Check if this potential root went through check process and we can
2321 // deinterleave it
2322 if (!RootToNode.count(x: RootInstruction))
2323 continue;
2324
2325 IRBuilder<> Builder(RootInstruction);
2326 auto RootNode = RootToNode[RootInstruction];
2327 Value *R = replaceNode(Builder, Node: RootNode.get());
2328
2329 if (RootNode->Operation ==
2330 ComplexDeinterleavingOperation::ReductionOperation) {
2331 auto *RootReal = cast<Instruction>(Val: RootNode->Real);
2332 auto *RootImag = cast<Instruction>(Val: RootNode->Imag);
2333 ReductionInfo[RootReal].first->removeIncomingValue(BB: BackEdge);
2334 ReductionInfo[RootImag].first->removeIncomingValue(BB: BackEdge);
2335 DeadInstrRoots.push_back(Elt: RootReal);
2336 DeadInstrRoots.push_back(Elt: RootImag);
2337 } else if (RootNode->Operation ==
2338 ComplexDeinterleavingOperation::ReductionSingle) {
2339 auto *RootInst = cast<Instruction>(Val: RootNode->Real);
2340 auto &Info = ReductionInfo[RootInst];
2341 Info.first->removeIncomingValue(BB: BackEdge);
2342 DeadInstrRoots.push_back(Elt: Info.second);
2343 } else {
2344 assert(R && "Unable to find replacement for RootInstruction");
2345 DeadInstrRoots.push_back(Elt: RootInstruction);
2346 RootInstruction->replaceAllUsesWith(V: R);
2347 }
2348 }
2349
2350 for (auto *I : DeadInstrRoots)
2351 RecursivelyDeleteTriviallyDeadInstructions(V: I, TLI);
2352}
2353