1 | //===- ComplexDeinterleavingPass.cpp --------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Identification: |
10 | // This step is responsible for finding the patterns that can be lowered to |
11 | // complex instructions, and building a graph to represent the complex |
12 | // structures. Starting from the "Converging Shuffle" (a shuffle that |
13 | // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the |
14 | // operands are evaluated and identified as "Composite Nodes" (collections of |
15 | // instructions that can potentially be lowered to a single complex |
16 | // instruction). This is performed by checking the real and imaginary components |
17 | // and tracking the data flow for each component while following the operand |
18 | // pairs. Validity of each node is expected to be done upon creation, and any |
19 | // validation errors should halt traversal and prevent further graph |
20 | // construction. |
21 | // Instead of relying on Shuffle operations, vector interleaving and |
22 | // deinterleaving can be represented by vector.interleave2 and |
23 | // vector.deinterleave2 intrinsics. Scalable vectors can be represented only by |
24 | // these intrinsics, whereas, fixed-width vectors are recognized for both |
25 | // shufflevector instruction and intrinsics. |
26 | // |
27 | // Replacement: |
28 | // This step traverses the graph built up by identification, delegating to the |
29 | // target to validate and generate the correct intrinsics, and plumbs them |
30 | // together connecting each end of the new intrinsics graph to the existing |
31 | // use-def chain. This step is assumed to finish successfully, as all |
32 | // information is expected to be correct by this point. |
33 | // |
34 | // |
35 | // Internal data structure: |
36 | // ComplexDeinterleavingGraph: |
37 | // Keeps references to all the valid CompositeNodes formed as part of the |
38 | // transformation, and every Instruction contained within said nodes. It also |
39 | // holds onto a reference to the root Instruction, and the root node that should |
40 | // replace it. |
41 | // |
42 | // ComplexDeinterleavingCompositeNode: |
43 | // A CompositeNode represents a single transformation point; each node should |
44 | // transform into a single complex instruction (ignoring vector splitting, which |
45 | // would generate more instructions per node). They are identified in a |
46 | // depth-first manner, traversing and identifying the operands of each |
47 | // instruction in the order they appear in the IR. |
48 | // Each node maintains a reference to its Real and Imaginary instructions, |
49 | // as well as any additional instructions that make up the identified operation |
50 | // (Internal instructions should only have uses within their containing node). |
51 | // A Node also contains the rotation and operation type that it represents. |
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementNode is the transformed Value* that has been emitted
// to the IR.
//
// Note: If the operation of a Node is Shuffle, only the Real, Imag, and
// ReplacementNode fields of that Node are relevant, where ReplacementNode
// should be pre-populated.
59 | // |
60 | //===----------------------------------------------------------------------===// |
61 | |
62 | #include "llvm/CodeGen/ComplexDeinterleavingPass.h" |
63 | #include "llvm/ADT/MapVector.h" |
64 | #include "llvm/ADT/Statistic.h" |
65 | #include "llvm/Analysis/TargetLibraryInfo.h" |
66 | #include "llvm/Analysis/TargetTransformInfo.h" |
67 | #include "llvm/CodeGen/TargetLowering.h" |
68 | #include "llvm/CodeGen/TargetPassConfig.h" |
69 | #include "llvm/CodeGen/TargetSubtargetInfo.h" |
70 | #include "llvm/IR/IRBuilder.h" |
71 | #include "llvm/IR/PatternMatch.h" |
72 | #include "llvm/InitializePasses.h" |
73 | #include "llvm/Target/TargetMachine.h" |
74 | #include "llvm/Transforms/Utils/Local.h" |
75 | #include <algorithm> |
76 | |
77 | using namespace llvm; |
78 | using namespace PatternMatch; |
79 | |
80 | #define DEBUG_TYPE "complex-deinterleaving" |
81 | |
82 | STATISTIC(, "Amount of complex patterns transformed" ); |
83 | |
84 | static cl::opt<bool> ComplexDeinterleavingEnabled( |
85 | "enable-complex-deinterleaving" , |
86 | cl::desc("Enable generation of complex instructions" ), cl::init(Val: true), |
87 | cl::Hidden); |
88 | |
89 | /// Checks the given mask, and determines whether said mask is interleaving. |
90 | /// |
91 | /// To be interleaving, a mask must alternate between `i` and `i + (Length / |
92 | /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a |
93 | /// 4x vector interleaving mask would be <0, 2, 1, 3>). |
94 | static bool isInterleavingMask(ArrayRef<int> Mask); |
95 | |
96 | /// Checks the given mask, and determines whether said mask is deinterleaving. |
97 | /// |
98 | /// To be deinterleaving, a mask must increment in steps of 2, and either start |
99 | /// with 0 or 1. |
100 | /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or |
101 | /// <1, 3, 5, 7>). |
102 | static bool isDeinterleavingMask(ArrayRef<int> Mask); |
103 | |
104 | /// Returns true if the operation is a negation of V, and it works for both |
105 | /// integers and floats. |
106 | static bool isNeg(Value *V); |
107 | |
108 | /// Returns the operand for negation operation. |
109 | static Value *getNegOperand(Value *V); |
110 | |
111 | namespace { |
112 | |
113 | class ComplexDeinterleavingLegacyPass : public FunctionPass { |
114 | public: |
115 | static char ID; |
116 | |
117 | ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr) |
118 | : FunctionPass(ID), TM(TM) { |
119 | initializeComplexDeinterleavingLegacyPassPass( |
120 | *PassRegistry::getPassRegistry()); |
121 | } |
122 | |
123 | StringRef getPassName() const override { |
124 | return "Complex Deinterleaving Pass" ; |
125 | } |
126 | |
127 | bool runOnFunction(Function &F) override; |
128 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
129 | AU.addRequired<TargetLibraryInfoWrapperPass>(); |
130 | AU.setPreservesCFG(); |
131 | } |
132 | |
133 | private: |
134 | const TargetMachine *TM; |
135 | }; |
136 | |
137 | class ComplexDeinterleavingGraph; |
138 | struct ComplexDeinterleavingCompositeNode { |
139 | |
140 | ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op, |
141 | Value *R, Value *I) |
142 | : Operation(Op), Real(R), Imag(I) {} |
143 | |
144 | private: |
145 | friend class ComplexDeinterleavingGraph; |
146 | using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>; |
147 | using RawNodePtr = ComplexDeinterleavingCompositeNode *; |
148 | |
149 | public: |
150 | ComplexDeinterleavingOperation Operation; |
151 | Value *Real; |
152 | Value *Imag; |
153 | |
154 | // This two members are required exclusively for generating |
155 | // ComplexDeinterleavingOperation::Symmetric operations. |
156 | unsigned Opcode; |
157 | std::optional<FastMathFlags> Flags; |
158 | |
159 | ComplexDeinterleavingRotation Rotation = |
160 | ComplexDeinterleavingRotation::Rotation_0; |
161 | SmallVector<RawNodePtr> Operands; |
162 | Value *ReplacementNode = nullptr; |
163 | |
164 | void addOperand(NodePtr Node) { Operands.push_back(Elt: Node.get()); } |
165 | |
166 | void dump() { dump(OS&: dbgs()); } |
167 | void dump(raw_ostream &OS) { |
168 | auto PrintValue = [&](Value *V) { |
169 | if (V) { |
170 | OS << "\"" ; |
171 | V->print(O&: OS, IsForDebug: true); |
172 | OS << "\"\n" ; |
173 | } else |
174 | OS << "nullptr\n" ; |
175 | }; |
176 | auto PrintNodeRef = [&](RawNodePtr Ptr) { |
177 | if (Ptr) |
178 | OS << Ptr << "\n" ; |
179 | else |
180 | OS << "nullptr\n" ; |
181 | }; |
182 | |
183 | OS << "- CompositeNode: " << this << "\n" ; |
184 | OS << " Real: " ; |
185 | PrintValue(Real); |
186 | OS << " Imag: " ; |
187 | PrintValue(Imag); |
188 | OS << " ReplacementNode: " ; |
189 | PrintValue(ReplacementNode); |
190 | OS << " Operation: " << (int)Operation << "\n" ; |
191 | OS << " Rotation: " << ((int)Rotation * 90) << "\n" ; |
192 | OS << " Operands: \n" ; |
193 | for (const auto &Op : Operands) { |
194 | OS << " - " ; |
195 | PrintNodeRef(Op); |
196 | } |
197 | } |
198 | }; |
199 | |
200 | class ComplexDeinterleavingGraph { |
201 | public: |
202 | struct Product { |
203 | Value *Multiplier; |
204 | Value *Multiplicand; |
205 | bool IsPositive; |
206 | }; |
207 | |
208 | using Addend = std::pair<Value *, bool>; |
209 | using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr; |
210 | using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr; |
211 | |
212 | // Helper struct for holding info about potential partial multiplication |
213 | // candidates |
214 | struct PartialMulCandidate { |
215 | Value *Common; |
216 | NodePtr Node; |
217 | unsigned RealIdx; |
218 | unsigned ImagIdx; |
219 | bool IsNodeInverted; |
220 | }; |
221 | |
222 | explicit ComplexDeinterleavingGraph(const TargetLowering *TL, |
223 | const TargetLibraryInfo *TLI) |
224 | : TL(TL), TLI(TLI) {} |
225 | |
226 | private: |
227 | const TargetLowering *TL = nullptr; |
228 | const TargetLibraryInfo *TLI = nullptr; |
229 | SmallVector<NodePtr> CompositeNodes; |
230 | DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult; |
231 | |
232 | SmallPtrSet<Instruction *, 16> FinalInstructions; |
233 | |
234 | /// Root instructions are instructions from which complex computation starts |
235 | std::map<Instruction *, NodePtr> RootToNode; |
236 | |
237 | /// Topologically sorted root instructions |
238 | SmallVector<Instruction *, 1> OrderedRoots; |
239 | |
240 | /// When examining a basic block for complex deinterleaving, if it is a simple |
241 | /// one-block loop, then the only incoming block is 'Incoming' and the |
242 | /// 'BackEdge' block is the block itself." |
243 | BasicBlock *BackEdge = nullptr; |
244 | BasicBlock *Incoming = nullptr; |
245 | |
246 | /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction |
247 | /// %OutsideUser as it is shown in the IR: |
248 | /// |
249 | /// vector.body: |
250 | /// %PHInode = phi <vector type> [ zeroinitializer, %entry ], |
251 | /// [ %ReductionOp, %vector.body ] |
252 | /// ... |
253 | /// %ReductionOp = fadd i64 ... |
254 | /// ... |
255 | /// br i1 %condition, label %vector.body, %middle.block |
256 | /// |
257 | /// middle.block: |
258 | /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp) |
259 | /// |
260 | /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding |
261 | /// `llvm.vector.reduce.fadd` when unroll factor isn't one. |
262 | MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo; |
263 | |
264 | /// In the process of detecting a reduction, we consider a pair of |
265 | /// %ReductionOP, which we refer to as real and imag (or vice versa), and |
266 | /// traverse the use-tree to detect complex operations. As this is a reduction |
267 | /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds |
268 | /// to the %ReductionOPs that we suspect to be complex. |
269 | /// RealPHI and ImagPHI are used by the identifyPHINode method. |
270 | PHINode *RealPHI = nullptr; |
271 | PHINode *ImagPHI = nullptr; |
272 | |
273 | /// Set this flag to true if RealPHI and ImagPHI were reached during reduction |
274 | /// detection. |
275 | bool PHIsFound = false; |
276 | |
277 | /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode. |
278 | /// The new PHINode corresponds to a vector of deinterleaved complex numbers. |
279 | /// This mapping is populated during |
280 | /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then |
281 | /// used in the ComplexDeinterleavingOperation::ReductionOperation node |
282 | /// replacement process. |
283 | std::map<PHINode *, PHINode *> OldToNewPHI; |
284 | |
285 | NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation, |
286 | Value *R, Value *I) { |
287 | assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI && |
288 | Operation != ComplexDeinterleavingOperation::ReductionOperation) || |
289 | (R && I)) && |
290 | "Reduction related nodes must have Real and Imaginary parts" ); |
291 | return std::make_shared<ComplexDeinterleavingCompositeNode>(args&: Operation, args&: R, |
292 | args&: I); |
293 | } |
294 | |
295 | NodePtr submitCompositeNode(NodePtr Node) { |
296 | CompositeNodes.push_back(Elt: Node); |
297 | if (Node->Real && Node->Imag) |
298 | CachedResult[{Node->Real, Node->Imag}] = Node; |
299 | return Node; |
300 | } |
301 | |
302 | /// Identifies a complex partial multiply pattern and its rotation, based on |
303 | /// the following patterns |
304 | /// |
305 | /// 0: r: cr + ar * br |
306 | /// i: ci + ar * bi |
307 | /// 90: r: cr - ai * bi |
308 | /// i: ci + ai * br |
309 | /// 180: r: cr - ar * br |
310 | /// i: ci - ar * bi |
311 | /// 270: r: cr + ai * bi |
312 | /// i: ci - ai * br |
313 | NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag); |
314 | |
315 | /// Identify the other branch of a Partial Mul, taking the CommonOperandI that |
316 | /// is partially known from identifyPartialMul, filling in the other half of |
317 | /// the complex pair. |
318 | NodePtr |
319 | identifyNodeWithImplicitAdd(Instruction *I, Instruction *J, |
320 | std::pair<Value *, Value *> &CommonOperandI); |
321 | |
322 | /// Identifies a complex add pattern and its rotation, based on the following |
323 | /// patterns. |
324 | /// |
325 | /// 90: r: ar - bi |
326 | /// i: ai + br |
327 | /// 270: r: ar + bi |
328 | /// i: ai - br |
329 | NodePtr identifyAdd(Instruction *Real, Instruction *Imag); |
330 | NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag); |
331 | |
332 | NodePtr identifyNode(Value *R, Value *I); |
333 | |
334 | /// Determine if a sum of complex numbers can be formed from \p RealAddends |
335 | /// and \p ImagAddens. If \p Accumulator is not null, add the result to it. |
336 | /// Return nullptr if it is not possible to construct a complex number. |
337 | /// \p Flags are needed to generate symmetric Add and Sub operations. |
338 | NodePtr identifyAdditions(std::list<Addend> &RealAddends, |
339 | std::list<Addend> &ImagAddends, |
340 | std::optional<FastMathFlags> Flags, |
341 | NodePtr Accumulator); |
342 | |
343 | /// Extract one addend that have both real and imaginary parts positive. |
344 | NodePtr extractPositiveAddend(std::list<Addend> &RealAddends, |
345 | std::list<Addend> &ImagAddends); |
346 | |
347 | /// Determine if sum of multiplications of complex numbers can be formed from |
348 | /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result |
349 | /// to it. Return nullptr if it is not possible to construct a complex number. |
350 | NodePtr identifyMultiplications(std::vector<Product> &RealMuls, |
351 | std::vector<Product> &ImagMuls, |
352 | NodePtr Accumulator); |
353 | |
354 | /// Go through pairs of multiplication (one Real and one Imag) and find all |
355 | /// possible candidates for partial multiplication and put them into \p |
356 | /// Candidates. Returns true if all Product has pair with common operand |
357 | bool collectPartialMuls(const std::vector<Product> &RealMuls, |
358 | const std::vector<Product> &ImagMuls, |
359 | std::vector<PartialMulCandidate> &Candidates); |
360 | |
361 | /// If the code is compiled with -Ofast or expressions have `reassoc` flag, |
362 | /// the order of complex computation operations may be significantly altered, |
363 | /// and the real and imaginary parts may not be executed in parallel. This |
364 | /// function takes this into consideration and employs a more general approach |
365 | /// to identify complex computations. Initially, it gathers all the addends |
366 | /// and multiplicands and then constructs a complex expression from them. |
367 | NodePtr identifyReassocNodes(Instruction *I, Instruction *J); |
368 | |
369 | NodePtr identifyRoot(Instruction *I); |
370 | |
371 | /// Identifies the Deinterleave operation applied to a vector containing |
372 | /// complex numbers. There are two ways to represent the Deinterleave |
373 | /// operation: |
374 | /// * Using two shufflevectors with even indices for /pReal instruction and |
375 | /// odd indices for /pImag instructions (only for fixed-width vectors) |
376 | /// * Using two extractvalue instructions applied to `vector.deinterleave2` |
377 | /// intrinsic (for both fixed and scalable vectors) |
378 | NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag); |
379 | |
380 | /// identifying the operation that represents a complex number repeated in a |
381 | /// Splat vector. There are two possible types of splats: ConstantExpr with |
382 | /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an |
383 | /// initialization mask with all values set to zero. |
384 | NodePtr identifySplat(Value *Real, Value *Imag); |
385 | |
386 | NodePtr identifyPHINode(Instruction *Real, Instruction *Imag); |
387 | |
388 | /// Identifies SelectInsts in a loop that has reduction with predication masks |
389 | /// and/or predicated tail folding |
390 | NodePtr identifySelectNode(Instruction *Real, Instruction *Imag); |
391 | |
392 | Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node); |
393 | |
394 | /// Complete IR modifications after producing new reduction operation: |
395 | /// * Populate the PHINode generated for |
396 | /// ComplexDeinterleavingOperation::ReductionPHI |
397 | /// * Deinterleave the final value outside of the loop and repurpose original |
398 | /// reduction users |
399 | void processReductionOperation(Value *OperationReplacement, RawNodePtr Node); |
400 | |
401 | public: |
402 | void dump() { dump(OS&: dbgs()); } |
403 | void dump(raw_ostream &OS) { |
404 | for (const auto &Node : CompositeNodes) |
405 | Node->dump(OS); |
406 | } |
407 | |
408 | /// Returns false if the deinterleaving operation should be cancelled for the |
409 | /// current graph. |
410 | bool identifyNodes(Instruction *RootI); |
411 | |
412 | /// In case \pB is one-block loop, this function seeks potential reductions |
413 | /// and populates ReductionInfo. Returns true if any reductions were |
414 | /// identified. |
415 | bool collectPotentialReductions(BasicBlock *B); |
416 | |
417 | void identifyReductionNodes(); |
418 | |
419 | /// Check that every instruction, from the roots to the leaves, has internal |
420 | /// uses. |
421 | bool checkNodes(); |
422 | |
423 | /// Perform the actual replacement of the underlying instruction graph. |
424 | void replaceNodes(); |
425 | }; |
426 | |
427 | class ComplexDeinterleaving { |
428 | public: |
429 | ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli) |
430 | : TL(tl), TLI(tli) {} |
431 | bool runOnFunction(Function &F); |
432 | |
433 | private: |
434 | bool evaluateBasicBlock(BasicBlock *B); |
435 | |
436 | const TargetLowering *TL = nullptr; |
437 | const TargetLibraryInfo *TLI = nullptr; |
438 | }; |
439 | |
440 | } // namespace |
441 | |
442 | char ComplexDeinterleavingLegacyPass::ID = 0; |
443 | |
444 | INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, |
445 | "Complex Deinterleaving" , false, false) |
446 | INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, |
447 | "Complex Deinterleaving" , false, false) |
448 | |
449 | PreservedAnalyses ComplexDeinterleavingPass::run(Function &F, |
450 | FunctionAnalysisManager &AM) { |
451 | const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering(); |
452 | auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(IR&: F); |
453 | if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F)) |
454 | return PreservedAnalyses::all(); |
455 | |
456 | PreservedAnalyses PA; |
457 | PA.preserve<FunctionAnalysisManagerModuleProxy>(); |
458 | return PA; |
459 | } |
460 | |
461 | FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) { |
462 | return new ComplexDeinterleavingLegacyPass(TM); |
463 | } |
464 | |
465 | bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) { |
466 | const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering(); |
467 | auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); |
468 | return ComplexDeinterleaving(TL, &TLI).runOnFunction(F); |
469 | } |
470 | |
471 | bool ComplexDeinterleaving::runOnFunction(Function &F) { |
472 | if (!ComplexDeinterleavingEnabled) { |
473 | LLVM_DEBUG( |
474 | dbgs() << "Complex deinterleaving has been explicitly disabled.\n" ); |
475 | return false; |
476 | } |
477 | |
478 | if (!TL->isComplexDeinterleavingSupported()) { |
479 | LLVM_DEBUG( |
480 | dbgs() << "Complex deinterleaving has been disabled, target does " |
481 | "not support lowering of complex number operations.\n" ); |
482 | return false; |
483 | } |
484 | |
485 | bool Changed = false; |
486 | for (auto &B : F) |
487 | Changed |= evaluateBasicBlock(B: &B); |
488 | |
489 | return Changed; |
490 | } |
491 | |
492 | static bool isInterleavingMask(ArrayRef<int> Mask) { |
493 | // If the size is not even, it's not an interleaving mask |
494 | if ((Mask.size() & 1)) |
495 | return false; |
496 | |
497 | int HalfNumElements = Mask.size() / 2; |
498 | for (int Idx = 0; Idx < HalfNumElements; ++Idx) { |
499 | int MaskIdx = Idx * 2; |
500 | if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements)) |
501 | return false; |
502 | } |
503 | |
504 | return true; |
505 | } |
506 | |
507 | static bool isDeinterleavingMask(ArrayRef<int> Mask) { |
508 | int Offset = Mask[0]; |
509 | int HalfNumElements = Mask.size() / 2; |
510 | |
511 | for (int Idx = 1; Idx < HalfNumElements; ++Idx) { |
512 | if (Mask[Idx] != (Idx * 2) + Offset) |
513 | return false; |
514 | } |
515 | |
516 | return true; |
517 | } |
518 | |
519 | bool isNeg(Value *V) { |
520 | return match(V, P: m_FNeg(X: m_Value())) || match(V, P: m_Neg(V: m_Value())); |
521 | } |
522 | |
523 | Value *getNegOperand(Value *V) { |
524 | assert(isNeg(V)); |
525 | auto *I = cast<Instruction>(Val: V); |
526 | if (I->getOpcode() == Instruction::FNeg) |
527 | return I->getOperand(i: 0); |
528 | |
529 | return I->getOperand(i: 1); |
530 | } |
531 | |
532 | bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) { |
533 | ComplexDeinterleavingGraph Graph(TL, TLI); |
534 | if (Graph.collectPotentialReductions(B)) |
535 | Graph.identifyReductionNodes(); |
536 | |
537 | for (auto &I : *B) |
538 | Graph.identifyNodes(RootI: &I); |
539 | |
540 | if (Graph.checkNodes()) { |
541 | Graph.replaceNodes(); |
542 | return true; |
543 | } |
544 | |
545 | return false; |
546 | } |
547 | |
548 | ComplexDeinterleavingGraph::NodePtr |
549 | ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd( |
550 | Instruction *Real, Instruction *Imag, |
551 | std::pair<Value *, Value *> &PartialMatch) { |
552 | LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag |
553 | << "\n" ); |
554 | |
555 | if (!Real->hasOneUse() || !Imag->hasOneUse()) { |
556 | LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n" ); |
557 | return nullptr; |
558 | } |
559 | |
560 | if ((Real->getOpcode() != Instruction::FMul && |
561 | Real->getOpcode() != Instruction::Mul) || |
562 | (Imag->getOpcode() != Instruction::FMul && |
563 | Imag->getOpcode() != Instruction::Mul)) { |
564 | LLVM_DEBUG( |
565 | dbgs() << " - Real or imaginary instruction is not fmul or mul\n" ); |
566 | return nullptr; |
567 | } |
568 | |
569 | Value *R0 = Real->getOperand(i: 0); |
570 | Value *R1 = Real->getOperand(i: 1); |
571 | Value *I0 = Imag->getOperand(i: 0); |
572 | Value *I1 = Imag->getOperand(i: 1); |
573 | |
574 | // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the |
575 | // rotations and use the operand. |
576 | unsigned Negs = 0; |
577 | Value *Op; |
578 | if (match(V: R0, P: m_Neg(V: m_Value(V&: Op)))) { |
579 | Negs |= 1; |
580 | R0 = Op; |
581 | } else if (match(V: R1, P: m_Neg(V: m_Value(V&: Op)))) { |
582 | Negs |= 1; |
583 | R1 = Op; |
584 | } |
585 | |
586 | if (isNeg(V: I0)) { |
587 | Negs |= 2; |
588 | Negs ^= 1; |
589 | I0 = Op; |
590 | } else if (match(V: I1, P: m_Neg(V: m_Value(V&: Op)))) { |
591 | Negs |= 2; |
592 | Negs ^= 1; |
593 | I1 = Op; |
594 | } |
595 | |
596 | ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs; |
597 | |
598 | Value *CommonOperand; |
599 | Value *UncommonRealOp; |
600 | Value *UncommonImagOp; |
601 | |
602 | if (R0 == I0 || R0 == I1) { |
603 | CommonOperand = R0; |
604 | UncommonRealOp = R1; |
605 | } else if (R1 == I0 || R1 == I1) { |
606 | CommonOperand = R1; |
607 | UncommonRealOp = R0; |
608 | } else { |
609 | LLVM_DEBUG(dbgs() << " - No equal operand\n" ); |
610 | return nullptr; |
611 | } |
612 | |
613 | UncommonImagOp = (CommonOperand == I0) ? I1 : I0; |
614 | if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || |
615 | Rotation == ComplexDeinterleavingRotation::Rotation_270) |
616 | std::swap(a&: UncommonRealOp, b&: UncommonImagOp); |
617 | |
618 | // Between identifyPartialMul and here we need to have found a complete valid |
619 | // pair from the CommonOperand of each part. |
620 | if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || |
621 | Rotation == ComplexDeinterleavingRotation::Rotation_180) |
622 | PartialMatch.first = CommonOperand; |
623 | else |
624 | PartialMatch.second = CommonOperand; |
625 | |
626 | if (!PartialMatch.first || !PartialMatch.second) { |
627 | LLVM_DEBUG(dbgs() << " - Incomplete partial match\n" ); |
628 | return nullptr; |
629 | } |
630 | |
631 | NodePtr CommonNode = identifyNode(R: PartialMatch.first, I: PartialMatch.second); |
632 | if (!CommonNode) { |
633 | LLVM_DEBUG(dbgs() << " - No CommonNode identified\n" ); |
634 | return nullptr; |
635 | } |
636 | |
637 | NodePtr UncommonNode = identifyNode(R: UncommonRealOp, I: UncommonImagOp); |
638 | if (!UncommonNode) { |
639 | LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n" ); |
640 | return nullptr; |
641 | } |
642 | |
643 | NodePtr Node = prepareCompositeNode( |
644 | Operation: ComplexDeinterleavingOperation::CMulPartial, R: Real, I: Imag); |
645 | Node->Rotation = Rotation; |
646 | Node->addOperand(Node: CommonNode); |
647 | Node->addOperand(Node: UncommonNode); |
648 | return submitCompositeNode(Node); |
649 | } |
650 | |
651 | ComplexDeinterleavingGraph::NodePtr |
652 | ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real, |
653 | Instruction *Imag) { |
654 | LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag |
655 | << "\n" ); |
656 | // Determine rotation |
657 | auto IsAdd = [](unsigned Op) { |
658 | return Op == Instruction::FAdd || Op == Instruction::Add; |
659 | }; |
660 | auto IsSub = [](unsigned Op) { |
661 | return Op == Instruction::FSub || Op == Instruction::Sub; |
662 | }; |
663 | ComplexDeinterleavingRotation Rotation; |
664 | if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode())) |
665 | Rotation = ComplexDeinterleavingRotation::Rotation_0; |
666 | else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode())) |
667 | Rotation = ComplexDeinterleavingRotation::Rotation_90; |
668 | else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode())) |
669 | Rotation = ComplexDeinterleavingRotation::Rotation_180; |
670 | else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode())) |
671 | Rotation = ComplexDeinterleavingRotation::Rotation_270; |
672 | else { |
673 | LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n" ); |
674 | return nullptr; |
675 | } |
676 | |
677 | if (isa<FPMathOperator>(Val: Real) && |
678 | (!Real->getFastMathFlags().allowContract() || |
679 | !Imag->getFastMathFlags().allowContract())) { |
680 | LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n" ); |
681 | return nullptr; |
682 | } |
683 | |
684 | Value *CR = Real->getOperand(i: 0); |
685 | Instruction *RealMulI = dyn_cast<Instruction>(Val: Real->getOperand(i: 1)); |
686 | if (!RealMulI) |
687 | return nullptr; |
688 | Value *CI = Imag->getOperand(i: 0); |
689 | Instruction *ImagMulI = dyn_cast<Instruction>(Val: Imag->getOperand(i: 1)); |
690 | if (!ImagMulI) |
691 | return nullptr; |
692 | |
693 | if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) { |
694 | LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n" ); |
695 | return nullptr; |
696 | } |
697 | |
698 | Value *R0 = RealMulI->getOperand(i: 0); |
699 | Value *R1 = RealMulI->getOperand(i: 1); |
700 | Value *I0 = ImagMulI->getOperand(i: 0); |
701 | Value *I1 = ImagMulI->getOperand(i: 1); |
702 | |
703 | Value *CommonOperand; |
704 | Value *UncommonRealOp; |
705 | Value *UncommonImagOp; |
706 | |
707 | if (R0 == I0 || R0 == I1) { |
708 | CommonOperand = R0; |
709 | UncommonRealOp = R1; |
710 | } else if (R1 == I0 || R1 == I1) { |
711 | CommonOperand = R1; |
712 | UncommonRealOp = R0; |
713 | } else { |
714 | LLVM_DEBUG(dbgs() << " - No equal operand\n" ); |
715 | return nullptr; |
716 | } |
717 | |
718 | UncommonImagOp = (CommonOperand == I0) ? I1 : I0; |
719 | if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || |
720 | Rotation == ComplexDeinterleavingRotation::Rotation_270) |
721 | std::swap(a&: UncommonRealOp, b&: UncommonImagOp); |
722 | |
723 | std::pair<Value *, Value *> PartialMatch( |
724 | (Rotation == ComplexDeinterleavingRotation::Rotation_0 || |
725 | Rotation == ComplexDeinterleavingRotation::Rotation_180) |
726 | ? CommonOperand |
727 | : nullptr, |
728 | (Rotation == ComplexDeinterleavingRotation::Rotation_90 || |
729 | Rotation == ComplexDeinterleavingRotation::Rotation_270) |
730 | ? CommonOperand |
731 | : nullptr); |
732 | |
733 | auto *CRInst = dyn_cast<Instruction>(Val: CR); |
734 | auto *CIInst = dyn_cast<Instruction>(Val: CI); |
735 | |
736 | if (!CRInst || !CIInst) { |
737 | LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n" ); |
738 | return nullptr; |
739 | } |
740 | |
741 | NodePtr CNode = identifyNodeWithImplicitAdd(Real: CRInst, Imag: CIInst, PartialMatch); |
742 | if (!CNode) { |
743 | LLVM_DEBUG(dbgs() << " - No cnode identified\n" ); |
744 | return nullptr; |
745 | } |
746 | |
747 | NodePtr UncommonRes = identifyNode(R: UncommonRealOp, I: UncommonImagOp); |
748 | if (!UncommonRes) { |
749 | LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n" ); |
750 | return nullptr; |
751 | } |
752 | |
753 | assert(PartialMatch.first && PartialMatch.second); |
754 | NodePtr CommonRes = identifyNode(R: PartialMatch.first, I: PartialMatch.second); |
755 | if (!CommonRes) { |
756 | LLVM_DEBUG(dbgs() << " - No CommonRes identified\n" ); |
757 | return nullptr; |
758 | } |
759 | |
760 | NodePtr Node = prepareCompositeNode( |
761 | Operation: ComplexDeinterleavingOperation::CMulPartial, R: Real, I: Imag); |
762 | Node->Rotation = Rotation; |
763 | Node->addOperand(Node: CommonRes); |
764 | Node->addOperand(Node: UncommonRes); |
765 | Node->addOperand(Node: CNode); |
766 | return submitCompositeNode(Node); |
767 | } |
768 | |
769 | ComplexDeinterleavingGraph::NodePtr |
770 | ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) { |
771 | LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n" ); |
772 | |
773 | // Determine rotation |
774 | ComplexDeinterleavingRotation Rotation; |
775 | if ((Real->getOpcode() == Instruction::FSub && |
776 | Imag->getOpcode() == Instruction::FAdd) || |
777 | (Real->getOpcode() == Instruction::Sub && |
778 | Imag->getOpcode() == Instruction::Add)) |
779 | Rotation = ComplexDeinterleavingRotation::Rotation_90; |
780 | else if ((Real->getOpcode() == Instruction::FAdd && |
781 | Imag->getOpcode() == Instruction::FSub) || |
782 | (Real->getOpcode() == Instruction::Add && |
783 | Imag->getOpcode() == Instruction::Sub)) |
784 | Rotation = ComplexDeinterleavingRotation::Rotation_270; |
785 | else { |
786 | LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n" ); |
787 | return nullptr; |
788 | } |
789 | |
790 | auto *AR = dyn_cast<Instruction>(Val: Real->getOperand(i: 0)); |
791 | auto *BI = dyn_cast<Instruction>(Val: Real->getOperand(i: 1)); |
792 | auto *AI = dyn_cast<Instruction>(Val: Imag->getOperand(i: 0)); |
793 | auto *BR = dyn_cast<Instruction>(Val: Imag->getOperand(i: 1)); |
794 | |
795 | if (!AR || !AI || !BR || !BI) { |
796 | LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n" ); |
797 | return nullptr; |
798 | } |
799 | |
800 | NodePtr ResA = identifyNode(R: AR, I: AI); |
801 | if (!ResA) { |
802 | LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n" ); |
803 | return nullptr; |
804 | } |
805 | NodePtr ResB = identifyNode(R: BR, I: BI); |
806 | if (!ResB) { |
807 | LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n" ); |
808 | return nullptr; |
809 | } |
810 | |
811 | NodePtr Node = |
812 | prepareCompositeNode(Operation: ComplexDeinterleavingOperation::CAdd, R: Real, I: Imag); |
813 | Node->Rotation = Rotation; |
814 | Node->addOperand(Node: ResA); |
815 | Node->addOperand(Node: ResB); |
816 | return submitCompositeNode(Node); |
817 | } |
818 | |
819 | static bool isInstructionPairAdd(Instruction *A, Instruction *B) { |
820 | unsigned OpcA = A->getOpcode(); |
821 | unsigned OpcB = B->getOpcode(); |
822 | |
823 | return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) || |
824 | (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) || |
825 | (OpcA == Instruction::Sub && OpcB == Instruction::Add) || |
826 | (OpcA == Instruction::Add && OpcB == Instruction::Sub); |
827 | } |
828 | |
829 | static bool isInstructionPairMul(Instruction *A, Instruction *B) { |
830 | auto Pattern = |
831 | m_BinOp(L: m_FMul(L: m_Value(), R: m_Value()), R: m_FMul(L: m_Value(), R: m_Value())); |
832 | |
833 | return match(V: A, P: Pattern) && match(V: B, P: Pattern); |
834 | } |
835 | |
836 | static bool isInstructionPotentiallySymmetric(Instruction *I) { |
837 | switch (I->getOpcode()) { |
838 | case Instruction::FAdd: |
839 | case Instruction::FSub: |
840 | case Instruction::FMul: |
841 | case Instruction::FNeg: |
842 | case Instruction::Add: |
843 | case Instruction::Sub: |
844 | case Instruction::Mul: |
845 | return true; |
846 | default: |
847 | return false; |
848 | } |
849 | } |
850 | |
851 | ComplexDeinterleavingGraph::NodePtr |
852 | ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real, |
853 | Instruction *Imag) { |
854 | if (Real->getOpcode() != Imag->getOpcode()) |
855 | return nullptr; |
856 | |
857 | if (!isInstructionPotentiallySymmetric(I: Real) || |
858 | !isInstructionPotentiallySymmetric(I: Imag)) |
859 | return nullptr; |
860 | |
861 | auto *R0 = Real->getOperand(i: 0); |
862 | auto *I0 = Imag->getOperand(i: 0); |
863 | |
864 | NodePtr Op0 = identifyNode(R: R0, I: I0); |
865 | NodePtr Op1 = nullptr; |
866 | if (Op0 == nullptr) |
867 | return nullptr; |
868 | |
869 | if (Real->isBinaryOp()) { |
870 | auto *R1 = Real->getOperand(i: 1); |
871 | auto *I1 = Imag->getOperand(i: 1); |
872 | Op1 = identifyNode(R: R1, I: I1); |
873 | if (Op1 == nullptr) |
874 | return nullptr; |
875 | } |
876 | |
877 | if (isa<FPMathOperator>(Val: Real) && |
878 | Real->getFastMathFlags() != Imag->getFastMathFlags()) |
879 | return nullptr; |
880 | |
881 | auto Node = prepareCompositeNode(Operation: ComplexDeinterleavingOperation::Symmetric, |
882 | R: Real, I: Imag); |
883 | Node->Opcode = Real->getOpcode(); |
884 | if (isa<FPMathOperator>(Val: Real)) |
885 | Node->Flags = Real->getFastMathFlags(); |
886 | |
887 | Node->addOperand(Node: Op0); |
888 | if (Real->isBinaryOp()) |
889 | Node->addOperand(Node: Op1); |
890 | |
891 | return submitCompositeNode(Node); |
892 | } |
893 | |
894 | ComplexDeinterleavingGraph::NodePtr |
895 | ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) { |
896 | LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n" ); |
897 | assert(R->getType() == I->getType() && |
898 | "Real and imaginary parts should not have different types" ); |
899 | |
900 | auto It = CachedResult.find(Val: {R, I}); |
901 | if (It != CachedResult.end()) { |
902 | LLVM_DEBUG(dbgs() << " - Folding to existing node\n" ); |
903 | return It->second; |
904 | } |
905 | |
906 | if (NodePtr CN = identifySplat(Real: R, Imag: I)) |
907 | return CN; |
908 | |
909 | auto *Real = dyn_cast<Instruction>(Val: R); |
910 | auto *Imag = dyn_cast<Instruction>(Val: I); |
911 | if (!Real || !Imag) |
912 | return nullptr; |
913 | |
914 | if (NodePtr CN = identifyDeinterleave(Real, Imag)) |
915 | return CN; |
916 | |
917 | if (NodePtr CN = identifyPHINode(Real, Imag)) |
918 | return CN; |
919 | |
920 | if (NodePtr CN = identifySelectNode(Real, Imag)) |
921 | return CN; |
922 | |
923 | auto *VTy = cast<VectorType>(Val: Real->getType()); |
924 | auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); |
925 | |
926 | bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported( |
927 | Operation: ComplexDeinterleavingOperation::CMulPartial, Ty: NewVTy); |
928 | bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported( |
929 | Operation: ComplexDeinterleavingOperation::CAdd, Ty: NewVTy); |
930 | |
931 | if (HasCMulSupport && isInstructionPairMul(A: Real, B: Imag)) { |
932 | if (NodePtr CN = identifyPartialMul(Real, Imag)) |
933 | return CN; |
934 | } |
935 | |
936 | if (HasCAddSupport && isInstructionPairAdd(A: Real, B: Imag)) { |
937 | if (NodePtr CN = identifyAdd(Real, Imag)) |
938 | return CN; |
939 | } |
940 | |
941 | if (HasCMulSupport && HasCAddSupport) { |
942 | if (NodePtr CN = identifyReassocNodes(I: Real, J: Imag)) |
943 | return CN; |
944 | } |
945 | |
946 | if (NodePtr CN = identifySymmetricOperation(Real, Imag)) |
947 | return CN; |
948 | |
949 | LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n" ); |
950 | CachedResult[{R, I}] = nullptr; |
951 | return nullptr; |
952 | } |
953 | |
954 | ComplexDeinterleavingGraph::NodePtr |
955 | ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real, |
956 | Instruction *Imag) { |
957 | auto IsOperationSupported = [](unsigned Opcode) -> bool { |
958 | return Opcode == Instruction::FAdd || Opcode == Instruction::FSub || |
959 | Opcode == Instruction::FNeg || Opcode == Instruction::Add || |
960 | Opcode == Instruction::Sub; |
961 | }; |
962 | |
963 | if (!IsOperationSupported(Real->getOpcode()) || |
964 | !IsOperationSupported(Imag->getOpcode())) |
965 | return nullptr; |
966 | |
967 | std::optional<FastMathFlags> Flags; |
968 | if (isa<FPMathOperator>(Val: Real)) { |
969 | if (Real->getFastMathFlags() != Imag->getFastMathFlags()) { |
970 | LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are " |
971 | "not identical\n" ); |
972 | return nullptr; |
973 | } |
974 | |
975 | Flags = Real->getFastMathFlags(); |
976 | if (!Flags->allowReassoc()) { |
977 | LLVM_DEBUG( |
978 | dbgs() |
979 | << "the 'Reassoc' attribute is missing in the FastMath flags\n" ); |
980 | return nullptr; |
981 | } |
982 | } |
983 | |
984 | // Collect multiplications and addend instructions from the given instruction |
985 | // while traversing it operands. Additionally, verify that all instructions |
986 | // have the same fast math flags. |
987 | auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls, |
988 | std::list<Addend> &Addends) -> bool { |
989 | SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}}; |
990 | SmallPtrSet<Value *, 8> Visited; |
991 | while (!Worklist.empty()) { |
992 | auto [V, IsPositive] = Worklist.back(); |
993 | Worklist.pop_back(); |
994 | if (!Visited.insert(Ptr: V).second) |
995 | continue; |
996 | |
997 | Instruction *I = dyn_cast<Instruction>(Val: V); |
998 | if (!I) { |
999 | Addends.emplace_back(args&: V, args&: IsPositive); |
1000 | continue; |
1001 | } |
1002 | |
1003 | // If an instruction has more than one user, it indicates that it either |
1004 | // has an external user, which will be later checked by the checkNodes |
1005 | // function, or it is a subexpression utilized by multiple expressions. In |
1006 | // the latter case, we will attempt to separately identify the complex |
1007 | // operation from here in order to create a shared |
1008 | // ComplexDeinterleavingCompositeNode. |
1009 | if (I != Insn && I->getNumUses() > 1) { |
1010 | LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n" ); |
1011 | Addends.emplace_back(args&: I, args&: IsPositive); |
1012 | continue; |
1013 | } |
1014 | switch (I->getOpcode()) { |
1015 | case Instruction::FAdd: |
1016 | case Instruction::Add: |
1017 | Worklist.emplace_back(Args: I->getOperand(i: 1), Args&: IsPositive); |
1018 | Worklist.emplace_back(Args: I->getOperand(i: 0), Args&: IsPositive); |
1019 | break; |
1020 | case Instruction::FSub: |
1021 | Worklist.emplace_back(Args: I->getOperand(i: 1), Args: !IsPositive); |
1022 | Worklist.emplace_back(Args: I->getOperand(i: 0), Args&: IsPositive); |
1023 | break; |
1024 | case Instruction::Sub: |
1025 | if (isNeg(V: I)) { |
1026 | Worklist.emplace_back(Args: getNegOperand(V: I), Args: !IsPositive); |
1027 | } else { |
1028 | Worklist.emplace_back(Args: I->getOperand(i: 1), Args: !IsPositive); |
1029 | Worklist.emplace_back(Args: I->getOperand(i: 0), Args&: IsPositive); |
1030 | } |
1031 | break; |
1032 | case Instruction::FMul: |
1033 | case Instruction::Mul: { |
1034 | Value *A, *B; |
1035 | if (isNeg(V: I->getOperand(i: 0))) { |
1036 | A = getNegOperand(V: I->getOperand(i: 0)); |
1037 | IsPositive = !IsPositive; |
1038 | } else { |
1039 | A = I->getOperand(i: 0); |
1040 | } |
1041 | |
1042 | if (isNeg(V: I->getOperand(i: 1))) { |
1043 | B = getNegOperand(V: I->getOperand(i: 1)); |
1044 | IsPositive = !IsPositive; |
1045 | } else { |
1046 | B = I->getOperand(i: 1); |
1047 | } |
1048 | Muls.push_back(x: Product{.Multiplier: A, .Multiplicand: B, .IsPositive: IsPositive}); |
1049 | break; |
1050 | } |
1051 | case Instruction::FNeg: |
1052 | Worklist.emplace_back(Args: I->getOperand(i: 0), Args: !IsPositive); |
1053 | break; |
1054 | default: |
1055 | Addends.emplace_back(args&: I, args&: IsPositive); |
1056 | continue; |
1057 | } |
1058 | |
1059 | if (Flags && I->getFastMathFlags() != *Flags) { |
1060 | LLVM_DEBUG(dbgs() << "The instruction's fast math flags are " |
1061 | "inconsistent with the root instructions' flags: " |
1062 | << *I << "\n" ); |
1063 | return false; |
1064 | } |
1065 | } |
1066 | return true; |
1067 | }; |
1068 | |
1069 | std::vector<Product> RealMuls, ImagMuls; |
1070 | std::list<Addend> RealAddends, ImagAddends; |
1071 | if (!Collect(Real, RealMuls, RealAddends) || |
1072 | !Collect(Imag, ImagMuls, ImagAddends)) |
1073 | return nullptr; |
1074 | |
1075 | if (RealAddends.size() != ImagAddends.size()) |
1076 | return nullptr; |
1077 | |
1078 | NodePtr FinalNode; |
1079 | if (!RealMuls.empty() || !ImagMuls.empty()) { |
1080 | // If there are multiplicands, extract positive addend and use it as an |
1081 | // accumulator |
1082 | FinalNode = extractPositiveAddend(RealAddends, ImagAddends); |
1083 | FinalNode = identifyMultiplications(RealMuls, ImagMuls, Accumulator: FinalNode); |
1084 | if (!FinalNode) |
1085 | return nullptr; |
1086 | } |
1087 | |
1088 | // Identify and process remaining additions |
1089 | if (!RealAddends.empty() || !ImagAddends.empty()) { |
1090 | FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, Accumulator: FinalNode); |
1091 | if (!FinalNode) |
1092 | return nullptr; |
1093 | } |
1094 | assert(FinalNode && "FinalNode can not be nullptr here" ); |
1095 | // Set the Real and Imag fields of the final node and submit it |
1096 | FinalNode->Real = Real; |
1097 | FinalNode->Imag = Imag; |
1098 | submitCompositeNode(Node: FinalNode); |
1099 | return FinalNode; |
1100 | } |
1101 | |
1102 | bool ComplexDeinterleavingGraph::collectPartialMuls( |
1103 | const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls, |
1104 | std::vector<PartialMulCandidate> &PartialMulCandidates) { |
1105 | // Helper function to extract a common operand from two products |
1106 | auto FindCommonInstruction = [](const Product &Real, |
1107 | const Product &Imag) -> Value * { |
1108 | if (Real.Multiplicand == Imag.Multiplicand || |
1109 | Real.Multiplicand == Imag.Multiplier) |
1110 | return Real.Multiplicand; |
1111 | |
1112 | if (Real.Multiplier == Imag.Multiplicand || |
1113 | Real.Multiplier == Imag.Multiplier) |
1114 | return Real.Multiplier; |
1115 | |
1116 | return nullptr; |
1117 | }; |
1118 | |
1119 | // Iterating over real and imaginary multiplications to find common operands |
1120 | // If a common operand is found, a partial multiplication candidate is created |
1121 | // and added to the candidates vector The function returns false if no common |
1122 | // operands are found for any product |
1123 | for (unsigned i = 0; i < RealMuls.size(); ++i) { |
1124 | bool FoundCommon = false; |
1125 | for (unsigned j = 0; j < ImagMuls.size(); ++j) { |
1126 | auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]); |
1127 | if (!Common) |
1128 | continue; |
1129 | |
1130 | auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier |
1131 | : RealMuls[i].Multiplicand; |
1132 | auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier |
1133 | : ImagMuls[j].Multiplicand; |
1134 | |
1135 | auto Node = identifyNode(R: A, I: B); |
1136 | if (Node) { |
1137 | FoundCommon = true; |
1138 | PartialMulCandidates.push_back(x: {.Common: Common, .Node: Node, .RealIdx: i, .ImagIdx: j, .IsNodeInverted: false}); |
1139 | } |
1140 | |
1141 | Node = identifyNode(R: B, I: A); |
1142 | if (Node) { |
1143 | FoundCommon = true; |
1144 | PartialMulCandidates.push_back(x: {.Common: Common, .Node: Node, .RealIdx: i, .ImagIdx: j, .IsNodeInverted: true}); |
1145 | } |
1146 | } |
1147 | if (!FoundCommon) |
1148 | return false; |
1149 | } |
1150 | return true; |
1151 | } |
1152 | |
1153 | ComplexDeinterleavingGraph::NodePtr |
1154 | ComplexDeinterleavingGraph::identifyMultiplications( |
1155 | std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls, |
1156 | NodePtr Accumulator = nullptr) { |
1157 | if (RealMuls.size() != ImagMuls.size()) |
1158 | return nullptr; |
1159 | |
1160 | std::vector<PartialMulCandidate> Info; |
1161 | if (!collectPartialMuls(RealMuls, ImagMuls, PartialMulCandidates&: Info)) |
1162 | return nullptr; |
1163 | |
1164 | // Map to store common instruction to node pointers |
1165 | std::map<Value *, NodePtr> CommonToNode; |
1166 | std::vector<bool> Processed(Info.size(), false); |
1167 | for (unsigned I = 0; I < Info.size(); ++I) { |
1168 | if (Processed[I]) |
1169 | continue; |
1170 | |
1171 | PartialMulCandidate &InfoA = Info[I]; |
1172 | for (unsigned J = I + 1; J < Info.size(); ++J) { |
1173 | if (Processed[J]) |
1174 | continue; |
1175 | |
1176 | PartialMulCandidate &InfoB = Info[J]; |
1177 | auto *InfoReal = &InfoA; |
1178 | auto *InfoImag = &InfoB; |
1179 | |
1180 | auto NodeFromCommon = identifyNode(R: InfoReal->Common, I: InfoImag->Common); |
1181 | if (!NodeFromCommon) { |
1182 | std::swap(a&: InfoReal, b&: InfoImag); |
1183 | NodeFromCommon = identifyNode(R: InfoReal->Common, I: InfoImag->Common); |
1184 | } |
1185 | if (!NodeFromCommon) |
1186 | continue; |
1187 | |
1188 | CommonToNode[InfoReal->Common] = NodeFromCommon; |
1189 | CommonToNode[InfoImag->Common] = NodeFromCommon; |
1190 | Processed[I] = true; |
1191 | Processed[J] = true; |
1192 | } |
1193 | } |
1194 | |
1195 | std::vector<bool> ProcessedReal(RealMuls.size(), false); |
1196 | std::vector<bool> ProcessedImag(ImagMuls.size(), false); |
1197 | NodePtr Result = Accumulator; |
1198 | for (auto &PMI : Info) { |
1199 | if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx]) |
1200 | continue; |
1201 | |
1202 | auto It = CommonToNode.find(x: PMI.Common); |
1203 | // TODO: Process independent complex multiplications. Cases like this: |
1204 | // A.real() * B where both A and B are complex numbers. |
1205 | if (It == CommonToNode.end()) { |
1206 | LLVM_DEBUG({ |
1207 | dbgs() << "Unprocessed independent partial multiplication:\n" ; |
1208 | for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]}) |
1209 | dbgs().indent(4) << (Mul->IsPositive ? "+" : "-" ) << *Mul->Multiplier |
1210 | << " multiplied by " << *Mul->Multiplicand << "\n" ; |
1211 | }); |
1212 | return nullptr; |
1213 | } |
1214 | |
1215 | auto &RealMul = RealMuls[PMI.RealIdx]; |
1216 | auto &ImagMul = ImagMuls[PMI.ImagIdx]; |
1217 | |
1218 | auto NodeA = It->second; |
1219 | auto NodeB = PMI.Node; |
1220 | auto IsMultiplicandReal = PMI.Common == NodeA->Real; |
1221 | // The following table illustrates the relationship between multiplications |
1222 | // and rotations. If we consider the multiplication (X + iY) * (U + iV), we |
1223 | // can see: |
1224 | // |
1225 | // Rotation | Real | Imag | |
1226 | // ---------+--------+--------+ |
1227 | // 0 | x * u | x * v | |
1228 | // 90 | -y * v | y * u | |
1229 | // 180 | -x * u | -x * v | |
1230 | // 270 | y * v | -y * u | |
1231 | // |
1232 | // Check if the candidate can indeed be represented by partial |
1233 | // multiplication |
1234 | // TODO: Add support for multiplication by complex one |
1235 | if ((IsMultiplicandReal && PMI.IsNodeInverted) || |
1236 | (!IsMultiplicandReal && !PMI.IsNodeInverted)) |
1237 | continue; |
1238 | |
1239 | // Determine the rotation based on the multiplications |
1240 | ComplexDeinterleavingRotation Rotation; |
1241 | if (IsMultiplicandReal) { |
1242 | // Detect 0 and 180 degrees rotation |
1243 | if (RealMul.IsPositive && ImagMul.IsPositive) |
1244 | Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0; |
1245 | else if (!RealMul.IsPositive && !ImagMul.IsPositive) |
1246 | Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180; |
1247 | else |
1248 | continue; |
1249 | |
1250 | } else { |
1251 | // Detect 90 and 270 degrees rotation |
1252 | if (!RealMul.IsPositive && ImagMul.IsPositive) |
1253 | Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90; |
1254 | else if (RealMul.IsPositive && !ImagMul.IsPositive) |
1255 | Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270; |
1256 | else |
1257 | continue; |
1258 | } |
1259 | |
1260 | LLVM_DEBUG({ |
1261 | dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n" ; |
1262 | dbgs().indent(4) << "X: " << *NodeA->Real << "\n" ; |
1263 | dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n" ; |
1264 | dbgs().indent(4) << "U: " << *NodeB->Real << "\n" ; |
1265 | dbgs().indent(4) << "V: " << *NodeB->Imag << "\n" ; |
1266 | dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n" ; |
1267 | }); |
1268 | |
1269 | NodePtr NodeMul = prepareCompositeNode( |
1270 | Operation: ComplexDeinterleavingOperation::CMulPartial, R: nullptr, I: nullptr); |
1271 | NodeMul->Rotation = Rotation; |
1272 | NodeMul->addOperand(Node: NodeA); |
1273 | NodeMul->addOperand(Node: NodeB); |
1274 | if (Result) |
1275 | NodeMul->addOperand(Node: Result); |
1276 | submitCompositeNode(Node: NodeMul); |
1277 | Result = NodeMul; |
1278 | ProcessedReal[PMI.RealIdx] = true; |
1279 | ProcessedImag[PMI.ImagIdx] = true; |
1280 | } |
1281 | |
1282 | // Ensure all products have been processed, if not return nullptr. |
1283 | if (!all_of(Range&: ProcessedReal, P: [](bool V) { return V; }) || |
1284 | !all_of(Range&: ProcessedImag, P: [](bool V) { return V; })) { |
1285 | |
1286 | // Dump debug information about which partial multiplications are not |
1287 | // processed. |
1288 | LLVM_DEBUG({ |
1289 | dbgs() << "Unprocessed products (Real):\n" ; |
1290 | for (size_t i = 0; i < ProcessedReal.size(); ++i) { |
1291 | if (!ProcessedReal[i]) |
1292 | dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-" ) |
1293 | << *RealMuls[i].Multiplier << " multiplied by " |
1294 | << *RealMuls[i].Multiplicand << "\n" ; |
1295 | } |
1296 | dbgs() << "Unprocessed products (Imag):\n" ; |
1297 | for (size_t i = 0; i < ProcessedImag.size(); ++i) { |
1298 | if (!ProcessedImag[i]) |
1299 | dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-" ) |
1300 | << *ImagMuls[i].Multiplier << " multiplied by " |
1301 | << *ImagMuls[i].Multiplicand << "\n" ; |
1302 | } |
1303 | }); |
1304 | return nullptr; |
1305 | } |
1306 | |
1307 | return Result; |
1308 | } |
1309 | |
1310 | ComplexDeinterleavingGraph::NodePtr |
1311 | ComplexDeinterleavingGraph::identifyAdditions( |
1312 | std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends, |
1313 | std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) { |
1314 | if (RealAddends.size() != ImagAddends.size()) |
1315 | return nullptr; |
1316 | |
1317 | NodePtr Result; |
1318 | // If we have accumulator use it as first addend |
1319 | if (Accumulator) |
1320 | Result = Accumulator; |
1321 | // Otherwise find an element with both positive real and imaginary parts. |
1322 | else |
1323 | Result = extractPositiveAddend(RealAddends, ImagAddends); |
1324 | |
1325 | if (!Result) |
1326 | return nullptr; |
1327 | |
1328 | while (!RealAddends.empty()) { |
1329 | auto ItR = RealAddends.begin(); |
1330 | auto [R, IsPositiveR] = *ItR; |
1331 | |
1332 | bool FoundImag = false; |
1333 | for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { |
1334 | auto [I, IsPositiveI] = *ItI; |
1335 | ComplexDeinterleavingRotation Rotation; |
1336 | if (IsPositiveR && IsPositiveI) |
1337 | Rotation = ComplexDeinterleavingRotation::Rotation_0; |
1338 | else if (!IsPositiveR && IsPositiveI) |
1339 | Rotation = ComplexDeinterleavingRotation::Rotation_90; |
1340 | else if (!IsPositiveR && !IsPositiveI) |
1341 | Rotation = ComplexDeinterleavingRotation::Rotation_180; |
1342 | else |
1343 | Rotation = ComplexDeinterleavingRotation::Rotation_270; |
1344 | |
1345 | NodePtr AddNode; |
1346 | if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || |
1347 | Rotation == ComplexDeinterleavingRotation::Rotation_180) { |
1348 | AddNode = identifyNode(R, I); |
1349 | } else { |
1350 | AddNode = identifyNode(R: I, I: R); |
1351 | } |
1352 | if (AddNode) { |
1353 | LLVM_DEBUG({ |
1354 | dbgs() << "Identified addition:\n" ; |
1355 | dbgs().indent(4) << "X: " << *R << "\n" ; |
1356 | dbgs().indent(4) << "Y: " << *I << "\n" ; |
1357 | dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n" ; |
1358 | }); |
1359 | |
1360 | NodePtr TmpNode; |
1361 | if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) { |
1362 | TmpNode = prepareCompositeNode( |
1363 | Operation: ComplexDeinterleavingOperation::Symmetric, R: nullptr, I: nullptr); |
1364 | if (Flags) { |
1365 | TmpNode->Opcode = Instruction::FAdd; |
1366 | TmpNode->Flags = *Flags; |
1367 | } else { |
1368 | TmpNode->Opcode = Instruction::Add; |
1369 | } |
1370 | } else if (Rotation == |
1371 | llvm::ComplexDeinterleavingRotation::Rotation_180) { |
1372 | TmpNode = prepareCompositeNode( |
1373 | Operation: ComplexDeinterleavingOperation::Symmetric, R: nullptr, I: nullptr); |
1374 | if (Flags) { |
1375 | TmpNode->Opcode = Instruction::FSub; |
1376 | TmpNode->Flags = *Flags; |
1377 | } else { |
1378 | TmpNode->Opcode = Instruction::Sub; |
1379 | } |
1380 | } else { |
1381 | TmpNode = prepareCompositeNode(Operation: ComplexDeinterleavingOperation::CAdd, |
1382 | R: nullptr, I: nullptr); |
1383 | TmpNode->Rotation = Rotation; |
1384 | } |
1385 | |
1386 | TmpNode->addOperand(Node: Result); |
1387 | TmpNode->addOperand(Node: AddNode); |
1388 | submitCompositeNode(Node: TmpNode); |
1389 | Result = TmpNode; |
1390 | RealAddends.erase(position: ItR); |
1391 | ImagAddends.erase(position: ItI); |
1392 | FoundImag = true; |
1393 | break; |
1394 | } |
1395 | } |
1396 | if (!FoundImag) |
1397 | return nullptr; |
1398 | } |
1399 | return Result; |
1400 | } |
1401 | |
1402 | ComplexDeinterleavingGraph::NodePtr |
1403 | ComplexDeinterleavingGraph::( |
1404 | std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) { |
1405 | for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) { |
1406 | for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { |
1407 | auto [R, IsPositiveR] = *ItR; |
1408 | auto [I, IsPositiveI] = *ItI; |
1409 | if (IsPositiveR && IsPositiveI) { |
1410 | auto Result = identifyNode(R, I); |
1411 | if (Result) { |
1412 | RealAddends.erase(position: ItR); |
1413 | ImagAddends.erase(position: ItI); |
1414 | return Result; |
1415 | } |
1416 | } |
1417 | } |
1418 | } |
1419 | return nullptr; |
1420 | } |
1421 | |
1422 | bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { |
1423 | // This potential root instruction might already have been recognized as |
1424 | // reduction. Because RootToNode maps both Real and Imaginary parts to |
1425 | // CompositeNode we should choose only one either Real or Imag instruction to |
1426 | // use as an anchor for generating complex instruction. |
1427 | auto It = RootToNode.find(x: RootI); |
1428 | if (It != RootToNode.end()) { |
1429 | auto RootNode = It->second; |
1430 | assert(RootNode->Operation == |
1431 | ComplexDeinterleavingOperation::ReductionOperation); |
1432 | // Find out which part, Real or Imag, comes later, and only if we come to |
1433 | // the latest part, add it to OrderedRoots. |
1434 | auto *R = cast<Instruction>(Val: RootNode->Real); |
1435 | auto *I = cast<Instruction>(Val: RootNode->Imag); |
1436 | auto *ReplacementAnchor = R->comesBefore(Other: I) ? I : R; |
1437 | if (ReplacementAnchor != RootI) |
1438 | return false; |
1439 | OrderedRoots.push_back(Elt: RootI); |
1440 | return true; |
1441 | } |
1442 | |
1443 | auto RootNode = identifyRoot(I: RootI); |
1444 | if (!RootNode) |
1445 | return false; |
1446 | |
1447 | LLVM_DEBUG({ |
1448 | Function *F = RootI->getFunction(); |
1449 | BasicBlock *B = RootI->getParent(); |
1450 | dbgs() << "Complex deinterleaving graph for " << F->getName() |
1451 | << "::" << B->getName() << ".\n" ; |
1452 | dump(dbgs()); |
1453 | dbgs() << "\n" ; |
1454 | }); |
1455 | RootToNode[RootI] = RootNode; |
1456 | OrderedRoots.push_back(Elt: RootI); |
1457 | return true; |
1458 | } |
1459 | |
1460 | bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) { |
1461 | bool FoundPotentialReduction = false; |
1462 | |
1463 | auto *Br = dyn_cast<BranchInst>(Val: B->getTerminator()); |
1464 | if (!Br || Br->getNumSuccessors() != 2) |
1465 | return false; |
1466 | |
1467 | // Identify simple one-block loop |
1468 | if (Br->getSuccessor(i: 0) != B && Br->getSuccessor(i: 1) != B) |
1469 | return false; |
1470 | |
1471 | SmallVector<PHINode *> PHIs; |
1472 | for (auto &PHI : B->phis()) { |
1473 | if (PHI.getNumIncomingValues() != 2) |
1474 | continue; |
1475 | |
1476 | if (!PHI.getType()->isVectorTy()) |
1477 | continue; |
1478 | |
1479 | auto *ReductionOp = dyn_cast<Instruction>(Val: PHI.getIncomingValueForBlock(BB: B)); |
1480 | if (!ReductionOp) |
1481 | continue; |
1482 | |
1483 | // Check if final instruction is reduced outside of current block |
1484 | Instruction *FinalReduction = nullptr; |
1485 | auto NumUsers = 0u; |
1486 | for (auto *U : ReductionOp->users()) { |
1487 | ++NumUsers; |
1488 | if (U == &PHI) |
1489 | continue; |
1490 | FinalReduction = dyn_cast<Instruction>(Val: U); |
1491 | } |
1492 | |
1493 | if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B || |
1494 | isa<PHINode>(Val: FinalReduction)) |
1495 | continue; |
1496 | |
1497 | ReductionInfo[ReductionOp] = {&PHI, FinalReduction}; |
1498 | BackEdge = B; |
1499 | auto BackEdgeIdx = PHI.getBasicBlockIndex(BB: B); |
1500 | auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0; |
1501 | Incoming = PHI.getIncomingBlock(i: IncomingIdx); |
1502 | FoundPotentialReduction = true; |
1503 | |
1504 | // If the initial value of PHINode is an Instruction, consider it a leaf |
1505 | // value of a complex deinterleaving graph. |
1506 | if (auto *InitPHI = |
1507 | dyn_cast<Instruction>(Val: PHI.getIncomingValueForBlock(BB: Incoming))) |
1508 | FinalInstructions.insert(Ptr: InitPHI); |
1509 | } |
1510 | return FoundPotentialReduction; |
1511 | } |
1512 | |
1513 | void ComplexDeinterleavingGraph::identifyReductionNodes() { |
1514 | SmallVector<bool> Processed(ReductionInfo.size(), false); |
1515 | SmallVector<Instruction *> OperationInstruction; |
1516 | for (auto &P : ReductionInfo) |
1517 | OperationInstruction.push_back(Elt: P.first); |
1518 | |
1519 | // Identify a complex computation by evaluating two reduction operations that |
1520 | // potentially could be involved |
1521 | for (size_t i = 0; i < OperationInstruction.size(); ++i) { |
1522 | if (Processed[i]) |
1523 | continue; |
1524 | for (size_t j = i + 1; j < OperationInstruction.size(); ++j) { |
1525 | if (Processed[j]) |
1526 | continue; |
1527 | |
1528 | auto *Real = OperationInstruction[i]; |
1529 | auto *Imag = OperationInstruction[j]; |
1530 | if (Real->getType() != Imag->getType()) |
1531 | continue; |
1532 | |
1533 | RealPHI = ReductionInfo[Real].first; |
1534 | ImagPHI = ReductionInfo[Imag].first; |
1535 | PHIsFound = false; |
1536 | auto Node = identifyNode(R: Real, I: Imag); |
1537 | if (!Node) { |
1538 | std::swap(a&: Real, b&: Imag); |
1539 | std::swap(a&: RealPHI, b&: ImagPHI); |
1540 | Node = identifyNode(R: Real, I: Imag); |
1541 | } |
1542 | |
1543 | // If a node is identified and reduction PHINode is used in the chain of |
1544 | // operations, mark its operation instructions as used to prevent |
1545 | // re-identification and attach the node to the real part |
1546 | if (Node && PHIsFound) { |
1547 | LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: " |
1548 | << *Real << " / " << *Imag << "\n" ); |
1549 | Processed[i] = true; |
1550 | Processed[j] = true; |
1551 | auto RootNode = prepareCompositeNode( |
1552 | Operation: ComplexDeinterleavingOperation::ReductionOperation, R: Real, I: Imag); |
1553 | RootNode->addOperand(Node); |
1554 | RootToNode[Real] = RootNode; |
1555 | RootToNode[Imag] = RootNode; |
1556 | submitCompositeNode(Node: RootNode); |
1557 | break; |
1558 | } |
1559 | } |
1560 | } |
1561 | |
1562 | RealPHI = nullptr; |
1563 | ImagPHI = nullptr; |
1564 | } |
1565 | |
1566 | bool ComplexDeinterleavingGraph::checkNodes() { |
1567 | // Collect all instructions from roots to leaves |
1568 | SmallPtrSet<Instruction *, 16> AllInstructions; |
1569 | SmallVector<Instruction *, 8> Worklist; |
1570 | for (auto &Pair : RootToNode) |
1571 | Worklist.push_back(Elt: Pair.first); |
1572 | |
1573 | // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG |
1574 | // chains |
1575 | while (!Worklist.empty()) { |
1576 | auto *I = Worklist.back(); |
1577 | Worklist.pop_back(); |
1578 | |
1579 | if (!AllInstructions.insert(Ptr: I).second) |
1580 | continue; |
1581 | |
1582 | for (Value *Op : I->operands()) { |
1583 | if (auto *OpI = dyn_cast<Instruction>(Val: Op)) { |
1584 | if (!FinalInstructions.count(Ptr: I)) |
1585 | Worklist.emplace_back(Args&: OpI); |
1586 | } |
1587 | } |
1588 | } |
1589 | |
1590 | // Find instructions that have users outside of chain |
1591 | SmallVector<Instruction *, 2> OuterInstructions; |
1592 | for (auto *I : AllInstructions) { |
1593 | // Skip root nodes |
1594 | if (RootToNode.count(x: I)) |
1595 | continue; |
1596 | |
1597 | for (User *U : I->users()) { |
1598 | if (AllInstructions.count(Ptr: cast<Instruction>(Val: U))) |
1599 | continue; |
1600 | |
1601 | // Found an instruction that is not used by XCMLA/XCADD chain |
1602 | Worklist.emplace_back(Args&: I); |
1603 | break; |
1604 | } |
1605 | } |
1606 | |
1607 | // If any instructions are found to be used outside, find and remove roots |
1608 | // that somehow connect to those instructions. |
1609 | SmallPtrSet<Instruction *, 16> Visited; |
1610 | while (!Worklist.empty()) { |
1611 | auto *I = Worklist.back(); |
1612 | Worklist.pop_back(); |
1613 | if (!Visited.insert(Ptr: I).second) |
1614 | continue; |
1615 | |
1616 | // Found an impacted root node. Removing it from the nodes to be |
1617 | // deinterleaved |
1618 | if (RootToNode.count(x: I)) { |
1619 | LLVM_DEBUG(dbgs() << "Instruction " << *I |
1620 | << " could be deinterleaved but its chain of complex " |
1621 | "operations have an outside user\n" ); |
1622 | RootToNode.erase(x: I); |
1623 | } |
1624 | |
1625 | if (!AllInstructions.count(Ptr: I) || FinalInstructions.count(Ptr: I)) |
1626 | continue; |
1627 | |
1628 | for (User *U : I->users()) |
1629 | Worklist.emplace_back(Args: cast<Instruction>(Val: U)); |
1630 | |
1631 | for (Value *Op : I->operands()) { |
1632 | if (auto *OpI = dyn_cast<Instruction>(Val: Op)) |
1633 | Worklist.emplace_back(Args&: OpI); |
1634 | } |
1635 | } |
1636 | return !RootToNode.empty(); |
1637 | } |
1638 | |
1639 | ComplexDeinterleavingGraph::NodePtr |
1640 | ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) { |
1641 | if (auto *Intrinsic = dyn_cast<IntrinsicInst>(Val: RootI)) { |
1642 | if (Intrinsic->getIntrinsicID() != Intrinsic::vector_interleave2) |
1643 | return nullptr; |
1644 | |
1645 | auto *Real = dyn_cast<Instruction>(Val: Intrinsic->getOperand(i_nocapture: 0)); |
1646 | auto *Imag = dyn_cast<Instruction>(Val: Intrinsic->getOperand(i_nocapture: 1)); |
1647 | if (!Real || !Imag) |
1648 | return nullptr; |
1649 | |
1650 | return identifyNode(R: Real, I: Imag); |
1651 | } |
1652 | |
1653 | auto *SVI = dyn_cast<ShuffleVectorInst>(Val: RootI); |
1654 | if (!SVI) |
1655 | return nullptr; |
1656 | |
1657 | // Look for a shufflevector that takes separate vectors of the real and |
1658 | // imaginary components and recombines them into a single vector. |
1659 | if (!isInterleavingMask(Mask: SVI->getShuffleMask())) |
1660 | return nullptr; |
1661 | |
1662 | Instruction *Real; |
1663 | Instruction *Imag; |
1664 | if (!match(V: RootI, P: m_Shuffle(v1: m_Instruction(I&: Real), v2: m_Instruction(I&: Imag)))) |
1665 | return nullptr; |
1666 | |
1667 | return identifyNode(R: Real, I: Imag); |
1668 | } |
1669 | |
1670 | ComplexDeinterleavingGraph::NodePtr |
1671 | ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real, |
1672 | Instruction *Imag) { |
1673 | Instruction *I = nullptr; |
1674 | Value *FinalValue = nullptr; |
1675 | if (match(V: Real, P: m_ExtractValue<0>(V: m_Instruction(I))) && |
1676 | match(V: Imag, P: m_ExtractValue<1>(V: m_Specific(V: I))) && |
1677 | match(V: I, P: m_Intrinsic<Intrinsic::vector_deinterleave2>( |
1678 | Op0: m_Value(V&: FinalValue)))) { |
1679 | NodePtr PlaceholderNode = prepareCompositeNode( |
1680 | Operation: llvm::ComplexDeinterleavingOperation::Deinterleave, R: Real, I: Imag); |
1681 | PlaceholderNode->ReplacementNode = FinalValue; |
1682 | FinalInstructions.insert(Ptr: Real); |
1683 | FinalInstructions.insert(Ptr: Imag); |
1684 | return submitCompositeNode(Node: PlaceholderNode); |
1685 | } |
1686 | |
1687 | auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Val: Real); |
1688 | auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Val: Imag); |
1689 | if (!RealShuffle || !ImagShuffle) { |
1690 | if (RealShuffle || ImagShuffle) |
1691 | LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n" ); |
1692 | return nullptr; |
1693 | } |
1694 | |
1695 | Value *RealOp1 = RealShuffle->getOperand(i_nocapture: 1); |
1696 | if (!isa<UndefValue>(Val: RealOp1) && !isa<ConstantAggregateZero>(Val: RealOp1)) { |
1697 | LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n" ); |
1698 | return nullptr; |
1699 | } |
1700 | Value *ImagOp1 = ImagShuffle->getOperand(i_nocapture: 1); |
1701 | if (!isa<UndefValue>(Val: ImagOp1) && !isa<ConstantAggregateZero>(Val: ImagOp1)) { |
1702 | LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n" ); |
1703 | return nullptr; |
1704 | } |
1705 | |
1706 | Value *RealOp0 = RealShuffle->getOperand(i_nocapture: 0); |
1707 | Value *ImagOp0 = ImagShuffle->getOperand(i_nocapture: 0); |
1708 | |
1709 | if (RealOp0 != ImagOp0) { |
1710 | LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n" ); |
1711 | return nullptr; |
1712 | } |
1713 | |
1714 | ArrayRef<int> RealMask = RealShuffle->getShuffleMask(); |
1715 | ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask(); |
1716 | if (!isDeinterleavingMask(Mask: RealMask) || !isDeinterleavingMask(Mask: ImagMask)) { |
1717 | LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n" ); |
1718 | return nullptr; |
1719 | } |
1720 | |
1721 | if (RealMask[0] != 0 || ImagMask[0] != 1) { |
1722 | LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n" ); |
1723 | return nullptr; |
1724 | } |
1725 | |
1726 | // Type checking, the shuffle type should be a vector type of the same |
1727 | // scalar type, but half the size |
1728 | auto CheckType = [&](ShuffleVectorInst *Shuffle) { |
1729 | Value *Op = Shuffle->getOperand(i_nocapture: 0); |
1730 | auto *ShuffleTy = cast<FixedVectorType>(Val: Shuffle->getType()); |
1731 | auto *OpTy = cast<FixedVectorType>(Val: Op->getType()); |
1732 | |
1733 | if (OpTy->getScalarType() != ShuffleTy->getScalarType()) |
1734 | return false; |
1735 | if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements()) |
1736 | return false; |
1737 | |
1738 | return true; |
1739 | }; |
1740 | |
1741 | auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool { |
1742 | if (!CheckType(Shuffle)) |
1743 | return false; |
1744 | |
1745 | ArrayRef<int> Mask = Shuffle->getShuffleMask(); |
1746 | int Last = *Mask.rbegin(); |
1747 | |
1748 | Value *Op = Shuffle->getOperand(i_nocapture: 0); |
1749 | auto *OpTy = cast<FixedVectorType>(Val: Op->getType()); |
1750 | int NumElements = OpTy->getNumElements(); |
1751 | |
1752 | // Ensure that the deinterleaving shuffle only pulls from the first |
1753 | // shuffle operand. |
1754 | return Last < NumElements; |
1755 | }; |
1756 | |
1757 | if (RealShuffle->getType() != ImagShuffle->getType()) { |
1758 | LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n" ); |
1759 | return nullptr; |
1760 | } |
1761 | if (!CheckDeinterleavingShuffle(RealShuffle)) { |
1762 | LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n" ); |
1763 | return nullptr; |
1764 | } |
1765 | if (!CheckDeinterleavingShuffle(ImagShuffle)) { |
1766 | LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n" ); |
1767 | return nullptr; |
1768 | } |
1769 | |
1770 | NodePtr PlaceholderNode = |
1771 | prepareCompositeNode(Operation: llvm::ComplexDeinterleavingOperation::Deinterleave, |
1772 | R: RealShuffle, I: ImagShuffle); |
1773 | PlaceholderNode->ReplacementNode = RealShuffle->getOperand(i_nocapture: 0); |
1774 | FinalInstructions.insert(Ptr: RealShuffle); |
1775 | FinalInstructions.insert(Ptr: ImagShuffle); |
1776 | return submitCompositeNode(Node: PlaceholderNode); |
1777 | } |
1778 | |
1779 | ComplexDeinterleavingGraph::NodePtr |
1780 | ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) { |
1781 | auto IsSplat = [](Value *V) -> bool { |
1782 | // Fixed-width vector with constants |
1783 | if (isa<ConstantDataVector>(Val: V)) |
1784 | return true; |
1785 | |
1786 | VectorType *VTy; |
1787 | ArrayRef<int> Mask; |
1788 | // Splats are represented differently depending on whether the repeated |
1789 | // value is a constant or an Instruction |
1790 | if (auto *Const = dyn_cast<ConstantExpr>(Val: V)) { |
1791 | if (Const->getOpcode() != Instruction::ShuffleVector) |
1792 | return false; |
1793 | VTy = cast<VectorType>(Val: Const->getType()); |
1794 | Mask = Const->getShuffleMask(); |
1795 | } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(Val: V)) { |
1796 | VTy = Shuf->getType(); |
1797 | Mask = Shuf->getShuffleMask(); |
1798 | } else { |
1799 | return false; |
1800 | } |
1801 | |
1802 | // When the data type is <1 x Type>, it's not possible to differentiate |
1803 | // between the ComplexDeinterleaving::Deinterleave and |
1804 | // ComplexDeinterleaving::Splat operations. |
1805 | if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1) |
1806 | return false; |
1807 | |
1808 | return all_equal(Range&: Mask) && Mask[0] == 0; |
1809 | }; |
1810 | |
1811 | if (!IsSplat(R) || !IsSplat(I)) |
1812 | return nullptr; |
1813 | |
1814 | auto *Real = dyn_cast<Instruction>(Val: R); |
1815 | auto *Imag = dyn_cast<Instruction>(Val: I); |
1816 | if ((!Real && Imag) || (Real && !Imag)) |
1817 | return nullptr; |
1818 | |
1819 | if (Real && Imag) { |
1820 | // Non-constant splats should be in the same basic block |
1821 | if (Real->getParent() != Imag->getParent()) |
1822 | return nullptr; |
1823 | |
1824 | FinalInstructions.insert(Ptr: Real); |
1825 | FinalInstructions.insert(Ptr: Imag); |
1826 | } |
1827 | NodePtr PlaceholderNode = |
1828 | prepareCompositeNode(Operation: ComplexDeinterleavingOperation::Splat, R, I); |
1829 | return submitCompositeNode(Node: PlaceholderNode); |
1830 | } |
1831 | |
1832 | ComplexDeinterleavingGraph::NodePtr |
1833 | ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real, |
1834 | Instruction *Imag) { |
1835 | if (Real != RealPHI || Imag != ImagPHI) |
1836 | return nullptr; |
1837 | |
1838 | PHIsFound = true; |
1839 | NodePtr PlaceholderNode = prepareCompositeNode( |
1840 | Operation: ComplexDeinterleavingOperation::ReductionPHI, R: Real, I: Imag); |
1841 | return submitCompositeNode(Node: PlaceholderNode); |
1842 | } |
1843 | |
1844 | ComplexDeinterleavingGraph::NodePtr |
1845 | ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real, |
1846 | Instruction *Imag) { |
1847 | auto *SelectReal = dyn_cast<SelectInst>(Val: Real); |
1848 | auto *SelectImag = dyn_cast<SelectInst>(Val: Imag); |
1849 | if (!SelectReal || !SelectImag) |
1850 | return nullptr; |
1851 | |
1852 | Instruction *MaskA, *MaskB; |
1853 | Instruction *AR, *AI, *RA, *BI; |
1854 | if (!match(V: Real, P: m_Select(C: m_Instruction(I&: MaskA), L: m_Instruction(I&: AR), |
1855 | R: m_Instruction(I&: RA))) || |
1856 | !match(V: Imag, P: m_Select(C: m_Instruction(I&: MaskB), L: m_Instruction(I&: AI), |
1857 | R: m_Instruction(I&: BI)))) |
1858 | return nullptr; |
1859 | |
1860 | if (MaskA != MaskB && !MaskA->isIdenticalTo(I: MaskB)) |
1861 | return nullptr; |
1862 | |
1863 | if (!MaskA->getType()->isVectorTy()) |
1864 | return nullptr; |
1865 | |
1866 | auto NodeA = identifyNode(R: AR, I: AI); |
1867 | if (!NodeA) |
1868 | return nullptr; |
1869 | |
1870 | auto NodeB = identifyNode(R: RA, I: BI); |
1871 | if (!NodeB) |
1872 | return nullptr; |
1873 | |
1874 | NodePtr PlaceholderNode = prepareCompositeNode( |
1875 | Operation: ComplexDeinterleavingOperation::ReductionSelect, R: Real, I: Imag); |
1876 | PlaceholderNode->addOperand(Node: NodeA); |
1877 | PlaceholderNode->addOperand(Node: NodeB); |
1878 | FinalInstructions.insert(Ptr: MaskA); |
1879 | FinalInstructions.insert(Ptr: MaskB); |
1880 | return submitCompositeNode(Node: PlaceholderNode); |
1881 | } |
1882 | |
1883 | static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, |
1884 | std::optional<FastMathFlags> Flags, |
1885 | Value *InputA, Value *InputB) { |
1886 | Value *I; |
1887 | switch (Opcode) { |
1888 | case Instruction::FNeg: |
1889 | I = B.CreateFNeg(V: InputA); |
1890 | break; |
1891 | case Instruction::FAdd: |
1892 | I = B.CreateFAdd(L: InputA, R: InputB); |
1893 | break; |
1894 | case Instruction::Add: |
1895 | I = B.CreateAdd(LHS: InputA, RHS: InputB); |
1896 | break; |
1897 | case Instruction::FSub: |
1898 | I = B.CreateFSub(L: InputA, R: InputB); |
1899 | break; |
1900 | case Instruction::Sub: |
1901 | I = B.CreateSub(LHS: InputA, RHS: InputB); |
1902 | break; |
1903 | case Instruction::FMul: |
1904 | I = B.CreateFMul(L: InputA, R: InputB); |
1905 | break; |
1906 | case Instruction::Mul: |
1907 | I = B.CreateMul(LHS: InputA, RHS: InputB); |
1908 | break; |
1909 | default: |
1910 | llvm_unreachable("Incorrect symmetric opcode" ); |
1911 | } |
1912 | if (Flags) |
1913 | cast<Instruction>(Val: I)->setFastMathFlags(*Flags); |
1914 | return I; |
1915 | } |
1916 | |
1917 | Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder, |
1918 | RawNodePtr Node) { |
1919 | if (Node->ReplacementNode) |
1920 | return Node->ReplacementNode; |
1921 | |
1922 | auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * { |
1923 | return Node->Operands.size() > Idx |
1924 | ? replaceNode(Builder, Node: Node->Operands[Idx]) |
1925 | : nullptr; |
1926 | }; |
1927 | |
1928 | Value *ReplacementNode; |
1929 | switch (Node->Operation) { |
1930 | case ComplexDeinterleavingOperation::CAdd: |
1931 | case ComplexDeinterleavingOperation::CMulPartial: |
1932 | case ComplexDeinterleavingOperation::Symmetric: { |
1933 | Value *Input0 = ReplaceOperandIfExist(Node, 0); |
1934 | Value *Input1 = ReplaceOperandIfExist(Node, 1); |
1935 | Value *Accumulator = ReplaceOperandIfExist(Node, 2); |
1936 | assert(!Input1 || (Input0->getType() == Input1->getType() && |
1937 | "Node inputs need to be of the same type" )); |
1938 | assert(!Accumulator || |
1939 | (Input0->getType() == Accumulator->getType() && |
1940 | "Accumulator and input need to be of the same type" )); |
1941 | if (Node->Operation == ComplexDeinterleavingOperation::Symmetric) |
1942 | ReplacementNode = replaceSymmetricNode(B&: Builder, Opcode: Node->Opcode, Flags: Node->Flags, |
1943 | InputA: Input0, InputB: Input1); |
1944 | else |
1945 | ReplacementNode = TL->createComplexDeinterleavingIR( |
1946 | B&: Builder, OperationType: Node->Operation, Rotation: Node->Rotation, InputA: Input0, InputB: Input1, |
1947 | Accumulator); |
1948 | break; |
1949 | } |
1950 | case ComplexDeinterleavingOperation::Deinterleave: |
1951 | llvm_unreachable("Deinterleave node should already have ReplacementNode" ); |
1952 | break; |
1953 | case ComplexDeinterleavingOperation::Splat: { |
1954 | auto *NewTy = VectorType::getDoubleElementsVectorType( |
1955 | VTy: cast<VectorType>(Val: Node->Real->getType())); |
1956 | auto *R = dyn_cast<Instruction>(Val: Node->Real); |
1957 | auto *I = dyn_cast<Instruction>(Val: Node->Imag); |
1958 | if (R && I) { |
1959 | // Splats that are not constant are interleaved where they are located |
1960 | Instruction *InsertPoint = (I->comesBefore(Other: R) ? R : I)->getNextNode(); |
1961 | IRBuilder<> IRB(InsertPoint); |
1962 | ReplacementNode = IRB.CreateIntrinsic(ID: Intrinsic::vector_interleave2, |
1963 | Types: NewTy, Args: {Node->Real, Node->Imag}); |
1964 | } else { |
1965 | ReplacementNode = Builder.CreateIntrinsic( |
1966 | ID: Intrinsic::vector_interleave2, Types: NewTy, Args: {Node->Real, Node->Imag}); |
1967 | } |
1968 | break; |
1969 | } |
1970 | case ComplexDeinterleavingOperation::ReductionPHI: { |
1971 | // If Operation is ReductionPHI, a new empty PHINode is created. |
1972 | // It is filled later when the ReductionOperation is processed. |
1973 | auto *VTy = cast<VectorType>(Val: Node->Real->getType()); |
1974 | auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); |
1975 | auto *NewPHI = PHINode::Create(Ty: NewVTy, NumReservedValues: 0, NameStr: "" , InsertBefore: BackEdge->getFirstNonPHIIt()); |
1976 | OldToNewPHI[dyn_cast<PHINode>(Val: Node->Real)] = NewPHI; |
1977 | ReplacementNode = NewPHI; |
1978 | break; |
1979 | } |
1980 | case ComplexDeinterleavingOperation::ReductionOperation: |
1981 | ReplacementNode = replaceNode(Builder, Node: Node->Operands[0]); |
1982 | processReductionOperation(OperationReplacement: ReplacementNode, Node); |
1983 | break; |
1984 | case ComplexDeinterleavingOperation::ReductionSelect: { |
1985 | auto *MaskReal = cast<Instruction>(Val: Node->Real)->getOperand(i: 0); |
1986 | auto *MaskImag = cast<Instruction>(Val: Node->Imag)->getOperand(i: 0); |
1987 | auto *A = replaceNode(Builder, Node: Node->Operands[0]); |
1988 | auto *B = replaceNode(Builder, Node: Node->Operands[1]); |
1989 | auto *NewMaskTy = VectorType::getDoubleElementsVectorType( |
1990 | VTy: cast<VectorType>(Val: MaskReal->getType())); |
1991 | auto *NewMask = Builder.CreateIntrinsic(ID: Intrinsic::vector_interleave2, |
1992 | Types: NewMaskTy, Args: {MaskReal, MaskImag}); |
1993 | ReplacementNode = Builder.CreateSelect(C: NewMask, True: A, False: B); |
1994 | break; |
1995 | } |
1996 | } |
1997 | |
1998 | assert(ReplacementNode && "Target failed to create Intrinsic call." ); |
1999 | NumComplexTransformations += 1; |
2000 | Node->ReplacementNode = ReplacementNode; |
2001 | return ReplacementNode; |
2002 | } |
2003 | |
2004 | void ComplexDeinterleavingGraph::processReductionOperation( |
2005 | Value *OperationReplacement, RawNodePtr Node) { |
2006 | auto *Real = cast<Instruction>(Val: Node->Real); |
2007 | auto *Imag = cast<Instruction>(Val: Node->Imag); |
2008 | auto *OldPHIReal = ReductionInfo[Real].first; |
2009 | auto *OldPHIImag = ReductionInfo[Imag].first; |
2010 | auto *NewPHI = OldToNewPHI[OldPHIReal]; |
2011 | |
2012 | auto *VTy = cast<VectorType>(Val: Real->getType()); |
2013 | auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); |
2014 | |
2015 | // We have to interleave initial origin values coming from IncomingBlock |
2016 | Value *InitReal = OldPHIReal->getIncomingValueForBlock(BB: Incoming); |
2017 | Value *InitImag = OldPHIImag->getIncomingValueForBlock(BB: Incoming); |
2018 | |
2019 | IRBuilder<> Builder(Incoming->getTerminator()); |
2020 | auto *NewInit = Builder.CreateIntrinsic(ID: Intrinsic::vector_interleave2, Types: NewVTy, |
2021 | Args: {InitReal, InitImag}); |
2022 | |
2023 | NewPHI->addIncoming(V: NewInit, BB: Incoming); |
2024 | NewPHI->addIncoming(V: OperationReplacement, BB: BackEdge); |
2025 | |
2026 | // Deinterleave complex vector outside of loop so that it can be finally |
2027 | // reduced |
2028 | auto *FinalReductionReal = ReductionInfo[Real].second; |
2029 | auto *FinalReductionImag = ReductionInfo[Imag].second; |
2030 | |
2031 | Builder.SetInsertPoint( |
2032 | &*FinalReductionReal->getParent()->getFirstInsertionPt()); |
2033 | auto *Deinterleave = Builder.CreateIntrinsic(ID: Intrinsic::vector_deinterleave2, |
2034 | Types: OperationReplacement->getType(), |
2035 | Args: OperationReplacement); |
2036 | |
2037 | auto *NewReal = Builder.CreateExtractValue(Agg: Deinterleave, Idxs: (uint64_t)0); |
2038 | FinalReductionReal->replaceUsesOfWith(From: Real, To: NewReal); |
2039 | |
2040 | Builder.SetInsertPoint(FinalReductionImag); |
2041 | auto *NewImag = Builder.CreateExtractValue(Agg: Deinterleave, Idxs: 1); |
2042 | FinalReductionImag->replaceUsesOfWith(From: Imag, To: NewImag); |
2043 | } |
2044 | |
2045 | void ComplexDeinterleavingGraph::replaceNodes() { |
2046 | SmallVector<Instruction *, 16> DeadInstrRoots; |
2047 | for (auto *RootInstruction : OrderedRoots) { |
2048 | // Check if this potential root went through check process and we can |
2049 | // deinterleave it |
2050 | if (!RootToNode.count(x: RootInstruction)) |
2051 | continue; |
2052 | |
2053 | IRBuilder<> Builder(RootInstruction); |
2054 | auto RootNode = RootToNode[RootInstruction]; |
2055 | Value *R = replaceNode(Builder, Node: RootNode.get()); |
2056 | |
2057 | if (RootNode->Operation == |
2058 | ComplexDeinterleavingOperation::ReductionOperation) { |
2059 | auto *RootReal = cast<Instruction>(Val: RootNode->Real); |
2060 | auto *RootImag = cast<Instruction>(Val: RootNode->Imag); |
2061 | ReductionInfo[RootReal].first->removeIncomingValue(BB: BackEdge); |
2062 | ReductionInfo[RootImag].first->removeIncomingValue(BB: BackEdge); |
2063 | DeadInstrRoots.push_back(Elt: cast<Instruction>(Val: RootReal)); |
2064 | DeadInstrRoots.push_back(Elt: cast<Instruction>(Val: RootImag)); |
2065 | } else { |
2066 | assert(R && "Unable to find replacement for RootInstruction" ); |
2067 | DeadInstrRoots.push_back(Elt: RootInstruction); |
2068 | RootInstruction->replaceAllUsesWith(V: R); |
2069 | } |
2070 | } |
2071 | |
2072 | for (auto *I : DeadInstrRoots) |
2073 | RecursivelyDeleteTriviallyDeadInstructions(V: I, TLI); |
2074 | } |
2075 | |