1//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Identification:
10// This step is responsible for finding the patterns that can be lowered to
11// complex instructions, and building a graph to represent the complex
12// structures. Starting from the "Converging Shuffle" (a shuffle that
13// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14// operands are evaluated and identified as "Composite Nodes" (collections of
15// instructions that can potentially be lowered to a single complex
16// instruction). This is performed by checking the real and imaginary components
17// and tracking the data flow for each component while following the operand
18// pairs. Validity of each node is expected to be done upon creation, and any
19// validation errors should halt traversal and prevent further graph
20// construction.
21// Instead of relying on Shuffle operations, vector interleaving and
22// deinterleaving can be represented by vector.interleave2 and
23// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24// these intrinsics, whereas, fixed-width vectors are recognized for both
25// shufflevector instruction and intrinsics.
26//
27// Replacement:
28// This step traverses the graph built up by identification, delegating to the
29// target to validate and generate the correct intrinsics, and plumbs them
30// together connecting each end of the new intrinsics graph to the existing
31// use-def chain. This step is assumed to finish successfully, as all
32// information is expected to be correct by this point.
33//
34//
35// Internal data structure:
36// ComplexDeinterleavingGraph:
37// Keeps references to all the valid CompositeNodes formed as part of the
38// transformation, and every Instruction contained within said nodes. It also
39// holds onto a reference to the root Instruction, and the root node that should
40// replace it.
41//
42// ComplexDeinterleavingCompositeNode:
43// A CompositeNode represents a single transformation point; each node should
44// transform into a single complex instruction (ignoring vector splitting, which
45// would generate more instructions per node). They are identified in a
46// depth-first manner, traversing and identifying the operands of each
47// instruction in the order they appear in the IR.
48// Each node maintains a reference to its Real and Imaginary instructions,
49// as well as any additional instructions that make up the identified operation
50// (Internal instructions should only have uses within their containing node).
51// A Node also contains the rotation and operation type that it represents.
52// Operands contains pointers to other CompositeNodes, acting as the edges in
53// the graph. ReplacementValue is the transformed Value* that has been emitted
54// to the IR.
55//
56// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57// ReplacementValue fields of that Node are relevant, where the ReplacementValue
58// should be pre-populated.
59//
60//===----------------------------------------------------------------------===//
61
62#include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63#include "llvm/ADT/AllocatorList.h"
64#include "llvm/ADT/MapVector.h"
65#include "llvm/ADT/Statistic.h"
66#include "llvm/Analysis/TargetLibraryInfo.h"
67#include "llvm/Analysis/TargetTransformInfo.h"
68#include "llvm/CodeGen/TargetLowering.h"
69#include "llvm/CodeGen/TargetSubtargetInfo.h"
70#include "llvm/IR/IRBuilder.h"
71#include "llvm/IR/Intrinsics.h"
72#include "llvm/IR/PatternMatch.h"
73#include "llvm/InitializePasses.h"
74#include "llvm/Support/Allocator.h"
75#include "llvm/Target/TargetMachine.h"
76#include "llvm/Transforms/Utils/Local.h"
77#include <algorithm>
78
79using namespace llvm;
80using namespace PatternMatch;
81
82#define DEBUG_TYPE "complex-deinterleaving"
83
84STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
85
86static cl::opt<bool> ComplexDeinterleavingEnabled(
87 "enable-complex-deinterleaving",
88 cl::desc("Enable generation of complex instructions"), cl::init(Val: true),
89 cl::Hidden);
90
91/// Checks the given mask, and determines whether said mask is interleaving.
92///
93/// To be interleaving, a mask must alternate between `i` and `i + (Length /
94/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
95/// 4x vector interleaving mask would be <0, 2, 1, 3>).
96static bool isInterleavingMask(ArrayRef<int> Mask);
97
98/// Checks the given mask, and determines whether said mask is deinterleaving.
99///
100/// To be deinterleaving, a mask must increment in steps of 2, and either start
101/// with 0 or 1.
102/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
103/// <1, 3, 5, 7>).
104static bool isDeinterleavingMask(ArrayRef<int> Mask);
105
106/// Returns true if the operation is a negation of V, and it works for both
107/// integers and floats.
108static bool isNeg(Value *V);
109
110/// Returns the operand for negation operation.
111static Value *getNegOperand(Value *V);
112
113namespace {
114struct ComplexValue {
115 Value *Real = nullptr;
116 Value *Imag = nullptr;
117
118 bool operator==(const ComplexValue &Other) const {
119 return Real == Other.Real && Imag == Other.Imag;
120 }
121};
122hash_code hash_value(const ComplexValue &Arg) {
123 return hash_combine(args: DenseMapInfo<Value *>::getHashValue(PtrVal: Arg.Real),
124 args: DenseMapInfo<Value *>::getHashValue(PtrVal: Arg.Imag));
125}
126} // end namespace
127typedef SmallVector<struct ComplexValue, 2> ComplexValues;
128
129template <> struct llvm::DenseMapInfo<ComplexValue> {
130 static unsigned getHashValue(const ComplexValue &Val) {
131 return hash_combine(args: DenseMapInfo<Value *>::getHashValue(PtrVal: Val.Real),
132 args: DenseMapInfo<Value *>::getHashValue(PtrVal: Val.Imag));
133 }
134 static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS) {
135 return LHS.Real == RHS.Real && LHS.Imag == RHS.Imag;
136 }
137};
138
139namespace {
140template <typename T, typename IterT>
141std::optional<T> findCommonBetweenCollections(IterT A, IterT B) {
142 auto Common = llvm::find_if(A, [B](T I) { return llvm::is_contained(B, I); });
143 if (Common != A.end())
144 return std::make_optional(*Common);
145 return std::nullopt;
146}
147
148class ComplexDeinterleavingLegacyPass : public FunctionPass {
149public:
150 static char ID;
151
152 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
153 : FunctionPass(ID), TM(TM) {}
154
155 StringRef getPassName() const override {
156 return "Complex Deinterleaving Pass";
157 }
158
159 bool runOnFunction(Function &F) override;
160 void getAnalysisUsage(AnalysisUsage &AU) const override {
161 AU.addRequired<TargetLibraryInfoWrapperPass>();
162 AU.setPreservesCFG();
163 }
164
165private:
166 const TargetMachine *TM;
167};
168
169class ComplexDeinterleavingGraph;
170struct ComplexDeinterleavingCompositeNode {
171
172 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
173 Value *R, Value *I)
174 : Operation(Op) {
175 Vals.push_back(Elt: {.Real: R, .Imag: I});
176 }
177
178 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
179 ComplexValues &Other)
180 : Operation(Op), Vals(Other) {}
181
182private:
183 friend class ComplexDeinterleavingGraph;
184 using CompositeNode = ComplexDeinterleavingCompositeNode;
185 bool OperandsValid = true;
186
187public:
188 ComplexDeinterleavingOperation Operation;
189 ComplexValues Vals;
190
191 // This two members are required exclusively for generating
192 // ComplexDeinterleavingOperation::Symmetric operations.
193 unsigned Opcode;
194 std::optional<FastMathFlags> Flags;
195
196 ComplexDeinterleavingRotation Rotation =
197 ComplexDeinterleavingRotation::Rotation_0;
198 SmallVector<CompositeNode *> Operands;
199 Value *ReplacementNode = nullptr;
200
201 void addOperand(CompositeNode *Node) {
202 if (!Node)
203 OperandsValid = false;
204 Operands.push_back(Elt: Node);
205 }
206
207 void dump() { dump(OS&: dbgs()); }
208 void dump(raw_ostream &OS) {
209 auto PrintValue = [&](Value *V) {
210 if (V) {
211 OS << "\"";
212 V->print(O&: OS, IsForDebug: true);
213 OS << "\"\n";
214 } else
215 OS << "nullptr\n";
216 };
217 auto PrintNodeRef = [&](CompositeNode *Ptr) {
218 if (Ptr)
219 OS << Ptr << "\n";
220 else
221 OS << "nullptr\n";
222 };
223
224 OS << "- CompositeNode: " << this << "\n";
225 for (unsigned I = 0; I < Vals.size(); I++) {
226 OS << " Real(" << I << ") : ";
227 PrintValue(Vals[I].Real);
228 OS << " Imag(" << I << ") : ";
229 PrintValue(Vals[I].Imag);
230 }
231 OS << " ReplacementNode: ";
232 PrintValue(ReplacementNode);
233 OS << " Operation: " << (int)Operation << "\n";
234 OS << " Rotation: " << ((int)Rotation * 90) << "\n";
235 OS << " Operands: \n";
236 for (const auto &Op : Operands) {
237 OS << " - ";
238 PrintNodeRef(Op);
239 }
240 }
241
242 bool areOperandsValid() { return OperandsValid; }
243};
244
245class ComplexDeinterleavingGraph {
246public:
247 struct Product {
248 Value *Multiplier;
249 Value *Multiplicand;
250 bool IsPositive;
251 };
252
253 using Addend = std::pair<Value *, bool>;
254 using AddendList = BumpPtrList<Addend>;
255 using CompositeNode = ComplexDeinterleavingCompositeNode::CompositeNode;
256
257 // Helper struct for holding info about potential partial multiplication
258 // candidates
259 struct PartialMulCandidate {
260 Value *Common;
261 CompositeNode *Node;
262 unsigned RealIdx;
263 unsigned ImagIdx;
264 bool IsNodeInverted;
265 };
266
267 explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
268 const TargetLibraryInfo *TLI,
269 unsigned Factor)
270 : TL(TL), TLI(TLI), Factor(Factor) {}
271
272private:
273 const TargetLowering *TL = nullptr;
274 const TargetLibraryInfo *TLI = nullptr;
275 unsigned Factor;
276 SmallVector<CompositeNode *> CompositeNodes;
277 DenseMap<ComplexValues, CompositeNode *> CachedResult;
278 SpecificBumpPtrAllocator<ComplexDeinterleavingCompositeNode> Allocator;
279
280 SmallPtrSet<Instruction *, 16> FinalInstructions;
281
282 /// Root instructions are instructions from which complex computation starts
283 DenseMap<Instruction *, CompositeNode *> RootToNode;
284
285 /// Topologically sorted root instructions
286 SmallVector<Instruction *, 1> OrderedRoots;
287
288 /// When examining a basic block for complex deinterleaving, if it is a simple
289 /// one-block loop, then the only incoming block is 'Incoming' and the
290 /// 'BackEdge' block is the block itself."
291 BasicBlock *BackEdge = nullptr;
292 BasicBlock *Incoming = nullptr;
293
294 /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
295 /// %OutsideUser as it is shown in the IR:
296 ///
297 /// vector.body:
298 /// %PHInode = phi <vector type> [ zeroinitializer, %entry ],
299 /// [ %ReductionOp, %vector.body ]
300 /// ...
301 /// %ReductionOp = fadd i64 ...
302 /// ...
303 /// br i1 %condition, label %vector.body, %middle.block
304 ///
305 /// middle.block:
306 /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
307 ///
308 /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
309 /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
310 MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
311
312 /// In the process of detecting a reduction, we consider a pair of
313 /// %ReductionOP, which we refer to as real and imag (or vice versa), and
314 /// traverse the use-tree to detect complex operations. As this is a reduction
315 /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
316 /// to the %ReductionOPs that we suspect to be complex.
317 /// RealPHI and ImagPHI are used by the identifyPHINode method.
318 PHINode *RealPHI = nullptr;
319 PHINode *ImagPHI = nullptr;
320
321 /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
322 /// detection.
323 bool PHIsFound = false;
324
325 /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
326 /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
327 /// This mapping is populated during
328 /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
329 /// used in the ComplexDeinterleavingOperation::ReductionOperation node
330 /// replacement process.
331 DenseMap<PHINode *, PHINode *> OldToNewPHI;
332
333 CompositeNode *prepareCompositeNode(ComplexDeinterleavingOperation Operation,
334 Value *R, Value *I) {
335 assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
336 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
337 (R && I)) &&
338 "Reduction related nodes must have Real and Imaginary parts");
339 return new (Allocator.Allocate())
340 ComplexDeinterleavingCompositeNode(Operation, R, I);
341 }
342
343 CompositeNode *prepareCompositeNode(ComplexDeinterleavingOperation Operation,
344 ComplexValues &Vals) {
345#ifndef NDEBUG
346 for (auto &V : Vals) {
347 assert(
348 ((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
349 Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
350 (V.Real && V.Imag)) &&
351 "Reduction related nodes must have Real and Imaginary parts");
352 }
353#endif
354 return new (Allocator.Allocate())
355 ComplexDeinterleavingCompositeNode(Operation, Vals);
356 }
357
358 CompositeNode *submitCompositeNode(CompositeNode *Node) {
359 CompositeNodes.push_back(Elt: Node);
360 if (Node->Vals[0].Real)
361 CachedResult[Node->Vals] = Node;
362 return Node;
363 }
364
365 /// Identifies a complex partial multiply pattern and its rotation, based on
366 /// the following patterns
367 ///
368 /// 0: r: cr + ar * br
369 /// i: ci + ar * bi
370 /// 90: r: cr - ai * bi
371 /// i: ci + ai * br
372 /// 180: r: cr - ar * br
373 /// i: ci - ar * bi
374 /// 270: r: cr + ai * bi
375 /// i: ci - ai * br
376 CompositeNode *identifyPartialMul(Instruction *Real, Instruction *Imag);
377
378 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
379 /// is partially known from identifyPartialMul, filling in the other half of
380 /// the complex pair.
381 CompositeNode *
382 identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
383 std::pair<Value *, Value *> &CommonOperandI);
384
385 /// Identifies a complex add pattern and its rotation, based on the following
386 /// patterns.
387 ///
388 /// 90: r: ar - bi
389 /// i: ai + br
390 /// 270: r: ar + bi
391 /// i: ai - br
392 CompositeNode *identifyAdd(Instruction *Real, Instruction *Imag);
393 CompositeNode *identifySymmetricOperation(ComplexValues &Vals);
394 CompositeNode *identifyPartialReduction(Value *R, Value *I);
395 CompositeNode *identifyDotProduct(Value *Inst);
396
397 CompositeNode *identifyNode(ComplexValues &Vals);
398
399 CompositeNode *identifyNode(Value *R, Value *I) {
400 ComplexValues Vals;
401 Vals.push_back(Elt: {.Real: R, .Imag: I});
402 return identifyNode(Vals);
403 }
404
405 /// Determine if a sum of complex numbers can be formed from \p RealAddends
406 /// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
407 /// Return nullptr if it is not possible to construct a complex number.
408 /// \p Flags are needed to generate symmetric Add and Sub operations.
409 CompositeNode *identifyAdditions(AddendList &RealAddends,
410 AddendList &ImagAddends,
411 std::optional<FastMathFlags> Flags,
412 CompositeNode *Accumulator);
413
414 /// Extract one addend that have both real and imaginary parts positive.
415 CompositeNode *extractPositiveAddend(AddendList &RealAddends,
416 AddendList &ImagAddends);
417
418 /// Determine if sum of multiplications of complex numbers can be formed from
419 /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
420 /// to it. Return nullptr if it is not possible to construct a complex number.
421 CompositeNode *identifyMultiplications(SmallVectorImpl<Product> &RealMuls,
422 SmallVectorImpl<Product> &ImagMuls,
423 CompositeNode *Accumulator);
424
425 /// Go through pairs of multiplication (one Real and one Imag) and find all
426 /// possible candidates for partial multiplication and put them into \p
427 /// Candidates. Returns true if all Product has pair with common operand
428 bool collectPartialMuls(ArrayRef<Product> RealMuls,
429 ArrayRef<Product> ImagMuls,
430 SmallVectorImpl<PartialMulCandidate> &Candidates);
431
432 /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
433 /// the order of complex computation operations may be significantly altered,
434 /// and the real and imaginary parts may not be executed in parallel. This
435 /// function takes this into consideration and employs a more general approach
436 /// to identify complex computations. Initially, it gathers all the addends
437 /// and multiplicands and then constructs a complex expression from them.
438 CompositeNode *identifyReassocNodes(Instruction *I, Instruction *J);
439
440 CompositeNode *identifyRoot(Instruction *I);
441
442 /// Identifies the Deinterleave operation applied to a vector containing
443 /// complex numbers. There are two ways to represent the Deinterleave
444 /// operation:
445 /// * Using two shufflevectors with even indices for /pReal instruction and
446 /// odd indices for /pImag instructions (only for fixed-width vectors)
447 /// * Using N extractvalue instructions applied to `vector.deinterleaveN`
448 /// intrinsics (for both fixed and scalable vectors) where N is a multiple of
449 /// 2.
450 CompositeNode *identifyDeinterleave(ComplexValues &Vals);
451
452 /// identifying the operation that represents a complex number repeated in a
453 /// Splat vector. There are two possible types of splats: ConstantExpr with
454 /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
455 /// initialization mask with all values set to zero.
456 CompositeNode *identifySplat(ComplexValues &Vals);
457
458 CompositeNode *identifyPHINode(Instruction *Real, Instruction *Imag);
459
460 /// Identifies SelectInsts in a loop that has reduction with predication masks
461 /// and/or predicated tail folding
462 CompositeNode *identifySelectNode(Instruction *Real, Instruction *Imag);
463
464 Value *replaceNode(IRBuilderBase &Builder, CompositeNode *Node);
465
466 /// Complete IR modifications after producing new reduction operation:
467 /// * Populate the PHINode generated for
468 /// ComplexDeinterleavingOperation::ReductionPHI
469 /// * Deinterleave the final value outside of the loop and repurpose original
470 /// reduction users
471 void processReductionOperation(Value *OperationReplacement,
472 CompositeNode *Node);
473 void processReductionSingle(Value *OperationReplacement, CompositeNode *Node);
474
475public:
476 void dump() { dump(OS&: dbgs()); }
477 void dump(raw_ostream &OS) {
478 for (const auto &Node : CompositeNodes)
479 Node->dump(OS);
480 }
481
482 /// Returns false if the deinterleaving operation should be cancelled for the
483 /// current graph.
484 bool identifyNodes(Instruction *RootI);
485
486 /// In case \pB is one-block loop, this function seeks potential reductions
487 /// and populates ReductionInfo. Returns true if any reductions were
488 /// identified.
489 bool collectPotentialReductions(BasicBlock *B);
490
491 void identifyReductionNodes();
492
493 /// Check that every instruction, from the roots to the leaves, has internal
494 /// uses.
495 bool checkNodes();
496
497 /// Perform the actual replacement of the underlying instruction graph.
498 void replaceNodes();
499};
500
501class ComplexDeinterleaving {
502public:
503 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
504 : TL(tl), TLI(tli) {}
505 bool runOnFunction(Function &F);
506
507private:
508 bool evaluateBasicBlock(BasicBlock *B, unsigned Factor);
509
510 const TargetLowering *TL = nullptr;
511 const TargetLibraryInfo *TLI = nullptr;
512};
513
514} // namespace
515
516char ComplexDeinterleavingLegacyPass::ID = 0;
517
518INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
519 "Complex Deinterleaving", false, false)
520INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
521 "Complex Deinterleaving", false, false)
522
523PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
524 FunctionAnalysisManager &AM) {
525 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
526 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(IR&: F);
527 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
528 return PreservedAnalyses::all();
529
530 PreservedAnalyses PA;
531 PA.preserve<FunctionAnalysisManagerModuleProxy>();
532 return PA;
533}
534
535FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
536 return new ComplexDeinterleavingLegacyPass(TM);
537}
538
539bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
540 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
541 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
542 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
543}
544
545bool ComplexDeinterleaving::runOnFunction(Function &F) {
546 if (!ComplexDeinterleavingEnabled) {
547 LLVM_DEBUG(
548 dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
549 return false;
550 }
551
552 if (!TL->isComplexDeinterleavingSupported()) {
553 LLVM_DEBUG(
554 dbgs() << "Complex deinterleaving has been disabled, target does "
555 "not support lowering of complex number operations.\n");
556 return false;
557 }
558
559 bool Changed = false;
560 for (auto &B : F)
561 Changed |= evaluateBasicBlock(B: &B, Factor: 2);
562
563 // TODO: Permit changes for both interleave factors in the same function.
564 if (!Changed) {
565 for (auto &B : F)
566 Changed |= evaluateBasicBlock(B: &B, Factor: 4);
567 }
568
569 // TODO: We can also support interleave factors of 6 and 8 if needed.
570
571 return Changed;
572}
573
574static bool isInterleavingMask(ArrayRef<int> Mask) {
575 // If the size is not even, it's not an interleaving mask
576 if ((Mask.size() & 1))
577 return false;
578
579 int HalfNumElements = Mask.size() / 2;
580 for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
581 int MaskIdx = Idx * 2;
582 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
583 return false;
584 }
585
586 return true;
587}
588
589static bool isDeinterleavingMask(ArrayRef<int> Mask) {
590 int Offset = Mask[0];
591 int HalfNumElements = Mask.size() / 2;
592
593 for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
594 if (Mask[Idx] != (Idx * 2) + Offset)
595 return false;
596 }
597
598 return true;
599}
600
601bool isNeg(Value *V) {
602 return match(V, P: m_FNeg(X: m_Value())) || match(V, P: m_Neg(V: m_Value()));
603}
604
605Value *getNegOperand(Value *V) {
606 assert(isNeg(V));
607 auto *I = cast<Instruction>(Val: V);
608 if (I->getOpcode() == Instruction::FNeg)
609 return I->getOperand(i: 0);
610
611 return I->getOperand(i: 1);
612}
613
614bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B, unsigned Factor) {
615 ComplexDeinterleavingGraph Graph(TL, TLI, Factor);
616 if (Graph.collectPotentialReductions(B))
617 Graph.identifyReductionNodes();
618
619 for (auto &I : *B)
620 Graph.identifyNodes(RootI: &I);
621
622 if (Graph.checkNodes()) {
623 Graph.replaceNodes();
624 return true;
625 }
626
627 return false;
628}
629
630ComplexDeinterleavingGraph::CompositeNode *
631ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
632 Instruction *Real, Instruction *Imag,
633 std::pair<Value *, Value *> &PartialMatch) {
634 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
635 << "\n");
636
637 if (!Real->hasOneUse() || !Imag->hasOneUse()) {
638 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");
639 return nullptr;
640 }
641
642 if ((Real->getOpcode() != Instruction::FMul &&
643 Real->getOpcode() != Instruction::Mul) ||
644 (Imag->getOpcode() != Instruction::FMul &&
645 Imag->getOpcode() != Instruction::Mul)) {
646 LLVM_DEBUG(
647 dbgs() << " - Real or imaginary instruction is not fmul or mul\n");
648 return nullptr;
649 }
650
651 Value *R0 = Real->getOperand(i: 0);
652 Value *R1 = Real->getOperand(i: 1);
653 Value *I0 = Imag->getOperand(i: 0);
654 Value *I1 = Imag->getOperand(i: 1);
655
656 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
657 // rotations and use the operand.
658 unsigned Negs = 0;
659 Value *Op;
660 if (match(V: R0, P: m_Neg(V: m_Value(V&: Op)))) {
661 Negs |= 1;
662 R0 = Op;
663 } else if (match(V: R1, P: m_Neg(V: m_Value(V&: Op)))) {
664 Negs |= 1;
665 R1 = Op;
666 }
667
668 if (isNeg(V: I0)) {
669 Negs |= 2;
670 Negs ^= 1;
671 I0 = Op;
672 } else if (match(V: I1, P: m_Neg(V: m_Value(V&: Op)))) {
673 Negs |= 2;
674 Negs ^= 1;
675 I1 = Op;
676 }
677
678 ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
679
680 Value *CommonOperand;
681 Value *UncommonRealOp;
682 Value *UncommonImagOp;
683
684 if (R0 == I0 || R0 == I1) {
685 CommonOperand = R0;
686 UncommonRealOp = R1;
687 } else if (R1 == I0 || R1 == I1) {
688 CommonOperand = R1;
689 UncommonRealOp = R0;
690 } else {
691 LLVM_DEBUG(dbgs() << " - No equal operand\n");
692 return nullptr;
693 }
694
695 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
696 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
697 Rotation == ComplexDeinterleavingRotation::Rotation_270)
698 std::swap(a&: UncommonRealOp, b&: UncommonImagOp);
699
700 // Between identifyPartialMul and here we need to have found a complete valid
701 // pair from the CommonOperand of each part.
702 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
703 Rotation == ComplexDeinterleavingRotation::Rotation_180)
704 PartialMatch.first = CommonOperand;
705 else
706 PartialMatch.second = CommonOperand;
707
708 if (!PartialMatch.first || !PartialMatch.second) {
709 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
710 return nullptr;
711 }
712
713 CompositeNode *CommonNode =
714 identifyNode(R: PartialMatch.first, I: PartialMatch.second);
715 if (!CommonNode) {
716 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
717 return nullptr;
718 }
719
720 CompositeNode *UncommonNode = identifyNode(R: UncommonRealOp, I: UncommonImagOp);
721 if (!UncommonNode) {
722 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
723 return nullptr;
724 }
725
726 CompositeNode *Node = prepareCompositeNode(
727 Operation: ComplexDeinterleavingOperation::CMulPartial, R: Real, I: Imag);
728 Node->Rotation = Rotation;
729 Node->addOperand(Node: CommonNode);
730 Node->addOperand(Node: UncommonNode);
731 return submitCompositeNode(Node);
732}
733
734ComplexDeinterleavingGraph::CompositeNode *
735ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
736 Instruction *Imag) {
737 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
738 << "\n");
739
740 // Determine rotation
741 auto IsAdd = [](unsigned Op) {
742 return Op == Instruction::FAdd || Op == Instruction::Add;
743 };
744 auto IsSub = [](unsigned Op) {
745 return Op == Instruction::FSub || Op == Instruction::Sub;
746 };
747 ComplexDeinterleavingRotation Rotation;
748 if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
749 Rotation = ComplexDeinterleavingRotation::Rotation_0;
750 else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
751 Rotation = ComplexDeinterleavingRotation::Rotation_90;
752 else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
753 Rotation = ComplexDeinterleavingRotation::Rotation_180;
754 else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
755 Rotation = ComplexDeinterleavingRotation::Rotation_270;
756 else {
757 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
758 return nullptr;
759 }
760
761 if (isa<FPMathOperator>(Val: Real) &&
762 (!Real->getFastMathFlags().allowContract() ||
763 !Imag->getFastMathFlags().allowContract())) {
764 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
765 return nullptr;
766 }
767
768 Value *CR = Real->getOperand(i: 0);
769 Instruction *RealMulI = dyn_cast<Instruction>(Val: Real->getOperand(i: 1));
770 if (!RealMulI)
771 return nullptr;
772 Value *CI = Imag->getOperand(i: 0);
773 Instruction *ImagMulI = dyn_cast<Instruction>(Val: Imag->getOperand(i: 1));
774 if (!ImagMulI)
775 return nullptr;
776
777 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
778 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
779 return nullptr;
780 }
781
782 Value *R0 = RealMulI->getOperand(i: 0);
783 Value *R1 = RealMulI->getOperand(i: 1);
784 Value *I0 = ImagMulI->getOperand(i: 0);
785 Value *I1 = ImagMulI->getOperand(i: 1);
786
787 Value *CommonOperand;
788 Value *UncommonRealOp;
789 Value *UncommonImagOp;
790
791 if (R0 == I0 || R0 == I1) {
792 CommonOperand = R0;
793 UncommonRealOp = R1;
794 } else if (R1 == I0 || R1 == I1) {
795 CommonOperand = R1;
796 UncommonRealOp = R0;
797 } else {
798 LLVM_DEBUG(dbgs() << " - No equal operand\n");
799 return nullptr;
800 }
801
802 UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
803 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
804 Rotation == ComplexDeinterleavingRotation::Rotation_270)
805 std::swap(a&: UncommonRealOp, b&: UncommonImagOp);
806
807 std::pair<Value *, Value *> PartialMatch(
808 (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
809 Rotation == ComplexDeinterleavingRotation::Rotation_180)
810 ? CommonOperand
811 : nullptr,
812 (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
813 Rotation == ComplexDeinterleavingRotation::Rotation_270)
814 ? CommonOperand
815 : nullptr);
816
817 auto *CRInst = dyn_cast<Instruction>(Val: CR);
818 auto *CIInst = dyn_cast<Instruction>(Val: CI);
819
820 if (!CRInst || !CIInst) {
821 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");
822 return nullptr;
823 }
824
825 CompositeNode *CNode =
826 identifyNodeWithImplicitAdd(Real: CRInst, Imag: CIInst, PartialMatch);
827 if (!CNode) {
828 LLVM_DEBUG(dbgs() << " - No cnode identified\n");
829 return nullptr;
830 }
831
832 CompositeNode *UncommonRes = identifyNode(R: UncommonRealOp, I: UncommonImagOp);
833 if (!UncommonRes) {
834 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
835 return nullptr;
836 }
837
838 assert(PartialMatch.first && PartialMatch.second);
839 CompositeNode *CommonRes =
840 identifyNode(R: PartialMatch.first, I: PartialMatch.second);
841 if (!CommonRes) {
842 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
843 return nullptr;
844 }
845
846 CompositeNode *Node = prepareCompositeNode(
847 Operation: ComplexDeinterleavingOperation::CMulPartial, R: Real, I: Imag);
848 Node->Rotation = Rotation;
849 Node->addOperand(Node: CommonRes);
850 Node->addOperand(Node: UncommonRes);
851 Node->addOperand(Node: CNode);
852 return submitCompositeNode(Node);
853}
854
855ComplexDeinterleavingGraph::CompositeNode *
856ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
857 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
858
859 // Determine rotation
860 ComplexDeinterleavingRotation Rotation;
861 if ((Real->getOpcode() == Instruction::FSub &&
862 Imag->getOpcode() == Instruction::FAdd) ||
863 (Real->getOpcode() == Instruction::Sub &&
864 Imag->getOpcode() == Instruction::Add))
865 Rotation = ComplexDeinterleavingRotation::Rotation_90;
866 else if ((Real->getOpcode() == Instruction::FAdd &&
867 Imag->getOpcode() == Instruction::FSub) ||
868 (Real->getOpcode() == Instruction::Add &&
869 Imag->getOpcode() == Instruction::Sub))
870 Rotation = ComplexDeinterleavingRotation::Rotation_270;
871 else {
872 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
873 return nullptr;
874 }
875
876 auto *AR = dyn_cast<Instruction>(Val: Real->getOperand(i: 0));
877 auto *BI = dyn_cast<Instruction>(Val: Real->getOperand(i: 1));
878 auto *AI = dyn_cast<Instruction>(Val: Imag->getOperand(i: 0));
879 auto *BR = dyn_cast<Instruction>(Val: Imag->getOperand(i: 1));
880
881 if (!AR || !AI || !BR || !BI) {
882 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
883 return nullptr;
884 }
885
886 CompositeNode *ResA = identifyNode(R: AR, I: AI);
887 if (!ResA) {
888 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
889 return nullptr;
890 }
891 CompositeNode *ResB = identifyNode(R: BR, I: BI);
892 if (!ResB) {
893 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
894 return nullptr;
895 }
896
897 CompositeNode *Node =
898 prepareCompositeNode(Operation: ComplexDeinterleavingOperation::CAdd, R: Real, I: Imag);
899 Node->Rotation = Rotation;
900 Node->addOperand(Node: ResA);
901 Node->addOperand(Node: ResB);
902 return submitCompositeNode(Node);
903}
904
905static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
906 unsigned OpcA = A->getOpcode();
907 unsigned OpcB = B->getOpcode();
908
909 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
910 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
911 (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
912 (OpcA == Instruction::Add && OpcB == Instruction::Sub);
913}
914
915static bool isInstructionPairMul(Instruction *A, Instruction *B) {
916 auto Pattern =
917 m_BinOp(L: m_FMul(L: m_Value(), R: m_Value()), R: m_FMul(L: m_Value(), R: m_Value()));
918
919 return match(V: A, P: Pattern) && match(V: B, P: Pattern);
920}
921
922static bool isInstructionPotentiallySymmetric(Instruction *I) {
923 switch (I->getOpcode()) {
924 case Instruction::FAdd:
925 case Instruction::FSub:
926 case Instruction::FMul:
927 case Instruction::FNeg:
928 case Instruction::Add:
929 case Instruction::Sub:
930 case Instruction::Mul:
931 return true;
932 default:
933 return false;
934 }
935}
936
937ComplexDeinterleavingGraph::CompositeNode *
938ComplexDeinterleavingGraph::identifySymmetricOperation(ComplexValues &Vals) {
939 auto *FirstReal = cast<Instruction>(Val: Vals[0].Real);
940 unsigned FirstOpc = FirstReal->getOpcode();
941 for (auto &V : Vals) {
942 auto *Real = cast<Instruction>(Val: V.Real);
943 auto *Imag = cast<Instruction>(Val: V.Imag);
944 if (Real->getOpcode() != FirstOpc || Imag->getOpcode() != FirstOpc)
945 return nullptr;
946
947 if (!isInstructionPotentiallySymmetric(I: Real) ||
948 !isInstructionPotentiallySymmetric(I: Imag))
949 return nullptr;
950
951 if (isa<FPMathOperator>(Val: FirstReal))
952 if (Real->getFastMathFlags() != FirstReal->getFastMathFlags() ||
953 Imag->getFastMathFlags() != FirstReal->getFastMathFlags())
954 return nullptr;
955 }
956
957 ComplexValues OpVals;
958 for (auto &V : Vals) {
959 auto *R0 = cast<Instruction>(Val: V.Real)->getOperand(i: 0);
960 auto *I0 = cast<Instruction>(Val: V.Imag)->getOperand(i: 0);
961 OpVals.push_back(Elt: {.Real: R0, .Imag: I0});
962 }
963
964 CompositeNode *Op0 = identifyNode(Vals&: OpVals);
965 CompositeNode *Op1 = nullptr;
966 if (Op0 == nullptr)
967 return nullptr;
968
969 if (FirstReal->isBinaryOp()) {
970 OpVals.clear();
971 for (auto &V : Vals) {
972 auto *R1 = cast<Instruction>(Val: V.Real)->getOperand(i: 1);
973 auto *I1 = cast<Instruction>(Val: V.Imag)->getOperand(i: 1);
974 OpVals.push_back(Elt: {.Real: R1, .Imag: I1});
975 }
976 Op1 = identifyNode(Vals&: OpVals);
977 if (Op1 == nullptr)
978 return nullptr;
979 }
980
981 auto Node =
982 prepareCompositeNode(Operation: ComplexDeinterleavingOperation::Symmetric, Vals);
983 Node->Opcode = FirstReal->getOpcode();
984 if (isa<FPMathOperator>(Val: FirstReal))
985 Node->Flags = FirstReal->getFastMathFlags();
986
987 Node->addOperand(Node: Op0);
988 if (FirstReal->isBinaryOp())
989 Node->addOperand(Node: Op1);
990
991 return submitCompositeNode(Node);
992}
993
994ComplexDeinterleavingGraph::CompositeNode *
995ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
996 if (!TL->isComplexDeinterleavingOperationSupported(
997 Operation: ComplexDeinterleavingOperation::CDot, Ty: V->getType())) {
998 LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving "
999 "operation CDot with the type "
1000 << *V->getType() << "\n");
1001 return nullptr;
1002 }
1003
1004 auto *Inst = cast<Instruction>(Val: V);
1005 auto *RealUser = cast<Instruction>(Val: *Inst->user_begin());
1006
1007 CompositeNode *CN =
1008 prepareCompositeNode(Operation: ComplexDeinterleavingOperation::CDot, R: Inst, I: nullptr);
1009
1010 CompositeNode *ANode = nullptr;
1011
1012 const Intrinsic::ID PartialReduceInt = Intrinsic::vector_partial_reduce_add;
1013
1014 Value *AReal = nullptr;
1015 Value *AImag = nullptr;
1016 Value *BReal = nullptr;
1017 Value *BImag = nullptr;
1018 Value *Phi = nullptr;
1019
1020 auto UnwrapCast = [](Value *V) -> Value * {
1021 if (auto *CI = dyn_cast<CastInst>(Val: V))
1022 return CI->getOperand(i_nocapture: 0);
1023 return V;
1024 };
1025
1026 auto PatternRot0 = m_Intrinsic<PartialReduceInt>(
1027 Op0: m_Intrinsic<PartialReduceInt>(Op0: m_Value(V&: Phi),
1028 Op1: m_Mul(L: m_Value(V&: BReal), R: m_Value(V&: AReal))),
1029 Op1: m_Neg(V: m_Mul(L: m_Value(V&: BImag), R: m_Value(V&: AImag))));
1030
1031 auto PatternRot270 = m_Intrinsic<PartialReduceInt>(
1032 Op0: m_Intrinsic<PartialReduceInt>(
1033 Op0: m_Value(V&: Phi), Op1: m_Neg(V: m_Mul(L: m_Value(V&: BReal), R: m_Value(V&: AImag)))),
1034 Op1: m_Mul(L: m_Value(V&: BImag), R: m_Value(V&: AReal)));
1035
1036 if (match(V: Inst, P: PatternRot0)) {
1037 CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
1038 } else if (match(V: Inst, P: PatternRot270)) {
1039 CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
1040 } else {
1041 Value *A0, *A1;
1042 // The rotations 90 and 180 share the same operation pattern, so inspect the
1043 // order of the operands, identifying where the real and imaginary
1044 // components of A go, to discern between the aforementioned rotations.
1045 auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>(
1046 Op0: m_Intrinsic<PartialReduceInt>(Op0: m_Value(V&: Phi),
1047 Op1: m_Mul(L: m_Value(V&: BReal), R: m_Value(V&: A0))),
1048 Op1: m_Mul(L: m_Value(V&: BImag), R: m_Value(V&: A1)));
1049
1050 if (!match(V: Inst, P: PatternRot90Rot180))
1051 return nullptr;
1052
1053 A0 = UnwrapCast(A0);
1054 A1 = UnwrapCast(A1);
1055
1056 // Test if A0 is real/A1 is imag
1057 ANode = identifyNode(R: A0, I: A1);
1058 if (!ANode) {
1059 // Test if A0 is imag/A1 is real
1060 ANode = identifyNode(R: A1, I: A0);
1061 // Unable to identify operand components, thus unable to identify rotation
1062 if (!ANode)
1063 return nullptr;
1064 CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
1065 AReal = A1;
1066 AImag = A0;
1067 } else {
1068 AReal = A0;
1069 AImag = A1;
1070 CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
1071 }
1072 }
1073
1074 AReal = UnwrapCast(AReal);
1075 AImag = UnwrapCast(AImag);
1076 BReal = UnwrapCast(BReal);
1077 BImag = UnwrapCast(BImag);
1078
1079 VectorType *VTy = cast<VectorType>(Val: V->getType());
1080 Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, NumSubdivs: 2);
1081 if (AReal->getType() != ExpectedOperandTy)
1082 return nullptr;
1083 if (AImag->getType() != ExpectedOperandTy)
1084 return nullptr;
1085 if (BReal->getType() != ExpectedOperandTy)
1086 return nullptr;
1087 if (BImag->getType() != ExpectedOperandTy)
1088 return nullptr;
1089
1090 if (Phi->getType() != VTy && RealUser->getType() != VTy)
1091 return nullptr;
1092
1093 CompositeNode *Node = identifyNode(R: AReal, I: AImag);
1094
1095 // In the case that a node was identified to figure out the rotation, ensure
1096 // that trying to identify a node with AReal and AImag post-unwrap results in
1097 // the same node
1098 if (ANode && Node != ANode) {
1099 LLVM_DEBUG(
1100 dbgs()
1101 << "Identified node is different from previously identified node. "
1102 "Unable to confidently generate a complex operation node\n");
1103 return nullptr;
1104 }
1105
1106 CN->addOperand(Node);
1107 CN->addOperand(Node: identifyNode(R: BReal, I: BImag));
1108 CN->addOperand(Node: identifyNode(R: Phi, I: RealUser));
1109
1110 return submitCompositeNode(Node: CN);
1111}
1112
1113ComplexDeinterleavingGraph::CompositeNode *
1114ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
1115 // Partial reductions don't support non-vector types, so check these first
1116 if (!isa<VectorType>(Val: R->getType()) || !isa<VectorType>(Val: I->getType()))
1117 return nullptr;
1118
1119 if (!R->hasUseList() || !I->hasUseList())
1120 return nullptr;
1121
1122 auto CommonUser =
1123 findCommonBetweenCollections<Value *>(A: R->users(), B: I->users());
1124 if (!CommonUser)
1125 return nullptr;
1126
1127 auto *IInst = dyn_cast<IntrinsicInst>(Val: *CommonUser);
1128 if (!IInst || IInst->getIntrinsicID() != Intrinsic::vector_partial_reduce_add)
1129 return nullptr;
1130
1131 if (CompositeNode *CN = identifyDotProduct(V: IInst))
1132 return CN;
1133
1134 return nullptr;
1135}
1136
1137ComplexDeinterleavingGraph::CompositeNode *
1138ComplexDeinterleavingGraph::identifyNode(ComplexValues &Vals) {
1139 auto It = CachedResult.find(Val: Vals);
1140 if (It != CachedResult.end()) {
1141 LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
1142 return It->second;
1143 }
1144
1145 if (Vals.size() == 1) {
1146 assert(Factor == 2 && "Can only handle interleave factors of 2");
1147 Value *R = Vals[0].Real;
1148 Value *I = Vals[0].Imag;
1149 if (CompositeNode *CN = identifyPartialReduction(R, I))
1150 return CN;
1151 bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I);
1152 if (!IsReduction && R->getType() != I->getType())
1153 return nullptr;
1154 }
1155
1156 if (CompositeNode *CN = identifySplat(Vals))
1157 return CN;
1158
1159 for (auto &V : Vals) {
1160 auto *Real = dyn_cast<Instruction>(Val: V.Real);
1161 auto *Imag = dyn_cast<Instruction>(Val: V.Imag);
1162 if (!Real || !Imag)
1163 return nullptr;
1164 }
1165
1166 if (CompositeNode *CN = identifyDeinterleave(Vals))
1167 return CN;
1168
1169 if (Vals.size() == 1) {
1170 assert(Factor == 2 && "Can only handle interleave factors of 2");
1171 auto *Real = dyn_cast<Instruction>(Val: Vals[0].Real);
1172 auto *Imag = dyn_cast<Instruction>(Val: Vals[0].Imag);
1173 if (CompositeNode *CN = identifyPHINode(Real, Imag))
1174 return CN;
1175
1176 if (CompositeNode *CN = identifySelectNode(Real, Imag))
1177 return CN;
1178
1179 auto *VTy = cast<VectorType>(Val: Real->getType());
1180 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1181
1182 bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
1183 Operation: ComplexDeinterleavingOperation::CMulPartial, Ty: NewVTy);
1184 bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
1185 Operation: ComplexDeinterleavingOperation::CAdd, Ty: NewVTy);
1186
1187 if (HasCMulSupport && isInstructionPairMul(A: Real, B: Imag)) {
1188 if (CompositeNode *CN = identifyPartialMul(Real, Imag))
1189 return CN;
1190 }
1191
1192 if (HasCAddSupport && isInstructionPairAdd(A: Real, B: Imag)) {
1193 if (CompositeNode *CN = identifyAdd(Real, Imag))
1194 return CN;
1195 }
1196
1197 if (HasCMulSupport && HasCAddSupport) {
1198 if (CompositeNode *CN = identifyReassocNodes(I: Real, J: Imag)) {
1199 return CN;
1200 }
1201 }
1202 }
1203
1204 if (CompositeNode *CN = identifySymmetricOperation(Vals))
1205 return CN;
1206
1207 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
1208 CachedResult[Vals] = nullptr;
1209 return nullptr;
1210}
1211
1212ComplexDeinterleavingGraph::CompositeNode *
1213ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
1214 Instruction *Imag) {
1215 auto IsOperationSupported = [](unsigned Opcode) -> bool {
1216 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
1217 Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
1218 Opcode == Instruction::Sub;
1219 };
1220
1221 if (!IsOperationSupported(Real->getOpcode()) ||
1222 !IsOperationSupported(Imag->getOpcode()))
1223 return nullptr;
1224
1225 std::optional<FastMathFlags> Flags;
1226 if (isa<FPMathOperator>(Val: Real)) {
1227 if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
1228 LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
1229 "not identical\n");
1230 return nullptr;
1231 }
1232
1233 Flags = Real->getFastMathFlags();
1234 if (!Flags->allowReassoc()) {
1235 LLVM_DEBUG(
1236 dbgs()
1237 << "the 'Reassoc' attribute is missing in the FastMath flags\n");
1238 return nullptr;
1239 }
1240 }
1241
1242 // Collect multiplications and addend instructions from the given instruction
1243 // while traversing it operands. Additionally, verify that all instructions
1244 // have the same fast math flags.
1245 auto Collect = [&Flags](Instruction *Insn, SmallVectorImpl<Product> &Muls,
1246 AddendList &Addends) -> bool {
1247 SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
1248 while (!Worklist.empty()) {
1249 auto [V, IsPositive] = Worklist.pop_back_val();
1250
1251 Instruction *I = dyn_cast<Instruction>(Val: V);
1252 if (!I) {
1253 Addends.emplace_back(Vs&: V, Vs&: IsPositive);
1254 continue;
1255 }
1256
1257 // If an instruction has more than one user, it indicates that it either
1258 // has an external user, which will be later checked by the checkNodes
1259 // function, or it is a subexpression utilized by multiple expressions. In
1260 // the latter case, we will attempt to separately identify the complex
1261 // operation from here in order to create a shared
1262 // ComplexDeinterleavingCompositeNode.
1263 if (I != Insn && I->hasNUsesOrMore(N: 2)) {
1264 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
1265 Addends.emplace_back(Vs&: I, Vs&: IsPositive);
1266 continue;
1267 }
1268 switch (I->getOpcode()) {
1269 case Instruction::FAdd:
1270 case Instruction::Add:
1271 Worklist.emplace_back(Args: I->getOperand(i: 1), Args&: IsPositive);
1272 Worklist.emplace_back(Args: I->getOperand(i: 0), Args&: IsPositive);
1273 break;
1274 case Instruction::FSub:
1275 Worklist.emplace_back(Args: I->getOperand(i: 1), Args: !IsPositive);
1276 Worklist.emplace_back(Args: I->getOperand(i: 0), Args&: IsPositive);
1277 break;
1278 case Instruction::Sub:
1279 if (isNeg(V: I)) {
1280 Worklist.emplace_back(Args: getNegOperand(V: I), Args: !IsPositive);
1281 } else {
1282 Worklist.emplace_back(Args: I->getOperand(i: 1), Args: !IsPositive);
1283 Worklist.emplace_back(Args: I->getOperand(i: 0), Args&: IsPositive);
1284 }
1285 break;
1286 case Instruction::FMul:
1287 case Instruction::Mul: {
1288 Value *A, *B;
1289 if (isNeg(V: I->getOperand(i: 0))) {
1290 A = getNegOperand(V: I->getOperand(i: 0));
1291 IsPositive = !IsPositive;
1292 } else {
1293 A = I->getOperand(i: 0);
1294 }
1295
1296 if (isNeg(V: I->getOperand(i: 1))) {
1297 B = getNegOperand(V: I->getOperand(i: 1));
1298 IsPositive = !IsPositive;
1299 } else {
1300 B = I->getOperand(i: 1);
1301 }
1302 Muls.push_back(Elt: Product{.Multiplier: A, .Multiplicand: B, .IsPositive: IsPositive});
1303 break;
1304 }
1305 case Instruction::FNeg:
1306 Worklist.emplace_back(Args: I->getOperand(i: 0), Args: !IsPositive);
1307 break;
1308 default:
1309 Addends.emplace_back(Vs&: I, Vs&: IsPositive);
1310 continue;
1311 }
1312
1313 if (Flags && I->getFastMathFlags() != *Flags) {
1314 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1315 "inconsistent with the root instructions' flags: "
1316 << *I << "\n");
1317 return false;
1318 }
1319 }
1320 return true;
1321 };
1322
1323 SmallVector<Product> RealMuls, ImagMuls;
1324 AddendList RealAddends, ImagAddends;
1325 if (!Collect(Real, RealMuls, RealAddends) ||
1326 !Collect(Imag, ImagMuls, ImagAddends))
1327 return nullptr;
1328
1329 if (RealAddends.size() != ImagAddends.size())
1330 return nullptr;
1331
1332 CompositeNode *FinalNode = nullptr;
1333 if (!RealMuls.empty() || !ImagMuls.empty()) {
1334 // If there are multiplicands, extract positive addend and use it as an
1335 // accumulator
1336 FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1337 FinalNode = identifyMultiplications(RealMuls, ImagMuls, Accumulator: FinalNode);
1338 if (!FinalNode)
1339 return nullptr;
1340 }
1341
1342 // Identify and process remaining additions
1343 if (!RealAddends.empty() || !ImagAddends.empty()) {
1344 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, Accumulator: FinalNode);
1345 if (!FinalNode)
1346 return nullptr;
1347 }
1348 assert(FinalNode && "FinalNode can not be nullptr here");
1349 assert(FinalNode->Vals.size() == 1);
1350 // Set the Real and Imag fields of the final node and submit it
1351 FinalNode->Vals[0].Real = Real;
1352 FinalNode->Vals[0].Imag = Imag;
1353 submitCompositeNode(Node: FinalNode);
1354 return FinalNode;
1355}
1356
1357bool ComplexDeinterleavingGraph::collectPartialMuls(
1358 ArrayRef<Product> RealMuls, ArrayRef<Product> ImagMuls,
1359 SmallVectorImpl<PartialMulCandidate> &PartialMulCandidates) {
1360 // Helper function to extract a common operand from two products
1361 auto FindCommonInstruction = [](const Product &Real,
1362 const Product &Imag) -> Value * {
1363 if (Real.Multiplicand == Imag.Multiplicand ||
1364 Real.Multiplicand == Imag.Multiplier)
1365 return Real.Multiplicand;
1366
1367 if (Real.Multiplier == Imag.Multiplicand ||
1368 Real.Multiplier == Imag.Multiplier)
1369 return Real.Multiplier;
1370
1371 return nullptr;
1372 };
1373
1374 // Iterating over real and imaginary multiplications to find common operands
1375 // If a common operand is found, a partial multiplication candidate is created
1376 // and added to the candidates vector The function returns false if no common
1377 // operands are found for any product
1378 for (unsigned i = 0; i < RealMuls.size(); ++i) {
1379 bool FoundCommon = false;
1380 for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1381 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1382 if (!Common)
1383 continue;
1384
1385 auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1386 : RealMuls[i].Multiplicand;
1387 auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1388 : ImagMuls[j].Multiplicand;
1389
1390 auto Node = identifyNode(R: A, I: B);
1391 if (Node) {
1392 FoundCommon = true;
1393 PartialMulCandidates.push_back(Elt: {.Common: Common, .Node: Node, .RealIdx: i, .ImagIdx: j, .IsNodeInverted: false});
1394 }
1395
1396 Node = identifyNode(R: B, I: A);
1397 if (Node) {
1398 FoundCommon = true;
1399 PartialMulCandidates.push_back(Elt: {.Common: Common, .Node: Node, .RealIdx: i, .ImagIdx: j, .IsNodeInverted: true});
1400 }
1401 }
1402 if (!FoundCommon)
1403 return false;
1404 }
1405 return true;
1406}
1407
1408ComplexDeinterleavingGraph::CompositeNode *
1409ComplexDeinterleavingGraph::identifyMultiplications(
1410 SmallVectorImpl<Product> &RealMuls, SmallVectorImpl<Product> &ImagMuls,
1411 CompositeNode *Accumulator = nullptr) {
1412 if (RealMuls.size() != ImagMuls.size())
1413 return nullptr;
1414
1415 SmallVector<PartialMulCandidate> Info;
1416 if (!collectPartialMuls(RealMuls, ImagMuls, PartialMulCandidates&: Info))
1417 return nullptr;
1418
1419 // Map to store common instruction to node pointers
1420 DenseMap<Value *, CompositeNode *> CommonToNode;
1421 SmallVector<bool> Processed(Info.size(), false);
1422 for (unsigned I = 0; I < Info.size(); ++I) {
1423 if (Processed[I])
1424 continue;
1425
1426 PartialMulCandidate &InfoA = Info[I];
1427 for (unsigned J = I + 1; J < Info.size(); ++J) {
1428 if (Processed[J])
1429 continue;
1430
1431 PartialMulCandidate &InfoB = Info[J];
1432 auto *InfoReal = &InfoA;
1433 auto *InfoImag = &InfoB;
1434
1435 auto NodeFromCommon = identifyNode(R: InfoReal->Common, I: InfoImag->Common);
1436 if (!NodeFromCommon) {
1437 std::swap(a&: InfoReal, b&: InfoImag);
1438 NodeFromCommon = identifyNode(R: InfoReal->Common, I: InfoImag->Common);
1439 }
1440 if (!NodeFromCommon)
1441 continue;
1442
1443 CommonToNode[InfoReal->Common] = NodeFromCommon;
1444 CommonToNode[InfoImag->Common] = NodeFromCommon;
1445 Processed[I] = true;
1446 Processed[J] = true;
1447 }
1448 }
1449
1450 SmallVector<bool> ProcessedReal(RealMuls.size(), false);
1451 SmallVector<bool> ProcessedImag(ImagMuls.size(), false);
1452 CompositeNode *Result = Accumulator;
1453 for (auto &PMI : Info) {
1454 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1455 continue;
1456
1457 auto It = CommonToNode.find(Val: PMI.Common);
1458 // TODO: Process independent complex multiplications. Cases like this:
1459 // A.real() * B where both A and B are complex numbers.
1460 if (It == CommonToNode.end()) {
1461 LLVM_DEBUG({
1462 dbgs() << "Unprocessed independent partial multiplication:\n";
1463 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1464 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1465 << " multiplied by " << *Mul->Multiplicand << "\n";
1466 });
1467 return nullptr;
1468 }
1469
1470 auto &RealMul = RealMuls[PMI.RealIdx];
1471 auto &ImagMul = ImagMuls[PMI.ImagIdx];
1472
1473 auto NodeA = It->second;
1474 auto NodeB = PMI.Node;
1475 auto IsMultiplicandReal = PMI.Common == NodeA->Vals[0].Real;
1476 // The following table illustrates the relationship between multiplications
1477 // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1478 // can see:
1479 //
1480 // Rotation | Real | Imag |
1481 // ---------+--------+--------+
1482 // 0 | x * u | x * v |
1483 // 90 | -y * v | y * u |
1484 // 180 | -x * u | -x * v |
1485 // 270 | y * v | -y * u |
1486 //
1487 // Check if the candidate can indeed be represented by partial
1488 // multiplication
1489 // TODO: Add support for multiplication by complex one
1490 if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1491 (!IsMultiplicandReal && !PMI.IsNodeInverted))
1492 continue;
1493
1494 // Determine the rotation based on the multiplications
1495 ComplexDeinterleavingRotation Rotation;
1496 if (IsMultiplicandReal) {
1497 // Detect 0 and 180 degrees rotation
1498 if (RealMul.IsPositive && ImagMul.IsPositive)
1499 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;
1500 else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1501 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;
1502 else
1503 continue;
1504
1505 } else {
1506 // Detect 90 and 270 degrees rotation
1507 if (!RealMul.IsPositive && ImagMul.IsPositive)
1508 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;
1509 else if (RealMul.IsPositive && !ImagMul.IsPositive)
1510 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;
1511 else
1512 continue;
1513 }
1514
1515 LLVM_DEBUG({
1516 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1517 dbgs().indent(4) << "X: " << *NodeA->Vals[0].Real << "\n";
1518 dbgs().indent(4) << "Y: " << *NodeA->Vals[0].Imag << "\n";
1519 dbgs().indent(4) << "U: " << *NodeB->Vals[0].Real << "\n";
1520 dbgs().indent(4) << "V: " << *NodeB->Vals[0].Imag << "\n";
1521 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1522 });
1523
1524 CompositeNode *NodeMul = prepareCompositeNode(
1525 Operation: ComplexDeinterleavingOperation::CMulPartial, R: nullptr, I: nullptr);
1526 NodeMul->Rotation = Rotation;
1527 NodeMul->addOperand(Node: NodeA);
1528 NodeMul->addOperand(Node: NodeB);
1529 if (Result)
1530 NodeMul->addOperand(Node: Result);
1531 submitCompositeNode(Node: NodeMul);
1532 Result = NodeMul;
1533 ProcessedReal[PMI.RealIdx] = true;
1534 ProcessedImag[PMI.ImagIdx] = true;
1535 }
1536
1537 // Ensure all products have been processed, if not return nullptr.
1538 if (!all_of(Range&: ProcessedReal, P: [](bool V) { return V; }) ||
1539 !all_of(Range&: ProcessedImag, P: [](bool V) { return V; })) {
1540
1541 // Dump debug information about which partial multiplications are not
1542 // processed.
1543 LLVM_DEBUG({
1544 dbgs() << "Unprocessed products (Real):\n";
1545 for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1546 if (!ProcessedReal[i])
1547 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1548 << *RealMuls[i].Multiplier << " multiplied by "
1549 << *RealMuls[i].Multiplicand << "\n";
1550 }
1551 dbgs() << "Unprocessed products (Imag):\n";
1552 for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1553 if (!ProcessedImag[i])
1554 dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1555 << *ImagMuls[i].Multiplier << " multiplied by "
1556 << *ImagMuls[i].Multiplicand << "\n";
1557 }
1558 });
1559 return nullptr;
1560 }
1561
1562 return Result;
1563}
1564
1565ComplexDeinterleavingGraph::CompositeNode *
1566ComplexDeinterleavingGraph::identifyAdditions(
1567 AddendList &RealAddends, AddendList &ImagAddends,
1568 std::optional<FastMathFlags> Flags, CompositeNode *Accumulator = nullptr) {
1569 if (RealAddends.size() != ImagAddends.size())
1570 return nullptr;
1571
1572 CompositeNode *Result = nullptr;
1573 // If we have accumulator use it as first addend
1574 if (Accumulator)
1575 Result = Accumulator;
1576 // Otherwise find an element with both positive real and imaginary parts.
1577 else
1578 Result = extractPositiveAddend(RealAddends, ImagAddends);
1579
1580 if (!Result)
1581 return nullptr;
1582
1583 while (!RealAddends.empty()) {
1584 auto ItR = RealAddends.begin();
1585 auto [R, IsPositiveR] = *ItR;
1586
1587 bool FoundImag = false;
1588 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1589 auto [I, IsPositiveI] = *ItI;
1590 ComplexDeinterleavingRotation Rotation;
1591 if (IsPositiveR && IsPositiveI)
1592 Rotation = ComplexDeinterleavingRotation::Rotation_0;
1593 else if (!IsPositiveR && IsPositiveI)
1594 Rotation = ComplexDeinterleavingRotation::Rotation_90;
1595 else if (!IsPositiveR && !IsPositiveI)
1596 Rotation = ComplexDeinterleavingRotation::Rotation_180;
1597 else
1598 Rotation = ComplexDeinterleavingRotation::Rotation_270;
1599
1600 CompositeNode *AddNode = nullptr;
1601 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1602 Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1603 AddNode = identifyNode(R, I);
1604 } else {
1605 AddNode = identifyNode(R: I, I: R);
1606 }
1607 if (AddNode) {
1608 LLVM_DEBUG({
1609 dbgs() << "Identified addition:\n";
1610 dbgs().indent(4) << "X: " << *R << "\n";
1611 dbgs().indent(4) << "Y: " << *I << "\n";
1612 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1613 });
1614
1615 CompositeNode *TmpNode = nullptr;
1616 if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
1617 TmpNode = prepareCompositeNode(
1618 Operation: ComplexDeinterleavingOperation::Symmetric, R: nullptr, I: nullptr);
1619 if (Flags) {
1620 TmpNode->Opcode = Instruction::FAdd;
1621 TmpNode->Flags = *Flags;
1622 } else {
1623 TmpNode->Opcode = Instruction::Add;
1624 }
1625 } else if (Rotation ==
1626 llvm::ComplexDeinterleavingRotation::Rotation_180) {
1627 TmpNode = prepareCompositeNode(
1628 Operation: ComplexDeinterleavingOperation::Symmetric, R: nullptr, I: nullptr);
1629 if (Flags) {
1630 TmpNode->Opcode = Instruction::FSub;
1631 TmpNode->Flags = *Flags;
1632 } else {
1633 TmpNode->Opcode = Instruction::Sub;
1634 }
1635 } else {
1636 TmpNode = prepareCompositeNode(Operation: ComplexDeinterleavingOperation::CAdd,
1637 R: nullptr, I: nullptr);
1638 TmpNode->Rotation = Rotation;
1639 }
1640
1641 TmpNode->addOperand(Node: Result);
1642 TmpNode->addOperand(Node: AddNode);
1643 submitCompositeNode(Node: TmpNode);
1644 Result = TmpNode;
1645 RealAddends.erase(I: ItR);
1646 ImagAddends.erase(I: ItI);
1647 FoundImag = true;
1648 break;
1649 }
1650 }
1651 if (!FoundImag)
1652 return nullptr;
1653 }
1654 return Result;
1655}
1656
1657ComplexDeinterleavingGraph::CompositeNode *
1658ComplexDeinterleavingGraph::extractPositiveAddend(AddendList &RealAddends,
1659 AddendList &ImagAddends) {
1660 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1661 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1662 auto [R, IsPositiveR] = *ItR;
1663 auto [I, IsPositiveI] = *ItI;
1664 if (IsPositiveR && IsPositiveI) {
1665 auto Result = identifyNode(R, I);
1666 if (Result) {
1667 RealAddends.erase(I: ItR);
1668 ImagAddends.erase(I: ItI);
1669 return Result;
1670 }
1671 }
1672 }
1673 }
1674 return nullptr;
1675}
1676
1677bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1678 // This potential root instruction might already have been recognized as
1679 // reduction. Because RootToNode maps both Real and Imaginary parts to
1680 // CompositeNode we should choose only one either Real or Imag instruction to
1681 // use as an anchor for generating complex instruction.
1682 auto It = RootToNode.find(Val: RootI);
1683 if (It != RootToNode.end()) {
1684 auto RootNode = It->second;
1685 assert(RootNode->Operation ==
1686 ComplexDeinterleavingOperation::ReductionOperation ||
1687 RootNode->Operation ==
1688 ComplexDeinterleavingOperation::ReductionSingle);
1689 assert(RootNode->Vals.size() == 1 &&
1690 "Cannot handle reductions involving multiple complex values");
1691 // Find out which part, Real or Imag, comes later, and only if we come to
1692 // the latest part, add it to OrderedRoots.
1693 auto *R = cast<Instruction>(Val: RootNode->Vals[0].Real);
1694 auto *I = RootNode->Vals[0].Imag ? cast<Instruction>(Val: RootNode->Vals[0].Imag)
1695 : nullptr;
1696
1697 Instruction *ReplacementAnchor;
1698 if (I)
1699 ReplacementAnchor = R->comesBefore(Other: I) ? I : R;
1700 else
1701 ReplacementAnchor = R;
1702
1703 if (ReplacementAnchor != RootI)
1704 return false;
1705 OrderedRoots.push_back(Elt: RootI);
1706 return true;
1707 }
1708
1709 auto RootNode = identifyRoot(I: RootI);
1710 if (!RootNode)
1711 return false;
1712
1713 LLVM_DEBUG({
1714 Function *F = RootI->getFunction();
1715 BasicBlock *B = RootI->getParent();
1716 dbgs() << "Complex deinterleaving graph for " << F->getName()
1717 << "::" << B->getName() << ".\n";
1718 dump(dbgs());
1719 dbgs() << "\n";
1720 });
1721 RootToNode[RootI] = RootNode;
1722 OrderedRoots.push_back(Elt: RootI);
1723 return true;
1724}
1725
1726bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1727 bool FoundPotentialReduction = false;
1728 if (Factor != 2)
1729 return false;
1730
1731 auto *Br = dyn_cast<CondBrInst>(Val: B->getTerminator());
1732 if (!Br)
1733 return false;
1734
1735 // Identify simple one-block loop
1736 if (Br->getSuccessor(i: 0) != B && Br->getSuccessor(i: 1) != B)
1737 return false;
1738
1739 for (auto &PHI : B->phis()) {
1740 if (PHI.getNumIncomingValues() != 2)
1741 continue;
1742
1743 if (!PHI.getType()->isVectorTy())
1744 continue;
1745
1746 auto *ReductionOp = dyn_cast<Instruction>(Val: PHI.getIncomingValueForBlock(BB: B));
1747 if (!ReductionOp)
1748 continue;
1749
1750 // Check if final instruction is reduced outside of current block
1751 Instruction *FinalReduction = nullptr;
1752 auto NumUsers = 0u;
1753 for (auto *U : ReductionOp->users()) {
1754 ++NumUsers;
1755 if (U == &PHI)
1756 continue;
1757 FinalReduction = dyn_cast<Instruction>(Val: U);
1758 }
1759
1760 if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1761 isa<PHINode>(Val: FinalReduction))
1762 continue;
1763
1764 ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1765 BackEdge = B;
1766 auto BackEdgeIdx = PHI.getBasicBlockIndex(BB: B);
1767 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1768 Incoming = PHI.getIncomingBlock(i: IncomingIdx);
1769 FoundPotentialReduction = true;
1770
1771 // If the initial value of PHINode is an Instruction, consider it a leaf
1772 // value of a complex deinterleaving graph.
1773 if (auto *InitPHI =
1774 dyn_cast<Instruction>(Val: PHI.getIncomingValueForBlock(BB: Incoming)))
1775 FinalInstructions.insert(Ptr: InitPHI);
1776 }
1777 return FoundPotentialReduction;
1778}
1779
1780void ComplexDeinterleavingGraph::identifyReductionNodes() {
1781 assert(Factor == 2 && "Cannot handle multiple complex values");
1782
1783 SmallVector<bool> Processed(ReductionInfo.size(), false);
1784 SmallVector<Instruction *> OperationInstruction;
1785 for (auto &P : ReductionInfo)
1786 OperationInstruction.push_back(Elt: P.first);
1787
1788 // Identify a complex computation by evaluating two reduction operations that
1789 // potentially could be involved
1790 for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1791 if (Processed[i])
1792 continue;
1793 for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1794 if (Processed[j])
1795 continue;
1796 auto *Real = OperationInstruction[i];
1797 auto *Imag = OperationInstruction[j];
1798 if (Real->getType() != Imag->getType())
1799 continue;
1800
1801 RealPHI = ReductionInfo[Real].first;
1802 ImagPHI = ReductionInfo[Imag].first;
1803 PHIsFound = false;
1804 auto Node = identifyNode(R: Real, I: Imag);
1805 if (!Node) {
1806 std::swap(a&: Real, b&: Imag);
1807 std::swap(a&: RealPHI, b&: ImagPHI);
1808 Node = identifyNode(R: Real, I: Imag);
1809 }
1810
1811 // If a node is identified and reduction PHINode is used in the chain of
1812 // operations, mark its operation instructions as used to prevent
1813 // re-identification and attach the node to the real part
1814 if (Node && PHIsFound) {
1815 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1816 << *Real << " / " << *Imag << "\n");
1817 Processed[i] = true;
1818 Processed[j] = true;
1819 auto RootNode = prepareCompositeNode(
1820 Operation: ComplexDeinterleavingOperation::ReductionOperation, R: Real, I: Imag);
1821 RootNode->addOperand(Node);
1822 RootToNode[Real] = RootNode;
1823 RootToNode[Imag] = RootNode;
1824 submitCompositeNode(Node: RootNode);
1825 break;
1826 }
1827 }
1828
1829 auto *Real = OperationInstruction[i];
1830 // We want to check that we have 2 operands, but the function attributes
1831 // being counted as operands bloats this value.
1832 if (Processed[i] || Real->getNumOperands() < 2)
1833 continue;
1834
1835 // Can only combined integer reductions at the moment.
1836 if (!ReductionInfo[Real].second->getType()->isIntegerTy())
1837 continue;
1838
1839 RealPHI = ReductionInfo[Real].first;
1840 ImagPHI = nullptr;
1841 PHIsFound = false;
1842 auto Node = identifyNode(R: Real->getOperand(i: 0), I: Real->getOperand(i: 1));
1843 if (Node && PHIsFound) {
1844 LLVM_DEBUG(
1845 dbgs() << "Identified single reduction starting from instruction: "
1846 << *Real << "/" << *ReductionInfo[Real].second << "\n");
1847
1848 // Reducing to a single vector is not supported, only permit reducing down
1849 // to scalar values.
1850 // Doing this here will leave the prior node in the graph,
1851 // however with no uses the node will be unreachable by the replacement
1852 // process. That along with the usage outside the graph should prevent the
1853 // replacement process from kicking off at all for this graph.
1854 // TODO Add support for reducing to a single vector value
1855 if (ReductionInfo[Real].second->getType()->isVectorTy())
1856 continue;
1857
1858 Processed[i] = true;
1859 auto RootNode = prepareCompositeNode(
1860 Operation: ComplexDeinterleavingOperation::ReductionSingle, R: Real, I: nullptr);
1861 RootNode->addOperand(Node);
1862 RootToNode[Real] = RootNode;
1863 submitCompositeNode(Node: RootNode);
1864 }
1865 }
1866
1867 RealPHI = nullptr;
1868 ImagPHI = nullptr;
1869}
1870
1871bool ComplexDeinterleavingGraph::checkNodes() {
1872 bool FoundDeinterleaveNode = false;
1873 for (CompositeNode *N : CompositeNodes) {
1874 if (!N->areOperandsValid())
1875 return false;
1876
1877 if (N->Operation == ComplexDeinterleavingOperation::Deinterleave)
1878 FoundDeinterleaveNode = true;
1879 }
1880
1881 // We need a deinterleave node in order to guarantee that we're working with
1882 // complex numbers.
1883 if (!FoundDeinterleaveNode) {
1884 LLVM_DEBUG(
1885 dbgs() << "Couldn't find a deinterleave node within the graph, cannot "
1886 "guarantee safety during graph transformation.\n");
1887 return false;
1888 }
1889
1890 // Collect all instructions from roots to leaves
1891 SmallPtrSet<Instruction *, 16> AllInstructions;
1892 SmallVector<Instruction *, 8> Worklist;
1893 for (auto &Pair : RootToNode)
1894 Worklist.push_back(Elt: Pair.first);
1895
1896 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1897 // chains
1898 while (!Worklist.empty()) {
1899 auto *I = Worklist.pop_back_val();
1900
1901 if (!AllInstructions.insert(Ptr: I).second)
1902 continue;
1903
1904 for (Value *Op : I->operands()) {
1905 if (auto *OpI = dyn_cast<Instruction>(Val: Op)) {
1906 if (!FinalInstructions.count(Ptr: I))
1907 Worklist.emplace_back(Args&: OpI);
1908 }
1909 }
1910 }
1911
1912 // Find instructions that have users outside of chain
1913 for (auto *I : AllInstructions) {
1914 // Skip root nodes
1915 if (RootToNode.count(Val: I))
1916 continue;
1917
1918 for (User *U : I->users()) {
1919 if (AllInstructions.count(Ptr: cast<Instruction>(Val: U)))
1920 continue;
1921
1922 // Found an instruction that is not used by XCMLA/XCADD chain
1923 Worklist.emplace_back(Args&: I);
1924 break;
1925 }
1926 }
1927
1928 // If any instructions are found to be used outside, find and remove roots
1929 // that somehow connect to those instructions.
1930 SmallPtrSet<Instruction *, 16> Visited;
1931 while (!Worklist.empty()) {
1932 auto *I = Worklist.pop_back_val();
1933 if (!Visited.insert(Ptr: I).second)
1934 continue;
1935
1936 // Found an impacted root node. Removing it from the nodes to be
1937 // deinterleaved
1938 if (RootToNode.count(Val: I)) {
1939 LLVM_DEBUG(dbgs() << "Instruction " << *I
1940 << " could be deinterleaved but its chain of complex "
1941 "operations have an outside user\n");
1942 RootToNode.erase(Val: I);
1943 }
1944
1945 if (!AllInstructions.count(Ptr: I) || FinalInstructions.count(Ptr: I))
1946 continue;
1947
1948 for (User *U : I->users())
1949 Worklist.emplace_back(Args: cast<Instruction>(Val: U));
1950
1951 for (Value *Op : I->operands()) {
1952 if (auto *OpI = dyn_cast<Instruction>(Val: Op))
1953 Worklist.emplace_back(Args&: OpI);
1954 }
1955 }
1956 return !RootToNode.empty();
1957}
1958
1959ComplexDeinterleavingGraph::CompositeNode *
1960ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1961 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(Val: RootI)) {
1962 if (Intrinsic::getInterleaveIntrinsicID(Factor) !=
1963 Intrinsic->getIntrinsicID())
1964 return nullptr;
1965
1966 ComplexValues Vals;
1967 for (unsigned I = 0; I < Factor; I += 2) {
1968 auto *Real = dyn_cast<Instruction>(Val: Intrinsic->getOperand(i_nocapture: I));
1969 auto *Imag = dyn_cast<Instruction>(Val: Intrinsic->getOperand(i_nocapture: I + 1));
1970 if (!Real || !Imag)
1971 return nullptr;
1972 Vals.push_back(Elt: {.Real: Real, .Imag: Imag});
1973 }
1974
1975 ComplexDeinterleavingGraph::CompositeNode *Node1 = identifyNode(Vals);
1976 if (!Node1)
1977 return nullptr;
1978 return Node1;
1979 }
1980
1981 // TODO: We could also add support for fixed-width interleave factors of 4
1982 // and above, but currently for symmetric operations the interleaves and
1983 // deinterleaves are already removed by VectorCombine. If we extend this to
1984 // permit complex multiplications, reductions, etc. then we should also add
1985 // support for fixed-width here.
1986 if (Factor != 2)
1987 return nullptr;
1988
1989 auto *SVI = dyn_cast<ShuffleVectorInst>(Val: RootI);
1990 if (!SVI)
1991 return nullptr;
1992
1993 // Look for a shufflevector that takes separate vectors of the real and
1994 // imaginary components and recombines them into a single vector.
1995 if (!isInterleavingMask(Mask: SVI->getShuffleMask()))
1996 return nullptr;
1997
1998 Instruction *Real;
1999 Instruction *Imag;
2000 if (!match(V: RootI, P: m_Shuffle(v1: m_Instruction(I&: Real), v2: m_Instruction(I&: Imag))))
2001 return nullptr;
2002
2003 return identifyNode(R: Real, I: Imag);
2004}
2005
2006ComplexDeinterleavingGraph::CompositeNode *
2007ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
2008 Instruction *II = nullptr;
2009
2010 // Must be at least one complex value.
2011 auto CheckExtract = [&](Value *V, unsigned ExpectedIdx,
2012 Instruction *ExpectedInsn) -> ExtractValueInst * {
2013 auto *EVI = dyn_cast<ExtractValueInst>(Val: V);
2014 if (!EVI || EVI->getNumIndices() != 1 ||
2015 EVI->getIndices()[0] != ExpectedIdx ||
2016 !isa<Instruction>(Val: EVI->getAggregateOperand()) ||
2017 (ExpectedInsn && ExpectedInsn != EVI->getAggregateOperand()))
2018 return nullptr;
2019 return EVI;
2020 };
2021
2022 for (unsigned Idx = 0; Idx < Vals.size(); Idx++) {
2023 ExtractValueInst *RealEVI = CheckExtract(Vals[Idx].Real, Idx * 2, II);
2024 if (RealEVI && Idx == 0)
2025 II = cast<Instruction>(Val: RealEVI->getAggregateOperand());
2026 if (!RealEVI || !CheckExtract(Vals[Idx].Imag, (Idx * 2) + 1, II)) {
2027 II = nullptr;
2028 break;
2029 }
2030 }
2031
2032 if (auto *IntrinsicII = dyn_cast_or_null<IntrinsicInst>(Val: II)) {
2033 if (IntrinsicII->getIntrinsicID() !=
2034 Intrinsic::getDeinterleaveIntrinsicID(Factor: 2 * Vals.size()))
2035 return nullptr;
2036
2037 // The remaining should match too.
2038 CompositeNode *PlaceholderNode = prepareCompositeNode(
2039 Operation: llvm::ComplexDeinterleavingOperation::Deinterleave, Vals);
2040 PlaceholderNode->ReplacementNode = II->getOperand(i: 0);
2041 for (auto &V : Vals) {
2042 FinalInstructions.insert(Ptr: cast<Instruction>(Val: V.Real));
2043 FinalInstructions.insert(Ptr: cast<Instruction>(Val: V.Imag));
2044 }
2045 return submitCompositeNode(Node: PlaceholderNode);
2046 }
2047
2048 if (Vals.size() != 1)
2049 return nullptr;
2050
2051 Value *Real = Vals[0].Real;
2052 Value *Imag = Vals[0].Imag;
2053 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Val: Real);
2054 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Val: Imag);
2055 if (!RealShuffle || !ImagShuffle) {
2056 if (RealShuffle || ImagShuffle)
2057 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
2058 return nullptr;
2059 }
2060
2061 Value *RealOp1 = RealShuffle->getOperand(i_nocapture: 1);
2062 if (!isa<UndefValue>(Val: RealOp1) && !match(V: RealOp1, P: m_Zero())) {
2063 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
2064 return nullptr;
2065 }
2066 Value *ImagOp1 = ImagShuffle->getOperand(i_nocapture: 1);
2067 if (!isa<UndefValue>(Val: ImagOp1) && !match(V: ImagOp1, P: m_Zero())) {
2068 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
2069 return nullptr;
2070 }
2071
2072 Value *RealOp0 = RealShuffle->getOperand(i_nocapture: 0);
2073 Value *ImagOp0 = ImagShuffle->getOperand(i_nocapture: 0);
2074
2075 if (RealOp0 != ImagOp0) {
2076 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
2077 return nullptr;
2078 }
2079
2080 ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
2081 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
2082 if (!isDeinterleavingMask(Mask: RealMask) || !isDeinterleavingMask(Mask: ImagMask)) {
2083 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
2084 return nullptr;
2085 }
2086
2087 if (RealMask[0] != 0 || ImagMask[0] != 1) {
2088 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
2089 return nullptr;
2090 }
2091
2092 // Type checking, the shuffle type should be a vector type of the same
2093 // scalar type, but half the size
2094 auto CheckType = [&](ShuffleVectorInst *Shuffle) {
2095 Value *Op = Shuffle->getOperand(i_nocapture: 0);
2096 auto *ShuffleTy = cast<FixedVectorType>(Val: Shuffle->getType());
2097 auto *OpTy = cast<FixedVectorType>(Val: Op->getType());
2098
2099 if (OpTy->getScalarType() != ShuffleTy->getScalarType())
2100 return false;
2101 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
2102 return false;
2103
2104 return true;
2105 };
2106
2107 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
2108 if (!CheckType(Shuffle))
2109 return false;
2110
2111 ArrayRef<int> Mask = Shuffle->getShuffleMask();
2112 int Last = *Mask.rbegin();
2113
2114 Value *Op = Shuffle->getOperand(i_nocapture: 0);
2115 auto *OpTy = cast<FixedVectorType>(Val: Op->getType());
2116 int NumElements = OpTy->getNumElements();
2117
2118 // Ensure that the deinterleaving shuffle only pulls from the first
2119 // shuffle operand.
2120 return Last < NumElements;
2121 };
2122
2123 if (RealShuffle->getType() != ImagShuffle->getType()) {
2124 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
2125 return nullptr;
2126 }
2127 if (!CheckDeinterleavingShuffle(RealShuffle)) {
2128 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
2129 return nullptr;
2130 }
2131 if (!CheckDeinterleavingShuffle(ImagShuffle)) {
2132 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
2133 return nullptr;
2134 }
2135
2136 CompositeNode *PlaceholderNode =
2137 prepareCompositeNode(Operation: llvm::ComplexDeinterleavingOperation::Deinterleave,
2138 R: RealShuffle, I: ImagShuffle);
2139 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(i_nocapture: 0);
2140 FinalInstructions.insert(Ptr: RealShuffle);
2141 FinalInstructions.insert(Ptr: ImagShuffle);
2142 return submitCompositeNode(Node: PlaceholderNode);
2143}
2144
2145ComplexDeinterleavingGraph::CompositeNode *
2146ComplexDeinterleavingGraph::identifySplat(ComplexValues &Vals) {
2147 auto IsSplat = [](Value *V) -> bool {
2148 // Fixed-width vector with constants
2149 if (isa<ConstantDataVector>(Val: V))
2150 return true;
2151
2152 if (isa<ConstantInt>(Val: V) || isa<ConstantFP>(Val: V))
2153 return isa<VectorType>(Val: V->getType());
2154
2155 VectorType *VTy;
2156 ArrayRef<int> Mask;
2157 // Splats are represented differently depending on whether the repeated
2158 // value is a constant or an Instruction
2159 if (auto *Const = dyn_cast<ConstantExpr>(Val: V)) {
2160 if (Const->getOpcode() != Instruction::ShuffleVector)
2161 return false;
2162 VTy = cast<VectorType>(Val: Const->getType());
2163 Mask = Const->getShuffleMask();
2164 } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(Val: V)) {
2165 VTy = Shuf->getType();
2166 Mask = Shuf->getShuffleMask();
2167 } else {
2168 return false;
2169 }
2170
2171 // When the data type is <1 x Type>, it's not possible to differentiate
2172 // between the ComplexDeinterleaving::Deinterleave and
2173 // ComplexDeinterleaving::Splat operations.
2174 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
2175 return false;
2176
2177 return all_equal(Range&: Mask) && Mask[0] == 0;
2178 };
2179
2180 // The splats must meet the following requirements:
2181 // 1. Must either be all instructions or all values.
2182 // 2. Non-constant splats must live in the same block.
2183 if (auto *FirstValAsInstruction = dyn_cast<Instruction>(Val: Vals[0].Real)) {
2184 BasicBlock *FirstBB = FirstValAsInstruction->getParent();
2185 for (auto &V : Vals) {
2186 if (!IsSplat(V.Real) || !IsSplat(V.Imag))
2187 return nullptr;
2188
2189 auto *Real = dyn_cast<Instruction>(Val: V.Real);
2190 auto *Imag = dyn_cast<Instruction>(Val: V.Imag);
2191 if (!Real || !Imag || Real->getParent() != FirstBB ||
2192 Imag->getParent() != FirstBB)
2193 return nullptr;
2194 }
2195 } else {
2196 for (auto &V : Vals) {
2197 if (!IsSplat(V.Real) || !IsSplat(V.Imag) || isa<Instruction>(Val: V.Real) ||
2198 isa<Instruction>(Val: V.Imag))
2199 return nullptr;
2200 }
2201 }
2202
2203 for (auto &V : Vals) {
2204 auto *Real = dyn_cast<Instruction>(Val: V.Real);
2205 auto *Imag = dyn_cast<Instruction>(Val: V.Imag);
2206 if (Real && Imag) {
2207 FinalInstructions.insert(Ptr: Real);
2208 FinalInstructions.insert(Ptr: Imag);
2209 }
2210 }
2211 CompositeNode *PlaceholderNode =
2212 prepareCompositeNode(Operation: ComplexDeinterleavingOperation::Splat, Vals);
2213 return submitCompositeNode(Node: PlaceholderNode);
2214}
2215
2216ComplexDeinterleavingGraph::CompositeNode *
2217ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
2218 Instruction *Imag) {
2219 if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
2220 return nullptr;
2221
2222 PHIsFound = true;
2223 CompositeNode *PlaceholderNode = prepareCompositeNode(
2224 Operation: ComplexDeinterleavingOperation::ReductionPHI, R: Real, I: Imag);
2225 return submitCompositeNode(Node: PlaceholderNode);
2226}
2227
2228ComplexDeinterleavingGraph::CompositeNode *
2229ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
2230 Instruction *Imag) {
2231 auto *SelectReal = dyn_cast<SelectInst>(Val: Real);
2232 auto *SelectImag = dyn_cast<SelectInst>(Val: Imag);
2233 if (!SelectReal || !SelectImag)
2234 return nullptr;
2235
2236 Instruction *MaskA, *MaskB;
2237 Instruction *AR, *AI, *RA, *BI;
2238 if (!match(V: Real, P: m_Select(C: m_Instruction(I&: MaskA), L: m_Instruction(I&: AR),
2239 R: m_Instruction(I&: RA))) ||
2240 !match(V: Imag, P: m_Select(C: m_Instruction(I&: MaskB), L: m_Instruction(I&: AI),
2241 R: m_Instruction(I&: BI))))
2242 return nullptr;
2243
2244 if (MaskA != MaskB && !MaskA->isIdenticalTo(I: MaskB))
2245 return nullptr;
2246
2247 if (!MaskA->getType()->isVectorTy())
2248 return nullptr;
2249
2250 auto NodeA = identifyNode(R: AR, I: AI);
2251 if (!NodeA)
2252 return nullptr;
2253
2254 auto NodeB = identifyNode(R: RA, I: BI);
2255 if (!NodeB)
2256 return nullptr;
2257
2258 CompositeNode *PlaceholderNode = prepareCompositeNode(
2259 Operation: ComplexDeinterleavingOperation::ReductionSelect, R: Real, I: Imag);
2260 PlaceholderNode->addOperand(Node: NodeA);
2261 PlaceholderNode->addOperand(Node: NodeB);
2262 FinalInstructions.insert(Ptr: MaskA);
2263 FinalInstructions.insert(Ptr: MaskB);
2264 return submitCompositeNode(Node: PlaceholderNode);
2265}
2266
2267static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
2268 std::optional<FastMathFlags> Flags,
2269 Value *InputA, Value *InputB) {
2270 Value *I;
2271 switch (Opcode) {
2272 case Instruction::FNeg:
2273 I = B.CreateFNeg(V: InputA);
2274 break;
2275 case Instruction::FAdd:
2276 I = B.CreateFAdd(L: InputA, R: InputB);
2277 break;
2278 case Instruction::Add:
2279 I = B.CreateAdd(LHS: InputA, RHS: InputB);
2280 break;
2281 case Instruction::FSub:
2282 I = B.CreateFSub(L: InputA, R: InputB);
2283 break;
2284 case Instruction::Sub:
2285 I = B.CreateSub(LHS: InputA, RHS: InputB);
2286 break;
2287 case Instruction::FMul:
2288 I = B.CreateFMul(L: InputA, R: InputB);
2289 break;
2290 case Instruction::Mul:
2291 I = B.CreateMul(LHS: InputA, RHS: InputB);
2292 break;
2293 default:
2294 llvm_unreachable("Incorrect symmetric opcode");
2295 }
2296 if (Flags)
2297 cast<Instruction>(Val: I)->setFastMathFlags(*Flags);
2298 return I;
2299}
2300
2301Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
2302 CompositeNode *Node) {
2303 if (Node->ReplacementNode)
2304 return Node->ReplacementNode;
2305
2306 auto ReplaceOperandIfExist = [&](CompositeNode *Node,
2307 unsigned Idx) -> Value * {
2308 return Node->Operands.size() > Idx
2309 ? replaceNode(Builder, Node: Node->Operands[Idx])
2310 : nullptr;
2311 };
2312
2313 Value *ReplacementNode = nullptr;
2314 switch (Node->Operation) {
2315 case ComplexDeinterleavingOperation::CDot: {
2316 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2317 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2318 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
2319 assert(!Input1 || (Input0->getType() == Input1->getType() &&
2320 "Node inputs need to be of the same type"));
2321 ReplacementNode = TL->createComplexDeinterleavingIR(
2322 B&: Builder, OperationType: Node->Operation, Rotation: Node->Rotation, InputA: Input0, InputB: Input1, Accumulator);
2323 break;
2324 }
2325 case ComplexDeinterleavingOperation::CAdd:
2326 case ComplexDeinterleavingOperation::CMulPartial:
2327 case ComplexDeinterleavingOperation::Symmetric: {
2328 Value *Input0 = ReplaceOperandIfExist(Node, 0);
2329 Value *Input1 = ReplaceOperandIfExist(Node, 1);
2330 Value *Accumulator = ReplaceOperandIfExist(Node, 2);
2331 assert(!Input1 || (Input0->getType() == Input1->getType() &&
2332 "Node inputs need to be of the same type"));
2333 assert(!Accumulator ||
2334 (Input0->getType() == Accumulator->getType() &&
2335 "Accumulator and input need to be of the same type"));
2336 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
2337 ReplacementNode = replaceSymmetricNode(B&: Builder, Opcode: Node->Opcode, Flags: Node->Flags,
2338 InputA: Input0, InputB: Input1);
2339 else
2340 ReplacementNode = TL->createComplexDeinterleavingIR(
2341 B&: Builder, OperationType: Node->Operation, Rotation: Node->Rotation, InputA: Input0, InputB: Input1,
2342 Accumulator);
2343 break;
2344 }
2345 case ComplexDeinterleavingOperation::Deinterleave:
2346 llvm_unreachable("Deinterleave node should already have ReplacementNode");
2347 break;
2348 case ComplexDeinterleavingOperation::Splat: {
2349 SmallVector<Value *> Ops;
2350 for (auto &V : Node->Vals) {
2351 Ops.push_back(Elt: V.Real);
2352 Ops.push_back(Elt: V.Imag);
2353 }
2354 auto *R = dyn_cast<Instruction>(Val: Node->Vals[0].Real);
2355 auto *I = dyn_cast<Instruction>(Val: Node->Vals[0].Imag);
2356 if (R && I) {
2357 // Splats that are not constant are interleaved where they are located
2358 Instruction *InsertPoint = R;
2359 for (auto V : Node->Vals) {
2360 if (InsertPoint->comesBefore(Other: cast<Instruction>(Val: V.Real)))
2361 InsertPoint = cast<Instruction>(Val: V.Real);
2362 if (InsertPoint->comesBefore(Other: cast<Instruction>(Val: V.Imag)))
2363 InsertPoint = cast<Instruction>(Val: V.Imag);
2364 }
2365 InsertPoint = InsertPoint->getNextNode();
2366 IRBuilder<> IRB(InsertPoint);
2367 ReplacementNode = IRB.CreateVectorInterleave(Ops);
2368 } else {
2369 ReplacementNode = Builder.CreateVectorInterleave(Ops);
2370 }
2371 break;
2372 }
2373 case ComplexDeinterleavingOperation::ReductionPHI: {
2374 // If Operation is ReductionPHI, a new empty PHINode is created.
2375 // It is filled later when the ReductionOperation is processed.
2376 auto *OldPHI = cast<PHINode>(Val: Node->Vals[0].Real);
2377 auto *VTy = cast<VectorType>(Val: Node->Vals[0].Real->getType());
2378 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2379 auto *NewPHI = PHINode::Create(Ty: NewVTy, NumReservedValues: 0, NameStr: "", InsertBefore: BackEdge->getFirstNonPHIIt());
2380 OldToNewPHI[OldPHI] = NewPHI;
2381 ReplacementNode = NewPHI;
2382 break;
2383 }
2384 case ComplexDeinterleavingOperation::ReductionSingle:
2385 ReplacementNode = replaceNode(Builder, Node: Node->Operands[0]);
2386 processReductionSingle(OperationReplacement: ReplacementNode, Node);
2387 break;
2388 case ComplexDeinterleavingOperation::ReductionOperation:
2389 ReplacementNode = replaceNode(Builder, Node: Node->Operands[0]);
2390 processReductionOperation(OperationReplacement: ReplacementNode, Node);
2391 break;
2392 case ComplexDeinterleavingOperation::ReductionSelect: {
2393 auto *MaskReal = cast<Instruction>(Val: Node->Vals[0].Real)->getOperand(i: 0);
2394 auto *MaskImag = cast<Instruction>(Val: Node->Vals[0].Imag)->getOperand(i: 0);
2395 auto *A = replaceNode(Builder, Node: Node->Operands[0]);
2396 auto *B = replaceNode(Builder, Node: Node->Operands[1]);
2397 auto *NewMask = Builder.CreateVectorInterleave(Ops: {MaskReal, MaskImag});
2398 ReplacementNode = Builder.CreateSelect(C: NewMask, True: A, False: B);
2399 break;
2400 }
2401 }
2402
2403 assert(ReplacementNode && "Target failed to create Intrinsic call.");
2404 NumComplexTransformations += 1;
2405 Node->ReplacementNode = ReplacementNode;
2406 return ReplacementNode;
2407}
2408
2409void ComplexDeinterleavingGraph::processReductionSingle(
2410 Value *OperationReplacement, CompositeNode *Node) {
2411 auto *Real = cast<Instruction>(Val: Node->Vals[0].Real);
2412 auto *OldPHI = ReductionInfo[Real].first;
2413 auto *NewPHI = OldToNewPHI[OldPHI];
2414 auto *VTy = cast<VectorType>(Val: Real->getType());
2415 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2416
2417 Value *Init = OldPHI->getIncomingValueForBlock(BB: Incoming);
2418
2419 IRBuilder<> Builder(Incoming->getTerminator());
2420
2421 Value *NewInit = nullptr;
2422 if (auto *C = dyn_cast<Constant>(Val: Init)) {
2423 if (C->isNullValue())
2424 NewInit = Constant::getNullValue(Ty: NewVTy);
2425 }
2426
2427 if (!NewInit)
2428 NewInit =
2429 Builder.CreateVectorInterleave(Ops: {Init, Constant::getNullValue(Ty: VTy)});
2430
2431 NewPHI->addIncoming(V: NewInit, BB: Incoming);
2432 NewPHI->addIncoming(V: OperationReplacement, BB: BackEdge);
2433
2434 auto *FinalReduction = ReductionInfo[Real].second;
2435 Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
2436
2437 auto *AddReduce = Builder.CreateAddReduce(Src: OperationReplacement);
2438 FinalReduction->replaceAllUsesWith(V: AddReduce);
2439}
2440
2441void ComplexDeinterleavingGraph::processReductionOperation(
2442 Value *OperationReplacement, CompositeNode *Node) {
2443 auto *Real = cast<Instruction>(Val: Node->Vals[0].Real);
2444 auto *Imag = cast<Instruction>(Val: Node->Vals[0].Imag);
2445 auto *OldPHIReal = ReductionInfo[Real].first;
2446 auto *OldPHIImag = ReductionInfo[Imag].first;
2447 auto *NewPHI = OldToNewPHI[OldPHIReal];
2448
2449 // We have to interleave initial origin values coming from IncomingBlock
2450 Value *InitReal = OldPHIReal->getIncomingValueForBlock(BB: Incoming);
2451 Value *InitImag = OldPHIImag->getIncomingValueForBlock(BB: Incoming);
2452
2453 IRBuilder<> Builder(Incoming->getTerminator());
2454 auto *NewInit = Builder.CreateVectorInterleave(Ops: {InitReal, InitImag});
2455
2456 NewPHI->addIncoming(V: NewInit, BB: Incoming);
2457 NewPHI->addIncoming(V: OperationReplacement, BB: BackEdge);
2458
2459 // Deinterleave complex vector outside of loop so that it can be finally
2460 // reduced
2461 auto *FinalReductionReal = ReductionInfo[Real].second;
2462 auto *FinalReductionImag = ReductionInfo[Imag].second;
2463
2464 auto *Br = cast<CondBrInst>(Val: BackEdge->getTerminator());
2465 BasicBlock *ExitBB = Br->getSuccessor(i: Br->getSuccessor(i: 0) == BackEdge);
2466 Builder.SetInsertPoint(&*ExitBB->getFirstInsertionPt());
2467
2468 auto *Deinterleave = Builder.CreateIntrinsic(ID: Intrinsic::vector_deinterleave2,
2469 OverloadTypes: OperationReplacement->getType(),
2470 Args: OperationReplacement);
2471
2472 auto *NewReal = Builder.CreateExtractValue(Agg: Deinterleave, Idxs: (uint64_t)0);
2473 FinalReductionReal->replaceUsesOfWith(From: Real, To: NewReal);
2474
2475 Builder.SetInsertPoint(FinalReductionImag);
2476 auto *NewImag = Builder.CreateExtractValue(Agg: Deinterleave, Idxs: 1);
2477 FinalReductionImag->replaceUsesOfWith(From: Imag, To: NewImag);
2478}
2479
2480void ComplexDeinterleavingGraph::replaceNodes() {
2481 SmallVector<Instruction *, 16> DeadInstrRoots;
2482 for (auto *RootInstruction : OrderedRoots) {
2483 // Check if this potential root went through check process and we can
2484 // deinterleave it
2485 if (!RootToNode.count(Val: RootInstruction))
2486 continue;
2487
2488 IRBuilder<> Builder(RootInstruction);
2489 auto RootNode = RootToNode[RootInstruction];
2490 Value *R = replaceNode(Builder, Node: RootNode);
2491
2492 if (RootNode->Operation ==
2493 ComplexDeinterleavingOperation::ReductionOperation) {
2494 auto *RootReal = cast<Instruction>(Val: RootNode->Vals[0].Real);
2495 auto *RootImag = cast<Instruction>(Val: RootNode->Vals[0].Imag);
2496 ReductionInfo[RootReal].first->removeIncomingValue(BB: BackEdge);
2497 ReductionInfo[RootImag].first->removeIncomingValue(BB: BackEdge);
2498 DeadInstrRoots.push_back(Elt: RootReal);
2499 DeadInstrRoots.push_back(Elt: RootImag);
2500 } else if (RootNode->Operation ==
2501 ComplexDeinterleavingOperation::ReductionSingle) {
2502 auto *RootInst = cast<Instruction>(Val: RootNode->Vals[0].Real);
2503 auto &Info = ReductionInfo[RootInst];
2504 Info.first->removeIncomingValue(BB: BackEdge);
2505 DeadInstrRoots.push_back(Elt: Info.second);
2506 } else {
2507 assert(R && "Unable to find replacement for RootInstruction");
2508 DeadInstrRoots.push_back(Elt: RootInstruction);
2509 RootInstruction->replaceAllUsesWith(V: R);
2510 }
2511 }
2512
2513 for (auto *I : DeadInstrRoots)
2514 RecursivelyDeleteTriviallyDeadInstructions(V: I, TLI);
2515}
2516