1//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Identification:
10// This step is responsible for finding the patterns that can be lowered to
11// complex instructions, and building a graph to represent the complex
12// structures. Starting from the "Converging Shuffle" (a shuffle that
13// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14// operands are evaluated and identified as "Composite Nodes" (collections of
15// instructions that can potentially be lowered to a single complex
16// instruction). This is performed by checking the real and imaginary components
17// and tracking the data flow for each component while following the operand
18// pairs. Validity of each node is expected to be done upon creation, and any
19// validation errors should halt traversal and prevent further graph
20// construction.
21// Instead of relying on Shuffle operations, vector interleaving and
22// deinterleaving can be represented by vector.interleave2 and
23// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24// these intrinsics, whereas, fixed-width vectors are recognized for both
25// shufflevector instruction and intrinsics.
26//
27// Replacement:
28// This step traverses the graph built up by identification, delegating to the
29// target to validate and generate the correct intrinsics, and plumbs them
30// together connecting each end of the new intrinsics graph to the existing
31// use-def chain. This step is assumed to finish successfully, as all
32// information is expected to be correct by this point.
33//
34//
35// Internal data structure:
36// ComplexDeinterleavingGraph:
37// Keeps references to all the valid CompositeNodes formed as part of the
38// transformation, and every Instruction contained within said nodes. It also
39// holds onto a reference to the root Instruction, and the root node that should
40// replace it.
41//
42// ComplexDeinterleavingCompositeNode:
43// A CompositeNode represents a single transformation point; each node should
44// transform into a single complex instruction (ignoring vector splitting, which
45// would generate more instructions per node). They are identified in a
46// depth-first manner, traversing and identifying the operands of each
47// instruction in the order they appear in the IR.
48// Each node maintains a reference to its Real and Imaginary instructions,
49// as well as any additional instructions that make up the identified operation
50// (Internal instructions should only have uses within their containing node).
51// A Node also contains the rotation and operation type that it represents.
52// Operands contains pointers to other CompositeNodes, acting as the edges in
53// the graph. ReplacementValue is the transformed Value* that has been emitted
54// to the IR.
55//
56// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57// ReplacementValue fields of that Node are relevant, where the ReplacementValue
58// should be pre-populated.
59//
60//===----------------------------------------------------------------------===//
61
62#include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63#include "llvm/ADT/AllocatorList.h"
64#include "llvm/ADT/MapVector.h"
65#include "llvm/ADT/Statistic.h"
66#include "llvm/Analysis/TargetLibraryInfo.h"
67#include "llvm/Analysis/TargetTransformInfo.h"
68#include "llvm/CodeGen/TargetLowering.h"
69#include "llvm/CodeGen/TargetSubtargetInfo.h"
70#include "llvm/IR/IRBuilder.h"
71#include "llvm/IR/Intrinsics.h"
72#include "llvm/IR/PatternMatch.h"
73#include "llvm/InitializePasses.h"
74#include "llvm/Support/Allocator.h"
75#include "llvm/Target/TargetMachine.h"
76#include "llvm/Transforms/Utils/Local.h"
77#include <algorithm>
78
79using namespace llvm;
80using namespace PatternMatch;
81
82#define DEBUG_TYPE "complex-deinterleaving"
83
84STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
85
86static cl::opt<bool> ComplexDeinterleavingEnabled(
87 "enable-complex-deinterleaving",
88 cl::desc("Enable generation of complex instructions"), cl::init(Val: true),
89 cl::Hidden);
90
/// Checks the given mask, and determines whether said mask is interleaving.
///
/// To be interleaving, a mask must alternate between `i` and `i + (Length /
/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
/// 4x vector interleaving mask would be <0, 2, 1, 3>). A mask with an odd
/// number of elements is never interleaving.
static bool isInterleavingMask(ArrayRef<int> Mask);

/// Checks the given mask, and determines whether said mask is deinterleaving.
///
/// To be deinterleaving, a mask must increment in steps of 2, and either start
/// with 0 or 1.
/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
/// <1, 3, 5, 7>).
static bool isDeinterleavingMask(ArrayRef<int> Mask);

/// Returns true if \p V is a negation, i.e. it matches either the
/// floating-point (m_FNeg) or the integer (m_Neg) negation pattern.
static bool isNeg(Value *V);

/// Returns the value being negated by \p V. \p V must satisfy isNeg(); this
/// is asserted. The result is operand 0 for an FNeg instruction and operand 1
/// otherwise.
static Value *getNegOperand(Value *V);
112
113namespace {
114struct ComplexValue {
115 Value *Real = nullptr;
116 Value *Imag = nullptr;
117
118 bool operator==(const ComplexValue &Other) const {
119 return Real == Other.Real && Imag == Other.Imag;
120 }
121};
122hash_code hash_value(const ComplexValue &Arg) {
123 return hash_combine(args: DenseMapInfo<Value *>::getHashValue(PtrVal: Arg.Real),
124 args: DenseMapInfo<Value *>::getHashValue(PtrVal: Arg.Imag));
125}
126} // end namespace
127typedef SmallVector<struct ComplexValue, 2> ComplexValues;
128
129template <> struct llvm::DenseMapInfo<ComplexValue> {
130 static inline ComplexValue getEmptyKey() {
131 return {.Real: DenseMapInfo<Value *>::getEmptyKey(),
132 .Imag: DenseMapInfo<Value *>::getEmptyKey()};
133 }
134 static inline ComplexValue getTombstoneKey() {
135 return {.Real: DenseMapInfo<Value *>::getTombstoneKey(),
136 .Imag: DenseMapInfo<Value *>::getTombstoneKey()};
137 }
138 static unsigned getHashValue(const ComplexValue &Val) {
139 return hash_combine(args: DenseMapInfo<Value *>::getHashValue(PtrVal: Val.Real),
140 args: DenseMapInfo<Value *>::getHashValue(PtrVal: Val.Imag));
141 }
142 static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS) {
143 return LHS.Real == RHS.Real && LHS.Imag == RHS.Imag;
144 }
145};
146
147namespace {
148template <typename T, typename IterT>
149std::optional<T> findCommonBetweenCollections(IterT A, IterT B) {
150 auto Common = llvm::find_if(A, [B](T I) { return llvm::is_contained(B, I); });
151 if (Common != A.end())
152 return std::make_optional(*Common);
153 return std::nullopt;
154}
155
156class ComplexDeinterleavingLegacyPass : public FunctionPass {
157public:
158 static char ID;
159
160 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
161 : FunctionPass(ID), TM(TM) {}
162
163 StringRef getPassName() const override {
164 return "Complex Deinterleaving Pass";
165 }
166
167 bool runOnFunction(Function &F) override;
168 void getAnalysisUsage(AnalysisUsage &AU) const override {
169 AU.addRequired<TargetLibraryInfoWrapperPass>();
170 AU.setPreservesCFG();
171 }
172
173private:
174 const TargetMachine *TM;
175};
176
177class ComplexDeinterleavingGraph;
178struct ComplexDeinterleavingCompositeNode {
179
180 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
181 Value *R, Value *I)
182 : Operation(Op) {
183 Vals.push_back(Elt: {.Real: R, .Imag: I});
184 }
185
186 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
187 ComplexValues &Other)
188 : Operation(Op), Vals(Other) {}
189
190private:
191 friend class ComplexDeinterleavingGraph;
192 using CompositeNode = ComplexDeinterleavingCompositeNode;
193 bool OperandsValid = true;
194
195public:
196 ComplexDeinterleavingOperation Operation;
197 ComplexValues Vals;
198
199 // This two members are required exclusively for generating
200 // ComplexDeinterleavingOperation::Symmetric operations.
201 unsigned Opcode;
202 std::optional<FastMathFlags> Flags;
203
204 ComplexDeinterleavingRotation Rotation =
205 ComplexDeinterleavingRotation::Rotation_0;
206 SmallVector<CompositeNode *> Operands;
207 Value *ReplacementNode = nullptr;
208
209 void addOperand(CompositeNode *Node) {
210 if (!Node)
211 OperandsValid = false;
212 Operands.push_back(Elt: Node);
213 }
214
215 void dump() { dump(OS&: dbgs()); }
216 void dump(raw_ostream &OS) {
217 auto PrintValue = [&](Value *V) {
218 if (V) {
219 OS << "\"";
220 V->print(O&: OS, IsForDebug: true);
221 OS << "\"\n";
222 } else
223 OS << "nullptr\n";
224 };
225 auto PrintNodeRef = [&](CompositeNode *Ptr) {
226 if (Ptr)
227 OS << Ptr << "\n";
228 else
229 OS << "nullptr\n";
230 };
231
232 OS << "- CompositeNode: " << this << "\n";
233 for (unsigned I = 0; I < Vals.size(); I++) {
234 OS << " Real(" << I << ") : ";
235 PrintValue(Vals[I].Real);
236 OS << " Imag(" << I << ") : ";
237 PrintValue(Vals[I].Imag);
238 }
239 OS << " ReplacementNode: ";
240 PrintValue(ReplacementNode);
241 OS << " Operation: " << (int)Operation << "\n";
242 OS << " Rotation: " << ((int)Rotation * 90) << "\n";
243 OS << " Operands: \n";
244 for (const auto &Op : Operands) {
245 OS << " - ";
246 PrintNodeRef(Op);
247 }
248 }
249
250 bool areOperandsValid() { return OperandsValid; }
251};
252
253class ComplexDeinterleavingGraph {
254public:
255 struct Product {
256 Value *Multiplier;
257 Value *Multiplicand;
258 bool IsPositive;
259 };
260
261 using Addend = std::pair<Value *, bool>;
262 using AddendList = BumpPtrList<Addend>;
263 using CompositeNode = ComplexDeinterleavingCompositeNode::CompositeNode;
264
265 // Helper struct for holding info about potential partial multiplication
266 // candidates
267 struct PartialMulCandidate {
268 Value *Common;
269 CompositeNode *Node;
270 unsigned RealIdx;
271 unsigned ImagIdx;
272 bool IsNodeInverted;
273 };
274
275 explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
276 const TargetLibraryInfo *TLI,
277 unsigned Factor)
278 : TL(TL), TLI(TLI), Factor(Factor) {}
279
280private:
281 const TargetLowering *TL = nullptr;
282 const TargetLibraryInfo *TLI = nullptr;
283 unsigned Factor;
284 SmallVector<CompositeNode *> CompositeNodes;
285 DenseMap<ComplexValues, CompositeNode *> CachedResult;
286 SpecificBumpPtrAllocator<ComplexDeinterleavingCompositeNode> Allocator;
287
288 SmallPtrSet<Instruction *, 16> FinalInstructions;
289
290 /// Root instructions are instructions from which complex computation starts
291 DenseMap<Instruction *, CompositeNode *> RootToNode;
292
293 /// Topologically sorted root instructions
294 SmallVector<Instruction *, 1> OrderedRoots;
295
296 /// When examining a basic block for complex deinterleaving, if it is a simple
297 /// one-block loop, then the only incoming block is 'Incoming' and the
298 /// 'BackEdge' block is the block itself."
299 BasicBlock *BackEdge = nullptr;
300 BasicBlock *Incoming = nullptr;
301
302 /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
303 /// %OutsideUser as it is shown in the IR:
304 ///
305 /// vector.body:
306 /// %PHInode = phi <vector type> [ zeroinitializer, %entry ],
307 /// [ %ReductionOp, %vector.body ]
308 /// ...
309 /// %ReductionOp = fadd i64 ...
310 /// ...
311 /// br i1 %condition, label %vector.body, %middle.block
312 ///
313 /// middle.block:
314 /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
315 ///
316 /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
317 /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
318 MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
319
320 /// In the process of detecting a reduction, we consider a pair of
321 /// %ReductionOP, which we refer to as real and imag (or vice versa), and
322 /// traverse the use-tree to detect complex operations. As this is a reduction
323 /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
324 /// to the %ReductionOPs that we suspect to be complex.
325 /// RealPHI and ImagPHI are used by the identifyPHINode method.
326 PHINode *RealPHI = nullptr;
327 PHINode *ImagPHI = nullptr;
328
329 /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
330 /// detection.
331 bool PHIsFound = false;
332
333 /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
334 /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
335 /// This mapping is populated during
336 /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
337 /// used in the ComplexDeinterleavingOperation::ReductionOperation node
338 /// replacement process.
339 DenseMap<PHINode *, PHINode *> OldToNewPHI;
340
341 CompositeNode *prepareCompositeNode(ComplexDeinterleavingOperation Operation,
342 Value *R, Value *I) {
343 assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
344 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
345 (R && I)) &&
346 "Reduction related nodes must have Real and Imaginary parts");
347 return new (Allocator.Allocate())
348 ComplexDeinterleavingCompositeNode(Operation, R, I);
349 }
350
351 CompositeNode *prepareCompositeNode(ComplexDeinterleavingOperation Operation,
352 ComplexValues &Vals) {
353#ifndef NDEBUG
354 for (auto &V : Vals) {
355 assert(
356 ((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
357 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
358 (V.Real && V.Imag)) &&
359 "Reduction related nodes must have Real and Imaginary parts");
360 }
361#endif
362 return new (Allocator.Allocate())
363 ComplexDeinterleavingCompositeNode(Operation, Vals);
364 }
365
366 CompositeNode *submitCompositeNode(CompositeNode *Node) {
367 CompositeNodes.push_back(Elt: Node);
368 if (Node->Vals[0].Real)
369 CachedResult[Node->Vals] = Node;
370 return Node;
371 }
372
373 /// Identifies a complex partial multiply pattern and its rotation, based on
374 /// the following patterns
375 ///
376 /// 0: r: cr + ar * br
377 /// i: ci + ar * bi
378 /// 90: r: cr - ai * bi
379 /// i: ci + ai * br
380 /// 180: r: cr - ar * br
381 /// i: ci - ar * bi
382 /// 270: r: cr + ai * bi
383 /// i: ci - ai * br
384 CompositeNode *identifyPartialMul(Instruction *Real, Instruction *Imag);
385
386 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
387 /// is partially known from identifyPartialMul, filling in the other half of
388 /// the complex pair.
389 CompositeNode *
390 identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
391 std::pair<Value *, Value *> &CommonOperandI);
392
393 /// Identifies a complex add pattern and its rotation, based on the following
394 /// patterns.
395 ///
396 /// 90: r: ar - bi
397 /// i: ai + br
398 /// 270: r: ar + bi
399 /// i: ai - br
400 CompositeNode *identifyAdd(Instruction *Real, Instruction *Imag);
401 CompositeNode *identifySymmetricOperation(ComplexValues &Vals);
402 CompositeNode *identifyPartialReduction(Value *R, Value *I);
403 CompositeNode *identifyDotProduct(Value *Inst);
404
405 CompositeNode *identifyNode(ComplexValues &Vals);
406
407 CompositeNode *identifyNode(Value *R, Value *I) {
408 ComplexValues Vals;
409 Vals.push_back(Elt: {.Real: R, .Imag: I});
410 return identifyNode(Vals);
411 }
412
413 /// Determine if a sum of complex numbers can be formed from \p RealAddends
414 /// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
415 /// Return nullptr if it is not possible to construct a complex number.
416 /// \p Flags are needed to generate symmetric Add and Sub operations.
417 CompositeNode *identifyAdditions(AddendList &RealAddends,
418 AddendList &ImagAddends,
419 std::optional<FastMathFlags> Flags,
420 CompositeNode *Accumulator);
421
422 /// Extract one addend that have both real and imaginary parts positive.
423 CompositeNode *extractPositiveAddend(AddendList &RealAddends,
424 AddendList &ImagAddends);
425
426 /// Determine if sum of multiplications of complex numbers can be formed from
427 /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
428 /// to it. Return nullptr if it is not possible to construct a complex number.
429 CompositeNode *identifyMultiplications(SmallVectorImpl<Product> &RealMuls,
430 SmallVectorImpl<Product> &ImagMuls,
431 CompositeNode *Accumulator);
432
433 /// Go through pairs of multiplication (one Real and one Imag) and find all
434 /// possible candidates for partial multiplication and put them into \p
435 /// Candidates. Returns true if all Product has pair with common operand
436 bool collectPartialMuls(ArrayRef<Product> RealMuls,
437 ArrayRef<Product> ImagMuls,
438 SmallVectorImpl<PartialMulCandidate> &Candidates);
439
440 /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
441 /// the order of complex computation operations may be significantly altered,
442 /// and the real and imaginary parts may not be executed in parallel. This
443 /// function takes this into consideration and employs a more general approach
444 /// to identify complex computations. Initially, it gathers all the addends
445 /// and multiplicands and then constructs a complex expression from them.
446 CompositeNode *identifyReassocNodes(Instruction *I, Instruction *J);
447
448 CompositeNode *identifyRoot(Instruction *I);
449
450 /// Identifies the Deinterleave operation applied to a vector containing
451 /// complex numbers. There are two ways to represent the Deinterleave
452 /// operation:
453 /// * Using two shufflevectors with even indices for /pReal instruction and
454 /// odd indices for /pImag instructions (only for fixed-width vectors)
455 /// * Using N extractvalue instructions applied to `vector.deinterleaveN`
456 /// intrinsics (for both fixed and scalable vectors) where N is a multiple of
457 /// 2.
458 CompositeNode *identifyDeinterleave(ComplexValues &Vals);
459
460 /// identifying the operation that represents a complex number repeated in a
461 /// Splat vector. There are two possible types of splats: ConstantExpr with
462 /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
463 /// initialization mask with all values set to zero.
464 CompositeNode *identifySplat(ComplexValues &Vals);
465
466 CompositeNode *identifyPHINode(Instruction *Real, Instruction *Imag);
467
468 /// Identifies SelectInsts in a loop that has reduction with predication masks
469 /// and/or predicated tail folding
470 CompositeNode *identifySelectNode(Instruction *Real, Instruction *Imag);
471
472 Value *replaceNode(IRBuilderBase &Builder, CompositeNode *Node);
473
474 /// Complete IR modifications after producing new reduction operation:
475 /// * Populate the PHINode generated for
476 /// ComplexDeinterleavingOperation::ReductionPHI
477 /// * Deinterleave the final value outside of the loop and repurpose original
478 /// reduction users
479 void processReductionOperation(Value *OperationReplacement,
480 CompositeNode *Node);
481 void processReductionSingle(Value *OperationReplacement, CompositeNode *Node);
482
483public:
484 void dump() { dump(OS&: dbgs()); }
485 void dump(raw_ostream &OS) {
486 for (const auto &Node : CompositeNodes)
487 Node->dump(OS);
488 }
489
490 /// Returns false if the deinterleaving operation should be cancelled for the
491 /// current graph.
492 bool identifyNodes(Instruction *RootI);
493
494 /// In case \pB is one-block loop, this function seeks potential reductions
495 /// and populates ReductionInfo. Returns true if any reductions were
496 /// identified.
497 bool collectPotentialReductions(BasicBlock *B);
498
499 void identifyReductionNodes();
500
501 /// Check that every instruction, from the roots to the leaves, has internal
502 /// uses.
503 bool checkNodes();
504
505 /// Perform the actual replacement of the underlying instruction graph.
506 void replaceNodes();
507};
508
509class ComplexDeinterleaving {
510public:
511 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
512 : TL(tl), TLI(tli) {}
513 bool runOnFunction(Function &F);
514
515private:
516 bool evaluateBasicBlock(BasicBlock *B, unsigned Factor);
517
518 const TargetLowering *TL = nullptr;
519 const TargetLibraryInfo *TLI = nullptr;
520};
521
522} // namespace
523
524char ComplexDeinterleavingLegacyPass::ID = 0;
525
526INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
527 "Complex Deinterleaving", false, false)
528INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
529 "Complex Deinterleaving", false, false)
530
531PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
532 FunctionAnalysisManager &AM) {
533 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
534 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(IR&: F);
535 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
536 return PreservedAnalyses::all();
537
538 PreservedAnalyses PA;
539 PA.preserve<FunctionAnalysisManagerModuleProxy>();
540 return PA;
541}
542
543FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
544 return new ComplexDeinterleavingLegacyPass(TM);
545}
546
547bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
548 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
549 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
550 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
551}
552
553bool ComplexDeinterleaving::runOnFunction(Function &F) {
554 if (!ComplexDeinterleavingEnabled) {
555 LLVM_DEBUG(
556 dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
557 return false;
558 }
559
560 if (!TL->isComplexDeinterleavingSupported()) {
561 LLVM_DEBUG(
562 dbgs() << "Complex deinterleaving has been disabled, target does "
563 "not support lowering of complex number operations.\n");
564 return false;
565 }
566
567 bool Changed = false;
568 for (auto &B : F)
569 Changed |= evaluateBasicBlock(B: &B, Factor: 2);
570
571 // TODO: Permit changes for both interleave factors in the same function.
572 if (!Changed) {
573 for (auto &B : F)
574 Changed |= evaluateBasicBlock(B: &B, Factor: 4);
575 }
576
577 // TODO: We can also support interleave factors of 6 and 8 if needed.
578
579 return Changed;
580}
581
582static bool isInterleavingMask(ArrayRef<int> Mask) {
583 // If the size is not even, it's not an interleaving mask
584 if ((Mask.size() & 1))
585 return false;
586
587 int HalfNumElements = Mask.size() / 2;
588 for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
589 int MaskIdx = Idx * 2;
590 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
591 return false;
592 }
593
594 return true;
595}
596
597static bool isDeinterleavingMask(ArrayRef<int> Mask) {
598 int Offset = Mask[0];
599 int HalfNumElements = Mask.size() / 2;
600
601 for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
602 if (Mask[Idx] != (Idx * 2) + Offset)
603 return false;
604 }
605
606 return true;
607}
608
609bool isNeg(Value *V) {
610 return match(V, P: m_FNeg(X: m_Value())) || match(V, P: m_Neg(V: m_Value()));
611}
612
613Value *getNegOperand(Value *V) {
614 assert(isNeg(V));
615 auto *I = cast<Instruction>(Val: V);
616 if (I->getOpcode() == Instruction::FNeg)
617 return I->getOperand(i: 0);
618
619 return I->getOperand(i: 1);
620}
621
622bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B, unsigned Factor) {
623 ComplexDeinterleavingGraph Graph(TL, TLI, Factor);
624 if (Graph.collectPotentialReductions(B))
625 Graph.identifyReductionNodes();
626
627 for (auto &I : *B)
628 Graph.identifyNodes(RootI: &I);
629
630 if (Graph.checkNodes()) {
631 Graph.replaceNodes();
632 return true;
633 }
634
635 return false;
636}
637
638ComplexDeinterleavingGraph::CompositeNode *
639ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
640 Instruction *Real, Instruction *Imag,
641 std::pair<Value *, Value *> &PartialMatch) {
642 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
643 << "\n");
644
645 if (!Real->hasOneUse() || !Imag->hasOneUse()) {
646 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
647 return nullptr;
648 }
649
650 if ((Real->getOpcode() != Instruction::FMul &&
651 Real->getOpcode() != Instruction::Mul) ||
652 (Imag->getOpcode() != Instruction::FMul &&
653 Imag->getOpcode() != Instruction::Mul)) {
654 LLVM_DEBUG(
655 dbgs() << " - Real or imaginary instruction is not fmul or mul\n");
656 return nullptr;
657 }
658
659 Value *R0 = Real->getOperand(i: 0);
660 Value *R1 = Real->getOperand(i: 1);
661 Value *I0 = Imag->getOperand(i: 0);
662 Value *I1 = Imag->getOperand(i: 1);
663
664 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
665 // rotations and use the operand.
666 unsigned Negs = 0;
667 Value *Op;
668 if (match(V: R0, P: m_Neg(V: m_Value(V&: Op)))) {
669 Negs |= 1;
670 R0 = Op;
671 } else if (match(V: R1, P: m_Neg(V: m_Value(V&: Op)))) {
672 Negs |= 1;
673 R1 = Op;
674 }
675
676 if (isNeg(V: I0)) {
677 Negs |= 2;
678 Negs ^= 1;
679 I0 = Op;
680 } else if (match(V: I1, P: m_Neg(V: m_Value(V&: Op)))) {
681 Negs |= 2;
682 Negs ^= 1;
683 I1 = Op;
684 }
685
686 ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
687
688 Value *CommonOperand;
689 Value *UncommonRealOp;
690 Value *UncommonImagOp;
691
692 if (R0 == I0 || R0 == I1) {
693 CommonOperand = R0;
694 UncommonRealOp = R1;
695 } else if (R1 == I0 || R1 == I1) {
696 CommonOperand = R1;
697 UncommonRealOp = R0;
698 } else {
699 LLVM_DEBUG(dbgs() << " - No equal operand\n");
700 return nullptr;
701 }
702
703 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
704 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
705 Rotation == ComplexDeinterleavingRotation::Rotation_270)
706 std::swap(a&: UncommonRealOp, b&: UncommonImagOp);
707
708 // Between identifyPartialMul and here we need to have found a complete valid
709 // pair from the CommonOperand of each part.
710 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
711 Rotation == ComplexDeinterleavingRotation::Rotation_180)
712 PartialMatch.first = CommonOperand;
713 else
714 PartialMatch.second = CommonOperand;
715
716 if (!PartialMatch.first || !PartialMatch.second) {
717 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
718 return nullptr;
719 }
720
721 CompositeNode *CommonNode =
722 identifyNode(R: PartialMatch.first, I: PartialMatch.second);
723 if (!CommonNode) {
724 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
725 return nullptr;
726 }
727
728 CompositeNode *UncommonNode = identifyNode(R: UncommonRealOp, I: UncommonImagOp);
729 if (!UncommonNode) {
730 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
731 return nullptr;
732 }
733
734 CompositeNode *Node = prepareCompositeNode(
735 Operation: ComplexDeinterleavingOperation::CMulPartial, R: Real, I: Imag);
736 Node->Rotation = Rotation;
737 Node->addOperand(Node: CommonNode);
738 Node->addOperand(Node: UncommonNode);
739 return submitCompositeNode(Node);
740}
741
742ComplexDeinterleavingGraph::CompositeNode *
743ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
744 Instruction *Imag) {
745 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
746 << "\n");
747
748 // Determine rotation
749 auto IsAdd = [](unsigned Op) {
750 return Op == Instruction::FAdd || Op == Instruction::Add;
751 };
752 auto IsSub = [](unsigned Op) {
753 return Op == Instruction::FSub || Op == Instruction::Sub;
754 };
755 ComplexDeinterleavingRotation Rotation;
756 if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
757 Rotation = ComplexDeinterleavingRotation::Rotation_0;
758 else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
759 Rotation = ComplexDeinterleavingRotation::Rotation_90;
760 else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
761 Rotation = ComplexDeinterleavingRotation::Rotation_180;
762 else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
763 Rotation = ComplexDeinterleavingRotation::Rotation_270;
764 else {
765 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
766 return nullptr;
767 }
768
769 if (isa<FPMathOperator>(Val: Real) &&
770 (!Real->getFastMathFlags().allowContract() ||
771 !Imag->getFastMathFlags().allowContract())) {
772 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
773 return nullptr;
774 }
775
776 Value *CR = Real->getOperand(i: 0);
777 Instruction *RealMulI = dyn_cast<Instruction>(Val: Real->getOperand(i: 1));
778 if (!RealMulI)
779 return nullptr;
780 Value *CI = Imag->getOperand(i: 0);
781 Instruction *ImagMulI = dyn_cast<Instruction>(Val: Imag->getOperand(i: 1));
782 if (!ImagMulI)
783 return nullptr;
784
785 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
786 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
787 return nullptr;
788 }
789
790 Value *R0 = RealMulI->getOperand(i: 0);
791 Value *R1 = RealMulI->getOperand(i: 1);
792 Value *I0 = ImagMulI->getOperand(i: 0);
793 Value *I1 = ImagMulI->getOperand(i: 1);
794
795 Value *CommonOperand;
796 Value *UncommonRealOp;
797 Value *UncommonImagOp;
798
799 if (R0 == I0 || R0 == I1) {
800 CommonOperand = R0;
801 UncommonRealOp = R1;
802 } else if (R1 == I0 || R1 == I1) {
803 CommonOperand = R1;
804 UncommonRealOp = R0;
805 } else {
806 LLVM_DEBUG(dbgs() << " - No equal operand\n");
807 return nullptr;
808 }
809
810 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
811 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
812 Rotation == ComplexDeinterleavingRotation::Rotation_270)
813 std::swap(a&: UncommonRealOp, b&: UncommonImagOp);
814
815 std::pair<Value *, Value *> PartialMatch(
816 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
817 Rotation == ComplexDeinterleavingRotation::Rotation_180)
818 ? CommonOperand
819 : nullptr,
820 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
821 Rotation == ComplexDeinterleavingRotation::Rotation_270)
822 ? CommonOperand
823 : nullptr);
824
825 auto *CRInst = dyn_cast<Instruction>(Val: CR);
826 auto *CIInst = dyn_cast<Instruction>(Val: CI);
827
828 if (!CRInst || !CIInst) {
829 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");
830 return nullptr;
831 }
832
833 CompositeNode *CNode =
834 identifyNodeWithImplicitAdd(Real: CRInst, Imag: CIInst, PartialMatch);
835 if (!CNode) {
836 LLVM_DEBUG(dbgs() << " - No cnode identified\n");
837 return nullptr;
838 }
839
840 CompositeNode *UncommonRes = identifyNode(R: UncommonRealOp, I: UncommonImagOp);
841 if (!UncommonRes) {
842 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
843 return nullptr;
844 }
845
846 assert(PartialMatch.first && PartialMatch.second);
847 CompositeNode *CommonRes =
848 identifyNode(R: PartialMatch.first, I: PartialMatch.second);
849 if (!CommonRes) {
850 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
851 return nullptr;
852 }
853
854 CompositeNode *Node = prepareCompositeNode(
855 Operation: ComplexDeinterleavingOperation::CMulPartial, R: Real, I: Imag);
856 Node->Rotation = Rotation;
857 Node->addOperand(Node: CommonRes);
858 Node->addOperand(Node: UncommonRes);
859 Node->addOperand(Node: CNode);
860 return submitCompositeNode(Node);
861}
862
863ComplexDeinterleavingGraph::CompositeNode *
864ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
865 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
866
867 // Determine rotation
868 ComplexDeinterleavingRotation Rotation;
869 if ((Real->getOpcode() == Instruction::FSub &&
870 Imag->getOpcode() == Instruction::FAdd) ||
871 (Real->getOpcode() == Instruction::Sub &&
872 Imag->getOpcode() == Instruction::Add))
873 Rotation = ComplexDeinterleavingRotation::Rotation_90;
874 else if ((Real->getOpcode() == Instruction::FAdd &&
875 Imag->getOpcode() == Instruction::FSub) ||
876 (Real->getOpcode() == Instruction::Add &&
877 Imag->getOpcode() == Instruction::Sub))
878 Rotation = ComplexDeinterleavingRotation::Rotation_270;
879 else {
880 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
881 return nullptr;
882 }
883
884 auto *AR = dyn_cast<Instruction>(Val: Real->getOperand(i: 0));
885 auto *BI = dyn_cast<Instruction>(Val: Real->getOperand(i: 1));
886 auto *AI = dyn_cast<Instruction>(Val: Imag->getOperand(i: 0));
887 auto *BR = dyn_cast<Instruction>(Val: Imag->getOperand(i: 1));
888
889 if (!AR || !AI || !BR || !BI) {
890 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
891 return nullptr;
892 }
893
894 CompositeNode *ResA = identifyNode(R: AR, I: AI);
895 if (!ResA) {
896 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
897 return nullptr;
898 }
899 CompositeNode *ResB = identifyNode(R: BR, I: BI);
900 if (!ResB) {
901 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
902 return nullptr;
903 }
904
905 CompositeNode *Node =
906 prepareCompositeNode(Operation: ComplexDeinterleavingOperation::CAdd, R: Real, I: Imag);
907 Node->Rotation = Rotation;
908 Node->addOperand(Node: ResA);
909 Node->addOperand(Node: ResB);
910 return submitCompositeNode(Node);
911}
912
913static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
914 unsigned OpcA = A->getOpcode();
915 unsigned OpcB = B->getOpcode();
916
917 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
918 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
919 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
920 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
921}
922
923static bool isInstructionPairMul(Instruction *A, Instruction *B) {
924 auto Pattern =
925 m_BinOp(L: m_FMul(L: m_Value(), R: m_Value()), R: m_FMul(L: m_Value(), R: m_Value()));
926
927 return match(V: A, P: Pattern) && match(V: B, P: Pattern);
928}
929
930static bool isInstructionPotentiallySymmetric(Instruction *I) {
931 switch (I->getOpcode()) {
932 case Instruction::FAdd:
933 case Instruction::FSub:
934 case Instruction::FMul:
935 case Instruction::FNeg:
936 case Instruction::Add:
937 case Instruction::Sub:
938 case Instruction::Mul:
939 return true;
940 default:
941 return false;
942 }
943}
944
945ComplexDeinterleavingGraph::CompositeNode *
946ComplexDeinterleavingGraph::identifySymmetricOperation(ComplexValues &Vals) {
947 auto *FirstReal = cast<Instruction>(Val: Vals[0].Real);
948 unsigned FirstOpc = FirstReal->getOpcode();
949 for (auto &V : Vals) {
950 auto *Real = cast<Instruction>(Val: V.Real);
951 auto *Imag = cast<Instruction>(Val: V.Imag);
952 if (Real->getOpcode() != FirstOpc || Imag->getOpcode() != FirstOpc)
953 return nullptr;
954
955 if (!isInstructionPotentiallySymmetric(I: Real) ||
956 !isInstructionPotentiallySymmetric(I: Imag))
957 return nullptr;
958
959 if (isa<FPMathOperator>(Val: FirstReal))
960 if (Real->getFastMathFlags() != FirstReal->getFastMathFlags() ||
961 Imag->getFastMathFlags() != FirstReal->getFastMathFlags())
962 return nullptr;
963 }
964
965 ComplexValues OpVals;
966 for (auto &V : Vals) {
967 auto *R0 = cast<Instruction>(Val: V.Real)->getOperand(i: 0);
968 auto *I0 = cast<Instruction>(Val: V.Imag)->getOperand(i: 0);
969 OpVals.push_back(Elt: {.Real: R0, .Imag: I0});
970 }
971
972 CompositeNode *Op0 = identifyNode(Vals&: OpVals);
973 CompositeNode *Op1 = nullptr;
974 if (Op0 == nullptr)
975 return nullptr;
976
977 if (FirstReal->isBinaryOp()) {
978 OpVals.clear();
979 for (auto &V : Vals) {
980 auto *R1 = cast<Instruction>(Val: V.Real)->getOperand(i: 1);
981 auto *I1 = cast<Instruction>(Val: V.Imag)->getOperand(i: 1);
982 OpVals.push_back(Elt: {.Real: R1, .Imag: I1});
983 }
984 Op1 = identifyNode(Vals&: OpVals);
985 if (Op1 == nullptr)
986 return nullptr;
987 }
988
989 auto Node =
990 prepareCompositeNode(Operation: ComplexDeinterleavingOperation::Symmetric, Vals);
991 Node->Opcode = FirstReal->getOpcode();
992 if (isa<FPMathOperator>(Val: FirstReal))
993 Node->Flags = FirstReal->getFastMathFlags();
994
995 Node->addOperand(Node: Op0);
996 if (FirstReal->isBinaryOp())
997 Node->addOperand(Node: Op1);
998
999 return submitCompositeNode(Node);
1000}
1001
1002ComplexDeinterleavingGraph::CompositeNode *
1003ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
1004 if (!TL->isComplexDeinterleavingOperationSupported(
1005 Operation: ComplexDeinterleavingOperation::CDot, Ty: V->getType())) {
1006 LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving "
1007 "operation CDot with the type "
1008 << *V->getType() << "\n");
1009 return nullptr;
1010 }
1011
1012 auto *Inst = cast<Instruction>(Val: V);
1013 auto *RealUser = cast<Instruction>(Val: *Inst->user_begin());
1014
1015 CompositeNode *CN =
1016 prepareCompositeNode(Operation: ComplexDeinterleavingOperation::CDot, R: Inst, I: nullptr);
1017
1018 CompositeNode *ANode = nullptr;
1019
1020 const Intrinsic::ID PartialReduceInt = Intrinsic::vector_partial_reduce_add;
1021
1022 Value *AReal = nullptr;
1023 Value *AImag = nullptr;
1024 Value *BReal = nullptr;
1025 Value *BImag = nullptr;
1026 Value *Phi = nullptr;
1027
1028 auto UnwrapCast = [](Value *V) -> Value * {
1029 if (auto *CI = dyn_cast<CastInst>(Val: V))
1030 return CI->getOperand(i_nocapture: 0);
1031 return V;
1032 };
1033
1034 auto PatternRot0 = m_Intrinsic<PartialReduceInt>(
1035 Op0: m_Intrinsic<PartialReduceInt>(Op0: m_Value(V&: Phi),
1036 Op1: m_Mul(L: m_Value(V&: BReal), R: m_Value(V&: AReal))),
1037 Op1: m_Neg(V: m_Mul(L: m_Value(V&: BImag), R: m_Value(V&: AImag))));
1038
1039 auto PatternRot270 = m_Intrinsic<PartialReduceInt>(
1040 Op0: m_Intrinsic<PartialReduceInt>(
1041 Op0: m_Value(V&: Phi), Op1: m_Neg(V: m_Mul(L: m_Value(V&: BReal), R: m_Value(V&: AImag)))),
1042 Op1: m_Mul(L: m_Value(V&: BImag), R: m_Value(V&: AReal)));
1043
1044 if (match(V: Inst, P: PatternRot0)) {
1045 CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
1046 } else if (match(V: Inst, P: PatternRot270)) {
1047 CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
1048 } else {
1049 Value *A0, *A1;
1050 // The rotations 90 and 180 share the same operation pattern, so inspect the
1051 // order of the operands, identifying where the real and imaginary
1052 // components of A go, to discern between the aforementioned rotations.
1053 auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>(
1054 Op0: m_Intrinsic<PartialReduceInt>(Op0: m_Value(V&: Phi),
1055 Op1: m_Mul(L: m_Value(V&: BReal), R: m_Value(V&: A0))),
1056 Op1: m_Mul(L: m_Value(V&: BImag), R: m_Value(V&: A1)));
1057
1058 if (!match(V: Inst, P: PatternRot90Rot180))
1059 return nullptr;
1060
1061 A0 = UnwrapCast(A0);
1062 A1 = UnwrapCast(A1);
1063
1064 // Test if A0 is real/A1 is imag
1065 ANode = identifyNode(R: A0, I: A1);
1066 if (!ANode) {
1067 // Test if A0 is imag/A1 is real
1068 ANode = identifyNode(R: A1, I: A0);
1069 // Unable to identify operand components, thus unable to identify rotation
1070 if (!ANode)
1071 return nullptr;
1072 CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
1073 AReal = A1;
1074 AImag = A0;
1075 } else {
1076 AReal = A0;
1077 AImag = A1;
1078 CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
1079 }
1080 }
1081
1082 AReal = UnwrapCast(AReal);
1083 AImag = UnwrapCast(AImag);
1084 BReal = UnwrapCast(BReal);
1085 BImag = UnwrapCast(BImag);
1086
1087 VectorType *VTy = cast<VectorType>(Val: V->getType());
1088 Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, NumSubdivs: 2);
1089 if (AReal->getType() != ExpectedOperandTy)
1090 return nullptr;
1091 if (AImag->getType() != ExpectedOperandTy)
1092 return nullptr;
1093 if (BReal->getType() != ExpectedOperandTy)
1094 return nullptr;
1095 if (BImag->getType() != ExpectedOperandTy)
1096 return nullptr;
1097
1098 if (Phi->getType() != VTy && RealUser->getType() != VTy)
1099 return nullptr;
1100
1101 CompositeNode *Node = identifyNode(R: AReal, I: AImag);
1102
1103 // In the case that a node was identified to figure out the rotation, ensure
1104 // that trying to identify a node with AReal and AImag post-unwrap results in
1105 // the same node
1106 if (ANode && Node != ANode) {
1107 LLVM_DEBUG(
1108 dbgs()
1109 << "Identified node is different from previously identified node. "
1110 "Unable to confidently generate a complex operation node\n");
1111 return nullptr;
1112 }
1113
1114 CN->addOperand(Node);
1115 CN->addOperand(Node: identifyNode(R: BReal, I: BImag));
1116 CN->addOperand(Node: identifyNode(R: Phi, I: RealUser));
1117
1118 return submitCompositeNode(Node: CN);
1119}
1120
1121ComplexDeinterleavingGraph::CompositeNode *
1122ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
1123 // Partial reductions don't support non-vector types, so check these first
1124 if (!isa<VectorType>(Val: R->getType()) || !isa<VectorType>(Val: I->getType()))
1125 return nullptr;
1126
1127 if (!R->hasUseList() || !I->hasUseList())
1128 return nullptr;
1129
1130 auto CommonUser =
1131 findCommonBetweenCollections<Value *>(A: R->users(), B: I->users());
1132 if (!CommonUser)
1133 return nullptr;
1134
1135 auto *IInst = dyn_cast<IntrinsicInst>(Val: *CommonUser);
1136 if (!IInst || IInst->getIntrinsicID() != Intrinsic::vector_partial_reduce_add)
1137 return nullptr;
1138
1139 if (CompositeNode *CN = identifyDotProduct(V: IInst))
1140 return CN;
1141
1142 return nullptr;
1143}
1144
1145ComplexDeinterleavingGraph::CompositeNode *
1146ComplexDeinterleavingGraph::identifyNode(ComplexValues &Vals) {
1147 auto It = CachedResult.find(Val: Vals);
1148 if (It != CachedResult.end()) {
1149 LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
1150 return It->second;
1151 }
1152
1153 if (Vals.size() == 1) {
1154 assert(Factor == 2 && "Can only handle interleave factors of 2");
1155 Value *R = Vals[0].Real;
1156 Value *I = Vals[0].Imag;
1157 if (CompositeNode *CN = identifyPartialReduction(R, I))
1158 return CN;
1159 bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I);
1160 if (!IsReduction && R->getType() != I->getType())
1161 return nullptr;
1162 }
1163
1164 if (CompositeNode *CN = identifySplat(Vals))
1165 return CN;
1166
1167 for (auto &V : Vals) {
1168 auto *Real = dyn_cast<Instruction>(Val: V.Real);
1169 auto *Imag = dyn_cast<Instruction>(Val: V.Imag);
1170 if (!Real || !Imag)
1171 return nullptr;
1172 }
1173
1174 if (CompositeNode *CN = identifyDeinterleave(Vals))
1175 return CN;
1176
1177 if (Vals.size() == 1) {
1178 assert(Factor == 2 && "Can only handle interleave factors of 2");
1179 auto *Real = dyn_cast<Instruction>(Val: Vals[0].Real);
1180 auto *Imag = dyn_cast<Instruction>(Val: Vals[0].Imag);
1181 if (CompositeNode *CN = identifyPHINode(Real, Imag))
1182 return CN;
1183
1184 if (CompositeNode *CN = identifySelectNode(Real, Imag))
1185 return CN;
1186
1187 auto *VTy = cast<VectorType>(Val: Real->getType());
1188 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1189
1190 bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
1191 Operation: ComplexDeinterleavingOperation::CMulPartial, Ty: NewVTy);
1192 bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
1193 Operation: ComplexDeinterleavingOperation::CAdd, Ty: NewVTy);
1194
1195 if (HasCMulSupport && isInstructionPairMul(A: Real, B: Imag)) {
1196 if (CompositeNode *CN = identifyPartialMul(Real, Imag))
1197 return CN;
1198 }
1199
1200 if (HasCAddSupport && isInstructionPairAdd(A: Real, B: Imag)) {
1201 if (CompositeNode *CN = identifyAdd(Real, Imag))
1202 return CN;
1203 }
1204
1205 if (HasCMulSupport && HasCAddSupport) {
1206 if (CompositeNode *CN = identifyReassocNodes(I: Real, J: Imag)) {
1207 return CN;
1208 }
1209 }
1210 }
1211
1212 if (CompositeNode *CN = identifySymmetricOperation(Vals))
1213 return CN;
1214
1215 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
1216 CachedResult[Vals] = nullptr;
1217 return nullptr;
1218}
1219
1220ComplexDeinterleavingGraph::CompositeNode *
1221ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
1222 Instruction *Imag) {
1223 auto IsOperationSupported = [](unsigned Opcode) -> bool {
1224 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
1225 Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
1226 Opcode == Instruction::Sub;
1227 };
1228
1229 if (!IsOperationSupported(Real->getOpcode()) ||
1230 !IsOperationSupported(Imag->getOpcode()))
1231 return nullptr;
1232
1233 std::optional<FastMathFlags> Flags;
1234 if (isa<FPMathOperator>(Val: Real)) {
1235 if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
1236 LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
1237 "not identical\n");
1238 return nullptr;
1239 }
1240
1241 Flags = Real->getFastMathFlags();
1242 if (!Flags->allowReassoc()) {
1243 LLVM_DEBUG(
1244 dbgs()
1245 << "the 'Reassoc' attribute is missing in the FastMath flags\n");
1246 return nullptr;
1247 }
1248 }
1249
1250 // Collect multiplications and addend instructions from the given instruction
1251 // while traversing it operands. Additionally, verify that all instructions
1252 // have the same fast math flags.
1253 auto Collect = [&Flags](Instruction *Insn, SmallVectorImpl<Product> &Muls,
1254 AddendList &Addends) -> bool {
1255 SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
1256 SmallPtrSet<Value *, 8> Visited;
1257 while (!Worklist.empty()) {
1258 auto [V, IsPositive] = Worklist.pop_back_val();
1259 if (!Visited.insert(Ptr: V).second)
1260 continue;
1261
1262 Instruction *I = dyn_cast<Instruction>(Val: V);
1263 if (!I) {
1264 Addends.emplace_back(Vs&: V, Vs&: IsPositive);
1265 continue;
1266 }
1267
1268 // If an instruction has more than one user, it indicates that it either
1269 // has an external user, which will be later checked by the checkNodes
1270 // function, or it is a subexpression utilized by multiple expressions. In
1271 // the latter case, we will attempt to separately identify the complex
1272 // operation from here in order to create a shared
1273 // ComplexDeinterleavingCompositeNode.
1274 if (I != Insn && I->hasNUsesOrMore(N: 2)) {
1275 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
1276 Addends.emplace_back(Vs&: I, Vs&: IsPositive);
1277 continue;
1278 }
1279 switch (I->getOpcode()) {
1280 case Instruction::FAdd:
1281 case Instruction::Add:
1282 Worklist.emplace_back(Args: I->getOperand(i: 1), Args&: IsPositive);
1283 Worklist.emplace_back(Args: I->getOperand(i: 0), Args&: IsPositive);
1284 break;
1285 case Instruction::FSub:
1286 Worklist.emplace_back(Args: I->getOperand(i: 1), Args: !IsPositive);
1287 Worklist.emplace_back(Args: I->getOperand(i: 0), Args&: IsPositive);
1288 break;
1289 case Instruction::Sub:
1290 if (isNeg(V: I)) {
1291 Worklist.emplace_back(Args: getNegOperand(V: I), Args: !IsPositive);
1292 } else {
1293 Worklist.emplace_back(Args: I->getOperand(i: 1), Args: !IsPositive);
1294 Worklist.emplace_back(Args: I->getOperand(i: 0), Args&: IsPositive);
1295 }
1296 break;
1297 case Instruction::FMul:
1298 case Instruction::Mul: {
1299 Value *A, *B;
1300 if (isNeg(V: I->getOperand(i: 0))) {
1301 A = getNegOperand(V: I->getOperand(i: 0));
1302 IsPositive = !IsPositive;
1303 } else {
1304 A = I->getOperand(i: 0);
1305 }
1306
1307 if (isNeg(V: I->getOperand(i: 1))) {
1308 B = getNegOperand(V: I->getOperand(i: 1));
1309 IsPositive = !IsPositive;
1310 } else {
1311 B = I->getOperand(i: 1);
1312 }
1313 Muls.push_back(Elt: Product{.Multiplier: A, .Multiplicand: B, .IsPositive: IsPositive});
1314 break;
1315 }
1316 case Instruction::FNeg:
1317 Worklist.emplace_back(Args: I->getOperand(i: 0), Args: !IsPositive);
1318 break;
1319 default:
1320 Addends.emplace_back(Vs&: I, Vs&: IsPositive);
1321 continue;
1322 }
1323
1324 if (Flags && I->getFastMathFlags() != *Flags) {
1325 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1326 "inconsistent with the root instructions' flags: "
1327 << *I << "\n");
1328 return false;
1329 }
1330 }
1331 return true;
1332 };
1333
1334 SmallVector<Product> RealMuls, ImagMuls;
1335 AddendList RealAddends, ImagAddends;
1336 if (!Collect(Real, RealMuls, RealAddends) ||
1337 !Collect(Imag, ImagMuls, ImagAddends))
1338 return nullptr;
1339
1340 if (RealAddends.size() != ImagAddends.size())
1341 return nullptr;
1342
1343 CompositeNode *FinalNode = nullptr;
1344 if (!RealMuls.empty() || !ImagMuls.empty()) {
1345 // If there are multiplicands, extract positive addend and use it as an
1346 // accumulator
1347 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1348 FinalNode = identifyMultiplications(RealMuls, ImagMuls, Accumulator: FinalNode);
1349 if (!FinalNode)
1350 return nullptr;
1351 }
1352
1353 // Identify and process remaining additions
1354 if (!RealAddends.empty() || !ImagAddends.empty()) {
1355 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, Accumulator: FinalNode);
1356 if (!FinalNode)
1357 return nullptr;
1358 }
1359 assert(FinalNode && "FinalNode can not be nullptr here");
1360 assert(FinalNode->Vals.size() == 1);
1361 // Set the Real and Imag fields of the final node and submit it
1362 FinalNode->Vals[0].Real = Real;
1363 FinalNode->Vals[0].Imag = Imag;
1364 submitCompositeNode(Node: FinalNode);
1365 return FinalNode;
1366}
1367
1368bool ComplexDeinterleavingGraph::collectPartialMuls(
1369 ArrayRef<Product> RealMuls, ArrayRef<Product> ImagMuls,
1370 SmallVectorImpl<PartialMulCandidate> &PartialMulCandidates) {
1371 // Helper function to extract a common operand from two products
1372 auto FindCommonInstruction = [](const Product &Real,
1373 const Product &Imag) -> Value * {
1374 if (Real.Multiplicand == Imag.Multiplicand ||
1375 Real.Multiplicand == Imag.Multiplier)
1376 return Real.Multiplicand;
1377
1378 if (Real.Multiplier == Imag.Multiplicand ||
1379 Real.Multiplier == Imag.Multiplier)
1380 return Real.Multiplier;
1381
1382 return nullptr;
1383 };
1384
1385 // Iterating over real and imaginary multiplications to find common operands
1386 // If a common operand is found, a partial multiplication candidate is created
1387 // and added to the candidates vector The function returns false if no common
1388 // operands are found for any product
1389 for (unsigned i = 0; i < RealMuls.size(); ++i) {
1390 bool FoundCommon = false;
1391 for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1392 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1393 if (!Common)
1394 continue;
1395
1396 auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1397 : RealMuls[i].Multiplicand;
1398 auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1399 : ImagMuls[j].Multiplicand;
1400
1401 auto Node = identifyNode(R: A, I: B);
1402 if (Node) {
1403 FoundCommon = true;
1404 PartialMulCandidates.push_back(Elt: {.Common: Common, .Node: Node, .RealIdx: i, .ImagIdx: j, .IsNodeInverted: false});
1405 }
1406
1407 Node = identifyNode(R: B, I: A);
1408 if (Node) {
1409 FoundCommon = true;
1410 PartialMulCandidates.push_back(Elt: {.Common: Common, .Node: Node, .RealIdx: i, .ImagIdx: j, .IsNodeInverted: true});
1411 }
1412 }
1413 if (!FoundCommon)
1414 return false;
1415 }
1416 return true;
1417}
1418
1419ComplexDeinterleavingGraph::CompositeNode *
1420ComplexDeinterleavingGraph::identifyMultiplications(
1421 SmallVectorImpl<Product> &RealMuls, SmallVectorImpl<Product> &ImagMuls,
1422 CompositeNode *Accumulator = nullptr) {
1423 if (RealMuls.size() != ImagMuls.size())
1424 return nullptr;
1425
1426 SmallVector<PartialMulCandidate> Info;
1427 if (!collectPartialMuls(RealMuls, ImagMuls, PartialMulCandidates&: Info))
1428 return nullptr;
1429
1430 // Map to store common instruction to node pointers
1431 DenseMap<Value *, CompositeNode *> CommonToNode;
1432 SmallVector<bool> Processed(Info.size(), false);
1433 for (unsigned I = 0; I < Info.size(); ++I) {
1434 if (Processed[I])
1435 continue;
1436
1437 PartialMulCandidate &InfoA = Info[I];
1438 for (unsigned J = I + 1; J < Info.size(); ++J) {
1439 if (Processed[J])
1440 continue;
1441
1442 PartialMulCandidate &InfoB = Info[J];
1443 auto *InfoReal = &InfoA;
1444 auto *InfoImag = &InfoB;
1445
1446 auto NodeFromCommon = identifyNode(R: InfoReal->Common, I: InfoImag->Common);
1447 if (!NodeFromCommon) {
1448 std::swap(a&: InfoReal, b&: InfoImag);
1449 NodeFromCommon = identifyNode(R: InfoReal->Common, I: InfoImag->Common);
1450 }
1451 if (!NodeFromCommon)
1452 continue;
1453
1454 CommonToNode[InfoReal->Common] = NodeFromCommon;
1455 CommonToNode[InfoImag->Common] = NodeFromCommon;
1456 Processed[I] = true;
1457 Processed[J] = true;
1458 }
1459 }
1460
1461 SmallVector<bool> ProcessedReal(RealMuls.size(), false);
1462 SmallVector<bool> ProcessedImag(ImagMuls.size(), false);
1463 CompositeNode *Result = Accumulator;
1464 for (auto &PMI : Info) {
1465 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1466 continue;
1467
1468 auto It = CommonToNode.find(Val: PMI.Common);
1469 // TODO: Process independent complex multiplications. Cases like this:
1470 // A.real() * B where both A and B are complex numbers.
1471 if (It == CommonToNode.end()) {
1472 LLVM_DEBUG({
1473 dbgs() << "Unprocessed independent partial multiplication:\n";
1474 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1475 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1476 << " multiplied by " << *Mul->Multiplicand << "\n";
1477 });
1478 return nullptr;
1479 }
1480
1481 auto &RealMul = RealMuls[PMI.RealIdx];
1482 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1483
1484 auto NodeA = It->second;
1485 auto NodeB = PMI.Node;
1486 auto IsMultiplicandReal = PMI.Common == NodeA->Vals[0].Real;
1487 // The following table illustrates the relationship between multiplications
1488 // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1489 // can see:
1490 //
1491 // Rotation | Real | Imag |
1492 // ---------+--------+--------+
1493 // 0 | x * u | x * v |
1494 // 90 | -y * v | y * u |
1495 // 180 | -x * u | -x * v |
1496 // 270 | y * v | -y * u |
1497 //
1498 // Check if the candidate can indeed be represented by partial
1499 // multiplication
1500 // TODO: Add support for multiplication by complex one
1501 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1502 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1503 continue;
1504
1505 // Determine the rotation based on the multiplications
1506 ComplexDeinterleavingRotation Rotation;
1507 if (IsMultiplicandReal) {
1508 // Detect 0 and 180 degrees rotation
1509 if (RealMul.IsPositive && ImagMul.IsPositive)
1510 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;
1511 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1512 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;
1513 else
1514 continue;
1515
1516 } else {
1517 // Detect 90 and 270 degrees rotation
1518 if (!RealMul.IsPositive && ImagMul.IsPositive)
1519 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;
1520 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1521 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;
1522 else
1523 continue;
1524 }
1525
1526 LLVM_DEBUG({
1527 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1528 dbgs().indent(4) << "X: " << *NodeA->Vals[0].Real << "\n";
1529 dbgs().indent(4) << "Y: " << *NodeA->Vals[0].Imag << "\n";
1530 dbgs().indent(4) << "U: " << *NodeB->Vals[0].Real << "\n";
1531 dbgs().indent(4) << "V: " << *NodeB->Vals[0].Imag << "\n";
1532 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1533 });
1534
1535 CompositeNode *NodeMul = prepareCompositeNode(
1536 Operation: ComplexDeinterleavingOperation::CMulPartial, R: nullptr, I: nullptr);
1537 NodeMul->Rotation = Rotation;
1538 NodeMul->addOperand(Node: NodeA);
1539 NodeMul->addOperand(Node: NodeB);
1540 if (Result)
1541 NodeMul->addOperand(Node: Result);
1542 submitCompositeNode(Node: NodeMul);
1543 Result = NodeMul;
1544 ProcessedReal[PMI.RealIdx] = true;
1545 ProcessedImag[PMI.ImagIdx] = true;
1546 }
1547
1548 // Ensure all products have been processed, if not return nullptr.
1549 if (!all_of(Range&: ProcessedReal, P: [](bool V) { return V; }) ||
1550 !all_of(Range&: ProcessedImag, P: [](bool V) { return V; })) {
1551
1552 // Dump debug information about which partial multiplications are not
1553 // processed.
1554 LLVM_DEBUG({
1555 dbgs() << "Unprocessed products (Real):\n";
1556 for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1557 if (!ProcessedReal[i])
1558 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1559 << *RealMuls[i].Multiplier << " multiplied by "
1560 << *RealMuls[i].Multiplicand << "\n";
1561 }
1562 dbgs() << "Unprocessed products (Imag):\n";
1563 for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1564 if (!ProcessedImag[i])
1565 dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1566 << *ImagMuls[i].Multiplier << " multiplied by "
1567 << *ImagMuls[i].Multiplicand << "\n";
1568 }
1569 });
1570 return nullptr;
1571 }
1572
1573 return Result;
1574}
1575
1576ComplexDeinterleavingGraph::CompositeNode *
1577ComplexDeinterleavingGraph::identifyAdditions(
1578 AddendList &RealAddends, AddendList &ImagAddends,
1579 std::optional<FastMathFlags> Flags, CompositeNode *Accumulator = nullptr) {
1580 if (RealAddends.size() != ImagAddends.size())
1581 return nullptr;
1582
1583 CompositeNode *Result = nullptr;
1584 // If we have accumulator use it as first addend
1585 if (Accumulator)
1586 Result = Accumulator;
1587 // Otherwise find an element with both positive real and imaginary parts.
1588 else
1589 Result = extractPositiveAddend(RealAddends, ImagAddends);
1590
1591 if (!Result)
1592 return nullptr;
1593
1594 while (!RealAddends.empty()) {
1595 auto ItR = RealAddends.begin();
1596 auto [R, IsPositiveR] = *ItR;
1597
1598 bool FoundImag = false;
1599 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1600 auto [I, IsPositiveI] = *ItI;
1601 ComplexDeinterleavingRotation Rotation;
1602 if (IsPositiveR && IsPositiveI)
1603 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1604 else if (!IsPositiveR && IsPositiveI)
1605 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1606 else if (!IsPositiveR && !IsPositiveI)
1607 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1608 else
1609 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1610
1611 CompositeNode *AddNode = nullptr;
1612 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1613 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1614 AddNode = identifyNode(R, I);
1615 } else {
1616 AddNode = identifyNode(R: I, I: R);
1617 }
1618 if (AddNode) {
1619 LLVM_DEBUG({
1620 dbgs() << "Identified addition:\n";
1621 dbgs().indent(4) << "X: " << *R << "\n";
1622 dbgs().indent(4) << "Y: " << *I << "\n";
1623 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1624 });
1625
1626 CompositeNode *TmpNode = nullptr;
1627 if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
1628 TmpNode = prepareCompositeNode(
1629 Operation: ComplexDeinterleavingOperation::Symmetric, R: nullptr, I: nullptr);
1630 if (Flags) {
1631 TmpNode->Opcode = Instruction::FAdd;
1632 TmpNode->Flags = *Flags;
1633 } else {
1634 TmpNode->Opcode = Instruction::Add;
1635 }
1636 } else if (Rotation ==
1637 llvm::ComplexDeinterleavingRotation::Rotation_180) {
1638 TmpNode = prepareCompositeNode(
1639 Operation: ComplexDeinterleavingOperation::Symmetric, R: nullptr, I: nullptr);
1640 if (Flags) {
1641 TmpNode->Opcode = Instruction::FSub;
1642 TmpNode->Flags = *Flags;
1643 } else {
1644 TmpNode->Opcode = Instruction::Sub;
1645 }
1646 } else {
1647 TmpNode = prepareCompositeNode(Operation: ComplexDeinterleavingOperation::CAdd,
1648 R: nullptr, I: nullptr);
1649 TmpNode->Rotation = Rotation;
1650 }
1651
1652 TmpNode->addOperand(Node: Result);
1653 TmpNode->addOperand(Node: AddNode);
1654 submitCompositeNode(Node: TmpNode);
1655 Result = TmpNode;
1656 RealAddends.erase(I: ItR);
1657 ImagAddends.erase(I: ItI);
1658 FoundImag = true;
1659 break;
1660 }
1661 }
1662 if (!FoundImag)
1663 return nullptr;
1664 }
1665 return Result;
1666}
1667
1668ComplexDeinterleavingGraph::CompositeNode *
1669ComplexDeinterleavingGraph::extractPositiveAddend(AddendList &RealAddends,
1670 AddendList &ImagAddends) {
1671 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1672 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1673 auto [R, IsPositiveR] = *ItR;
1674 auto [I, IsPositiveI] = *ItI;
1675 if (IsPositiveR && IsPositiveI) {
1676 auto Result = identifyNode(R, I);
1677 if (Result) {
1678 RealAddends.erase(I: ItR);
1679 ImagAddends.erase(I: ItI);
1680 return Result;
1681 }
1682 }
1683 }
1684 }
1685 return nullptr;
1686}
1687
1688bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1689 // This potential root instruction might already have been recognized as
1690 // reduction. Because RootToNode maps both Real and Imaginary parts to
1691 // CompositeNode we should choose only one either Real or Imag instruction to
1692 // use as an anchor for generating complex instruction.
1693 auto It = RootToNode.find(Val: RootI);
1694 if (It != RootToNode.end()) {
1695 auto RootNode = It->second;
1696 assert(RootNode->Operation ==
1697 ComplexDeinterleavingOperation::ReductionOperation ||
1698 RootNode->Operation ==
1699 ComplexDeinterleavingOperation::ReductionSingle);
1700 assert(RootNode->Vals.size() == 1 &&
1701 "Cannot handle reductions involving multiple complex values");
1702 // Find out which part, Real or Imag, comes later, and only if we come to
1703 // the latest part, add it to OrderedRoots.
1704 auto *R = cast<Instruction>(Val: RootNode->Vals[0].Real);
1705 auto *I = RootNode->Vals[0].Imag ? cast<Instruction>(Val: RootNode->Vals[0].Imag)
1706 : nullptr;
1707
1708 Instruction *ReplacementAnchor;
1709 if (I)
1710 ReplacementAnchor = R->comesBefore(Other: I) ? I : R;
1711 else
1712 ReplacementAnchor = R;
1713
1714 if (ReplacementAnchor != RootI)
1715 return false;
1716 OrderedRoots.push_back(Elt: RootI);
1717 return true;
1718 }
1719
1720 auto RootNode = identifyRoot(I: RootI);
1721 if (!RootNode)
1722 return false;
1723
1724 LLVM_DEBUG({
1725 Function *F = RootI->getFunction();
1726 BasicBlock *B = RootI->getParent();
1727 dbgs() << "Complex deinterleaving graph for " << F->getName()
1728 << "::" << B->getName() << ".\n";
1729 dump(dbgs());
1730 dbgs() << "\n";
1731 });
1732 RootToNode[RootI] = RootNode;
1733 OrderedRoots.push_back(Elt: RootI);
1734 return true;
1735}
1736
1737bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1738 bool FoundPotentialReduction = false;
1739 if (Factor != 2)
1740 return false;
1741
1742 auto *Br = dyn_cast<BranchInst>(Val: B->getTerminator());
1743 if (!Br || Br->getNumSuccessors() != 2)
1744 return false;
1745
1746 // Identify simple one-block loop
1747 if (Br->getSuccessor(i: 0) != B && Br->getSuccessor(i: 1) != B)
1748 return false;
1749
1750 for (auto &PHI : B->phis()) {
1751 if (PHI.getNumIncomingValues() != 2)
1752 continue;
1753
1754 if (!PHI.getType()->isVectorTy())
1755 continue;
1756
1757 auto *ReductionOp = dyn_cast<Instruction>(Val: PHI.getIncomingValueForBlock(BB: B));
1758 if (!ReductionOp)
1759 continue;
1760
1761 // Check if final instruction is reduced outside of current block
1762 Instruction *FinalReduction = nullptr;
1763 auto NumUsers = 0u;
1764 for (auto *U : ReductionOp->users()) {
1765 ++NumUsers;
1766 if (U == &PHI)
1767 continue;
1768 FinalReduction = dyn_cast<Instruction>(Val: U);
1769 }
1770
1771 if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1772 isa<PHINode>(Val: FinalReduction))
1773 continue;
1774
1775 ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1776 BackEdge = B;
1777 auto BackEdgeIdx = PHI.getBasicBlockIndex(BB: B);
1778 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1779 Incoming = PHI.getIncomingBlock(i: IncomingIdx);
1780 FoundPotentialReduction = true;
1781
1782 // If the initial value of PHINode is an Instruction, consider it a leaf
1783 // value of a complex deinterleaving graph.
1784 if (auto *InitPHI =
1785 dyn_cast<Instruction>(Val: PHI.getIncomingValueForBlock(BB: Incoming)))
1786 FinalInstructions.insert(Ptr: InitPHI);
1787 }
1788 return FoundPotentialReduction;
1789}
1790
1791void ComplexDeinterleavingGraph::identifyReductionNodes() {
1792 assert(Factor == 2 && "Cannot handle multiple complex values");
1793
1794 SmallVector<bool> Processed(ReductionInfo.size(), false);
1795 SmallVector<Instruction *> OperationInstruction;
1796 for (auto &P : ReductionInfo)
1797 OperationInstruction.push_back(Elt: P.first);
1798
1799 // Identify a complex computation by evaluating two reduction operations that
1800 // potentially could be involved
1801 for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1802 if (Processed[i])
1803 continue;
1804 for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1805 if (Processed[j])
1806 continue;
1807 auto *Real = OperationInstruction[i];
1808 auto *Imag = OperationInstruction[j];
1809 if (Real->getType() != Imag->getType())
1810 continue;
1811
1812 RealPHI = ReductionInfo[Real].first;
1813 ImagPHI = ReductionInfo[Imag].first;
1814 PHIsFound = false;
1815 auto Node = identifyNode(R: Real, I: Imag);
1816 if (!Node) {
1817 std::swap(a&: Real, b&: Imag);
1818 std::swap(a&: RealPHI, b&: ImagPHI);
1819 Node = identifyNode(R: Real, I: Imag);
1820 }
1821
1822 // If a node is identified and reduction PHINode is used in the chain of
1823 // operations, mark its operation instructions as used to prevent
1824 // re-identification and attach the node to the real part
1825 if (Node && PHIsFound) {
1826 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1827 << *Real << " / " << *Imag << "\n");
1828 Processed[i] = true;
1829 Processed[j] = true;
1830 auto RootNode = prepareCompositeNode(
1831 Operation: ComplexDeinterleavingOperation::ReductionOperation, R: Real, I: Imag);
1832 RootNode->addOperand(Node);
1833 RootToNode[Real] = RootNode;
1834 RootToNode[Imag] = RootNode;
1835 submitCompositeNode(Node: RootNode);
1836 break;
1837 }
1838 }
1839
1840 auto *Real = OperationInstruction[i];
1841 // We want to check that we have 2 operands, but the function attributes
1842 // being counted as operands bloats this value.
1843 if (Processed[i] || Real->getNumOperands() < 2)
1844 continue;
1845
1846 // Can only combined integer reductions at the moment.
1847 if (!ReductionInfo[Real].second->getType()->isIntegerTy())
1848 continue;
1849
1850 RealPHI = ReductionInfo[Real].first;
1851 ImagPHI = nullptr;
1852 PHIsFound = false;
1853 auto Node = identifyNode(R: Real->getOperand(i: 0), I: Real->getOperand(i: 1));
1854 if (Node && PHIsFound) {
1855 LLVM_DEBUG(
1856 dbgs() << "Identified single reduction starting from instruction: "
1857 << *Real << "/" << *ReductionInfo[Real].second << "\n");
1858
1859 // Reducing to a single vector is not supported, only permit reducing down
1860 // to scalar values.
1861 // Doing this here will leave the prior node in the graph,
1862 // however with no uses the node will be unreachable by the replacement
1863 // process. That along with the usage outside the graph should prevent the
1864 // replacement process from kicking off at all for this graph.
1865 // TODO Add support for reducing to a single vector value
1866 if (ReductionInfo[Real].second->getType()->isVectorTy())
1867 continue;
1868
1869 Processed[i] = true;
1870 auto RootNode = prepareCompositeNode(
1871 Operation: ComplexDeinterleavingOperation::ReductionSingle, R: Real, I: nullptr);
1872 RootNode->addOperand(Node);
1873 RootToNode[Real] = RootNode;
1874 submitCompositeNode(Node: RootNode);
1875 }
1876 }
1877
1878 RealPHI = nullptr;
1879 ImagPHI = nullptr;
1880}
1881
1882bool ComplexDeinterleavingGraph::checkNodes() {
1883 bool FoundDeinterleaveNode = false;
1884 for (CompositeNode *N : CompositeNodes) {
1885 if (!N->areOperandsValid())
1886 return false;
1887
1888 if (N->Operation == ComplexDeinterleavingOperation::Deinterleave)
1889 FoundDeinterleaveNode = true;
1890 }
1891
1892 // We need a deinterleave node in order to guarantee that we're working with
1893 // complex numbers.
1894 if (!FoundDeinterleaveNode) {
1895 LLVM_DEBUG(
1896 dbgs() << "Couldn't find a deinterleave node within the graph, cannot "
1897 "guarantee safety during graph transformation.\n");
1898 return false;
1899 }
1900
1901 // Collect all instructions from roots to leaves
1902 SmallPtrSet<Instruction *, 16> AllInstructions;
1903 SmallVector<Instruction *, 8> Worklist;
1904 for (auto &Pair : RootToNode)
1905 Worklist.push_back(Elt: Pair.first);
1906
1907 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1908 // chains
1909 while (!Worklist.empty()) {
1910 auto *I = Worklist.pop_back_val();
1911
1912 if (!AllInstructions.insert(Ptr: I).second)
1913 continue;
1914
1915 for (Value *Op : I->operands()) {
1916 if (auto *OpI = dyn_cast<Instruction>(Val: Op)) {
1917 if (!FinalInstructions.count(Ptr: I))
1918 Worklist.emplace_back(Args&: OpI);
1919 }
1920 }
1921 }
1922
1923 // Find instructions that have users outside of chain
1924 for (auto *I : AllInstructions) {
1925 // Skip root nodes
1926 if (RootToNode.count(Val: I))
1927 continue;
1928
1929 for (User *U : I->users()) {
1930 if (AllInstructions.count(Ptr: cast<Instruction>(Val: U)))
1931 continue;
1932
1933 // Found an instruction that is not used by XCMLA/XCADD chain
1934 Worklist.emplace_back(Args&: I);
1935 break;
1936 }
1937 }
1938
1939 // If any instructions are found to be used outside, find and remove roots
1940 // that somehow connect to those instructions.
1941 SmallPtrSet<Instruction *, 16> Visited;
1942 while (!Worklist.empty()) {
1943 auto *I = Worklist.pop_back_val();
1944 if (!Visited.insert(Ptr: I).second)
1945 continue;
1946
1947 // Found an impacted root node. Removing it from the nodes to be
1948 // deinterleaved
1949 if (RootToNode.count(Val: I)) {
1950 LLVM_DEBUG(dbgs() << "Instruction " << *I
1951 << " could be deinterleaved but its chain of complex "
1952 "operations have an outside user\n");
1953 RootToNode.erase(Val: I);
1954 }
1955
1956 if (!AllInstructions.count(Ptr: I) || FinalInstructions.count(Ptr: I))
1957 continue;
1958
1959 for (User *U : I->users())
1960 Worklist.emplace_back(Args: cast<Instruction>(Val: U));
1961
1962 for (Value *Op : I->operands()) {
1963 if (auto *OpI = dyn_cast<Instruction>(Val: Op))
1964 Worklist.emplace_back(Args&: OpI);
1965 }
1966 }
1967 return !RootToNode.empty();
1968}
1969
1970ComplexDeinterleavingGraph::CompositeNode *
1971ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1972 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(Val: RootI)) {
1973 if (Intrinsic::getInterleaveIntrinsicID(Factor) !=
1974 Intrinsic->getIntrinsicID())
1975 return nullptr;
1976
1977 ComplexValues Vals;
1978 for (unsigned I = 0; I < Factor; I += 2) {
1979 auto *Real = dyn_cast<Instruction>(Val: Intrinsic->getOperand(i_nocapture: I));
1980 auto *Imag = dyn_cast<Instruction>(Val: Intrinsic->getOperand(i_nocapture: I + 1));
1981 if (!Real || !Imag)
1982 return nullptr;
1983 Vals.push_back(Elt: {.Real: Real, .Imag: Imag});
1984 }
1985
1986 ComplexDeinterleavingGraph::CompositeNode *Node1 = identifyNode(Vals);
1987 if (!Node1)
1988 return nullptr;
1989 return Node1;
1990 }
1991
1992 // TODO: We could also add support for fixed-width interleave factors of 4
1993 // and above, but currently for symmetric operations the interleaves and
1994 // deinterleaves are already removed by VectorCombine. If we extend this to
1995 // permit complex multiplications, reductions, etc. then we should also add
1996 // support for fixed-width here.
1997 if (Factor != 2)
1998 return nullptr;
1999
2000 auto *SVI = dyn_cast<ShuffleVectorInst>(Val: RootI);
2001 if (!SVI)
2002 return nullptr;
2003
2004 // Look for a shufflevector that takes separate vectors of the real and
2005 // imaginary components and recombines them into a single vector.
2006 if (!isInterleavingMask(Mask: SVI->getShuffleMask()))
2007 return nullptr;
2008
2009 Instruction *Real;
2010 Instruction *Imag;
2011 if (!match(V: RootI, P: m_Shuffle(v1: m_Instruction(I&: Real), v2: m_Instruction(I&: Imag))))
2012 return nullptr;
2013
2014 return identifyNode(R: Real, I: Imag);
2015}
2016
2017ComplexDeinterleavingGraph::CompositeNode *
2018ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
2019 Instruction *II = nullptr;
2020
2021 // Must be at least one complex value.
2022 auto CheckExtract = [&](Value *V, unsigned ExpectedIdx,
2023 Instruction *ExpectedInsn) -> ExtractValueInst * {
2024 auto *EVI = dyn_cast<ExtractValueInst>(Val: V);
2025 if (!EVI || EVI->getNumIndices() != 1 ||
2026 EVI->getIndices()[0] != ExpectedIdx ||
2027 !isa<Instruction>(Val: EVI->getAggregateOperand()) ||
2028 (ExpectedInsn && ExpectedInsn != EVI->getAggregateOperand()))
2029 return nullptr;
2030 return EVI;
2031 };
2032
2033 for (unsigned Idx = 0; Idx < Vals.size(); Idx++) {
2034 ExtractValueInst *RealEVI = CheckExtract(Vals[Idx].Real, Idx * 2, II);
2035 if (RealEVI && Idx == 0)
2036 II = cast<Instruction>(Val: RealEVI->getAggregateOperand());
2037 if (!RealEVI || !CheckExtract(Vals[Idx].Imag, (Idx * 2) + 1, II)) {
2038 II = nullptr;
2039 break;
2040 }
2041 }
2042
2043 if (auto *IntrinsicII = dyn_cast_or_null<IntrinsicInst>(Val: II)) {
2044 if (IntrinsicII->getIntrinsicID() !=
2045 Intrinsic::getDeinterleaveIntrinsicID(Factor: 2 * Vals.size()))
2046 return nullptr;
2047
2048 // The remaining should match too.
2049 CompositeNode *PlaceholderNode = prepareCompositeNode(
2050 Operation: llvm::ComplexDeinterleavingOperation::Deinterleave, Vals);
2051 PlaceholderNode->ReplacementNode = II->getOperand(i: 0);
2052 for (auto &V : Vals) {
2053 FinalInstructions.insert(Ptr: cast<Instruction>(Val: V.Real));
2054 FinalInstructions.insert(Ptr: cast<Instruction>(Val: V.Imag));
2055 }
2056 return submitCompositeNode(Node: PlaceholderNode);
2057 }
2058
2059 if (Vals.size() != 1)
2060 return nullptr;
2061
2062 Value *Real = Vals[0].Real;
2063 Value *Imag = Vals[0].Imag;
2064 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Val: Real);
2065 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Val: Imag);
2066 if (!RealShuffle || !ImagShuffle) {
2067 if (RealShuffle || ImagShuffle)
2068 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
2069 return nullptr;
2070 }
2071
2072 Value *RealOp1 = RealShuffle->getOperand(i_nocapture: 1);
2073 if (!isa<UndefValue>(Val: RealOp1) && !isa<ConstantAggregateZero>(Val: RealOp1)) {
2074 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
2075 return nullptr;
2076 }
2077 Value *ImagOp1 = ImagShuffle->getOperand(i_nocapture: 1);
2078 if (!isa<UndefValue>(Val: ImagOp1) && !isa<ConstantAggregateZero>(Val: ImagOp1)) {
2079 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
2080 return nullptr;
2081 }
2082
2083 Value *RealOp0 = RealShuffle->getOperand(i_nocapture: 0);
2084 Value *ImagOp0 = ImagShuffle->getOperand(i_nocapture: 0);
2085
2086 if (RealOp0 != ImagOp0) {
2087 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
2088 return nullptr;
2089 }
2090
2091 ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
2092 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
2093 if (!isDeinterleavingMask(Mask: RealMask) || !isDeinterleavingMask(Mask: ImagMask)) {
2094 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
2095 return nullptr;
2096 }
2097
2098 if (RealMask[0] != 0 || ImagMask[0] != 1) {
2099 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
2100 return nullptr;
2101 }
2102
2103 // Type checking, the shuffle type should be a vector type of the same
2104 // scalar type, but half the size
2105 auto CheckType = [&](ShuffleVectorInst *Shuffle) {
2106 Value *Op = Shuffle->getOperand(i_nocapture: 0);
2107 auto *ShuffleTy = cast<FixedVectorType>(Val: Shuffle->getType());
2108 auto *OpTy = cast<FixedVectorType>(Val: Op->getType());
2109
2110 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
2111 return false;
2112 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
2113 return false;
2114
2115 return true;
2116 };
2117
2118 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
2119 if (!CheckType(Shuffle))
2120 return false;
2121
2122 ArrayRef<int> Mask = Shuffle->getShuffleMask();
2123 int Last = *Mask.rbegin();
2124
2125 Value *Op = Shuffle->getOperand(i_nocapture: 0);
2126 auto *OpTy = cast<FixedVectorType>(Val: Op->getType());
2127 int NumElements = OpTy->getNumElements();
2128
2129 // Ensure that the deinterleaving shuffle only pulls from the first
2130 // shuffle operand.
2131 return Last < NumElements;
2132 };
2133
2134 if (RealShuffle->getType() != ImagShuffle->getType()) {
2135 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
2136 return nullptr;
2137 }
2138 if (!CheckDeinterleavingShuffle(RealShuffle)) {
2139 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
2140 return nullptr;
2141 }
2142 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
2143 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
2144 return nullptr;
2145 }
2146
2147 CompositeNode *PlaceholderNode =
2148 prepareCompositeNode(Operation: llvm::ComplexDeinterleavingOperation::Deinterleave,
2149 R: RealShuffle, I: ImagShuffle);
2150 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(i_nocapture: 0);
2151 FinalInstructions.insert(Ptr: RealShuffle);
2152 FinalInstructions.insert(Ptr: ImagShuffle);
2153 return submitCompositeNode(Node: PlaceholderNode);
2154}
2155
2156ComplexDeinterleavingGraph::CompositeNode *
2157ComplexDeinterleavingGraph::identifySplat(ComplexValues &Vals) {
2158 auto IsSplat = [](Value *V) -> bool {
2159 // Fixed-width vector with constants
2160 if (isa<ConstantDataVector>(Val: V))
2161 return true;
2162
2163 if (isa<ConstantInt>(Val: V) || isa<ConstantFP>(Val: V))
2164 return isa<VectorType>(Val: V->getType());
2165
2166 VectorType *VTy;
2167 ArrayRef<int> Mask;
2168 // Splats are represented differently depending on whether the repeated
2169 // value is a constant or an Instruction
2170 if (auto *Const = dyn_cast<ConstantExpr>(Val: V)) {
2171 if (Const->getOpcode() != Instruction::ShuffleVector)
2172 return false;
2173 VTy = cast<VectorType>(Val: Const->getType());
2174 Mask = Const->getShuffleMask();
2175 } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(Val: V)) {
2176 VTy = Shuf->getType();
2177 Mask = Shuf->getShuffleMask();
2178 } else {
2179 return false;
2180 }
2181
2182 // When the data type is <1 x Type>, it's not possible to differentiate
2183 // between the ComplexDeinterleaving::Deinterleave and
2184 // ComplexDeinterleaving::Splat operations.
2185 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
2186 return false;
2187
2188 return all_equal(Range&: Mask) && Mask[0] == 0;
2189 };
2190
2191 // The splats must meet the following requirements:
2192 // 1. Must either be all instructions or all values.
2193 // 2. Non-constant splats must live in the same block.
2194 if (auto *FirstValAsInstruction = dyn_cast<Instruction>(Val: Vals[0].Real)) {
2195 BasicBlock *FirstBB = FirstValAsInstruction->getParent();
2196 for (auto &V : Vals) {
2197 if (!IsSplat(V.Real) || !IsSplat(V.Imag))
2198 return nullptr;
2199
2200 auto *Real = dyn_cast<Instruction>(Val: V.Real);
2201 auto *Imag = dyn_cast<Instruction>(Val: V.Imag);
2202 if (!Real || !Imag || Real->getParent() != FirstBB ||
2203 Imag->getParent() != FirstBB)
2204 return nullptr;
2205 }
2206 } else {
2207 for (auto &V : Vals) {
2208 if (!IsSplat(V.Real) || !IsSplat(V.Imag) || isa<Instruction>(Val: V.Real) ||
2209 isa<Instruction>(Val: V.Imag))
2210 return nullptr;
2211 }
2212 }
2213
2214 for (auto &V : Vals) {
2215 auto *Real = dyn_cast<Instruction>(Val: V.Real);
2216 auto *Imag = dyn_cast<Instruction>(Val: V.Imag);
2217 if (Real && Imag) {
2218 FinalInstructions.insert(Ptr: Real);
2219 FinalInstructions.insert(Ptr: Imag);
2220 }
2221 }
2222 CompositeNode *PlaceholderNode =
2223 prepareCompositeNode(Operation: ComplexDeinterleavingOperation::Splat, Vals);
2224 return submitCompositeNode(Node: PlaceholderNode);
2225}
2226
2227ComplexDeinterleavingGraph::CompositeNode *
2228ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
2229 Instruction *Imag) {
2230 if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
2231 return nullptr;
2232
2233 PHIsFound = true;
2234 CompositeNode *PlaceholderNode = prepareCompositeNode(
2235 Operation: ComplexDeinterleavingOperation::ReductionPHI, R: Real, I: Imag);
2236 return submitCompositeNode(Node: PlaceholderNode);
2237}
2238
2239ComplexDeinterleavingGraph::CompositeNode *
2240ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
2241 Instruction *Imag) {
2242 auto *SelectReal = dyn_cast<SelectInst>(Val: Real);
2243 auto *SelectImag = dyn_cast<SelectInst>(Val: Imag);
2244 if (!SelectReal || !SelectImag)
2245 return nullptr;
2246
2247 Instruction *MaskA, *MaskB;
2248 Instruction *AR, *AI, *RA, *BI;
2249 if (!match(V: Real, P: m_Select(C: m_Instruction(I&: MaskA), L: m_Instruction(I&: AR),
2250 R: m_Instruction(I&: RA))) ||
2251 !match(V: Imag, P: m_Select(C: m_Instruction(I&: MaskB), L: m_Instruction(I&: AI),
2252 R: m_Instruction(I&: BI))))
2253 return nullptr;
2254
2255 if (MaskA != MaskB && !MaskA->isIdenticalTo(I: MaskB))
2256 return nullptr;
2257
2258 if (!MaskA->getType()->isVectorTy())
2259 return nullptr;
2260
2261 auto NodeA = identifyNode(R: AR, I: AI);
2262 if (!NodeA)
2263 return nullptr;
2264
2265 auto NodeB = identifyNode(R: RA, I: BI);
2266 if (!NodeB)
2267 return nullptr;
2268
2269 CompositeNode *PlaceholderNode = prepareCompositeNode(
2270 Operation: ComplexDeinterleavingOperation::ReductionSelect, R: Real, I: Imag);
2271 PlaceholderNode->addOperand(Node: NodeA);
2272 PlaceholderNode->addOperand(Node: NodeB);
2273 FinalInstructions.insert(Ptr: MaskA);
2274 FinalInstructions.insert(Ptr: MaskB);
2275 return submitCompositeNode(Node: PlaceholderNode);
2276}
2277
2278static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
2279 std::optional<FastMathFlags> Flags,
2280 Value *InputA, Value *InputB) {
2281 Value *I;
2282 switch (Opcode) {
2283 case Instruction::FNeg:
2284 I = B.CreateFNeg(V: InputA);
2285 break;
2286 case Instruction::FAdd:
2287 I = B.CreateFAdd(L: InputA, R: InputB);
2288 break;
2289 case Instruction::Add:
2290 I = B.CreateAdd(LHS: InputA, RHS: InputB);
2291 break;
2292 case Instruction::FSub:
2293 I = B.CreateFSub(L: InputA, R: InputB);
2294 break;
2295 case Instruction::Sub:
2296 I = B.CreateSub(LHS: InputA, RHS: InputB);
2297 break;
2298 case Instruction::FMul:
2299 I = B.CreateFMul(L: InputA, R: InputB);
2300 break;
2301 case Instruction::Mul:
2302 I = B.CreateMul(LHS: InputA, RHS: InputB);
2303 break;
2304 default:
2305 llvm_unreachable("Incorrect symmetric opcode");
2306 }
2307 if (Flags)
2308 cast<Instruction>(Val: I)->setFastMathFlags(*Flags);
2309 return I;
2310}
2311
2312Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
2313 CompositeNode *Node) {
2314 if (Node->ReplacementNode)
2315 return Node->ReplacementNode;
2316
2317 auto ReplaceOperandIfExist = [&](CompositeNode *Node,
2318 unsigned Idx) -> Value * {
2319 return Node->Operands.size() > Idx
2320 ? replaceNode(Builder, Node: Node->Operands[Idx])
2321 : nullptr;
2322 };
2323
2324 Value *ReplacementNode = nullptr;
2325 switch (Node->Operation) {
2326 case ComplexDeinterleavingOperation::CDot: {
2327 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2328 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2329 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
2330 assert(!Input1 || (Input0->getType() == Input1->getType() &&
2331 "Node inputs need to be of the same type"));
2332 ReplacementNode = TL->createComplexDeinterleavingIR(
2333 B&: Builder, OperationType: Node->Operation, Rotation: Node->Rotation, InputA: Input0, InputB: Input1, Accumulator);
2334 break;
2335 }
2336 case ComplexDeinterleavingOperation::CAdd:
2337 case ComplexDeinterleavingOperation::CMulPartial:
2338 case ComplexDeinterleavingOperation::Symmetric: {
2339 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2340 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2341 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
2342 assert(!Input1 || (Input0->getType() == Input1->getType() &&
2343 "Node inputs need to be of the same type"));
2344 assert(!Accumulator ||
2345 (Input0->getType() == Accumulator->getType() &&
2346 "Accumulator and input need to be of the same type"));
2347 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
2348 ReplacementNode = replaceSymmetricNode(B&: Builder, Opcode: Node->Opcode, Flags: Node->Flags,
2349 InputA: Input0, InputB: Input1);
2350 else
2351 ReplacementNode = TL->createComplexDeinterleavingIR(
2352 B&: Builder, OperationType: Node->Operation, Rotation: Node->Rotation, InputA: Input0, InputB: Input1,
2353 Accumulator);
2354 break;
2355 }
2356 case ComplexDeinterleavingOperation::Deinterleave:
2357 llvm_unreachable("Deinterleave node should already have ReplacementNode");
2358 break;
2359 case ComplexDeinterleavingOperation::Splat: {
2360 SmallVector<Value *> Ops;
2361 for (auto &V : Node->Vals) {
2362 Ops.push_back(Elt: V.Real);
2363 Ops.push_back(Elt: V.Imag);
2364 }
2365 auto *R = dyn_cast<Instruction>(Val: Node->Vals[0].Real);
2366 auto *I = dyn_cast<Instruction>(Val: Node->Vals[0].Imag);
2367 if (R && I) {
2368 // Splats that are not constant are interleaved where they are located
2369 Instruction *InsertPoint = R;
2370 for (auto V : Node->Vals) {
2371 if (InsertPoint->comesBefore(Other: cast<Instruction>(Val: V.Real)))
2372 InsertPoint = cast<Instruction>(Val: V.Real);
2373 if (InsertPoint->comesBefore(Other: cast<Instruction>(Val: V.Imag)))
2374 InsertPoint = cast<Instruction>(Val: V.Imag);
2375 }
2376 InsertPoint = InsertPoint->getNextNode();
2377 IRBuilder<> IRB(InsertPoint);
2378 ReplacementNode = IRB.CreateVectorInterleave(Ops);
2379 } else {
2380 ReplacementNode = Builder.CreateVectorInterleave(Ops);
2381 }
2382 break;
2383 }
2384 case ComplexDeinterleavingOperation::ReductionPHI: {
2385 // If Operation is ReductionPHI, a new empty PHINode is created.
2386 // It is filled later when the ReductionOperation is processed.
2387 auto *OldPHI = cast<PHINode>(Val: Node->Vals[0].Real);
2388 auto *VTy = cast<VectorType>(Val: Node->Vals[0].Real->getType());
2389 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2390 auto *NewPHI = PHINode::Create(Ty: NewVTy, NumReservedValues: 0, NameStr: "", InsertBefore: BackEdge->getFirstNonPHIIt());
2391 OldToNewPHI[OldPHI] = NewPHI;
2392 ReplacementNode = NewPHI;
2393 break;
2394 }
2395 case ComplexDeinterleavingOperation::ReductionSingle:
2396 ReplacementNode = replaceNode(Builder, Node: Node->Operands[0]);
2397 processReductionSingle(OperationReplacement: ReplacementNode, Node);
2398 break;
2399 case ComplexDeinterleavingOperation::ReductionOperation:
2400 ReplacementNode = replaceNode(Builder, Node: Node->Operands[0]);
2401 processReductionOperation(OperationReplacement: ReplacementNode, Node);
2402 break;
2403 case ComplexDeinterleavingOperation::ReductionSelect: {
2404 auto *MaskReal = cast<Instruction>(Val: Node->Vals[0].Real)->getOperand(i: 0);
2405 auto *MaskImag = cast<Instruction>(Val: Node->Vals[0].Imag)->getOperand(i: 0);
2406 auto *A = replaceNode(Builder, Node: Node->Operands[0]);
2407 auto *B = replaceNode(Builder, Node: Node->Operands[1]);
2408 auto *NewMask = Builder.CreateVectorInterleave(Ops: {MaskReal, MaskImag});
2409 ReplacementNode = Builder.CreateSelect(C: NewMask, True: A, False: B);
2410 break;
2411 }
2412 }
2413
2414 assert(ReplacementNode && "Target failed to create Intrinsic call.");
2415 NumComplexTransformations += 1;
2416 Node->ReplacementNode = ReplacementNode;
2417 return ReplacementNode;
2418}
2419
2420void ComplexDeinterleavingGraph::processReductionSingle(
2421 Value *OperationReplacement, CompositeNode *Node) {
2422 auto *Real = cast<Instruction>(Val: Node->Vals[0].Real);
2423 auto *OldPHI = ReductionInfo[Real].first;
2424 auto *NewPHI = OldToNewPHI[OldPHI];
2425 auto *VTy = cast<VectorType>(Val: Real->getType());
2426 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2427
2428 Value *Init = OldPHI->getIncomingValueForBlock(BB: Incoming);
2429
2430 IRBuilder<> Builder(Incoming->getTerminator());
2431
2432 Value *NewInit = nullptr;
2433 if (auto *C = dyn_cast<Constant>(Val: Init)) {
2434 if (C->isZeroValue())
2435 NewInit = Constant::getNullValue(Ty: NewVTy);
2436 }
2437
2438 if (!NewInit)
2439 NewInit =
2440 Builder.CreateVectorInterleave(Ops: {Init, Constant::getNullValue(Ty: VTy)});
2441
2442 NewPHI->addIncoming(V: NewInit, BB: Incoming);
2443 NewPHI->addIncoming(V: OperationReplacement, BB: BackEdge);
2444
2445 auto *FinalReduction = ReductionInfo[Real].second;
2446 Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
2447
2448 auto *AddReduce = Builder.CreateAddReduce(Src: OperationReplacement);
2449 FinalReduction->replaceAllUsesWith(V: AddReduce);
2450}
2451
2452void ComplexDeinterleavingGraph::processReductionOperation(
2453 Value *OperationReplacement, CompositeNode *Node) {
2454 auto *Real = cast<Instruction>(Val: Node->Vals[0].Real);
2455 auto *Imag = cast<Instruction>(Val: Node->Vals[0].Imag);
2456 auto *OldPHIReal = ReductionInfo[Real].first;
2457 auto *OldPHIImag = ReductionInfo[Imag].first;
2458 auto *NewPHI = OldToNewPHI[OldPHIReal];
2459
2460 // We have to interleave initial origin values coming from IncomingBlock
2461 Value *InitReal = OldPHIReal->getIncomingValueForBlock(BB: Incoming);
2462 Value *InitImag = OldPHIImag->getIncomingValueForBlock(BB: Incoming);
2463
2464 IRBuilder<> Builder(Incoming->getTerminator());
2465 auto *NewInit = Builder.CreateVectorInterleave(Ops: {InitReal, InitImag});
2466
2467 NewPHI->addIncoming(V: NewInit, BB: Incoming);
2468 NewPHI->addIncoming(V: OperationReplacement, BB: BackEdge);
2469
2470 // Deinterleave complex vector outside of loop so that it can be finally
2471 // reduced
2472 auto *FinalReductionReal = ReductionInfo[Real].second;
2473 auto *FinalReductionImag = ReductionInfo[Imag].second;
2474
2475 Builder.SetInsertPoint(
2476 &*FinalReductionReal->getParent()->getFirstInsertionPt());
2477 auto *Deinterleave = Builder.CreateIntrinsic(ID: Intrinsic::vector_deinterleave2,
2478 Types: OperationReplacement->getType(),
2479 Args: OperationReplacement);
2480
2481 auto *NewReal = Builder.CreateExtractValue(Agg: Deinterleave, Idxs: (uint64_t)0);
2482 FinalReductionReal->replaceUsesOfWith(From: Real, To: NewReal);
2483
2484 Builder.SetInsertPoint(FinalReductionImag);
2485 auto *NewImag = Builder.CreateExtractValue(Agg: Deinterleave, Idxs: 1);
2486 FinalReductionImag->replaceUsesOfWith(From: Imag, To: NewImag);
2487}
2488
2489void ComplexDeinterleavingGraph::replaceNodes() {
2490 SmallVector<Instruction *, 16> DeadInstrRoots;
2491 for (auto *RootInstruction : OrderedRoots) {
2492 // Check if this potential root went through check process and we can
2493 // deinterleave it
2494 if (!RootToNode.count(Val: RootInstruction))
2495 continue;
2496
2497 IRBuilder<> Builder(RootInstruction);
2498 auto RootNode = RootToNode[RootInstruction];
2499 Value *R = replaceNode(Builder, Node: RootNode);
2500
2501 if (RootNode->Operation ==
2502 ComplexDeinterleavingOperation::ReductionOperation) {
2503 auto *RootReal = cast<Instruction>(Val: RootNode->Vals[0].Real);
2504 auto *RootImag = cast<Instruction>(Val: RootNode->Vals[0].Imag);
2505 ReductionInfo[RootReal].first->removeIncomingValue(BB: BackEdge);
2506 ReductionInfo[RootImag].first->removeIncomingValue(BB: BackEdge);
2507 DeadInstrRoots.push_back(Elt: RootReal);
2508 DeadInstrRoots.push_back(Elt: RootImag);
2509 } else if (RootNode->Operation ==
2510 ComplexDeinterleavingOperation::ReductionSingle) {
2511 auto *RootInst = cast<Instruction>(Val: RootNode->Vals[0].Real);
2512 auto &Info = ReductionInfo[RootInst];
2513 Info.first->removeIncomingValue(BB: BackEdge);
2514 DeadInstrRoots.push_back(Elt: Info.second);
2515 } else {
2516 assert(R && "Unable to find replacement for RootInstruction");
2517 DeadInstrRoots.push_back(Elt: RootInstruction);
2518 RootInstruction->replaceAllUsesWith(V: R);
2519 }
2520 }
2521
2522 for (auto *I : DeadInstrRoots)
2523 RecursivelyDeleteTriviallyDeadInstructions(V: I, TLI);
2524}
2525