//===- DAGCombiner.cpp - Implement a DAG node combiner --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass combines dag nodes to form fewer, simpler DAG nodes. It can be run
// both before and after the DAG is legalized.
//
// This pass is not a substitute for the LLVM IR instcombine pass. This pass is
// primarily intended to handle simplification opportunities that are implicit
// in the LLVM IR and exposed by the various codegen lowering phases.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/ByteProvider.h"
#include "llvm/CodeGen/DAGCombine.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/RuntimeLibcallUtil.h"
#include "llvm/CodeGen/SDPatternMatch.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/SelectionDAGTargetInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/CodeGenTypes/MachineValueType.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugCounter.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <iterator>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <variant>

#include "MatchContext.h"

using namespace llvm;
using namespace llvm::SDPatternMatch;

#define DEBUG_TYPE "dagcombine"

STATISTIC(NodesCombined   , "Number of dag nodes combined");
STATISTIC(PreIndexedNodes , "Number of pre-indexed nodes created");
STATISTIC(PostIndexedNodes, "Number of post-indexed nodes created");
STATISTIC(OpsNarrowed     , "Number of load/op/store narrowed");
STATISTIC(LdStFP2Int      , "Number of fp load/store pairs transformed to int");
STATISTIC(SlicedLoads, "Number of load sliced");
STATISTIC(NumFPLogicOpsConv, "Number of logic ops converted to fp ops");

DEBUG_COUNTER(DAGCombineCounter, "dagcombine",
              "Controls whether a DAG combine is performed for a node");

static cl::opt<bool>
CombinerGlobalAA("combiner-global-alias-analysis", cl::Hidden,
                 cl::desc("Enable DAG combiner's use of IR alias analysis"));

static cl::opt<bool>
UseTBAA("combiner-use-tbaa", cl::Hidden, cl::init(true),
        cl::desc("Enable DAG combiner's use of TBAA"));

#ifndef NDEBUG
static cl::opt<std::string>
CombinerAAOnlyFunc("combiner-aa-only-func", cl::Hidden,
                   cl::desc("Only use DAG-combiner alias analysis in this"
                            " function"));
#endif

/// Hidden option to stress test load slicing, i.e., when this option
/// is enabled, load slicing bypasses most of its profitability guards.
static cl::opt<bool>
StressLoadSlicing("combiner-stress-load-slicing", cl::Hidden,
                  cl::desc("Bypass the profitability model of load slicing"),
                  cl::init(false));

static cl::opt<bool>
  MaySplitLoadIndex("combiner-split-load-index", cl::Hidden, cl::init(true),
                    cl::desc("DAG combiner may split indexing from loads"));

static cl::opt<bool>
  EnableStoreMerging("combiner-store-merging", cl::Hidden, cl::init(true),
                     cl::desc("DAG combiner enable merging multiple stores "
                              "into a wider store"));

static cl::opt<unsigned> TokenFactorInlineLimit(
    "combiner-tokenfactor-inline-limit", cl::Hidden, cl::init(2048),
    cl::desc("Limit the number of operands to inline for Token Factors"));

static cl::opt<unsigned> StoreMergeDependenceLimit(
    "combiner-store-merge-dependence-limit", cl::Hidden, cl::init(10),
    cl::desc("Limit the number of times for the same StoreNode and RootNode "
             "to bail out in store merging dependence check"));

static cl::opt<bool> EnableReduceLoadOpStoreWidth(
    "combiner-reduce-load-op-store-width", cl::Hidden, cl::init(true),
    cl::desc("DAG combiner enable reducing the width of load/op/store "
             "sequence"));

static cl::opt<bool> EnableShrinkLoadReplaceStoreWithStore(
    "combiner-shrink-load-replace-store-with-store", cl::Hidden, cl::init(true),
    cl::desc("DAG combiner enable load/<replace bytes>/store with "
             "a narrower store"));

static cl::opt<bool> EnableVectorFCopySignExtendRound(
    "combiner-vector-fcopysign-extend-round", cl::Hidden, cl::init(false),
    cl::desc(
        "Enable merging extends and rounds into FCOPYSIGN on vector types"));

namespace {

  class DAGCombiner {
    SelectionDAG &DAG;
    const TargetLowering &TLI;
    const SelectionDAGTargetInfo *STI;
    CombineLevel Level = BeforeLegalizeTypes;
    CodeGenOptLevel OptLevel;
    bool LegalDAG = false;
    bool LegalOperations = false;
    bool LegalTypes = false;
    bool ForCodeSize;
    bool DisableGenericCombines;

    /// Worklist of all of the nodes that need to be simplified.
    ///
    /// This must behave as a stack -- new nodes to process are pushed onto the
    /// back and when processing we pop off of the back.
    ///
    /// The worklist will not contain duplicates but may contain null entries
    /// due to nodes being deleted from the underlying DAG. For fast lookup and
    /// deduplication, the index of the node in this vector is stored in the
    /// node in SDNode::CombinerWorklistIndex.
    SmallVector<SDNode *, 64> Worklist;

    /// This records all nodes attempted to be added to the worklist since we
    /// considered a new worklist entry. Because we do not add duplicate nodes
    /// to the worklist, this is different from the tail of the worklist.
    SmallSetVector<SDNode *, 32> PruningList;

    /// Map from candidate StoreNode to the pair of RootNode and count.
    /// The count is used to track how many times we have seen the StoreNode
    /// with the same RootNode bail out in the dependence check. If we have
    /// seen the bail out for the same pair more times than a limit, we won't
    /// consider the StoreNode with the same RootNode as a store merging
    /// candidate again.
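    /// (The limit is the -combiner-store-merge-dependence-limit option,
    /// StoreMergeDependenceLimit, declared above.)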
    DenseMap<SDNode *, std::pair<SDNode *, unsigned>> StoreRootCountMap;

    // AA - Used for DAG load/store alias analysis.
    AliasAnalysis *AA;

    /// This caches all chains that have already been processed in
    /// DAGCombiner::getStoreMergeCandidates() and found to have no mergeable
    /// stores candidates.
    SmallPtrSet<SDNode *, 4> ChainsWithoutMergeableStores;

    /// When an instruction is simplified, add all users of the instruction to
    /// the work lists because they might get more simplified now.
    void AddUsersToWorklist(SDNode *N) {
      for (SDNode *Node : N->uses())
        AddToWorklist(Node);
    }

    /// Convenient shorthand to add a node and all of its users to the worklist.
    void AddToWorklistWithUsers(SDNode *N) {
      AddUsersToWorklist(N);
      AddToWorklist(N);
    }

    // Prune potentially dangling nodes. This is called after
    // any visit to a node, but should also be called during a visit after any
    // failed combine which may have created a DAG node.
    void clearAddedDanglingWorklistEntries() {
      // Check any nodes added to the worklist to see if they are prunable.
      while (!PruningList.empty()) {
        auto *N = PruningList.pop_back_val();
        if (N->use_empty())
          recursivelyDeleteUnusedNodes(N);
      }
    }

    SDNode *getNextWorklistEntry() {
      // Before we do any work, remove nodes that are not in use.
      clearAddedDanglingWorklistEntries();
      SDNode *N = nullptr;
      // The Worklist holds the SDNodes in order, but it may contain null
      // entries.
      while (!N && !Worklist.empty()) {
        N = Worklist.pop_back_val();
      }

      if (N) {
        assert(N->getCombinerWorklistIndex() >= 0 &&
               "Found a worklist entry without a corresponding map entry!");
        // Set to -2 to indicate that we combined the node.
        N->setCombinerWorklistIndex(-2);
      }
      return N;
    }

    /// Call the node-specific routine that folds each particular type of node.
    SDValue visit(SDNode *N);

  public:
    DAGCombiner(SelectionDAG &D, AliasAnalysis *AA, CodeGenOptLevel OL)
        : DAG(D), TLI(D.getTargetLoweringInfo()),
          STI(D.getSubtarget().getSelectionDAGInfo()), OptLevel(OL), AA(AA) {
      ForCodeSize = DAG.shouldOptForSize();
      DisableGenericCombines = STI && STI->disableGenericCombines(OptLevel);

      MaximumLegalStoreInBits = 0;
      // We use the minimum store size here, since that's all we can guarantee
      // for the scalable vector types.
      for (MVT VT : MVT::all_valuetypes())
        if (EVT(VT).isSimple() && VT != MVT::Other &&
            TLI.isTypeLegal(EVT(VT)) &&
            VT.getSizeInBits().getKnownMinValue() >= MaximumLegalStoreInBits)
          MaximumLegalStoreInBits = VT.getSizeInBits().getKnownMinValue();
    }

    void ConsiderForPruning(SDNode *N) {
      // Mark this for potential pruning.
      PruningList.insert(N);
    }

    /// Add to the worklist making sure its instance is at the back (next to be
    /// processed.)
    void AddToWorklist(SDNode *N, bool IsCandidateForPruning = true,
                       bool SkipIfCombinedBefore = false) {
      assert(N->getOpcode() != ISD::DELETED_NODE &&
             "Deleted Node added to Worklist");

      // Skip handle nodes as they can't usefully be combined and confuse the
      // zero-use deletion strategy.
      if (N->getOpcode() == ISD::HANDLENODE)
        return;

      if (SkipIfCombinedBefore && N->getCombinerWorklistIndex() == -2)
        return;

      if (IsCandidateForPruning)
        ConsiderForPruning(N);

      if (N->getCombinerWorklistIndex() < 0) {
        N->setCombinerWorklistIndex(Worklist.size());
        Worklist.push_back(N);
      }
    }

    /// Remove all instances of N from the worklist.
    void removeFromWorklist(SDNode *N) {
      PruningList.remove(N);
      StoreRootCountMap.erase(N);

      int WorklistIndex = N->getCombinerWorklistIndex();
      // If not in the worklist, the index might be -1 or -2 (was combined
      // before). As the node gets deleted anyway, there's no need to update
      // the index.
      if (WorklistIndex < 0)
        return; // Not in the worklist.

      // Null out the entry rather than erasing it to avoid a linear operation.
      Worklist[WorklistIndex] = nullptr;
      N->setCombinerWorklistIndex(-1);
    }

    void deleteAndRecombine(SDNode *N);
    bool recursivelyDeleteUnusedNodes(SDNode *N);

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
                      bool AddTo = true);

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, SDValue Res, bool AddTo = true) {
      return CombineTo(N, &Res, 1, AddTo);
    }

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, SDValue Res0, SDValue Res1,
                      bool AddTo = true) {
      SDValue To[] = { Res0, Res1 };
      return CombineTo(N, To, 2, AddTo);
    }

    void CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO);

  private:
    unsigned MaximumLegalStoreInBits;

    /// Check the specified integer node value to see if it can be simplified
    /// or if things it uses can be simplified by bit propagation.
    /// If so, return true.
    bool SimplifyDemandedBits(SDValue Op) {
      unsigned BitWidth = Op.getScalarValueSizeInBits();
      APInt DemandedBits = APInt::getAllOnes(BitWidth);
      return SimplifyDemandedBits(Op, DemandedBits);
    }

    bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits) {
      EVT VT = Op.getValueType();
      APInt DemandedElts = VT.isFixedLengthVector()
                               ? APInt::getAllOnes(VT.getVectorNumElements())
                               : APInt(1, 1);
      return SimplifyDemandedBits(Op, DemandedBits, DemandedElts, false);
    }

    /// Check the specified vector node value to see if it can be simplified or
    /// if things it uses can be simplified as it only uses some of the
    /// elements. If so, return true.
    bool SimplifyDemandedVectorElts(SDValue Op) {
      // TODO: For now just pretend it cannot be simplified.
      if (Op.getValueType().isScalableVector())
        return false;

      unsigned NumElts = Op.getValueType().getVectorNumElements();
      APInt DemandedElts = APInt::getAllOnes(NumElts);
      return SimplifyDemandedVectorElts(Op, DemandedElts);
    }

    bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
                              const APInt &DemandedElts,
                              bool AssumeSingleUse = false);
    bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedElts,
                                    bool AssumeSingleUse = false);

    bool CombineToPreIndexedLoadStore(SDNode *N);
    bool CombineToPostIndexedLoadStore(SDNode *N);
    SDValue SplitIndexingFromLoad(LoadSDNode *LD);
    bool SliceUpLoad(SDNode *N);

    // Looks up the chain to find a unique (unaliased) store feeding the passed
    // load. If no such store is found, returns nullptr.
    // Note: This will look past a CALLSEQ_START if the load is chained to it,
    // so that it can find stack stores for byval params.
    StoreSDNode *getUniqueStoreFeeding(LoadSDNode *LD, int64_t &Offset);
    // Scalars have size 0 to distinguish from singleton vectors.
    SDValue ForwardStoreValueToDirectLoad(LoadSDNode *LD);
    bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val);
    bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val);

    /// Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed
    /// load.
    ///
    /// \param EVE ISD::EXTRACT_VECTOR_ELT to be replaced.
    /// \param InVecVT type of the input vector to EVE with bitcasts resolved.
    /// \param EltNo index of the vector element to load.
    /// \param OriginalLoad load that EVE came from to be replaced.
    /// \returns EVE on success, SDValue() on failure.
    SDValue scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
                                         SDValue EltNo,
                                         LoadSDNode *OriginalLoad);
    void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad);
    SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace);
    SDValue SExtPromoteOperand(SDValue Op, EVT PVT);
    SDValue ZExtPromoteOperand(SDValue Op, EVT PVT);
    SDValue PromoteIntBinOp(SDValue Op);
    SDValue PromoteIntShiftOp(SDValue Op);
    SDValue PromoteExtend(SDValue Op);
    bool PromoteLoad(SDValue Op);

    SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
                                SDValue RHS, SDValue True, SDValue False,
                                ISD::CondCode CC);

    /// Call the node-specific routine that knows how to fold each
    /// particular type of node. If that doesn't do anything, try the
    /// target-specific DAG combines.
    SDValue combine(SDNode *N);

    // Visitation implementation - Implement dag node combining for different
    // node types. The semantics are as follows:
    // Return Value:
    //   SDValue.getNode() == 0 - No change was made
    //   SDValue.getNode() == N - N was replaced, is dead and has been handled.
    //   otherwise              - N should be replaced by the returned Operand.
    //
    SDValue visitTokenFactor(SDNode *N);
    SDValue visitMERGE_VALUES(SDNode *N);
    SDValue visitADD(SDNode *N);
    SDValue visitADDLike(SDNode *N);
    SDValue visitADDLikeCommutative(SDValue N0, SDValue N1,
                                    SDNode *LocReference);
    SDValue visitSUB(SDNode *N);
    SDValue visitADDSAT(SDNode *N);
    SDValue visitSUBSAT(SDNode *N);
    SDValue visitADDC(SDNode *N);
    SDValue visitADDO(SDNode *N);
    SDValue visitUADDOLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitSUBC(SDNode *N);
    SDValue visitSUBO(SDNode *N);
    SDValue visitADDE(SDNode *N);
    SDValue visitUADDO_CARRY(SDNode *N);
    SDValue visitSADDO_CARRY(SDNode *N);
    SDValue visitUADDO_CARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
                                 SDNode *N);
    SDValue visitSADDO_CARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
                                 SDNode *N);
    SDValue visitSUBE(SDNode *N);
    SDValue visitUSUBO_CARRY(SDNode *N);
    SDValue visitSSUBO_CARRY(SDNode *N);
    template <class MatchContextClass> SDValue visitMUL(SDNode *N);
    SDValue visitMULFIX(SDNode *N);
    SDValue useDivRem(SDNode *N);
    SDValue visitSDIV(SDNode *N);
    SDValue visitSDIVLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitUDIV(SDNode *N);
    SDValue visitUDIVLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitREM(SDNode *N);
    SDValue visitMULHU(SDNode *N);
    SDValue visitMULHS(SDNode *N);
    SDValue visitAVG(SDNode *N);
    SDValue visitABD(SDNode *N);
    SDValue visitSMUL_LOHI(SDNode *N);
    SDValue visitUMUL_LOHI(SDNode *N);
    SDValue visitMULO(SDNode *N);
    SDValue visitIMINMAX(SDNode *N);
    SDValue visitAND(SDNode *N);
    SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitOR(SDNode *N);
    SDValue visitORLike(SDValue N0, SDValue N1, const SDLoc &DL);
    SDValue visitXOR(SDNode *N);
    SDValue SimplifyVCastOp(SDNode *N, const SDLoc &DL);
    SDValue SimplifyVBinOp(SDNode *N, const SDLoc &DL);
    SDValue visitSHL(SDNode *N);
    SDValue visitSRA(SDNode *N);
    SDValue visitSRL(SDNode *N);
    SDValue visitFunnelShift(SDNode *N);
    SDValue visitSHLSAT(SDNode *N);
    SDValue visitRotate(SDNode *N);
    SDValue visitABS(SDNode *N);
    SDValue visitBSWAP(SDNode *N);
    SDValue visitBITREVERSE(SDNode *N);
    SDValue visitCTLZ(SDNode *N);
    SDValue visitCTLZ_ZERO_UNDEF(SDNode *N);
    SDValue visitCTTZ(SDNode *N);
    SDValue visitCTTZ_ZERO_UNDEF(SDNode *N);
    SDValue visitCTPOP(SDNode *N);
    SDValue visitSELECT(SDNode *N);
    SDValue visitVSELECT(SDNode *N);
    SDValue visitVP_SELECT(SDNode *N);
    SDValue visitSELECT_CC(SDNode *N);
    SDValue visitSETCC(SDNode *N);
    SDValue visitSETCCCARRY(SDNode *N);
    SDValue visitSIGN_EXTEND(SDNode *N);
    SDValue visitZERO_EXTEND(SDNode *N);
    SDValue visitANY_EXTEND(SDNode *N);
    SDValue visitAssertExt(SDNode *N);
    SDValue visitAssertAlign(SDNode *N);
    SDValue visitSIGN_EXTEND_INREG(SDNode *N);
    SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
    SDValue visitTRUNCATE(SDNode *N);
    SDValue visitBITCAST(SDNode *N);
    SDValue visitFREEZE(SDNode *N);
    SDValue visitBUILD_PAIR(SDNode *N);
    SDValue visitFADD(SDNode *N);
    SDValue visitVP_FADD(SDNode *N);
    SDValue visitVP_FSUB(SDNode *N);
    SDValue visitSTRICT_FADD(SDNode *N);
    SDValue visitFSUB(SDNode *N);
    SDValue visitFMUL(SDNode *N);
    template <class MatchContextClass> SDValue visitFMA(SDNode *N);
    SDValue visitFMAD(SDNode *N);
    SDValue visitFDIV(SDNode *N);
    SDValue visitFREM(SDNode *N);
    SDValue visitFSQRT(SDNode *N);
    SDValue visitFCOPYSIGN(SDNode *N);
    SDValue visitFPOW(SDNode *N);
    SDValue visitSINT_TO_FP(SDNode *N);
    SDValue visitUINT_TO_FP(SDNode *N);
    SDValue visitFP_TO_SINT(SDNode *N);
    SDValue visitFP_TO_UINT(SDNode *N);
    SDValue visitXRINT(SDNode *N);
    SDValue visitFP_ROUND(SDNode *N);
    SDValue visitFP_EXTEND(SDNode *N);
    SDValue visitFNEG(SDNode *N);
    SDValue visitFABS(SDNode *N);
    SDValue visitFCEIL(SDNode *N);
    SDValue visitFTRUNC(SDNode *N);
    SDValue visitFFREXP(SDNode *N);
    SDValue visitFFLOOR(SDNode *N);
    SDValue visitFMinMax(SDNode *N);
    SDValue visitBRCOND(SDNode *N);
    SDValue visitBR_CC(SDNode *N);
    SDValue visitLOAD(SDNode *N);

    SDValue replaceStoreChain(StoreSDNode *ST, SDValue BetterChain);
    SDValue replaceStoreOfFPConstant(StoreSDNode *ST);
    SDValue replaceStoreOfInsertLoad(StoreSDNode *ST);

    bool refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(SDNode *N);

    SDValue visitSTORE(SDNode *N);
    SDValue visitATOMIC_STORE(SDNode *N);
    SDValue visitLIFETIME_END(SDNode *N);
    SDValue visitINSERT_VECTOR_ELT(SDNode *N);
    SDValue visitEXTRACT_VECTOR_ELT(SDNode *N);
    SDValue visitBUILD_VECTOR(SDNode *N);
    SDValue visitCONCAT_VECTORS(SDNode *N);
    SDValue visitEXTRACT_SUBVECTOR(SDNode *N);
    SDValue visitVECTOR_SHUFFLE(SDNode *N);
    SDValue visitSCALAR_TO_VECTOR(SDNode *N);
    SDValue visitINSERT_SUBVECTOR(SDNode *N);
    SDValue visitVECTOR_COMPRESS(SDNode *N);
    SDValue visitMLOAD(SDNode *N);
    SDValue visitMSTORE(SDNode *N);
    SDValue visitMGATHER(SDNode *N);
    SDValue visitMSCATTER(SDNode *N);
    SDValue visitVPGATHER(SDNode *N);
    SDValue visitVPSCATTER(SDNode *N);
    SDValue visitVP_STRIDED_LOAD(SDNode *N);
    SDValue visitVP_STRIDED_STORE(SDNode *N);
    SDValue visitFP_TO_FP16(SDNode *N);
    SDValue visitFP16_TO_FP(SDNode *N);
    SDValue visitFP_TO_BF16(SDNode *N);
    SDValue visitBF16_TO_FP(SDNode *N);
    SDValue visitVECREDUCE(SDNode *N);
    SDValue visitVPOp(SDNode *N);
    SDValue visitGET_FPENV_MEM(SDNode *N);
    SDValue visitSET_FPENV_MEM(SDNode *N);

    template <class MatchContextClass>
    SDValue visitFADDForFMACombine(SDNode *N);
    template <class MatchContextClass>
    SDValue visitFSUBForFMACombine(SDNode *N);
    SDValue visitFMULForFMADistributiveCombine(SDNode *N);

    SDValue XformToShuffleWithZero(SDNode *N);
    bool reassociationCanBreakAddressingModePattern(unsigned Opc,
                                                    const SDLoc &DL,
                                                    SDNode *N,
                                                    SDValue N0,
                                                    SDValue N1);
    SDValue reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, SDValue N0,
                                      SDValue N1, SDNodeFlags Flags);
    SDValue reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
                           SDValue N1, SDNodeFlags Flags);
    SDValue reassociateReduction(unsigned RedOpc, unsigned Opc, const SDLoc &DL,
                                 EVT VT, SDValue N0, SDValue N1,
                                 SDNodeFlags Flags = SDNodeFlags());

    SDValue visitShiftByConstant(SDNode *N);

    SDValue foldSelectOfConstants(SDNode *N);
    SDValue foldVSelectOfConstants(SDNode *N);
    SDValue foldBinOpIntoSelect(SDNode *BO);
    bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS);
    SDValue hoistLogicOpWithSameOpcodeHands(SDNode *N);
    SDValue SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2);
    SDValue SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
                             SDValue N2, SDValue N3, ISD::CondCode CC,
                             bool NotExtCompare = false);
    SDValue convertSelectOfFPConstantsToLoadOffset(
        const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
        ISD::CondCode CC);
    SDValue foldSignChangeInBitcast(SDNode *N);
    SDValue foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0, SDValue N1,
                                   SDValue N2, SDValue N3, ISD::CondCode CC);
    SDValue foldSelectOfBinops(SDNode *N);
    SDValue foldSextSetcc(SDNode *N);
    SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
                              const SDLoc &DL);
    SDValue foldSubToUSubSat(EVT DstVT, SDNode *N, const SDLoc &DL);
    SDValue foldABSToABD(SDNode *N, const SDLoc &DL);
    SDValue unfoldMaskedMerge(SDNode *N);
    SDValue unfoldExtremeBitClearingToShifts(SDNode *N);
    SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
                          const SDLoc &DL, bool foldBooleans);
    SDValue rebuildSetCC(SDValue N);

    bool isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
                           SDValue &CC, bool MatchStrict = false) const;
    bool isOneUseSetCC(SDValue N) const;

    SDValue foldAddToAvg(SDNode *N, const SDLoc &DL);
    SDValue foldSubToAvg(SDNode *N, const SDLoc &DL);

    SDValue SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
                                       unsigned HiOp);
    SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
    SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
                                 const TargetLowering &TLI);

    SDValue CombineExtLoad(SDNode *N);
    SDValue CombineZExtLogicopShiftLoad(SDNode *N);
    SDValue combineRepeatedFPDivisors(SDNode *N);
    SDValue combineFMulOrFDivWithIntPow2(SDNode *N);
    SDValue mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex);
    SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex);
    SDValue combineInsertEltToLoad(SDNode *N, unsigned InsIndex);
    SDValue ConstantFoldBITCASTofBUILD_VECTOR(SDNode *, EVT);
    SDValue BuildSDIV(SDNode *N);
    SDValue BuildSDIVPow2(SDNode *N);
    SDValue BuildUDIV(SDNode *N);
    SDValue BuildSREMPow2(SDNode *N);
    SDValue buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N);
    SDValue BuildLogBase2(SDValue V, const SDLoc &DL,
                          bool KnownNeverZero = false,
                          bool InexpensiveOnly = false,
                          std::optional<EVT> OutVT = std::nullopt);
    SDValue BuildDivEstimate(SDValue N, SDValue Op, SDNodeFlags Flags);
    SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags);
    SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags Flags);
    SDValue buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags, bool Recip);
    SDValue buildSqrtNROneConst(SDValue Arg, SDValue Est, unsigned Iterations,
                                SDNodeFlags Flags, bool Reciprocal);
    SDValue buildSqrtNRTwoConst(SDValue Arg, SDValue Est, unsigned Iterations,
                                SDNodeFlags Flags, bool Reciprocal);
    SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
                               bool DemandHighBits = true);
    SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
    SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
                              SDValue InnerPos, SDValue InnerNeg, bool HasPos,
                              unsigned PosOpcode, unsigned NegOpcode,
                              const SDLoc &DL);
    SDValue MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, SDValue Neg,
                              SDValue InnerPos, SDValue InnerNeg, bool HasPos,
                              unsigned PosOpcode, unsigned NegOpcode,
                              const SDLoc &DL);
    SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
    SDValue MatchLoadCombine(SDNode *N);
    SDValue mergeTruncStores(StoreSDNode *N);
    SDValue reduceLoadWidth(SDNode *N);
    SDValue ReduceLoadOpStoreWidth(SDNode *N);
    SDValue splitMergedValStore(StoreSDNode *ST);
    SDValue TransformFPLoadStorePair(SDNode *N);
    SDValue convertBuildVecZextToZext(SDNode *N);
    SDValue convertBuildVecZextToBuildVecWithZeros(SDNode *N);
    SDValue reduceBuildVecExtToExtBuildVec(SDNode *N);
    SDValue reduceBuildVecTruncToBitCast(SDNode *N);
    SDValue reduceBuildVecToShuffle(SDNode *N);
    SDValue createBuildVecShuffle(const SDLoc &DL, SDNode *N,
                                  ArrayRef<int> VectorMask, SDValue VecIn1,
                                  SDValue VecIn2, unsigned LeftIdx,
                                  bool DidSplitVec);
    SDValue matchVSelectOpSizesWithSetCC(SDNode *Cast);

    /// Walk up chain skipping non-aliasing memory nodes,
    /// looking for aliasing nodes and adding them to the Aliases vector.
    void GatherAllAliases(SDNode *N, SDValue OriginalChain,
                          SmallVectorImpl<SDValue> &Aliases);

    /// Return true if there is any possibility that the two addresses overlap.
    bool mayAlias(SDNode *Op0, SDNode *Op1) const;

    /// Walk up chain skipping non-aliasing memory nodes, looking for a better
    /// chain (aliasing node.)
    SDValue FindBetterChain(SDNode *N, SDValue Chain);

    /// Try to replace a store and any possibly adjacent stores on
    /// consecutive chains with better chains. Return true only if St is
    /// replaced.
    ///
    /// Notice that other chains may still be replaced even if the function
    /// returns false.
    bool findBetterNeighborChains(StoreSDNode *St);

    // Helper for findBetterNeighborChains. Walk up the store chain and add
    // additional chained stores that do not overlap and can be parallelized.
    bool parallelizeChainedStores(StoreSDNode *St);

    /// Holds a pointer to an LSBaseSDNode as well as information on where it
    /// is located in a sequence of memory operations connected by a chain.
    struct MemOpLink {
      // Ptr to the mem node.
      LSBaseSDNode *MemNode;

      // Offset from the base ptr.
      int64_t OffsetFromBase;

      MemOpLink(LSBaseSDNode *N, int64_t Offset)
          : MemNode(N), OffsetFromBase(Offset) {}
    };

    // Classify the origin of a stored value.
    enum class StoreSource { Unknown, Constant, Extract, Load };
    StoreSource getStoreSource(SDValue StoreVal) {
      switch (StoreVal.getOpcode()) {
      case ISD::Constant:
      case ISD::ConstantFP:
        return StoreSource::Constant;
      case ISD::BUILD_VECTOR:
        if (ISD::isBuildVectorOfConstantSDNodes(StoreVal.getNode()) ||
            ISD::isBuildVectorOfConstantFPSDNodes(StoreVal.getNode()))
          return StoreSource::Constant;
        return StoreSource::Unknown;
      case ISD::EXTRACT_VECTOR_ELT:
      case ISD::EXTRACT_SUBVECTOR:
        return StoreSource::Extract;
      case ISD::LOAD:
        return StoreSource::Load;
      default:
        return StoreSource::Unknown;
      }
    }

    /// This is a helper function for visitMUL to check the profitability
    /// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
    /// MulNode is the original multiply, AddNode is (add x, c1),
    /// and ConstNode is c2.
    bool isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
                                     SDValue ConstNode);

    /// This is a helper function for visitAND and visitZERO_EXTEND. Returns
    /// true if the (and (load x) c) pattern matches an extload. ExtVT returns
    /// the type of the loaded value to be extended.
    bool isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
                          EVT LoadResultTy, EVT &ExtVT);

    /// Helper function to calculate whether the given Load/Store can have its
    /// width reduced to ExtVT.
    bool isLegalNarrowLdSt(LSBaseSDNode *LDSTN, ISD::LoadExtType ExtType,
                           EVT &MemVT, unsigned ShAmt = 0);

    /// Used by BackwardsPropagateMask to find suitable loads.
    bool SearchForAndLoads(SDNode *N, SmallVectorImpl<LoadSDNode*> &Loads,
                           SmallPtrSetImpl<SDNode*> &NodesWithConsts,
                           ConstantSDNode *Mask, SDNode *&NodeToMask);
    /// Attempt to propagate a given AND node back to load leaves so that they
    /// can be combined into narrow loads.
    bool BackwardsPropagateMask(SDNode *N);

    /// Helper function for mergeConsecutiveStores which merges the component
    /// store chains.
    SDValue getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
                                unsigned NumStores);

    /// Helper function for mergeConsecutiveStores which checks if all the store
    /// nodes have the same underlying object. We can still reuse the first
    /// store's pointer info if all the stores are from the same object.
    bool hasSameUnderlyingObj(ArrayRef<MemOpLink> StoreNodes);

    /// This is a helper function for mergeConsecutiveStores. When the source
    /// elements of the consecutive stores are all constants or all extracted
    /// vector elements, try to merge them into one larger store introducing
    /// bitcasts if necessary. \return True if a merged store was created.
    bool mergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes,
                                         EVT MemVT, unsigned NumStores,
                                         bool IsConstantSrc, bool UseVector,
                                         bool UseTrunc);

    /// This is a helper function for mergeConsecutiveStores. Stores that
    /// potentially may be merged with St are placed in StoreNodes. On success,
    /// returns a chain predecessor to all store candidates.
    SDNode *getStoreMergeCandidates(StoreSDNode *St,
                                    SmallVectorImpl<MemOpLink> &StoreNodes);

    /// Helper function for mergeConsecutiveStores. Checks if candidate stores
    /// have indirect dependency through their operands. RootNode is the
    /// predecessor to all stores calculated by getStoreMergeCandidates and is
    /// used to prune the dependency check. \return True if safe to merge.
    bool checkMergeStoreCandidatesForDependencies(
        SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
        SDNode *RootNode);

    /// This is a helper function for mergeConsecutiveStores. Given a list of
    /// store candidates, find the first N that are consecutive in memory.
    /// Returns 0 if there are not at least 2 consecutive stores to try merging.
    unsigned getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
                                  int64_t ElementSizeBytes) const;

    /// This is a helper function for mergeConsecutiveStores. It is used for
    /// store chains that are composed entirely of constant values.
    bool tryStoreMergeOfConstants(SmallVectorImpl<MemOpLink> &StoreNodes,
                                  unsigned NumConsecutiveStores,
                                  EVT MemVT, SDNode *Root, bool AllowVectors);

    /// This is a helper function for mergeConsecutiveStores. It is used for
    /// store chains that are composed entirely of extracted vector elements.
    /// When extracting multiple vector elements, try to store them in one
    /// vector store rather than a sequence of scalar stores.
    bool tryStoreMergeOfExtracts(SmallVectorImpl<MemOpLink> &StoreNodes,
                                 unsigned NumConsecutiveStores, EVT MemVT,
                                 SDNode *Root);

    /// This is a helper function for mergeConsecutiveStores. It is used for
    /// store chains that are composed entirely of loaded values.
    bool tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
                              unsigned NumConsecutiveStores, EVT MemVT,
                              SDNode *Root, bool AllowVectors,
                              bool IsNonTemporalStore, bool IsNonTemporalLoad);

    /// Merge consecutive store operations into a wide store.
    /// This optimization uses wide integers or vectors when possible.
    /// \return true if stores were merged.
    bool mergeConsecutiveStores(StoreSDNode *St);

    /// Try to transform a truncation where C is a constant:
    ///   (trunc (and X, C)) -> (and (trunc X), (trunc C))
    ///
    /// \p N needs to be a truncation and its first operand an AND. Other
    /// requirements are checked by the function (e.g. that trunc is
    /// single-use) and if missed an empty SDValue is returned.
    SDValue distributeTruncateThroughAnd(SDNode *N);

    /// Helper function to determine whether the target supports the operation
    /// given by \p Opcode for type \p VT, that is, whether the operation
    /// is legal or custom before legalizing operations, and whether it is
    /// legal (but not custom) after legalization.
    bool hasOperation(unsigned Opcode, EVT VT) {
      return TLI.isOperationLegalOrCustom(Opcode, VT, LegalOperations);
    }

  public:
    /// Runs the dag combiner on all nodes in the work list
    void Run(CombineLevel AtLevel);

    SelectionDAG &getDAG() const { return DAG; }

    /// Convenience wrapper around TargetLowering::getShiftAmountTy.
    EVT getShiftAmountTy(EVT LHSTy) {
      return TLI.getShiftAmountTy(LHSTy, DAG.getDataLayout());
    }

    /// This method returns true if we are running before type legalization or
    /// if the specified VT is legal.
    bool isTypeLegal(const EVT &VT) {
      if (!LegalTypes) return true;
      return TLI.isTypeLegal(VT);
    }

    /// Convenience wrapper around TargetLowering::getSetCCResultType
    EVT getSetCCResultType(EVT VT) const {
      return TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
    }

    void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
                         SDValue OrigLoad, SDValue ExtLoad,
                         ISD::NodeType ExtType);
  };

/// This class is a DAGUpdateListener that removes any deleted
/// nodes from the worklist.
class WorklistRemover : public SelectionDAG::DAGUpdateListener {
  DAGCombiner &DC;

public:
  explicit WorklistRemover(DAGCombiner &dc)
      : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}

  void NodeDeleted(SDNode *N, SDNode *E) override {
    DC.removeFromWorklist(N);
  }
};

class WorklistInserter : public SelectionDAG::DAGUpdateListener {
  DAGCombiner &DC;

public:
  explicit WorklistInserter(DAGCombiner &dc)
      : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}

  // FIXME: Ideally we could add N to the worklist, but this causes exponential
  // compile time costs in large DAGs, e.g. Halide.
  void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); }
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
// TargetLowering::DAGCombinerInfo implementation
//===----------------------------------------------------------------------===//

void TargetLowering::DAGCombinerInfo::AddToWorklist(SDNode *N) {
  ((DAGCombiner*)DC)->AddToWorklist(N);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, ArrayRef<SDValue> To, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, &To[0], To.size(), AddTo);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, SDValue Res, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, Res, AddTo);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, SDValue Res0, SDValue Res1, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, Res0, Res1, AddTo);
}

bool TargetLowering::DAGCombinerInfo::
recursivelyDeleteUnusedNodes(SDNode *N) {
  return ((DAGCombiner*)DC)->recursivelyDeleteUnusedNodes(N);
}

void TargetLowering::DAGCombinerInfo::
CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
  return ((DAGCombiner*)DC)->CommitTargetLoweringOpt(TLO);
}

//===----------------------------------------------------------------------===//
// Helper Functions
//===----------------------------------------------------------------------===//

void DAGCombiner::deleteAndRecombine(SDNode *N) {
  removeFromWorklist(N);

  // If the operands of this node are only used by the node, they will now be
  // dead. Make sure to re-visit them and recursively delete dead nodes.
  for (const SDValue &Op : N->ops())
    // For an operand generating multiple values, one of the values may
    // become dead allowing further simplification (e.g. split index
    // arithmetic from an indexed load).
    if (Op->hasOneUse() || Op->getNumValues() > 1)
      AddToWorklist(Op.getNode());

  DAG.DeleteNode(N);
}

// APInts must be the same size for most operations, this helper
// function zero extends the shorter of the pair so that they match.
// We provide an Offset so that we can create bitwidths that won't overflow.
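// For example, with an 8-bit LHS, a 16-bit RHS, and Offset == 1, both values
// are zero extended to 17 bits.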
static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) {
  unsigned Bits = Offset + std::max(LHS.getBitWidth(), RHS.getBitWidth());
  LHS = LHS.zext(Bits);
  RHS = RHS.zext(Bits);
}

// Return true if this node is a setcc, or is a select_cc
// that selects between the target values used for true and false, making it
// equivalent to a setcc. Also, set the incoming LHS, RHS, and CC references to
// the appropriate nodes based on the type of node we are checking. This
// simplifies life a bit for the callers.
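// For example, (select_cc lhs, rhs, TV, FV, cc) is treated as
// (setcc lhs, rhs, cc) when TV and FV are the target's canonical true and
// false values and the boolean contents are not undefined.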
bool DAGCombiner::isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
                                    SDValue &CC, bool MatchStrict) const {
  if (N.getOpcode() == ISD::SETCC) {
    LHS = N.getOperand(0);
    RHS = N.getOperand(1);
    CC = N.getOperand(2);
    return true;
  }

  if (MatchStrict &&
      (N.getOpcode() == ISD::STRICT_FSETCC ||
       N.getOpcode() == ISD::STRICT_FSETCCS)) {
    LHS = N.getOperand(1);
    RHS = N.getOperand(2);
    CC = N.getOperand(3);
    return true;
  }

  if (N.getOpcode() != ISD::SELECT_CC || !TLI.isConstTrueVal(N.getOperand(2)) ||
      !TLI.isConstFalseVal(N.getOperand(3)))
    return false;

  if (TLI.getBooleanContents(N.getValueType()) ==
      TargetLowering::UndefinedBooleanContent)
    return false;

  LHS = N.getOperand(0);
  RHS = N.getOperand(1);
  CC = N.getOperand(4);
  return true;
}

/// Return true if this is a SetCC-equivalent operation with only one use.
/// If this is true, it allows the users to invert the operation for free when
/// it is profitable to do so.
bool DAGCombiner::isOneUseSetCC(SDValue N) const {
  SDValue N0, N1, N2;
  if (isSetCCEquivalent(N, N0, N1, N2) && N->hasOneUse())
    return true;
  return false;
}

static bool isConstantSplatVectorMaskForType(SDNode *N, EVT ScalarTy) {
  if (!ScalarTy.isSimple())
    return false;

  uint64_t MaskForTy = 0ULL;
  switch (ScalarTy.getSimpleVT().SimpleTy) {
  case MVT::i8:
    MaskForTy = 0xFFULL;
    break;
  case MVT::i16:
    MaskForTy = 0xFFFFULL;
    break;
  case MVT::i32:
    MaskForTy = 0xFFFFFFFFULL;
    break;
  default:
    return false;
  }

  APInt Val;
  if (ISD::isConstantSplatVector(N, Val))
    return Val.getLimitedValue() == MaskForTy;

  return false;
}

// Determines if it is a constant integer or a splat/build vector of constant
// integers (and undefs).
// Do not permit build vector implicit truncation.
static bool isConstantOrConstantVector(SDValue N, bool NoOpaques = false) {
  if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N))
    return !(Const->isOpaque() && NoOpaques);
  if (N.getOpcode() != ISD::BUILD_VECTOR && N.getOpcode() != ISD::SPLAT_VECTOR)
    return false;
  unsigned BitWidth = N.getScalarValueSizeInBits();
  for (const SDValue &Op : N->op_values()) {
    if (Op.isUndef())
      continue;
    ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Op);
    if (!Const || Const->getAPIntValue().getBitWidth() != BitWidth ||
        (Const->isOpaque() && NoOpaques))
      return false;
  }
  return true;
}

// Determines if a BUILD_VECTOR is composed of all constants possibly mixed
// with undefs.
static bool isAnyConstantBuildVector(SDValue V, bool NoOpaques = false) {
  if (V.getOpcode() != ISD::BUILD_VECTOR)
    return false;
  return isConstantOrConstantVector(V, NoOpaques) ||
         ISD::isBuildVectorOfConstantFPSDNodes(V.getNode());
}

// Determine if this is an indexed load with an opaque target constant index.
static bool canSplitIdx(LoadSDNode *LD) {
  return MaySplitLoadIndex &&
         (LD->getOperand(2).getOpcode() != ISD::TargetConstant ||
          !cast<ConstantSDNode>(LD->getOperand(2))->isOpaque());
}

bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
                                                             const SDLoc &DL,
                                                             SDNode *N,
                                                             SDValue N0,
                                                             SDValue N1) {
  // Currently this only tries to ensure we don't undo the GEP splits done by
  // CodeGenPrepare when shouldConsiderGEPOffsetSplit is true. To ensure this,
  // we check if the following transformation would be problematic:
  // (load/store (add, (add, x, offset1), offset2)) ->
  // (load/store (add, x, offset1+offset2)).

  // (load/store (add, (add, x, y), offset2)) ->
  // (load/store (add, (add, x, offset2), y)).

  if (N0.getOpcode() != ISD::ADD)
    return false;

  // Check for vscale addressing modes.
  // (load/store (add/sub (add x, y), vscale))
  // (load/store (add/sub (add x, y), (lsl vscale, C)))
  // (load/store (add/sub (add x, y), (mul vscale, C)))
  if ((N1.getOpcode() == ISD::VSCALE ||
       ((N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::MUL) &&
        N1.getOperand(0).getOpcode() == ISD::VSCALE &&
        isa<ConstantSDNode>(N1.getOperand(1)))) &&
      N1.getValueType().getFixedSizeInBits() <= 64) {
    int64_t ScalableOffset = N1.getOpcode() == ISD::VSCALE
                                 ? N1.getConstantOperandVal(0)
                                 : (N1.getOperand(0).getConstantOperandVal(0) *
                                    (N1.getOpcode() == ISD::SHL
                                         ? (1LL << N1.getConstantOperandVal(1))
                                         : N1.getConstantOperandVal(1)));
    if (Opc == ISD::SUB)
      ScalableOffset = -ScalableOffset;
    if (all_of(N->uses(), [&](SDNode *Node) {
          if (auto *LoadStore = dyn_cast<MemSDNode>(Node);
              LoadStore && LoadStore->getBasePtr().getNode() == N) {
            TargetLoweringBase::AddrMode AM;
            AM.HasBaseReg = true;
            AM.ScalableOffset = ScalableOffset;
            EVT VT = LoadStore->getMemoryVT();
            unsigned AS = LoadStore->getAddressSpace();
            Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
            return TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy,
                                             AS);
          }
          return false;
        }))
      return true;
  }

  if (Opc != ISD::ADD)
    return false;

  auto *C2 = dyn_cast<ConstantSDNode>(N1);
  if (!C2)
    return false;

  const APInt &C2APIntVal = C2->getAPIntValue();
  if (C2APIntVal.getSignificantBits() > 64)
    return false;

  if (auto *C1 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
    if (N0.hasOneUse())
      return false;

    const APInt &C1APIntVal = C1->getAPIntValue();
    const APInt CombinedValueIntVal = C1APIntVal + C2APIntVal;
    if (CombinedValueIntVal.getSignificantBits() > 64)
      return false;
    const int64_t CombinedValue = CombinedValueIntVal.getSExtValue();

    for (SDNode *Node : N->uses()) {
      if (auto *LoadStore = dyn_cast<MemSDNode>(Node)) {
        // Is x[offset2] already not a legal addressing mode? If so then
        // reassociating the constants breaks nothing (we test offset2 because
        // that's the one we hope to fold into the load or store).
        TargetLoweringBase::AddrMode AM;
        AM.HasBaseReg = true;
        AM.BaseOffs = C2APIntVal.getSExtValue();
        EVT VT = LoadStore->getMemoryVT();
        unsigned AS = LoadStore->getAddressSpace();
        Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
        if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
          continue;

        // Would x[offset1+offset2] still be a legal addressing mode?
        AM.BaseOffs = CombinedValue;
        if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
          return true;
      }
    }
  } else {
    if (auto *GA = dyn_cast<GlobalAddressSDNode>(N0.getOperand(1)))
      if (GA->getOpcode() == ISD::GlobalAddress && TLI.isOffsetFoldingLegal(GA))
        return false;

    for (SDNode *Node : N->uses()) {
      auto *LoadStore = dyn_cast<MemSDNode>(Node);
      if (!LoadStore)
        return false;

      // Is x[offset2] a legal addressing mode? If so then
      // reassociating the constants breaks the address pattern.
      TargetLoweringBase::AddrMode AM;
      AM.HasBaseReg = true;
      AM.BaseOffs = C2APIntVal.getSExtValue();
      EVT VT = LoadStore->getMemoryVT();
      unsigned AS = LoadStore->getAddressSpace();
      Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
      if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
        return false;
    }
    return true;
  }

  return false;
}

/// Helper for DAGCombiner::reassociateOps. Try to reassociate (Opc N0, N1) if
/// \p N0 is the same kind of operation as \p Opc.
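/// For example, with Opc == ISD::ADD this can turn (add (add x, c1), c2) into
/// (add x, (add c1, c2)) when the constant operands fold.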
SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL,
                                               SDValue N0, SDValue N1,
                                               SDNodeFlags Flags) {
  EVT VT = N0.getValueType();

  if (N0.getOpcode() != Opc)
    return SDValue();

  SDValue N00 = N0.getOperand(0);
  SDValue N01 = N0.getOperand(1);

  if (DAG.isConstantIntBuildVectorOrConstantInt(peekThroughBitcasts(N01))) {
    SDNodeFlags NewFlags;
    if (N0.getOpcode() == ISD::ADD && N0->getFlags().hasNoUnsignedWrap() &&
        Flags.hasNoUnsignedWrap())
      NewFlags.setNoUnsignedWrap(true);

    if (DAG.isConstantIntBuildVectorOrConstantInt(peekThroughBitcasts(N1))) {
      // Reassociate: (op (op x, c1), c2) -> (op x, (op c1, c2))
      if (SDValue OpNode = DAG.FoldConstantArithmetic(Opc, DL, VT, {N01, N1}))
        return DAG.getNode(Opc, DL, VT, N00, OpNode, NewFlags);
      return SDValue();
    }
    if (TLI.isReassocProfitable(DAG, N0, N1)) {
      // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1)
      //              iff (op x, c1) has one use
      SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N00, N1, NewFlags);
      return DAG.getNode(Opc, DL, VT, OpNode, N01, NewFlags);
    }
  }

  // Check for repeated operand logic simplifications.
  if (Opc == ISD::AND || Opc == ISD::OR) {
    // (N00 & N01) & N00 --> N00 & N01
    // (N00 & N01) & N01 --> N00 & N01
    // (N00 | N01) | N00 --> N00 | N01
    // (N00 | N01) | N01 --> N00 | N01
    if (N1 == N00 || N1 == N01)
      return N0;
  }
  if (Opc == ISD::XOR) {
    // (N00 ^ N01) ^ N00 --> N01
    if (N1 == N00)
      return N01;
    // (N00 ^ N01) ^ N01 --> N00
    if (N1 == N01)
      return N00;
  }

  if (TLI.isReassocProfitable(DAG, N0, N1)) {
    if (N1 != N01) {
      // Reassociate if (op N00, N1) already exists.
      if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N00, N1})) {
        // If (op (op N00, N1), N01) already exists, we need to stop
        // reassociating to avoid an infinite loop.
        if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N01}))
          return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N01);
      }
    }

    if (N1 != N00) {
      // Reassociate if (op N01, N1) already exists.
      if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N01, N1})) {
        // If (op (op N01, N1), N00) already exists, we need to stop
        // reassociating to avoid an infinite loop.
        if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N00}))
          return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N00);
      }
    }

    // Reassociate the operands from (OR/AND (OR/AND(N00, N01)), N1) to (OR/AND
    // (OR/AND(N00, N1)), N01) when N00 and N1 are comparisons with the same
    // predicate or to (OR/AND (OR/AND(N1, N01)), N00) when N01 and N1 are
    // comparisons with the same predicate. This enables optimizations such as
    // the following one:
    //   CMP(A,C)||CMP(B,C) => CMP(MIN/MAX(A,B), C)
    //   CMP(A,C)&&CMP(B,C) => CMP(MIN/MAX(A,B), C)
    if (Opc == ISD::AND || Opc == ISD::OR) {
      if (N1->getOpcode() == ISD::SETCC && N00->getOpcode() == ISD::SETCC &&
          N01->getOpcode() == ISD::SETCC) {
        ISD::CondCode CC1 = cast<CondCodeSDNode>(N1.getOperand(2))->get();
        ISD::CondCode CC00 = cast<CondCodeSDNode>(N00.getOperand(2))->get();
        ISD::CondCode CC01 = cast<CondCodeSDNode>(N01.getOperand(2))->get();
        if (CC1 == CC00 && CC1 != CC01) {
          SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N00, N1, Flags);
          return DAG.getNode(Opc, DL, VT, OpNode, N01, Flags);
        }
        if (CC1 == CC01 && CC1 != CC00) {
          SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N01, N1, Flags);
          return DAG.getNode(Opc, DL, VT, OpNode, N00, Flags);
        }
      }
    }
  }

  return SDValue();
}

/// Try to reassociate commutative (Opc N0, N1) if either \p N0 or \p N1 is the
/// same kind of operation as \p Opc.
SDValue DAGCombiner::reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
                                    SDValue N1, SDNodeFlags Flags) {
  assert(TLI.isCommutativeBinOp(Opc) && "Operation not commutative.");

  // Floating-point reassociation is not allowed without loose FP math.
  if (N0.getValueType().isFloatingPoint() ||
      N1.getValueType().isFloatingPoint())
    if (!Flags.hasAllowReassociation() || !Flags.hasNoSignedZeros())
      return SDValue();

  if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N0, N1, Flags))
    return Combined;
  if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N1, N0, Flags))
    return Combined;
  return SDValue();
}

// Try to fold Opc(vecreduce(x), vecreduce(y)) -> vecreduce(Opc(x, y))
// Note that we only expect Flags to be passed from FP operations. For integer
// operations they need to be dropped.
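// For example, (fadd (vecreduce_fadd x), (vecreduce_fadd y)) can become
// (vecreduce_fadd (fadd x, y)) when the target reports the reassociation as
// profitable.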
SDValue DAGCombiner::reassociateReduction(unsigned RedOpc, unsigned Opc,
                                          const SDLoc &DL, EVT VT, SDValue N0,
                                          SDValue N1, SDNodeFlags Flags) {
  if (N0.getOpcode() == RedOpc && N1.getOpcode() == RedOpc &&
      N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType() &&
      N0->hasOneUse() && N1->hasOneUse() &&
      TLI.isOperationLegalOrCustom(Opc, N0.getOperand(0).getValueType()) &&
      TLI.shouldReassociateReduction(RedOpc, N0.getOperand(0).getValueType())) {
    SelectionDAG::FlagInserter FlagsInserter(DAG, Flags);
    return DAG.getNode(RedOpc, DL, VT,
                       DAG.getNode(Opc, DL, N0.getOperand(0).getValueType(),
                                   N0.getOperand(0), N1.getOperand(0)));
  }
  return SDValue();
}

SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
                               bool AddTo) {
  assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.1 "; N->dump(&DAG); dbgs() << "\nWith: ";
             To[0].dump(&DAG);
             dbgs() << " and " << NumTo - 1 << " other values\n");
  for (unsigned i = 0, e = NumTo; i != e; ++i)
    assert((!To[i].getNode() ||
            N->getValueType(i) == To[i].getValueType()) &&
           "Cannot combine value to value of different type!");

  WorklistRemover DeadNodes(*this);
  DAG.ReplaceAllUsesWith(N, To);
  if (AddTo) {
    // Push the new nodes and any users onto the worklist
    for (unsigned i = 0, e = NumTo; i != e; ++i) {
      if (To[i].getNode())
        AddToWorklistWithUsers(To[i].getNode());
    }
  }

  // Finally, if the node is now dead, remove it from the graph. The node
  // may not be dead if the replacement process recursively simplified to
  // something else needing this node.
  if (N->use_empty())
    deleteAndRecombine(N);
  return SDValue(N, 0);
}

void DAGCombiner::
CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
  // Replace the old value with the new one.
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.dump(&DAG);
             dbgs() << "\nWith: "; TLO.New.dump(&DAG); dbgs() << '\n');

  // Replace all uses.
  DAG.ReplaceAllUsesOfValueWith(TLO.Old, TLO.New);

  // Push the new node and any (possibly new) users onto the worklist.
  AddToWorklistWithUsers(TLO.New.getNode());

  // Finally, if the node is now dead, remove it from the graph.
  recursivelyDeleteUnusedNodes(TLO.Old.getNode());
}
1375
1376/// Check the specified integer node value to see if it can be simplified or if
1377/// things it uses can be simplified by bit propagation. If so, return true.
1378bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
1379 const APInt &DemandedElts,
1380 bool AssumeSingleUse) {
1381 TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1382 KnownBits Known;
1383 if (!TLI.SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, Depth: 0,
1384 AssumeSingleUse))
1385 return false;
1386
1387 // Revisit the node.
1388 AddToWorklist(N: Op.getNode());
1389
1390 CommitTargetLoweringOpt(TLO);
1391 return true;
1392}
1393
1394/// Check the specified vector node value to see if it can be simplified or
1395/// if things it uses can be simplified as it only uses some of the elements.
1396/// If so, return true.
1397bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op,
1398 const APInt &DemandedElts,
1399 bool AssumeSingleUse) {
1400 TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1401 APInt KnownUndef, KnownZero;
1402 if (!TLI.SimplifyDemandedVectorElts(Op, DemandedEltMask: DemandedElts, KnownUndef, KnownZero,
1403 TLO, Depth: 0, AssumeSingleUse))
1404 return false;
1405
1406 // Revisit the node.
1407 AddToWorklist(N: Op.getNode());
1408
1409 CommitTargetLoweringOpt(TLO);
1410 return true;
1411}
1412
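/// Replace \p Load with the wider \p ExtLoad: value users of \p Load are routed
/// through a TRUNCATE of the extended load back to the original type, and chain
/// users are moved to \p ExtLoad's chain result.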
1413void DAGCombiner::ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad) {
1414 SDLoc DL(Load);
1415 EVT VT = Load->getValueType(ResNo: 0);
1416 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: SDValue(ExtLoad, 0));
1417
1418 LLVM_DEBUG(dbgs() << "\nReplacing.9 "; Load->dump(&DAG); dbgs() << "\nWith: ";
1419 Trunc.dump(&DAG); dbgs() << '\n');
1420
1421 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Load, 0), To: Trunc);
1422 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Load, 1), To: SDValue(ExtLoad, 1));
1423
1424 AddToWorklist(N: Trunc.getNode());
1425 recursivelyDeleteUnusedNodes(N: Load);
1426}
1427
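/// Return a PVT-typed version of \p Op for use by the integer promotion combines
/// below. Loads are rebuilt as extending loads of PVT (with \p Replace set so the
/// caller rewrites the original load), AssertSext/AssertZext and constants are
/// extended explicitly, and anything else is wrapped in ANY_EXTEND if legal.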
1428SDValue DAGCombiner::PromoteOperand(SDValue Op, EVT PVT, bool &Replace) {
1429 Replace = false;
1430 SDLoc DL(Op);
1431 if (ISD::isUNINDEXEDLoad(N: Op.getNode())) {
1432 LoadSDNode *LD = cast<LoadSDNode>(Val&: Op);
1433 EVT MemVT = LD->getMemoryVT();
1434 ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(N: LD) ? ISD::EXTLOAD
1435 : LD->getExtensionType();
1436 Replace = true;
1437 return DAG.getExtLoad(ExtType, dl: DL, VT: PVT,
1438 Chain: LD->getChain(), Ptr: LD->getBasePtr(),
1439 MemVT, MMO: LD->getMemOperand());
1440 }
1441
1442 unsigned Opc = Op.getOpcode();
1443 switch (Opc) {
1444 default: break;
1445 case ISD::AssertSext:
1446 if (SDValue Op0 = SExtPromoteOperand(Op: Op.getOperand(i: 0), PVT))
1447 return DAG.getNode(Opcode: ISD::AssertSext, DL, VT: PVT, N1: Op0, N2: Op.getOperand(i: 1));
1448 break;
1449 case ISD::AssertZext:
1450 if (SDValue Op0 = ZExtPromoteOperand(Op: Op.getOperand(i: 0), PVT))
1451 return DAG.getNode(Opcode: ISD::AssertZext, DL, VT: PVT, N1: Op0, N2: Op.getOperand(i: 1));
1452 break;
1453 case ISD::Constant: {
1454 unsigned ExtOpc =
1455 Op.getValueType().isByteSized() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
1456 return DAG.getNode(Opcode: ExtOpc, DL, VT: PVT, Operand: Op);
1457 }
1458 }
1459
1460 if (!TLI.isOperationLegal(Op: ISD::ANY_EXTEND, VT: PVT))
1461 return SDValue();
1462 return DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: PVT, Operand: Op);
1463}
1464
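/// Promote \p Op to \p PVT and then sign-extend from the original width with
/// SIGN_EXTEND_INREG, so the promoted value is a sign extension of the original.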
1465SDValue DAGCombiner::SExtPromoteOperand(SDValue Op, EVT PVT) {
1466 if (!TLI.isOperationLegal(Op: ISD::SIGN_EXTEND_INREG, VT: PVT))
1467 return SDValue();
1468 EVT OldVT = Op.getValueType();
1469 SDLoc DL(Op);
1470 bool Replace = false;
1471 SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1472 if (!NewOp.getNode())
1473 return SDValue();
1474 AddToWorklist(N: NewOp.getNode());
1475
1476 if (Replace)
1477 ReplaceLoadWithPromotedLoad(Load: Op.getNode(), ExtLoad: NewOp.getNode());
1478 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT: NewOp.getValueType(), N1: NewOp,
1479 N2: DAG.getValueType(OldVT));
1480}
1481
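/// Promote \p Op to \p PVT and then clear the bits above the original width, so
/// the promoted value is a zero extension of the original.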
1482SDValue DAGCombiner::ZExtPromoteOperand(SDValue Op, EVT PVT) {
1483 EVT OldVT = Op.getValueType();
1484 SDLoc DL(Op);
1485 bool Replace = false;
1486 SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1487 if (!NewOp.getNode())
1488 return SDValue();
1489 AddToWorklist(N: NewOp.getNode());
1490
1491 if (Replace)
1492 ReplaceLoadWithPromotedLoad(Load: Op.getNode(), ExtLoad: NewOp.getNode());
1493 return DAG.getZeroExtendInReg(Op: NewOp, DL, VT: OldVT);
1494}
1495
1496/// Promote the specified integer binary operation if the target indicates it is
1497/// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1498/// i32 since i16 instructions are longer.
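///
/// For example, an undesirable i16 add can become
///   (i16 (trunc (i32 add (any_extend x), (any_extend y))))
/// when the target asks to promote the operation to i32.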
1499SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) {
1500 if (!LegalOperations)
1501 return SDValue();
1502
1503 EVT VT = Op.getValueType();
1504 if (VT.isVector() || !VT.isInteger())
1505 return SDValue();
1506
1507 // If operation type is 'undesirable', e.g. i16 on x86, consider
1508 // promoting it.
1509 unsigned Opc = Op.getOpcode();
1510 if (TLI.isTypeDesirableForOp(Opc, VT))
1511 return SDValue();
1512
1513 EVT PVT = VT;
1514 // Consult target whether it is a good idea to promote this operation and
1515 // what's the right type to promote it to.
1516 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1517 assert(PVT != VT && "Don't know what type to promote to!");
1518
1519 LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1520
1521 bool Replace0 = false;
1522 SDValue N0 = Op.getOperand(i: 0);
1523 SDValue NN0 = PromoteOperand(Op: N0, PVT, Replace&: Replace0);
1524
1525 bool Replace1 = false;
1526 SDValue N1 = Op.getOperand(i: 1);
1527 SDValue NN1 = PromoteOperand(Op: N1, PVT, Replace&: Replace1);
1528 SDLoc DL(Op);
1529
1530 SDValue RV =
1531 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: DAG.getNode(Opcode: Opc, DL, VT: PVT, N1: NN0, N2: NN1));
1532
1533 // We are always replacing N0/N1's use in N and only need additional
1534 // replacements if there are additional uses.
1535 // Note: We are checking uses of the *nodes* (SDNode) rather than values
1536 // (SDValue) here because the node may reference multiple values
1537 // (for example, the chain value of a load node).
1538 Replace0 &= !N0->hasOneUse();
1539 Replace1 &= (N0 != N1) && !N1->hasOneUse();
1540
1541 // Combine Op here so it is preserved past replacements.
1542 CombineTo(N: Op.getNode(), Res: RV);
1543
1544 // If operands have a use ordering, make sure we deal with
1545 // predecessor first.
1546 if (Replace0 && Replace1 && N0->isPredecessorOf(N: N1.getNode())) {
1547 std::swap(a&: N0, b&: N1);
1548 std::swap(a&: NN0, b&: NN1);
1549 }
1550
1551 if (Replace0) {
1552 AddToWorklist(N: NN0.getNode());
1553 ReplaceLoadWithPromotedLoad(Load: N0.getNode(), ExtLoad: NN0.getNode());
1554 }
1555 if (Replace1) {
1556 AddToWorklist(N: NN1.getNode());
1557 ReplaceLoadWithPromotedLoad(Load: N1.getNode(), ExtLoad: NN1.getNode());
1558 }
1559 return Op;
1560 }
1561 return SDValue();
1562}
1563
1564/// Promote the specified integer shift operation if the target indicates it is
1565/// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1566/// i32 since i16 instructions are longer.
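///
/// The shifted value is widened with a sign extension for SRA, a zero extension
/// for SRL, and an any-extension for SHL; the shift amount is left unchanged and
/// the result is truncated back to the original type.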
1567SDValue DAGCombiner::PromoteIntShiftOp(SDValue Op) {
1568 if (!LegalOperations)
1569 return SDValue();
1570
1571 EVT VT = Op.getValueType();
1572 if (VT.isVector() || !VT.isInteger())
1573 return SDValue();
1574
1575 // If operation type is 'undesirable', e.g. i16 on x86, consider
1576 // promoting it.
1577 unsigned Opc = Op.getOpcode();
1578 if (TLI.isTypeDesirableForOp(Opc, VT))
1579 return SDValue();
1580
1581 EVT PVT = VT;
1582 // Consult target whether it is a good idea to promote this operation and
1583 // what's the right type to promote it to.
1584 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1585 assert(PVT != VT && "Don't know what type to promote to!");
1586
1587 LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1588
1589 bool Replace = false;
1590 SDValue N0 = Op.getOperand(i: 0);
1591 if (Opc == ISD::SRA)
1592 N0 = SExtPromoteOperand(Op: N0, PVT);
1593 else if (Opc == ISD::SRL)
1594 N0 = ZExtPromoteOperand(Op: N0, PVT);
1595 else
1596 N0 = PromoteOperand(Op: N0, PVT, Replace);
1597
1598 if (!N0.getNode())
1599 return SDValue();
1600
1601 SDLoc DL(Op);
1602 SDValue N1 = Op.getOperand(i: 1);
1603 SDValue RV =
1604 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: DAG.getNode(Opcode: Opc, DL, VT: PVT, N1: N0, N2: N1));
1605
1606 if (Replace)
1607 ReplaceLoadWithPromotedLoad(Load: Op.getOperand(i: 0).getNode(), ExtLoad: N0.getNode());
1608
1609 // Deal with Op being deleted.
1610 if (Op && Op.getOpcode() != ISD::DELETED_NODE)
1611 return RV;
1612 }
1613 return SDValue();
1614}
1615
1616SDValue DAGCombiner::PromoteExtend(SDValue Op) {
1617 if (!LegalOperations)
1618 return SDValue();
1619
1620 EVT VT = Op.getValueType();
1621 if (VT.isVector() || !VT.isInteger())
1622 return SDValue();
1623
1624 // If operation type is 'undesirable', e.g. i16 on x86, consider
1625 // promoting it.
1626 unsigned Opc = Op.getOpcode();
1627 if (TLI.isTypeDesirableForOp(Opc, VT))
1628 return SDValue();
1629
1630 EVT PVT = VT;
1631 // Consult target whether it is a good idea to promote this operation and
1632 // what's the right type to promote it to.
1633 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1634 assert(PVT != VT && "Don't know what type to promote to!");
1635 // fold (aext (aext x)) -> (aext x)
1636 // fold (aext (zext x)) -> (zext x)
1637 // fold (aext (sext x)) -> (sext x)
1638 LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1639 return DAG.getNode(Opcode: Op.getOpcode(), DL: SDLoc(Op), VT, Operand: Op.getOperand(i: 0));
1640 }
1641 return SDValue();
1642}
1643
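/// If \p Op is an unindexed integer load of an undesirable type, reload it as an
/// extending load of the target's preferred wider type, truncate the value back
/// for existing users, and move chain users to the new load. Returns true if the
/// load was replaced.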
1644bool DAGCombiner::PromoteLoad(SDValue Op) {
1645 if (!LegalOperations)
1646 return false;
1647
1648 if (!ISD::isUNINDEXEDLoad(N: Op.getNode()))
1649 return false;
1650
1651 EVT VT = Op.getValueType();
1652 if (VT.isVector() || !VT.isInteger())
1653 return false;
1654
1655 // If operation type is 'undesirable', e.g. i16 on x86, consider
1656 // promoting it.
1657 unsigned Opc = Op.getOpcode();
1658 if (TLI.isTypeDesirableForOp(Opc, VT))
1659 return false;
1660
1661 EVT PVT = VT;
1662 // Consult target whether it is a good idea to promote this operation and
1663 // what's the right type to promote it to.
1664 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1665 assert(PVT != VT && "Don't know what type to promote to!");
1666
1667 SDLoc DL(Op);
1668 SDNode *N = Op.getNode();
1669 LoadSDNode *LD = cast<LoadSDNode>(Val: N);
1670 EVT MemVT = LD->getMemoryVT();
1671 ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(N: LD) ? ISD::EXTLOAD
1672 : LD->getExtensionType();
1673 SDValue NewLD = DAG.getExtLoad(ExtType, dl: DL, VT: PVT,
1674 Chain: LD->getChain(), Ptr: LD->getBasePtr(),
1675 MemVT, MMO: LD->getMemOperand());
1676 SDValue Result = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: NewLD);
1677
1678 LLVM_DEBUG(dbgs() << "\nPromoting "; N->dump(&DAG); dbgs() << "\nTo: ";
1679 Result.dump(&DAG); dbgs() << '\n');
1680
1681 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result);
1682 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: NewLD.getValue(R: 1));
1683
1684 AddToWorklist(N: Result.getNode());
1685 recursivelyDeleteUnusedNodes(N);
1686 return true;
1687 }
1688
1689 return false;
1690}
1691
1692/// Recursively delete a node which has no uses and any operands for
1693/// which it is the only use.
1694///
1695/// Note that this both deletes the nodes and removes them from the worklist.
/// It also adds any nodes that have had a user deleted to the worklist, as they
/// may now have only one use and be subject to other combines.
1698bool DAGCombiner::recursivelyDeleteUnusedNodes(SDNode *N) {
1699 if (!N->use_empty())
1700 return false;
1701
1702 SmallSetVector<SDNode *, 16> Nodes;
1703 Nodes.insert(X: N);
1704 do {
1705 N = Nodes.pop_back_val();
1706 if (!N)
1707 continue;
1708
1709 if (N->use_empty()) {
1710 for (const SDValue &ChildN : N->op_values())
1711 Nodes.insert(X: ChildN.getNode());
1712
1713 removeFromWorklist(N);
1714 DAG.DeleteNode(N);
1715 } else {
1716 AddToWorklist(N);
1717 }
1718 } while (!Nodes.empty());
1719 return true;
1720}
1721
1722//===----------------------------------------------------------------------===//
1723// Main DAG Combiner implementation
1724//===----------------------------------------------------------------------===//
1725
1726void DAGCombiner::Run(CombineLevel AtLevel) {
  // Set the instance variables so that the various visit routines may use them.
1728 Level = AtLevel;
1729 LegalDAG = Level >= AfterLegalizeDAG;
1730 LegalOperations = Level >= AfterLegalizeVectorOps;
1731 LegalTypes = Level >= AfterLegalizeTypes;
1732
1733 WorklistInserter AddNodes(*this);
1734
1735 // Add all the dag nodes to the worklist.
1736 //
  // Note: Not all nodes are added to the PruningList here; only nodes with no
  // uses can be deleted, and all other nodes which would otherwise be added to
  // the worklist by the first call to getNextWorklistEntry are already present
  // in it.
1741 for (SDNode &Node : DAG.allnodes())
1742 AddToWorklist(N: &Node, /* IsCandidateForPruning */ Node.use_empty());
1743
  // Create a dummy node (which is not added to allnodes) that holds a reference
  // to the root node, preventing it from being deleted and tracking any changes
  // of the root.
1747 HandleSDNode Dummy(DAG.getRoot());
1748
1749 // While we have a valid worklist entry node, try to combine it.
1750 while (SDNode *N = getNextWorklistEntry()) {
1751 // If N has no uses, it is dead. Make sure to revisit all N's operands once
1752 // N is deleted from the DAG, since they too may now be dead or may have a
1753 // reduced number of uses, allowing other xforms.
1754 if (recursivelyDeleteUnusedNodes(N))
1755 continue;
1756
1757 WorklistRemover DeadNodes(*this);
1758
1759 // If this combine is running after legalizing the DAG, re-legalize any
1760 // nodes pulled off the worklist.
1761 if (LegalDAG) {
1762 SmallSetVector<SDNode *, 16> UpdatedNodes;
1763 bool NIsValid = DAG.LegalizeOp(N, UpdatedNodes);
1764
1765 for (SDNode *LN : UpdatedNodes)
1766 AddToWorklistWithUsers(N: LN);
1767
1768 if (!NIsValid)
1769 continue;
1770 }
1771
1772 LLVM_DEBUG(dbgs() << "\nCombining: "; N->dump(&DAG));
1773
1774 // Add any operands of the new node which have not yet been combined to the
1775 // worklist as well. getNextWorklistEntry flags nodes that have been
1776 // combined before. Because the worklist uniques things already, this won't
1777 // repeatedly process the same operand.
1778 for (const SDValue &ChildN : N->op_values())
1779 AddToWorklist(N: ChildN.getNode(), /*IsCandidateForPruning=*/true,
1780 /*SkipIfCombinedBefore=*/true);
1781
1782 SDValue RV = combine(N);
1783
1784 if (!RV.getNode())
1785 continue;
1786
1787 ++NodesCombined;
1788
1789 // Invalidate cached info.
1790 ChainsWithoutMergeableStores.clear();
1791
1792 // If we get back the same node we passed in, rather than a new node or
1793 // zero, we know that the node must have defined multiple values and
1794 // CombineTo was used. Since CombineTo takes care of the worklist
1795 // mechanics for us, we have no work to do in this case.
1796 if (RV.getNode() == N)
1797 continue;
1798
1799 assert(N->getOpcode() != ISD::DELETED_NODE &&
1800 RV.getOpcode() != ISD::DELETED_NODE &&
1801 "Node was deleted but visit returned new node!");
1802
1803 LLVM_DEBUG(dbgs() << " ... into: "; RV.dump(&DAG));
1804
1805 if (N->getNumValues() == RV->getNumValues())
1806 DAG.ReplaceAllUsesWith(From: N, To: RV.getNode());
1807 else {
1808 assert(N->getValueType(0) == RV.getValueType() &&
1809 N->getNumValues() == 1 && "Type mismatch");
1810 DAG.ReplaceAllUsesWith(From: N, To: &RV);
1811 }
1812
1813 // Push the new node and any users onto the worklist. Omit this if the
1814 // new node is the EntryToken (e.g. if a store managed to get optimized
1815 // out), because re-visiting the EntryToken and its users will not uncover
1816 // any additional opportunities, but there may be a large number of such
1817 // users, potentially causing compile time explosion.
1818 if (RV.getOpcode() != ISD::EntryToken)
1819 AddToWorklistWithUsers(N: RV.getNode());
1820
1821 // Finally, if the node is now dead, remove it from the graph. The node
1822 // may not be dead if the replacement process recursively simplified to
1823 // something else needing this node. This will also take care of adding any
1824 // operands which have lost a user to the worklist.
1825 recursivelyDeleteUnusedNodes(N);
1826 }
1827
  // If the root changed (e.g. it was a dead load), update the root.
1829 DAG.setRoot(Dummy.getValue());
1830 DAG.RemoveDeadNodes();
1831}
1832
1833SDValue DAGCombiner::visit(SDNode *N) {
1834 // clang-format off
1835 switch (N->getOpcode()) {
1836 default: break;
1837 case ISD::TokenFactor: return visitTokenFactor(N);
1838 case ISD::MERGE_VALUES: return visitMERGE_VALUES(N);
1839 case ISD::ADD: return visitADD(N);
1840 case ISD::SUB: return visitSUB(N);
1841 case ISD::SADDSAT:
1842 case ISD::UADDSAT: return visitADDSAT(N);
1843 case ISD::SSUBSAT:
1844 case ISD::USUBSAT: return visitSUBSAT(N);
1845 case ISD::ADDC: return visitADDC(N);
1846 case ISD::SADDO:
1847 case ISD::UADDO: return visitADDO(N);
1848 case ISD::SUBC: return visitSUBC(N);
1849 case ISD::SSUBO:
1850 case ISD::USUBO: return visitSUBO(N);
1851 case ISD::ADDE: return visitADDE(N);
1852 case ISD::UADDO_CARRY: return visitUADDO_CARRY(N);
1853 case ISD::SADDO_CARRY: return visitSADDO_CARRY(N);
1854 case ISD::SUBE: return visitSUBE(N);
1855 case ISD::USUBO_CARRY: return visitUSUBO_CARRY(N);
1856 case ISD::SSUBO_CARRY: return visitSSUBO_CARRY(N);
1857 case ISD::SMULFIX:
1858 case ISD::SMULFIXSAT:
1859 case ISD::UMULFIX:
1860 case ISD::UMULFIXSAT: return visitMULFIX(N);
1861 case ISD::MUL: return visitMUL<EmptyMatchContext>(N);
1862 case ISD::SDIV: return visitSDIV(N);
1863 case ISD::UDIV: return visitUDIV(N);
1864 case ISD::SREM:
1865 case ISD::UREM: return visitREM(N);
1866 case ISD::MULHU: return visitMULHU(N);
1867 case ISD::MULHS: return visitMULHS(N);
1868 case ISD::AVGFLOORS:
1869 case ISD::AVGFLOORU:
1870 case ISD::AVGCEILS:
1871 case ISD::AVGCEILU: return visitAVG(N);
1872 case ISD::ABDS:
1873 case ISD::ABDU: return visitABD(N);
1874 case ISD::SMUL_LOHI: return visitSMUL_LOHI(N);
1875 case ISD::UMUL_LOHI: return visitUMUL_LOHI(N);
1876 case ISD::SMULO:
1877 case ISD::UMULO: return visitMULO(N);
1878 case ISD::SMIN:
1879 case ISD::SMAX:
1880 case ISD::UMIN:
1881 case ISD::UMAX: return visitIMINMAX(N);
1882 case ISD::AND: return visitAND(N);
1883 case ISD::OR: return visitOR(N);
1884 case ISD::XOR: return visitXOR(N);
1885 case ISD::SHL: return visitSHL(N);
1886 case ISD::SRA: return visitSRA(N);
1887 case ISD::SRL: return visitSRL(N);
1888 case ISD::ROTR:
1889 case ISD::ROTL: return visitRotate(N);
1890 case ISD::FSHL:
1891 case ISD::FSHR: return visitFunnelShift(N);
1892 case ISD::SSHLSAT:
1893 case ISD::USHLSAT: return visitSHLSAT(N);
1894 case ISD::ABS: return visitABS(N);
1895 case ISD::BSWAP: return visitBSWAP(N);
1896 case ISD::BITREVERSE: return visitBITREVERSE(N);
1897 case ISD::CTLZ: return visitCTLZ(N);
1898 case ISD::CTLZ_ZERO_UNDEF: return visitCTLZ_ZERO_UNDEF(N);
1899 case ISD::CTTZ: return visitCTTZ(N);
1900 case ISD::CTTZ_ZERO_UNDEF: return visitCTTZ_ZERO_UNDEF(N);
1901 case ISD::CTPOP: return visitCTPOP(N);
1902 case ISD::SELECT: return visitSELECT(N);
1903 case ISD::VSELECT: return visitVSELECT(N);
1904 case ISD::SELECT_CC: return visitSELECT_CC(N);
1905 case ISD::SETCC: return visitSETCC(N);
1906 case ISD::SETCCCARRY: return visitSETCCCARRY(N);
1907 case ISD::SIGN_EXTEND: return visitSIGN_EXTEND(N);
1908 case ISD::ZERO_EXTEND: return visitZERO_EXTEND(N);
1909 case ISD::ANY_EXTEND: return visitANY_EXTEND(N);
1910 case ISD::AssertSext:
1911 case ISD::AssertZext: return visitAssertExt(N);
1912 case ISD::AssertAlign: return visitAssertAlign(N);
1913 case ISD::SIGN_EXTEND_INREG: return visitSIGN_EXTEND_INREG(N);
1914 case ISD::SIGN_EXTEND_VECTOR_INREG:
1915 case ISD::ZERO_EXTEND_VECTOR_INREG:
1916 case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
1917 case ISD::TRUNCATE: return visitTRUNCATE(N);
1918 case ISD::BITCAST: return visitBITCAST(N);
1919 case ISD::BUILD_PAIR: return visitBUILD_PAIR(N);
1920 case ISD::FADD: return visitFADD(N);
1921 case ISD::STRICT_FADD: return visitSTRICT_FADD(N);
1922 case ISD::FSUB: return visitFSUB(N);
1923 case ISD::FMUL: return visitFMUL(N);
1924 case ISD::FMA: return visitFMA<EmptyMatchContext>(N);
1925 case ISD::FMAD: return visitFMAD(N);
1926 case ISD::FDIV: return visitFDIV(N);
1927 case ISD::FREM: return visitFREM(N);
1928 case ISD::FSQRT: return visitFSQRT(N);
1929 case ISD::FCOPYSIGN: return visitFCOPYSIGN(N);
1930 case ISD::FPOW: return visitFPOW(N);
1931 case ISD::SINT_TO_FP: return visitSINT_TO_FP(N);
1932 case ISD::UINT_TO_FP: return visitUINT_TO_FP(N);
1933 case ISD::FP_TO_SINT: return visitFP_TO_SINT(N);
1934 case ISD::FP_TO_UINT: return visitFP_TO_UINT(N);
1935 case ISD::LRINT:
1936 case ISD::LLRINT: return visitXRINT(N);
1937 case ISD::FP_ROUND: return visitFP_ROUND(N);
1938 case ISD::FP_EXTEND: return visitFP_EXTEND(N);
1939 case ISD::FNEG: return visitFNEG(N);
1940 case ISD::FABS: return visitFABS(N);
1941 case ISD::FFLOOR: return visitFFLOOR(N);
1942 case ISD::FMINNUM:
1943 case ISD::FMAXNUM:
1944 case ISD::FMINIMUM:
1945 case ISD::FMAXIMUM: return visitFMinMax(N);
1946 case ISD::FCEIL: return visitFCEIL(N);
1947 case ISD::FTRUNC: return visitFTRUNC(N);
1948 case ISD::FFREXP: return visitFFREXP(N);
1949 case ISD::BRCOND: return visitBRCOND(N);
1950 case ISD::BR_CC: return visitBR_CC(N);
1951 case ISD::LOAD: return visitLOAD(N);
1952 case ISD::STORE: return visitSTORE(N);
1953 case ISD::ATOMIC_STORE: return visitATOMIC_STORE(N);
1954 case ISD::INSERT_VECTOR_ELT: return visitINSERT_VECTOR_ELT(N);
1955 case ISD::EXTRACT_VECTOR_ELT: return visitEXTRACT_VECTOR_ELT(N);
1956 case ISD::BUILD_VECTOR: return visitBUILD_VECTOR(N);
1957 case ISD::CONCAT_VECTORS: return visitCONCAT_VECTORS(N);
1958 case ISD::EXTRACT_SUBVECTOR: return visitEXTRACT_SUBVECTOR(N);
1959 case ISD::VECTOR_SHUFFLE: return visitVECTOR_SHUFFLE(N);
1960 case ISD::SCALAR_TO_VECTOR: return visitSCALAR_TO_VECTOR(N);
1961 case ISD::INSERT_SUBVECTOR: return visitINSERT_SUBVECTOR(N);
1962 case ISD::MGATHER: return visitMGATHER(N);
1963 case ISD::MLOAD: return visitMLOAD(N);
1964 case ISD::MSCATTER: return visitMSCATTER(N);
1965 case ISD::MSTORE: return visitMSTORE(N);
1966 case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N);
1967 case ISD::LIFETIME_END: return visitLIFETIME_END(N);
1968 case ISD::FP_TO_FP16: return visitFP_TO_FP16(N);
1969 case ISD::FP16_TO_FP: return visitFP16_TO_FP(N);
1970 case ISD::FP_TO_BF16: return visitFP_TO_BF16(N);
1971 case ISD::BF16_TO_FP: return visitBF16_TO_FP(N);
1972 case ISD::FREEZE: return visitFREEZE(N);
1973 case ISD::GET_FPENV_MEM: return visitGET_FPENV_MEM(N);
1974 case ISD::SET_FPENV_MEM: return visitSET_FPENV_MEM(N);
1975 case ISD::VECREDUCE_FADD:
1976 case ISD::VECREDUCE_FMUL:
1977 case ISD::VECREDUCE_ADD:
1978 case ISD::VECREDUCE_MUL:
1979 case ISD::VECREDUCE_AND:
1980 case ISD::VECREDUCE_OR:
1981 case ISD::VECREDUCE_XOR:
1982 case ISD::VECREDUCE_SMAX:
1983 case ISD::VECREDUCE_SMIN:
1984 case ISD::VECREDUCE_UMAX:
1985 case ISD::VECREDUCE_UMIN:
1986 case ISD::VECREDUCE_FMAX:
1987 case ISD::VECREDUCE_FMIN:
1988 case ISD::VECREDUCE_FMAXIMUM:
1989 case ISD::VECREDUCE_FMINIMUM: return visitVECREDUCE(N);
1990#define BEGIN_REGISTER_VP_SDNODE(SDOPC, ...) case ISD::SDOPC:
1991#include "llvm/IR/VPIntrinsics.def"
1992 return visitVPOp(N);
1993 }
1994 // clang-format on
1995 return SDValue();
1996}
1997
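/// Try to simplify \p N: first run the generic visit() combines, then any
/// target-specific PerformDAGCombine hook, then the integer promotion helpers,
/// and finally try to CSE against an existing commuted form of a commutative
/// binop. Returns the replacement value, or a null SDValue if nothing changed.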
1998SDValue DAGCombiner::combine(SDNode *N) {
1999 if (!DebugCounter::shouldExecute(CounterName: DAGCombineCounter))
2000 return SDValue();
2001
2002 SDValue RV;
2003 if (!DisableGenericCombines)
2004 RV = visit(N);
2005
2006 // If nothing happened, try a target-specific DAG combine.
2007 if (!RV.getNode()) {
2008 assert(N->getOpcode() != ISD::DELETED_NODE &&
2009 "Node was deleted but visit returned NULL!");
2010
2011 if (N->getOpcode() >= ISD::BUILTIN_OP_END ||
2012 TLI.hasTargetDAGCombine(NT: (ISD::NodeType)N->getOpcode())) {
2013
2014 // Expose the DAG combiner to the target combiner impls.
2015 TargetLowering::DAGCombinerInfo
2016 DagCombineInfo(DAG, Level, false, this);
2017
2018 RV = TLI.PerformDAGCombine(N, DCI&: DagCombineInfo);
2019 }
2020 }
2021
  // If still nothing happened, try promoting the operation.
2023 if (!RV.getNode()) {
2024 switch (N->getOpcode()) {
2025 default: break;
2026 case ISD::ADD:
2027 case ISD::SUB:
2028 case ISD::MUL:
2029 case ISD::AND:
2030 case ISD::OR:
2031 case ISD::XOR:
2032 RV = PromoteIntBinOp(Op: SDValue(N, 0));
2033 break;
2034 case ISD::SHL:
2035 case ISD::SRA:
2036 case ISD::SRL:
2037 RV = PromoteIntShiftOp(Op: SDValue(N, 0));
2038 break;
2039 case ISD::SIGN_EXTEND:
2040 case ISD::ZERO_EXTEND:
2041 case ISD::ANY_EXTEND:
2042 RV = PromoteExtend(Op: SDValue(N, 0));
2043 break;
2044 case ISD::LOAD:
2045 if (PromoteLoad(Op: SDValue(N, 0)))
2046 RV = SDValue(N, 0);
2047 break;
2048 }
2049 }
2050
2051 // If N is a commutative binary node, try to eliminate it if the commuted
2052 // version is already present in the DAG.
2053 if (!RV.getNode() && TLI.isCommutativeBinOp(Opcode: N->getOpcode())) {
2054 SDValue N0 = N->getOperand(Num: 0);
2055 SDValue N1 = N->getOperand(Num: 1);
2056
2057 // Constant operands are canonicalized to RHS.
2058 if (N0 != N1 && (isa<ConstantSDNode>(Val: N0) || !isa<ConstantSDNode>(Val: N1))) {
2059 SDValue Ops[] = {N1, N0};
2060 SDNode *CSENode = DAG.getNodeIfExists(Opcode: N->getOpcode(), VTList: N->getVTList(), Ops,
2061 Flags: N->getFlags());
2062 if (CSENode)
2063 return SDValue(CSENode, 0);
2064 }
2065 }
2066
2067 return RV;
2068}
2069
2070/// Given a node, return its input chain if it has one, otherwise return a null
2071/// sd operand.
2072static SDValue getInputChainForNode(SDNode *N) {
2073 if (unsigned NumOps = N->getNumOperands()) {
2074 if (N->getOperand(Num: 0).getValueType() == MVT::Other)
2075 return N->getOperand(Num: 0);
2076 if (N->getOperand(Num: NumOps-1).getValueType() == MVT::Other)
2077 return N->getOperand(Num: NumOps-1);
2078 for (unsigned i = 1; i < NumOps-1; ++i)
2079 if (N->getOperand(Num: i).getValueType() == MVT::Other)
2080 return N->getOperand(Num: i);
2081 }
2082 return SDValue();
2083}
2084
2085SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
2086 // If N has two operands, where one has an input chain equal to the other,
2087 // the 'other' chain is redundant.
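  // For example, in (TokenFactor t1, ch) where t1 is a load whose input chain is
  // ch, the ch operand is already covered by t1, so the whole TokenFactor can be
  // replaced by t1.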
2088 if (N->getNumOperands() == 2) {
2089 if (getInputChainForNode(N: N->getOperand(Num: 0).getNode()) == N->getOperand(Num: 1))
2090 return N->getOperand(Num: 0);
2091 if (getInputChainForNode(N: N->getOperand(Num: 1).getNode()) == N->getOperand(Num: 0))
2092 return N->getOperand(Num: 1);
2093 }
2094
2095 // Don't simplify token factors if optnone.
2096 if (OptLevel == CodeGenOptLevel::None)
2097 return SDValue();
2098
2099 // Don't simplify the token factor if the node itself has too many operands.
2100 if (N->getNumOperands() > TokenFactorInlineLimit)
2101 return SDValue();
2102
2103 // If the sole user is a token factor, we should make sure we have a
2104 // chance to merge them together. This prevents TF chains from inhibiting
2105 // optimizations.
2106 if (N->hasOneUse() && N->use_begin()->getOpcode() == ISD::TokenFactor)
2107 AddToWorklist(N: *(N->use_begin()));
2108
2109 SmallVector<SDNode *, 8> TFs; // List of token factors to visit.
2110 SmallVector<SDValue, 8> Ops; // Ops for replacing token factor.
2111 SmallPtrSet<SDNode*, 16> SeenOps;
2112 bool Changed = false; // If we should replace this token factor.
2113
2114 // Start out with this token factor.
2115 TFs.push_back(Elt: N);
2116
  // Iterate through token factors. The TFs list grows when new token factors
  // are encountered.
2119 for (unsigned i = 0; i < TFs.size(); ++i) {
2120 // Limit number of nodes to inline, to avoid quadratic compile times.
2121 // We have to add the outstanding Token Factors to Ops, otherwise we might
2122 // drop Ops from the resulting Token Factors.
2123 if (Ops.size() > TokenFactorInlineLimit) {
2124 for (unsigned j = i; j < TFs.size(); j++)
2125 Ops.emplace_back(Args&: TFs[j], Args: 0);
2126 // Drop unprocessed Token Factors from TFs, so we do not add them to the
2127 // combiner worklist later.
2128 TFs.resize(N: i);
2129 break;
2130 }
2131
2132 SDNode *TF = TFs[i];
2133 // Check each of the operands.
2134 for (const SDValue &Op : TF->op_values()) {
2135 switch (Op.getOpcode()) {
2136 case ISD::EntryToken:
2137 // Entry tokens don't need to be added to the list. They are
2138 // redundant.
2139 Changed = true;
2140 break;
2141
2142 case ISD::TokenFactor:
2143 if (Op.hasOneUse() && !is_contained(Range&: TFs, Element: Op.getNode())) {
2144 // Queue up for processing.
2145 TFs.push_back(Elt: Op.getNode());
2146 Changed = true;
2147 break;
2148 }
2149 [[fallthrough]];
2150
2151 default:
2152 // Only add if it isn't already in the list.
2153 if (SeenOps.insert(Ptr: Op.getNode()).second)
2154 Ops.push_back(Elt: Op);
2155 else
2156 Changed = true;
2157 break;
2158 }
2159 }
2160 }
2161
2162 // Re-visit inlined Token Factors, to clean them up in case they have been
2163 // removed. Skip the first Token Factor, as this is the current node.
2164 for (unsigned i = 1, e = TFs.size(); i < e; i++)
2165 AddToWorklist(N: TFs[i]);
2166
  // Remove nodes that are chained to another node in the list. Do so by
  // walking up chains breadth-first, stopping when we've seen another operand.
  // In general we must climb to the EntryNode, but we can exit early if we
  // find that all remaining work is associated with just one operand, as no
  // further pruning is possible.
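  // Each Worklist entry remembers which original operand its walk started from,
  // and OpWorkCount tracks how many chain nodes are still outstanding for that
  // operand. When one operand's walk reaches another operand, the reached operand
  // is pruned and its remaining work is merged into the walker's count.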
2172
2173 // List of nodes to search through and original Ops from which they originate.
2174 SmallVector<std::pair<SDNode *, unsigned>, 8> Worklist;
2175 SmallVector<unsigned, 8> OpWorkCount; // Count of work for each Op.
2176 SmallPtrSet<SDNode *, 16> SeenChains;
2177 bool DidPruneOps = false;
2178
2179 unsigned NumLeftToConsider = 0;
2180 for (const SDValue &Op : Ops) {
2181 Worklist.push_back(Elt: std::make_pair(x: Op.getNode(), y: NumLeftToConsider++));
2182 OpWorkCount.push_back(Elt: 1);
2183 }
2184
2185 auto AddToWorklist = [&](unsigned CurIdx, SDNode *Op, unsigned OpNumber) {
    // If this is an Op, we can remove the op from the list. Re-mark any
    // search associated with it as coming from the current OpNumber.
2188 if (SeenOps.contains(Ptr: Op)) {
2189 Changed = true;
2190 DidPruneOps = true;
2191 unsigned OrigOpNumber = 0;
2192 while (OrigOpNumber < Ops.size() && Ops[OrigOpNumber].getNode() != Op)
2193 OrigOpNumber++;
2194 assert((OrigOpNumber != Ops.size()) &&
2195 "expected to find TokenFactor Operand");
2196 // Re-mark worklist from OrigOpNumber to OpNumber
2197 for (unsigned i = CurIdx + 1; i < Worklist.size(); ++i) {
2198 if (Worklist[i].second == OrigOpNumber) {
2199 Worklist[i].second = OpNumber;
2200 }
2201 }
2202 OpWorkCount[OpNumber] += OpWorkCount[OrigOpNumber];
2203 OpWorkCount[OrigOpNumber] = 0;
2204 NumLeftToConsider--;
2205 }
2206 // Add if it's a new chain
2207 if (SeenChains.insert(Ptr: Op).second) {
2208 OpWorkCount[OpNumber]++;
2209 Worklist.push_back(Elt: std::make_pair(x&: Op, y&: OpNumber));
2210 }
2211 };
2212
2213 for (unsigned i = 0; i < Worklist.size() && i < 1024; ++i) {
    // We need to consider at least 2 Ops to prune.
2215 if (NumLeftToConsider <= 1)
2216 break;
2217 auto CurNode = Worklist[i].first;
2218 auto CurOpNumber = Worklist[i].second;
2219 assert((OpWorkCount[CurOpNumber] > 0) &&
2220 "Node should not appear in worklist");
2221 switch (CurNode->getOpcode()) {
2222 case ISD::EntryToken:
      // Hitting EntryToken is the only way for the search to terminate without
      // hitting another operand's search. Prevent us from marking this operand
      // considered.
2227 NumLeftToConsider++;
2228 break;
2229 case ISD::TokenFactor:
2230 for (const SDValue &Op : CurNode->op_values())
2231 AddToWorklist(i, Op.getNode(), CurOpNumber);
2232 break;
2233 case ISD::LIFETIME_START:
2234 case ISD::LIFETIME_END:
2235 case ISD::CopyFromReg:
2236 case ISD::CopyToReg:
2237 AddToWorklist(i, CurNode->getOperand(Num: 0).getNode(), CurOpNumber);
2238 break;
2239 default:
2240 if (auto *MemNode = dyn_cast<MemSDNode>(Val: CurNode))
2241 AddToWorklist(i, MemNode->getChain().getNode(), CurOpNumber);
2242 break;
2243 }
2244 OpWorkCount[CurOpNumber]--;
2245 if (OpWorkCount[CurOpNumber] == 0)
2246 NumLeftToConsider--;
2247 }
2248
2249 // If we've changed things around then replace token factor.
2250 if (Changed) {
2251 SDValue Result;
2252 if (Ops.empty()) {
2253 // The entry token is the only possible outcome.
2254 Result = DAG.getEntryNode();
2255 } else {
2256 if (DidPruneOps) {
2257 SmallVector<SDValue, 8> PrunedOps;
2259 for (const SDValue &Op : Ops) {
2260 if (SeenChains.count(Ptr: Op.getNode()) == 0)
2261 PrunedOps.push_back(Elt: Op);
2262 }
2263 Result = DAG.getTokenFactor(DL: SDLoc(N), Vals&: PrunedOps);
2264 } else {
2265 Result = DAG.getTokenFactor(DL: SDLoc(N), Vals&: Ops);
2266 }
2267 }
2268 return Result;
2269 }
2270 return SDValue();
2271}
2272
2273/// MERGE_VALUES can always be eliminated.
2274SDValue DAGCombiner::visitMERGE_VALUES(SDNode *N) {
2275 WorklistRemover DeadNodes(*this);
2276 // Replacing results may cause a different MERGE_VALUES to suddenly
2277 // be CSE'd with N, and carry its uses with it. Iterate until no
2278 // uses remain, to ensure that the node can be safely deleted.
2279 // First add the users of this node to the work list so that they
2280 // can be tried again once they have new operands.
2281 AddUsersToWorklist(N);
2282 do {
2283 // Do as a single replacement to avoid rewalking use lists.
2284 SmallVector<SDValue, 8> Ops;
2285 for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
2286 Ops.push_back(Elt: N->getOperand(Num: i));
2287 DAG.ReplaceAllUsesWith(From: N, To: Ops.data());
2288 } while (!N->use_empty());
2289 deleteAndRecombine(N);
2290 return SDValue(N, 0); // Return N so it doesn't get rechecked!
2291}
2292
/// If \p N is a ConstantSDNode with isOpaque() == false, return it cast to a
/// ConstantSDNode pointer, else nullptr.
2295static ConstantSDNode *getAsNonOpaqueConstant(SDValue N) {
2296 ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Val&: N);
2297 return Const != nullptr && !Const->isOpaque() ? Const : nullptr;
2298}
2299
2300// isTruncateOf - If N is a truncate of some other value, return true, record
2301// the value being truncated in Op and which of Op's bits are zero/one in Known.
2302// This function computes KnownBits to avoid a duplicated call to
2303// computeKnownBits in the caller.
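// For the i1 case, (setcc X, 0, ne) is treated as a truncate of X to i1 when all
// bits of X other than bit 0 are known to be zero.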
2304static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
2305 KnownBits &Known) {
2306 if (N->getOpcode() == ISD::TRUNCATE) {
2307 Op = N->getOperand(Num: 0);
2308 Known = DAG.computeKnownBits(Op);
2309 return true;
2310 }
2311
2312 if (N.getValueType().getScalarType() != MVT::i1 ||
2313 !sd_match(
2314 N, P: m_c_SetCC(LHS: m_Value(N&: Op), RHS: m_Zero(), CC: m_SpecificCondCode(CC: ISD::SETNE))))
2315 return false;
2316
2317 Known = DAG.computeKnownBits(Op);
2318 return (Known.Zero | 1).isAllOnes();
2319}
2320
2321/// Return true if 'Use' is a load or a store that uses N as its base pointer
2322/// and that N may be folded in the load / store addressing mode.
2323static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, SelectionDAG &DAG,
2324 const TargetLowering &TLI) {
2325 EVT VT;
2326 unsigned AS;
2327
2328 if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Val: Use)) {
2329 if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2330 return false;
2331 VT = LD->getMemoryVT();
2332 AS = LD->getAddressSpace();
2333 } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(Val: Use)) {
2334 if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2335 return false;
2336 VT = ST->getMemoryVT();
2337 AS = ST->getAddressSpace();
2338 } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(Val: Use)) {
2339 if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2340 return false;
2341 VT = LD->getMemoryVT();
2342 AS = LD->getAddressSpace();
2343 } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(Val: Use)) {
2344 if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2345 return false;
2346 VT = ST->getMemoryVT();
2347 AS = ST->getAddressSpace();
2348 } else {
2349 return false;
2350 }
2351
2352 TargetLowering::AddrMode AM;
2353 if (N->getOpcode() == ISD::ADD) {
2354 AM.HasBaseReg = true;
2355 ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
2356 if (Offset)
2357 // [reg +/- imm]
2358 AM.BaseOffs = Offset->getSExtValue();
2359 else
2360 // [reg +/- reg]
2361 AM.Scale = 1;
2362 } else if (N->getOpcode() == ISD::SUB) {
2363 AM.HasBaseReg = true;
2364 ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
2365 if (Offset)
2366 // [reg +/- imm]
2367 AM.BaseOffs = -Offset->getSExtValue();
2368 else
2369 // [reg +/- reg]
2370 AM.Scale = 1;
2371 } else {
2372 return false;
2373 }
2374
2375 return TLI.isLegalAddressingMode(DL: DAG.getDataLayout(), AM,
2376 Ty: VT.getTypeForEVT(Context&: *DAG.getContext()), AddrSpace: AS);
2377}
2378
2379/// This inverts a canonicalization in IR that replaces a variable select arm
2380/// with an identity constant. Codegen improves if we re-use the variable
2381/// operand rather than load a constant. This can also be converted into a
2382/// masked vector operation if the target supports it.
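///
/// For example:
///   add X, (vselect Cond, 0, Y) --> vselect Cond, freeze(X), (add freeze(X), Y)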
2383static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
2384 bool ShouldCommuteOperands) {
2385 // Match a select as operand 1. The identity constant that we are looking for
2386 // is only valid as operand 1 of a non-commutative binop.
2387 SDValue N0 = N->getOperand(Num: 0);
2388 SDValue N1 = N->getOperand(Num: 1);
2389 if (ShouldCommuteOperands)
2390 std::swap(a&: N0, b&: N1);
2391
2392 // TODO: Should this apply to scalar select too?
2393 if (N1.getOpcode() != ISD::VSELECT || !N1.hasOneUse())
2394 return SDValue();
2395
  // We can't hoist all instructions because some have immediate UB and are not
  // speculatable, for example div/rem by zero.
2398 if (!DAG.isSafeToSpeculativelyExecuteNode(N))
2399 return SDValue();
2400
2401 unsigned Opcode = N->getOpcode();
2402 EVT VT = N->getValueType(ResNo: 0);
2403 SDValue Cond = N1.getOperand(i: 0);
2404 SDValue TVal = N1.getOperand(i: 1);
2405 SDValue FVal = N1.getOperand(i: 2);
2406
2407 // This transform increases uses of N0, so freeze it to be safe.
2408 // binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal)
2409 unsigned OpNo = ShouldCommuteOperands ? 0 : 1;
2410 if (isNeutralConstant(Opc: Opcode, Flags: N->getFlags(), V: TVal, OperandNo: OpNo)) {
2411 SDValue F0 = DAG.getFreeze(V: N0);
2412 SDValue NewBO = DAG.getNode(Opcode, DL: SDLoc(N), VT, N1: F0, N2: FVal, Flags: N->getFlags());
2413 return DAG.getSelect(DL: SDLoc(N), VT, Cond, LHS: F0, RHS: NewBO);
2414 }
2415 // binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0
2416 if (isNeutralConstant(Opc: Opcode, Flags: N->getFlags(), V: FVal, OperandNo: OpNo)) {
2417 SDValue F0 = DAG.getFreeze(V: N0);
2418 SDValue NewBO = DAG.getNode(Opcode, DL: SDLoc(N), VT, N1: F0, N2: TVal, Flags: N->getFlags());
2419 return DAG.getSelect(DL: SDLoc(N), VT, Cond, LHS: NewBO, RHS: F0);
2420 }
2421
2422 return SDValue();
2423}
2424
2425SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
2426 assert(TLI.isBinOp(BO->getOpcode()) && BO->getNumValues() == 1 &&
2427 "Unexpected binary operator");
2428
2429 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2430 auto BinOpcode = BO->getOpcode();
2431 EVT VT = BO->getValueType(ResNo: 0);
2432 if (TLI.shouldFoldSelectWithIdentityConstant(BinOpcode, VT)) {
2433 if (SDValue Sel = foldSelectWithIdentityConstant(N: BO, DAG, ShouldCommuteOperands: false))
2434 return Sel;
2435
2436 if (TLI.isCommutativeBinOp(Opcode: BO->getOpcode()))
2437 if (SDValue Sel = foldSelectWithIdentityConstant(N: BO, DAG, ShouldCommuteOperands: true))
2438 return Sel;
2439 }
2440
2441 // Don't do this unless the old select is going away. We want to eliminate the
2442 // binary operator, not replace a binop with a select.
2443 // TODO: Handle ISD::SELECT_CC.
2444 unsigned SelOpNo = 0;
2445 SDValue Sel = BO->getOperand(Num: 0);
2446 if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
2447 SelOpNo = 1;
2448 Sel = BO->getOperand(Num: 1);
2449
2450 // Peek through trunc to shift amount type.
2451 if ((BinOpcode == ISD::SHL || BinOpcode == ISD::SRA ||
2452 BinOpcode == ISD::SRL) && Sel.hasOneUse()) {
2453 // This is valid when the truncated bits of x are already zero.
2454 SDValue Op;
2455 KnownBits Known;
2456 if (isTruncateOf(DAG, N: Sel, Op, Known) &&
2457 Known.countMaxActiveBits() < Sel.getScalarValueSizeInBits())
2458 Sel = Op;
2459 }
2460 }
2461
2462 if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse())
2463 return SDValue();
2464
2465 SDValue CT = Sel.getOperand(i: 1);
2466 if (!isConstantOrConstantVector(N: CT, NoOpaques: true) &&
2467 !DAG.isConstantFPBuildVectorOrConstantFP(N: CT))
2468 return SDValue();
2469
2470 SDValue CF = Sel.getOperand(i: 2);
2471 if (!isConstantOrConstantVector(N: CF, NoOpaques: true) &&
2472 !DAG.isConstantFPBuildVectorOrConstantFP(N: CF))
2473 return SDValue();
2474
  // Bail out if any constants are opaque because we can't constant fold those.
  // The exception is "and" and "or" with either 0 or -1, in which case we can
  // propagate non-constant operands into the select. I.e.:
2478 // and (select Cond, 0, -1), X --> select Cond, 0, X
2479 // or X, (select Cond, -1, 0) --> select Cond, -1, X
2480 bool CanFoldNonConst =
2481 (BinOpcode == ISD::AND || BinOpcode == ISD::OR) &&
2482 ((isNullOrNullSplat(V: CT) && isAllOnesOrAllOnesSplat(V: CF)) ||
2483 (isNullOrNullSplat(V: CF) && isAllOnesOrAllOnesSplat(V: CT)));
2484
2485 SDValue CBO = BO->getOperand(Num: SelOpNo ^ 1);
2486 if (!CanFoldNonConst &&
2487 !isConstantOrConstantVector(N: CBO, NoOpaques: true) &&
2488 !DAG.isConstantFPBuildVectorOrConstantFP(N: CBO))
2489 return SDValue();
2490
2491 SDLoc DL(Sel);
2492 SDValue NewCT, NewCF;
2493
2494 if (CanFoldNonConst) {
2495 // If CBO is an opaque constant, we can't rely on getNode to constant fold.
2496 if ((BinOpcode == ISD::AND && isNullOrNullSplat(V: CT)) ||
2497 (BinOpcode == ISD::OR && isAllOnesOrAllOnesSplat(V: CT)))
2498 NewCT = CT;
2499 else
2500 NewCT = CBO;
2501
2502 if ((BinOpcode == ISD::AND && isNullOrNullSplat(V: CF)) ||
2503 (BinOpcode == ISD::OR && isAllOnesOrAllOnesSplat(V: CF)))
2504 NewCF = CF;
2505 else
2506 NewCF = CBO;
2507 } else {
    // We have a select-of-constants followed by a binary operator with a
    // constant. Eliminate the binop by pulling the constant math into the
    // select. Example:
    //   add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO
2512 NewCT = SelOpNo ? DAG.FoldConstantArithmetic(Opcode: BinOpcode, DL, VT, Ops: {CBO, CT})
2513 : DAG.FoldConstantArithmetic(Opcode: BinOpcode, DL, VT, Ops: {CT, CBO});
2514 if (!NewCT)
2515 return SDValue();
2516
2517 NewCF = SelOpNo ? DAG.FoldConstantArithmetic(Opcode: BinOpcode, DL, VT, Ops: {CBO, CF})
2518 : DAG.FoldConstantArithmetic(Opcode: BinOpcode, DL, VT, Ops: {CF, CBO});
2519 if (!NewCF)
2520 return SDValue();
2521 }
2522
2523 SDValue SelectOp = DAG.getSelect(DL, VT, Cond: Sel.getOperand(i: 0), LHS: NewCT, RHS: NewCF);
2524 SelectOp->setFlags(BO->getFlags());
2525 return SelectOp;
2526}
2527
2528static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, const SDLoc &DL,
2529 SelectionDAG &DAG) {
2530 assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2531 "Expecting add or sub");
2532
2533 // Match a constant operand and a zext operand for the math instruction:
2534 // add Z, C
2535 // sub C, Z
2536 bool IsAdd = N->getOpcode() == ISD::ADD;
2537 SDValue C = IsAdd ? N->getOperand(Num: 1) : N->getOperand(Num: 0);
2538 SDValue Z = IsAdd ? N->getOperand(Num: 0) : N->getOperand(Num: 1);
2539 auto *CN = dyn_cast<ConstantSDNode>(Val&: C);
2540 if (!CN || Z.getOpcode() != ISD::ZERO_EXTEND)
2541 return SDValue();
2542
2543 // Match the zext operand as a setcc of a boolean.
2544 if (Z.getOperand(i: 0).getValueType() != MVT::i1)
2545 return SDValue();
2546
2547 // Match the compare as: setcc (X & 1), 0, eq.
2548 if (!sd_match(N: Z.getOperand(i: 0), P: m_SetCC(LHS: m_And(L: m_Value(), R: m_One()), RHS: m_Zero(),
2549 CC: m_SpecificCondCode(CC: ISD::SETEQ))))
2550 return SDValue();
2551
2552 // We are adding/subtracting a constant and an inverted low bit. Turn that
2553 // into a subtract/add of the low bit with incremented/decremented constant:
2554 // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
2555 // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
2556 EVT VT = C.getValueType();
2557 SDValue LowBit = DAG.getZExtOrTrunc(Op: Z.getOperand(i: 0).getOperand(i: 0), DL, VT);
2558 SDValue C1 = IsAdd ? DAG.getConstant(Val: CN->getAPIntValue() + 1, DL, VT)
2559 : DAG.getConstant(Val: CN->getAPIntValue() - 1, DL, VT);
2560 return DAG.getNode(Opcode: IsAdd ? ISD::SUB : ISD::ADD, DL, VT, N1: C1, N2: LowBit);
2561}
2562
2563// Attempt to form avgceil(A, B) from (A | B) - ((A ^ B) >> 1)
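// This works because A + B = 2*(A & B) + (A ^ B) and A | B = (A & B) + (A ^ B),
// so (A | B) - ((A ^ B) >> 1) computes ceil((A + B) / 2) without requiring a
// wider intermediate type.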
2564SDValue DAGCombiner::foldSubToAvg(SDNode *N, const SDLoc &DL) {
2565 SDValue N0 = N->getOperand(Num: 0);
2566 EVT VT = N0.getValueType();
2567 SDValue A, B;
2568
2569 if ((!LegalOperations || hasOperation(Opcode: ISD::AVGCEILU, VT)) &&
2570 sd_match(N, P: m_Sub(L: m_Or(L: m_Value(N&: A), R: m_Value(N&: B)),
2571 R: m_Srl(L: m_Xor(L: m_Deferred(V&: A), R: m_Deferred(V&: B)),
2572 R: m_SpecificInt(V: 1))))) {
2573 return DAG.getNode(Opcode: ISD::AVGCEILU, DL, VT, N1: A, N2: B);
2574 }
2575 if ((!LegalOperations || hasOperation(Opcode: ISD::AVGCEILS, VT)) &&
2576 sd_match(N, P: m_Sub(L: m_Or(L: m_Value(N&: A), R: m_Value(N&: B)),
2577 R: m_Sra(L: m_Xor(L: m_Deferred(V&: A), R: m_Deferred(V&: B)),
2578 R: m_SpecificInt(V: 1))))) {
2579 return DAG.getNode(Opcode: ISD::AVGCEILS, DL, VT, N1: A, N2: B);
2580 }
2581 return SDValue();
2582}
2583
/// Try to fold an add/sub that has a constant operand and whose other operand is
/// a 'not' value with its sign bit shifted down to bit 0, into a shift and an add
/// with a different constant.
2586static SDValue foldAddSubOfSignBit(SDNode *N, const SDLoc &DL,
2587 SelectionDAG &DAG) {
2588 assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2589 "Expecting add or sub");
2590
2591 // We need a constant operand for the add/sub, and the other operand is a
2592 // logical shift right: add (srl), C or sub C, (srl).
2593 bool IsAdd = N->getOpcode() == ISD::ADD;
2594 SDValue ConstantOp = IsAdd ? N->getOperand(Num: 1) : N->getOperand(Num: 0);
2595 SDValue ShiftOp = IsAdd ? N->getOperand(Num: 0) : N->getOperand(Num: 1);
2596 if (!DAG.isConstantIntBuildVectorOrConstantInt(N: ConstantOp) ||
2597 ShiftOp.getOpcode() != ISD::SRL)
2598 return SDValue();
2599
2600 // The shift must be of a 'not' value.
2601 SDValue Not = ShiftOp.getOperand(i: 0);
2602 if (!Not.hasOneUse() || !isBitwiseNot(V: Not))
2603 return SDValue();
2604
2605 // The shift must be moving the sign bit to the least-significant-bit.
2606 EVT VT = ShiftOp.getValueType();
2607 SDValue ShAmt = ShiftOp.getOperand(i: 1);
2608 ConstantSDNode *ShAmtC = isConstOrConstSplat(N: ShAmt);
2609 if (!ShAmtC || ShAmtC->getAPIntValue() != (VT.getScalarSizeInBits() - 1))
2610 return SDValue();
2611
2612 // Eliminate the 'not' by adjusting the shift and add/sub constant:
2613 // add (srl (not X), 31), C --> add (sra X, 31), (C + 1)
2614 // sub C, (srl (not X), 31) --> add (srl X, 31), (C - 1)
2615 if (SDValue NewC = DAG.FoldConstantArithmetic(
2616 Opcode: IsAdd ? ISD::ADD : ISD::SUB, DL, VT,
2617 Ops: {ConstantOp, DAG.getConstant(Val: 1, DL, VT)})) {
2618 SDValue NewShift = DAG.getNode(Opcode: IsAdd ? ISD::SRA : ISD::SRL, DL, VT,
2619 N1: Not.getOperand(i: 0), N2: ShAmt);
2620 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: NewShift, N2: NewC);
2621 }
2622
2623 return SDValue();
2624}
2625
2626static bool
2627areBitwiseNotOfEachother(SDValue Op0, SDValue Op1) {
2628 return (isBitwiseNot(V: Op0) && Op0.getOperand(i: 0) == Op1) ||
2629 (isBitwiseNot(V: Op1) && Op1.getOperand(i: 0) == Op0);
2630}
2631
2632/// Try to fold a node that behaves like an ADD (note that N isn't necessarily
2633/// an ISD::ADD here, it could for example be an ISD::OR if we know that there
2634/// are no common bits set in the operands).
2635SDValue DAGCombiner::visitADDLike(SDNode *N) {
2636 SDValue N0 = N->getOperand(Num: 0);
2637 SDValue N1 = N->getOperand(Num: 1);
2638 EVT VT = N0.getValueType();
2639 SDLoc DL(N);
2640
2641 // fold (add x, undef) -> undef
2642 if (N0.isUndef())
2643 return N0;
2644 if (N1.isUndef())
2645 return N1;
2646
2647 // fold (add c1, c2) -> c1+c2
2648 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL, VT, Ops: {N0, N1}))
2649 return C;
2650
2651 // canonicalize constant to RHS
2652 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
2653 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
2654 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: N0);
2655
2656 if (areBitwiseNotOfEachother(Op0: N0, Op1: N1))
2657 return DAG.getConstant(Val: APInt::getAllOnes(numBits: VT.getScalarSizeInBits()), DL, VT);
2658
2659 // fold vector ops
2660 if (VT.isVector()) {
2661 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
2662 return FoldedVOp;
2663
2664 // fold (add x, 0) -> x, vector edition
2665 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
2666 return N0;
2667 }
2668
2669 // fold (add x, 0) -> x
2670 if (isNullConstant(V: N1))
2671 return N0;
2672
2673 if (N0.getOpcode() == ISD::SUB) {
2674 SDValue N00 = N0.getOperand(i: 0);
2675 SDValue N01 = N0.getOperand(i: 1);
2676
2677 // fold ((A-c1)+c2) -> (A+(c2-c1))
2678 if (SDValue Sub = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N1, N01}))
2679 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: Sub);
2680
2681 // fold ((c1-A)+c2) -> (c1+c2)-A
2682 if (SDValue Add = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL, VT, Ops: {N1, N00}))
2683 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Add, N2: N0.getOperand(i: 1));
2684 }
2685
2686 // add (sext i1 X), 1 -> zext (not i1 X)
2687 // We don't transform this pattern:
2688 // add (zext i1 X), -1 -> sext (not i1 X)
2689 // because most (?) targets generate better code for the zext form.
2690 if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
2691 isOneOrOneSplat(V: N1)) {
2692 SDValue X = N0.getOperand(i: 0);
2693 if ((!LegalOperations ||
2694 (TLI.isOperationLegal(Op: ISD::XOR, VT: X.getValueType()) &&
2695 TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT))) &&
2696 X.getScalarValueSizeInBits() == 1) {
2697 SDValue Not = DAG.getNOT(DL, Val: X, VT: X.getValueType());
2698 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: Not);
2699 }
2700 }
2701
2702 // Fold (add (or x, c0), c1) -> (add x, (c0 + c1))
2703 // iff (or x, c0) is equivalent to (add x, c0).
2704 // Fold (add (xor x, c0), c1) -> (add x, (c0 + c1))
2705 // iff (xor x, c0) is equivalent to (add x, c0).
2706 if (DAG.isADDLike(Op: N0)) {
2707 SDValue N01 = N0.getOperand(i: 1);
2708 if (SDValue Add = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL, VT, Ops: {N1, N01}))
2709 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: Add);
2710 }
2711
2712 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
2713 return NewSel;
2714
2715 // reassociate add
2716 if (!reassociationCanBreakAddressingModePattern(Opc: ISD::ADD, DL, N, N0, N1)) {
2717 if (SDValue RADD = reassociateOps(Opc: ISD::ADD, DL, N0, N1, Flags: N->getFlags()))
2718 return RADD;
2719
    // Reassociate (add (or x, c), y) -> (add (add x, y), c) if (or x, c) is
    // equivalent to (add x, c).
    // Reassociate (add (xor x, c), y) -> (add (add x, y), c) if (xor x, c) is
    // equivalent to (add x, c).
2724 // Do this optimization only when adding c does not introduce instructions
2725 // for adding carries.
2726 auto ReassociateAddOr = [&](SDValue N0, SDValue N1) {
2727 if (DAG.isADDLike(Op: N0) && N0.hasOneUse() &&
2728 isConstantOrConstantVector(N: N0.getOperand(i: 1), /* NoOpaque */ NoOpaques: true)) {
        // If N0's type will not be split or expanded by legalization (it is
        // legal or only promoted), or c is the minimum signed value (a sign-bit
        // mask), the reassociation does not introduce an add-with-carry.
2731 auto TyActn = TLI.getTypeAction(Context&: *DAG.getContext(), VT: N0.getValueType());
2732 bool NoAddCarry = TyActn == TargetLoweringBase::TypeLegal ||
2733 TyActn == TargetLoweringBase::TypePromoteInteger ||
2734 isMinSignedConstant(V: N0.getOperand(i: 1));
2735 if (NoAddCarry)
2736 return DAG.getNode(
2737 Opcode: ISD::ADD, DL, VT,
2738 N1: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: N0.getOperand(i: 0)),
2739 N2: N0.getOperand(i: 1));
2740 }
2741 return SDValue();
2742 };
2743 if (SDValue Add = ReassociateAddOr(N0, N1))
2744 return Add;
2745 if (SDValue Add = ReassociateAddOr(N1, N0))
2746 return Add;
2747
2748 // Fold add(vecreduce(x), vecreduce(y)) -> vecreduce(add(x, y))
2749 if (SDValue SD =
2750 reassociateReduction(RedOpc: ISD::VECREDUCE_ADD, Opc: ISD::ADD, DL, VT, N0, N1))
2751 return SD;
2752 }
2753
2754 SDValue A, B, C, D;
2755
2756 // fold ((0-A) + B) -> B-A
2757 if (sd_match(N: N0, P: m_Neg(V: m_Value(N&: A))))
2758 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1, N2: A);
2759
2760 // fold (A + (0-B)) -> A-B
2761 if (sd_match(N: N1, P: m_Neg(V: m_Value(N&: B))))
2762 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: B);
2763
2764 // fold (A+(B-A)) -> B
2765 if (sd_match(N: N1, P: m_Sub(L: m_Value(N&: B), R: m_Specific(N: N0))))
2766 return B;
2767
2768 // fold ((B-A)+A) -> B
2769 if (sd_match(N: N0, P: m_Sub(L: m_Value(N&: B), R: m_Specific(N: N1))))
2770 return B;
2771
2772 // fold ((A-B)+(C-A)) -> (C-B)
2773 if (sd_match(N: N0, P: m_Sub(L: m_Value(N&: A), R: m_Value(N&: B))) &&
2774 sd_match(N: N1, P: m_Sub(L: m_Value(N&: C), R: m_Specific(N: A))))
2775 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: C, N2: B);
2776
2777 // fold ((A-B)+(B-C)) -> (A-C)
2778 if (sd_match(N: N0, P: m_Sub(L: m_Value(N&: A), R: m_Value(N&: B))) &&
2779 sd_match(N: N1, P: m_Sub(L: m_Specific(N: B), R: m_Value(N&: C))))
2780 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: A, N2: C);
2781
2782 // fold (A+(B-(A+C))) to (B-C)
2783 // fold (A+(B-(C+A))) to (B-C)
2784 if (sd_match(N: N1, P: m_Sub(L: m_Value(N&: B), R: m_Add(L: m_Specific(N: N0), R: m_Value(N&: C)))))
2785 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: B, N2: C);
2786
2787 // fold (A+((B-A)+or-C)) to (B+or-C)
2788 if (sd_match(N: N1,
2789 P: m_AnyOf(preds: m_Add(L: m_Sub(L: m_Value(N&: B), R: m_Specific(N: N0)), R: m_Value(N&: C)),
2790 preds: m_Sub(L: m_Sub(L: m_Value(N&: B), R: m_Specific(N: N0)), R: m_Value(N&: C)))))
2791 return DAG.getNode(Opcode: N1.getOpcode(), DL, VT, N1: B, N2: C);
2792
2793 // fold (A-B)+(C-D) to (A+C)-(B+D) when A or C is constant
2794 if (sd_match(N: N0, P: m_OneUse(P: m_Sub(L: m_Value(N&: A), R: m_Value(N&: B)))) &&
2795 sd_match(N: N1, P: m_OneUse(P: m_Sub(L: m_Value(N&: C), R: m_Value(N&: D)))) &&
2796 (isConstantOrConstantVector(N: A) || isConstantOrConstantVector(N: C)))
2797 return DAG.getNode(Opcode: ISD::SUB, DL, VT,
2798 N1: DAG.getNode(Opcode: ISD::ADD, DL: SDLoc(N0), VT, N1: A, N2: C),
2799 N2: DAG.getNode(Opcode: ISD::ADD, DL: SDLoc(N1), VT, N1: B, N2: D));
2800
2801 // fold (add (umax X, C), -C) --> (usubsat X, C)
2802 if (N0.getOpcode() == ISD::UMAX && hasOperation(Opcode: ISD::USUBSAT, VT)) {
2803 auto MatchUSUBSAT = [](ConstantSDNode *Max, ConstantSDNode *Op) {
2804 return (!Max && !Op) ||
2805 (Max && Op && Max->getAPIntValue() == (-Op->getAPIntValue()));
2806 };
2807 if (ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchUSUBSAT,
2808 /*AllowUndefs*/ true))
2809 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: N0.getOperand(i: 0),
2810 N2: N0.getOperand(i: 1));
2811 }
2812
2813 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
2814 return SDValue(N, 0);
2815
2816 if (isOneOrOneSplat(V: N1)) {
2817 // fold (add (xor a, -1), 1) -> (sub 0, a)
2818 if (isBitwiseNot(V: N0))
2819 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: DAG.getConstant(Val: 0, DL, VT),
2820 N2: N0.getOperand(i: 0));
2821
2822 // fold (add (add (xor a, -1), b), 1) -> (sub b, a)
2823 if (N0.getOpcode() == ISD::ADD) {
2824 SDValue A, Xor;
2825
2826 if (isBitwiseNot(V: N0.getOperand(i: 0))) {
2827 A = N0.getOperand(i: 1);
2828 Xor = N0.getOperand(i: 0);
2829 } else if (isBitwiseNot(V: N0.getOperand(i: 1))) {
2830 A = N0.getOperand(i: 0);
2831 Xor = N0.getOperand(i: 1);
2832 }
2833
2834 if (Xor)
2835 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: A, N2: Xor.getOperand(i: 0));
2836 }
2837
2838 // Look for:
2839 // add (add x, y), 1
2840 // And if the target does not like this form then turn into:
2841 // sub y, (xor x, -1)
2842 if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
2843 N0.hasOneUse() &&
2844 // Limit this to after legalization if the add has wrap flags
2845 (Level >= AfterLegalizeDAG || (!N->getFlags().hasNoUnsignedWrap() &&
2846 !N->getFlags().hasNoSignedWrap()))) {
2847 SDValue Not = DAG.getNOT(DL, Val: N0.getOperand(i: 0), VT);
2848 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 1), N2: Not);
2849 }
2850 }
2851
2852 // (x - y) + -1 -> add (xor y, -1), x
2853 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
2854 isAllOnesOrAllOnesSplat(V: N1, /*AllowUndefs=*/true)) {
2855 SDValue Not = DAG.getNOT(DL, Val: N0.getOperand(i: 1), VT);
2856 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Not, N2: N0.getOperand(i: 0));
2857 }
2858
2859 // Fold add(mul(add(A, CA), CM), CB) -> add(mul(A, CM), CM*CA+CB).
2860 // This can help if the inner add has multiple uses.
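// For example, add(mul(add(A, 1), 8), 4) becomes add(mul(A, 8), 12), since
// (A + 1) * 8 + 4 == A * 8 + 12.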
2861 APInt CM, CA;
2862 if (ConstantSDNode *CB = dyn_cast<ConstantSDNode>(Val&: N1)) {
2863 if (VT.getScalarSizeInBits() <= 64) {
2864 if (sd_match(N: N0, P: m_OneUse(P: m_Mul(L: m_Add(L: m_Value(N&: A), R: m_ConstInt(V&: CA)),
2865 R: m_ConstInt(V&: CM)))) &&
2866 TLI.isLegalAddImmediate(
2867 (CA * CM + CB->getAPIntValue()).getSExtValue())) {
2868 SDNodeFlags Flags;
2869 // If all the inputs are nuw, the outputs can be nuw. If all the inputs
2870 // are _also_ nsw, the outputs can be too.
2871 if (N->getFlags().hasNoUnsignedWrap() &&
2872 N0->getFlags().hasNoUnsignedWrap() &&
2873 N0.getOperand(i: 0)->getFlags().hasNoUnsignedWrap()) {
2874 Flags.setNoUnsignedWrap(true);
2875 if (N->getFlags().hasNoSignedWrap() &&
2876 N0->getFlags().hasNoSignedWrap() &&
2877 N0.getOperand(i: 0)->getFlags().hasNoSignedWrap())
2878 Flags.setNoSignedWrap(true);
2879 }
2880 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL: SDLoc(N1), VT, N1: A,
2881 N2: DAG.getConstant(Val: CM, DL, VT), Flags);
2882 return DAG.getNode(
2883 Opcode: ISD::ADD, DL, VT, N1: Mul,
2884 N2: DAG.getConstant(Val: CA * CM + CB->getAPIntValue(), DL, VT), Flags);
2885 }
2886 // Also look for the case where there is an intermediate add.
2887 if (sd_match(N: N0, P: m_OneUse(P: m_Add(
2888 L: m_OneUse(P: m_Mul(L: m_Add(L: m_Value(N&: A), R: m_ConstInt(V&: CA)),
2889 R: m_ConstInt(V&: CM))),
2890 R: m_Value(N&: B)))) &&
2891 TLI.isLegalAddImmediate(
2892 (CA * CM + CB->getAPIntValue()).getSExtValue())) {
2893 SDNodeFlags Flags;
2894 // If all the inputs are nuw, the outputs can be nuw. If all the inputs
2895 // are _also_ nsw, the outputs can be too.
2896 SDValue OMul =
2897 N0.getOperand(i: 0) == B ? N0.getOperand(i: 1) : N0.getOperand(i: 0);
2898 if (N->getFlags().hasNoUnsignedWrap() &&
2899 N0->getFlags().hasNoUnsignedWrap() &&
2900 OMul->getFlags().hasNoUnsignedWrap() &&
2901 OMul.getOperand(i: 0)->getFlags().hasNoUnsignedWrap()) {
2902 Flags.setNoUnsignedWrap(true);
2903 if (N->getFlags().hasNoSignedWrap() &&
2904 N0->getFlags().hasNoSignedWrap() &&
2905 OMul->getFlags().hasNoSignedWrap() &&
2906 OMul.getOperand(i: 0)->getFlags().hasNoSignedWrap())
2907 Flags.setNoSignedWrap(true);
2908 }
2909 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL: SDLoc(N1), VT, N1: A,
2910 N2: DAG.getConstant(Val: CM, DL, VT), Flags);
2911 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL: SDLoc(N1), VT, N1: Mul, N2: B, Flags);
2912 return DAG.getNode(
2913 Opcode: ISD::ADD, DL, VT, N1: Add,
2914 N2: DAG.getConstant(Val: CA * CM + CB->getAPIntValue(), DL, VT), Flags);
2915 }
2916 }
2917 }
2918
2919 if (SDValue Combined = visitADDLikeCommutative(N0, N1, LocReference: N))
2920 return Combined;
2921
2922 if (SDValue Combined = visitADDLikeCommutative(N0: N1, N1: N0, LocReference: N))
2923 return Combined;
2924
2925 return SDValue();
2926}
2927
2928// Attempt to form avgfloor(A, B) from (A & B) + ((A ^ B) >> 1)
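// Uses the identity A + B == 2 * (A & B) + (A ^ B): halving the disjoint part
// with a logical (resp. arithmetic) shift yields the unsigned (resp. signed)
// floor average without computing the potentially overflowing sum A + B.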
2929SDValue DAGCombiner::foldAddToAvg(SDNode *N, const SDLoc &DL) {
2930 SDValue N0 = N->getOperand(Num: 0);
2931 EVT VT = N0.getValueType();
2932 SDValue A, B;
2933
2934 if ((!LegalOperations || hasOperation(Opcode: ISD::AVGFLOORU, VT)) &&
2935 sd_match(N, P: m_Add(L: m_And(L: m_Value(N&: A), R: m_Value(N&: B)),
2936 R: m_Srl(L: m_Xor(L: m_Deferred(V&: A), R: m_Deferred(V&: B)),
2937 R: m_SpecificInt(V: 1))))) {
2938 return DAG.getNode(Opcode: ISD::AVGFLOORU, DL, VT, N1: A, N2: B);
2939 }
2940 if ((!LegalOperations || hasOperation(Opcode: ISD::AVGFLOORS, VT)) &&
2941 sd_match(N, P: m_Add(L: m_And(L: m_Value(N&: A), R: m_Value(N&: B)),
2942 R: m_Sra(L: m_Xor(L: m_Deferred(V&: A), R: m_Deferred(V&: B)),
2943 R: m_SpecificInt(V: 1))))) {
2944 return DAG.getNode(Opcode: ISD::AVGFLOORS, DL, VT, N1: A, N2: B);
2945 }
2946
2947 return SDValue();
2948}
2949
2950SDValue DAGCombiner::visitADD(SDNode *N) {
2951 SDValue N0 = N->getOperand(Num: 0);
2952 SDValue N1 = N->getOperand(Num: 1);
2953 EVT VT = N0.getValueType();
2954 SDLoc DL(N);
2955
2956 if (SDValue Combined = visitADDLike(N))
2957 return Combined;
2958
2959 if (SDValue V = foldAddSubBoolOfMaskedVal(N, DL, DAG))
2960 return V;
2961
2962 if (SDValue V = foldAddSubOfSignBit(N, DL, DAG))
2963 return V;
2964
2965 // Try to match AVGFLOOR fixedwidth pattern
2966 if (SDValue V = foldAddToAvg(N, DL))
2967 return V;
2968
2969 // fold (a+b) -> (a|b) iff a and b share no bits.
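// With no common set bits there are no carries, so add and or produce the
// same value; the disjoint flag records this fact for later combines.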
2970 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::OR, VT)) &&
2971 DAG.haveNoCommonBitsSet(A: N0, B: N1)) {
2972 SDNodeFlags Flags;
2973 Flags.setDisjoint(true);
2974 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: N0, N2: N1, Flags);
2975 }
2976
2977 // Fold (add (vscale * C0), (vscale * C1)) to (vscale * (C0 + C1)).
2978 if (N0.getOpcode() == ISD::VSCALE && N1.getOpcode() == ISD::VSCALE) {
2979 const APInt &C0 = N0->getConstantOperandAPInt(Num: 0);
2980 const APInt &C1 = N1->getConstantOperandAPInt(Num: 0);
2981 return DAG.getVScale(DL, VT, MulImm: C0 + C1);
2982 }
2983
2984 // fold a+vscale(c1)+vscale(c2) -> a+vscale(c1+c2)
2985 if (N0.getOpcode() == ISD::ADD &&
2986 N0.getOperand(i: 1).getOpcode() == ISD::VSCALE &&
2987 N1.getOpcode() == ISD::VSCALE) {
2988 const APInt &VS0 = N0.getOperand(i: 1)->getConstantOperandAPInt(Num: 0);
2989 const APInt &VS1 = N1->getConstantOperandAPInt(Num: 0);
2990 SDValue VS = DAG.getVScale(DL, VT, MulImm: VS0 + VS1);
2991 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: VS);
2992 }
2993
2994 // Fold (add step_vector(c1), step_vector(c2)) to step_vector(c1+c2)
2995 if (N0.getOpcode() == ISD::STEP_VECTOR &&
2996 N1.getOpcode() == ISD::STEP_VECTOR) {
2997 const APInt &C0 = N0->getConstantOperandAPInt(Num: 0);
2998 const APInt &C1 = N1->getConstantOperandAPInt(Num: 0);
2999 APInt NewStep = C0 + C1;
3000 return DAG.getStepVector(DL, ResVT: VT, StepVal: NewStep);
3001 }
3002
3003 // Fold a + step_vector(c1) + step_vector(c2) to a + step_vector(c1+c2)
3004 if (N0.getOpcode() == ISD::ADD &&
3005 N0.getOperand(i: 1).getOpcode() == ISD::STEP_VECTOR &&
3006 N1.getOpcode() == ISD::STEP_VECTOR) {
3007 const APInt &SV0 = N0.getOperand(i: 1)->getConstantOperandAPInt(Num: 0);
3008 const APInt &SV1 = N1->getConstantOperandAPInt(Num: 0);
3009 APInt NewStep = SV0 + SV1;
3010 SDValue SV = DAG.getStepVector(DL, ResVT: VT, StepVal: NewStep);
3011 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: SV);
3012 }
3013
3014 return SDValue();
3015}
3016
3017SDValue DAGCombiner::visitADDSAT(SDNode *N) {
3018 unsigned Opcode = N->getOpcode();
3019 SDValue N0 = N->getOperand(Num: 0);
3020 SDValue N1 = N->getOperand(Num: 1);
3021 EVT VT = N0.getValueType();
3022 bool IsSigned = Opcode == ISD::SADDSAT;
3023 SDLoc DL(N);
3024
3025 // fold (add_sat x, undef) -> -1
3026 if (N0.isUndef() || N1.isUndef())
3027 return DAG.getAllOnesConstant(DL, VT);
3028
3029 // fold (add_sat c1, c2) -> c3
3030 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
3031 return C;
3032
3033 // canonicalize constant to RHS
3034 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
3035 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
3036 return DAG.getNode(Opcode, DL, VT, N1, N2: N0);
3037
3038 // fold vector ops
3039 if (VT.isVector()) {
3040 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
3041 return FoldedVOp;
3042
3043 // fold (add_sat x, 0) -> x, vector edition
3044 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
3045 return N0;
3046 }
3047
3048 // fold (add_sat x, 0) -> x
3049 if (isNullConstant(V: N1))
3050 return N0;
3051
3052 // If it cannot overflow, transform into an add.
3053 if (DAG.willNotOverflowAdd(IsSigned, N0, N1))
3054 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1);
3055
3056 return SDValue();
3057}
3058
3059static SDValue getAsCarry(const TargetLowering &TLI, SDValue V,
3060 bool ForceCarryReconstruction = false) {
3061 bool Masked = false;
3062
3063 // First, peel away TRUNCATE/ZERO_EXTEND/AND nodes due to legalization.
3064 while (true) {
3065 if (V.getOpcode() == ISD::TRUNCATE || V.getOpcode() == ISD::ZERO_EXTEND) {
3066 V = V.getOperand(i: 0);
3067 continue;
3068 }
3069
3070 if (V.getOpcode() == ISD::AND && isOneConstant(V: V.getOperand(i: 1))) {
3071 if (ForceCarryReconstruction)
3072 return V;
3073
3074 Masked = true;
3075 V = V.getOperand(i: 0);
3076 continue;
3077 }
3078
3079 if (ForceCarryReconstruction && V.getValueType() == MVT::i1)
3080 return V;
3081
3082 break;
3083 }
3084
3085 // If this is not a carry, return.
3086 if (V.getResNo() != 1)
3087 return SDValue();
3088
3089 if (V.getOpcode() != ISD::UADDO_CARRY && V.getOpcode() != ISD::USUBO_CARRY &&
3090 V.getOpcode() != ISD::UADDO && V.getOpcode() != ISD::USUBO)
3091 return SDValue();
3092
3093 EVT VT = V->getValueType(ResNo: 0);
3094 if (!TLI.isOperationLegalOrCustom(Op: V.getOpcode(), VT))
3095 return SDValue();
3096
3097 // If the result is masked, then no matter what kind of bool it is we can
3098 // return. If it isn't, then we need to make sure the bool value is either 0 or
3099 // 1 and not other values.
3100 if (Masked ||
3101 TLI.getBooleanContents(Type: V.getValueType()) ==
3102 TargetLoweringBase::ZeroOrOneBooleanContent)
3103 return V;
3104
3105 return SDValue();
3106}
3107
3108/// Given the operands of an add/sub operation, see if the 2nd operand is a
3109/// masked 0/1 whose source operand is actually known to be 0/-1. If so, invert
3110/// the opcode and bypass the mask operation.
3111static SDValue foldAddSubMasked1(bool IsAdd, SDValue N0, SDValue N1,
3112 SelectionDAG &DAG, const SDLoc &DL) {
3113 if (N1.getOpcode() == ISD::ZERO_EXTEND)
3114 N1 = N1.getOperand(i: 0);
3115
3116 if (N1.getOpcode() != ISD::AND || !isOneOrOneSplat(V: N1->getOperand(Num: 1)))
3117 return SDValue();
3118
3119 EVT VT = N0.getValueType();
3120 SDValue N10 = N1.getOperand(i: 0);
3121 if (N10.getValueType() != VT && N10.getOpcode() == ISD::TRUNCATE)
3122 N10 = N10.getOperand(i: 0);
3123
3124 if (N10.getValueType() != VT)
3125 return SDValue();
3126
3127 if (DAG.ComputeNumSignBits(Op: N10) != VT.getScalarSizeInBits())
3128 return SDValue();
3129
3130 // add N0, (and (AssertSext X, i1), 1) --> sub N0, X
3131 // sub N0, (and (AssertSext X, i1), 1) --> add N0, X
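// This works because X is known to be all sign bits (0 or -1), so the masked
// value (X & 1) equals -X, and adding -X is the same as subtracting X.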
3132 return DAG.getNode(Opcode: IsAdd ? ISD::SUB : ISD::ADD, DL, VT, N1: N0, N2: N10);
3133}
3134
3135/// Helper for doing combines based on N0 and N1 being added to each other.
3136SDValue DAGCombiner::visitADDLikeCommutative(SDValue N0, SDValue N1,
3137 SDNode *LocReference) {
3138 EVT VT = N0.getValueType();
3139 SDLoc DL(LocReference);
3140
3141 // fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n))
3142 SDValue Y, N;
3143 if (sd_match(N: N1, P: m_Shl(L: m_Neg(V: m_Value(N&: Y)), R: m_Value(N))))
3144 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0,
3145 N2: DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Y, N2: N));
3146
3147 if (SDValue V = foldAddSubMasked1(IsAdd: true, N0, N1, DAG, DL))
3148 return V;
3149
3150 // Look for:
3151 // add (add x, 1), y
3152 // And if the target does not like this form then turn into:
3153 // sub y, (xor x, -1)
3154 if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
3155 N0.hasOneUse() && isOneOrOneSplat(V: N0.getOperand(i: 1)) &&
3156 // Limit this to after legalization if the add has wrap flags
3157 (Level >= AfterLegalizeDAG || (!N0->getFlags().hasNoUnsignedWrap() &&
3158 !N0->getFlags().hasNoSignedWrap()))) {
3159 SDValue Not = DAG.getNOT(DL, Val: N0.getOperand(i: 0), VT);
3160 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1, N2: Not);
3161 }
3162
3163 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse()) {
3164 // Hoist one-use subtraction by non-opaque constant:
3165 // (x - C) + y -> (x + y) - C
3166 // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
3167 if (isConstantOrConstantVector(N: N0.getOperand(i: 1), /*NoOpaques=*/true)) {
3168 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: N1);
3169 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Add, N2: N0.getOperand(i: 1));
3170 }
3171 // Hoist one-use subtraction from non-opaque constant:
3172 // (C - x) + y -> (y - x) + C
3173 if (isConstantOrConstantVector(N: N0.getOperand(i: 0), /*NoOpaques=*/true)) {
3174 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1, N2: N0.getOperand(i: 1));
3175 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Sub, N2: N0.getOperand(i: 0));
3176 }
3177 }
3178
3179 // add (mul x, C), x -> mul x, C+1
3180 if (N0.getOpcode() == ISD::MUL && N0.getOperand(i: 0) == N1 &&
3181 isConstantOrConstantVector(N: N0.getOperand(i: 1), /*NoOpaques=*/true) &&
3182 N0.hasOneUse()) {
3183 SDValue NewC = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 1),
3184 N2: DAG.getConstant(Val: 1, DL, VT));
3185 return DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: N0.getOperand(i: 0), N2: NewC);
3186 }
3187
3188 // If the target's bool is represented as 0/1, prefer to make this 'sub 0/1'
3189 // rather than 'add 0/-1' (the zext should get folded).
3190 // add (sext i1 Y), X --> sub X, (zext i1 Y)
3191 if (N0.getOpcode() == ISD::SIGN_EXTEND &&
3192 N0.getOperand(i: 0).getScalarValueSizeInBits() == 1 &&
3193 TLI.getBooleanContents(Type: VT) == TargetLowering::ZeroOrOneBooleanContent) {
3194 SDValue ZExt = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0.getOperand(i: 0));
3195 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1, N2: ZExt);
3196 }
3197
3198 // add X, (sextinreg Y i1) -> sub X, (and Y 1)
3199 if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
3200 VTSDNode *TN = cast<VTSDNode>(Val: N1.getOperand(i: 1));
3201 if (TN->getVT() == MVT::i1) {
3202 SDValue ZExt = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N1.getOperand(i: 0),
3203 N2: DAG.getConstant(Val: 1, DL, VT));
3204 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: ZExt);
3205 }
3206 }
3207
3208 // (add X, (uaddo_carry Y, 0, Carry)) -> (uaddo_carry X, Y, Carry)
3209 if (N1.getOpcode() == ISD::UADDO_CARRY && isNullConstant(V: N1.getOperand(i: 1)) &&
3210 N1.getResNo() == 0)
3211 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL, VTList: N1->getVTList(),
3212 N1: N0, N2: N1.getOperand(i: 0), N3: N1.getOperand(i: 2));
3213
3214 // (add X, Carry) -> (uaddo_carry X, 0, Carry)
3215 if (TLI.isOperationLegalOrCustom(Op: ISD::UADDO_CARRY, VT))
3216 if (SDValue Carry = getAsCarry(TLI, V: N1))
3217 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL,
3218 VTList: DAG.getVTList(VT1: VT, VT2: Carry.getValueType()), N1: N0,
3219 N2: DAG.getConstant(Val: 0, DL, VT), N3: Carry);
3220
3221 return SDValue();
3222}
3223
3224SDValue DAGCombiner::visitADDC(SDNode *N) {
3225 SDValue N0 = N->getOperand(Num: 0);
3226 SDValue N1 = N->getOperand(Num: 1);
3227 EVT VT = N0.getValueType();
3228 SDLoc DL(N);
3229
3230 // If the flag result is dead, turn this into an ADD.
3231 if (!N->hasAnyUseOfValue(Value: 1))
3232 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1),
3233 Res1: DAG.getNode(Opcode: ISD::CARRY_FALSE, DL, VT: MVT::Glue));
3234
3235 // canonicalize constant to RHS.
3236 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(Val&: N0);
3237 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
3238 if (N0C && !N1C)
3239 return DAG.getNode(Opcode: ISD::ADDC, DL, VTList: N->getVTList(), N1, N2: N0);
3240
3241 // fold (addc x, 0) -> x + no carry out
3242 if (isNullConstant(V: N1))
3243 return CombineTo(N, Res0: N0, Res1: DAG.getNode(Opcode: ISD::CARRY_FALSE,
3244 DL, VT: MVT::Glue));
3245
3246 // If it cannot overflow, transform into an add.
3247 if (DAG.computeOverflowForUnsignedAdd(N0, N1) == SelectionDAG::OFK_Never)
3248 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1),
3249 Res1: DAG.getNode(Opcode: ISD::CARRY_FALSE, DL, VT: MVT::Glue));
3250
3251 return SDValue();
3252}
3253
3254/**
3255 * Flips a boolean if it is cheaper to compute. If the Force parameter is set,
3256 * then the flip also occurs if computing the inverse is the same cost.
3257 * This function returns an empty SDValue in case it cannot flip the boolean
3258 * without increasing the cost of the computation. If you want to flip a boolean
3259 * no matter what, use DAG.getLogicalNOT.
3260 */
3261static SDValue extractBooleanFlip(SDValue V, SelectionDAG &DAG,
3262 const TargetLowering &TLI,
3263 bool Force) {
3264 if (Force && isa<ConstantSDNode>(Val: V))
3265 return DAG.getLogicalNOT(DL: SDLoc(V), Val: V, VT: V.getValueType());
3266
3267 if (V.getOpcode() != ISD::XOR)
3268 return SDValue();
3269
3270 ConstantSDNode *Const = isConstOrConstSplat(N: V.getOperand(i: 1), AllowUndefs: false);
3271 if (!Const)
3272 return SDValue();
3273
3274 EVT VT = V.getValueType();
3275
3276 bool IsFlip = false;
3277 switch(TLI.getBooleanContents(Type: VT)) {
3278 case TargetLowering::ZeroOrOneBooleanContent:
3279 IsFlip = Const->isOne();
3280 break;
3281 case TargetLowering::ZeroOrNegativeOneBooleanContent:
3282 IsFlip = Const->isAllOnes();
3283 break;
3284 case TargetLowering::UndefinedBooleanContent:
3285 IsFlip = (Const->getAPIntValue() & 0x01) == 1;
3286 break;
3287 }
3288
3289 if (IsFlip)
3290 return V.getOperand(i: 0);
3291 if (Force)
3292 return DAG.getLogicalNOT(DL: SDLoc(V), Val: V, VT: V.getValueType());
3293 return SDValue();
3294}
3295
3296SDValue DAGCombiner::visitADDO(SDNode *N) {
3297 SDValue N0 = N->getOperand(Num: 0);
3298 SDValue N1 = N->getOperand(Num: 1);
3299 EVT VT = N0.getValueType();
3300 bool IsSigned = (ISD::SADDO == N->getOpcode());
3301
3302 EVT CarryVT = N->getValueType(ResNo: 1);
3303 SDLoc DL(N);
3304
3305 // If the flag result is dead, turn this into an ADD.
3306 if (!N->hasAnyUseOfValue(Value: 1))
3307 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1),
3308 Res1: DAG.getUNDEF(VT: CarryVT));
3309
3310 // canonicalize constant to RHS.
3311 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
3312 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
3313 return DAG.getNode(Opcode: N->getOpcode(), DL, VTList: N->getVTList(), N1, N2: N0);
3314
3315 // fold (addo x, 0) -> x + no carry out
3316 if (isNullOrNullSplat(V: N1))
3317 return CombineTo(N, Res0: N0, Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
3318
3319 // If it cannot overflow, transform into an add.
3320 if (DAG.willNotOverflowAdd(IsSigned, N0, N1))
3321 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1),
3322 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
3323
3324 if (IsSigned) {
3325 // fold (saddo (xor a, -1), 1) -> (ssubo 0, a).
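// This relies on ~a + 1 == 0 - a; both forms overflow exactly when a is the
// minimum signed value, so the overflow result is preserved.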
3326 if (isBitwiseNot(V: N0) && isOneOrOneSplat(V: N1))
3327 return DAG.getNode(Opcode: ISD::SSUBO, DL, VTList: N->getVTList(),
3328 N1: DAG.getConstant(Val: 0, DL, VT), N2: N0.getOperand(i: 0));
3329 } else {
3330 // fold (uaddo (xor a, -1), 1) -> (usubo 0, a) and flip carry.
3331 if (isBitwiseNot(V: N0) && isOneOrOneSplat(V: N1)) {
3332 SDValue Sub = DAG.getNode(Opcode: ISD::USUBO, DL, VTList: N->getVTList(),
3333 N1: DAG.getConstant(Val: 0, DL, VT), N2: N0.getOperand(i: 0));
3334 return CombineTo(
3335 N, Res0: Sub, Res1: DAG.getLogicalNOT(DL, Val: Sub.getValue(R: 1), VT: Sub->getValueType(ResNo: 1)));
3336 }
3337
3338 if (SDValue Combined = visitUADDOLike(N0, N1, N))
3339 return Combined;
3340
3341 if (SDValue Combined = visitUADDOLike(N0: N1, N1: N0, N))
3342 return Combined;
3343 }
3344
3345 return SDValue();
3346}
3347
3348SDValue DAGCombiner::visitUADDOLike(SDValue N0, SDValue N1, SDNode *N) {
3349 EVT VT = N0.getValueType();
3350 if (VT.isVector())
3351 return SDValue();
3352
3353 // (uaddo X, (uaddo_carry Y, 0, Carry)) -> (uaddo_carry X, Y, Carry)
3354 // If Y + 1 cannot overflow.
3355 if (N1.getOpcode() == ISD::UADDO_CARRY && isNullConstant(V: N1.getOperand(i: 1))) {
3356 SDValue Y = N1.getOperand(i: 0);
3357 SDValue One = DAG.getConstant(Val: 1, DL: SDLoc(N), VT: Y.getValueType());
3358 if (DAG.computeOverflowForUnsignedAdd(N0: Y, N1: One) == SelectionDAG::OFK_Never)
3359 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: Y,
3360 N3: N1.getOperand(i: 2));
3361 }
3362
3363 // (uaddo X, Carry) -> (uaddo_carry X, 0, Carry)
3364 if (TLI.isOperationLegalOrCustom(Op: ISD::UADDO_CARRY, VT))
3365 if (SDValue Carry = getAsCarry(TLI, V: N1))
3366 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL: SDLoc(N), VTList: N->getVTList(), N1: N0,
3367 N2: DAG.getConstant(Val: 0, DL: SDLoc(N), VT), N3: Carry);
3368
3369 return SDValue();
3370}
3371
3372SDValue DAGCombiner::visitADDE(SDNode *N) {
3373 SDValue N0 = N->getOperand(Num: 0);
3374 SDValue N1 = N->getOperand(Num: 1);
3375 SDValue CarryIn = N->getOperand(Num: 2);
3376
3377 // canonicalize constant to RHS
3378 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(Val&: N0);
3379 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
3380 if (N0C && !N1C)
3381 return DAG.getNode(Opcode: ISD::ADDE, DL: SDLoc(N), VTList: N->getVTList(),
3382 N1, N2: N0, N3: CarryIn);
3383
3384 // fold (adde x, y, false) -> (addc x, y)
3385 if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
3386 return DAG.getNode(Opcode: ISD::ADDC, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: N1);
3387
3388 return SDValue();
3389}
3390
3391SDValue DAGCombiner::visitUADDO_CARRY(SDNode *N) {
3392 SDValue N0 = N->getOperand(Num: 0);
3393 SDValue N1 = N->getOperand(Num: 1);
3394 SDValue CarryIn = N->getOperand(Num: 2);
3395 SDLoc DL(N);
3396
3397 // canonicalize constant to RHS
3398 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(Val&: N0);
3399 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
3400 if (N0C && !N1C)
3401 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL, VTList: N->getVTList(), N1, N2: N0, N3: CarryIn);
3402
3403 // fold (uaddo_carry x, y, false) -> (uaddo x, y)
3404 if (isNullConstant(V: CarryIn)) {
3405 if (!LegalOperations ||
3406 TLI.isOperationLegalOrCustom(Op: ISD::UADDO, VT: N->getValueType(ResNo: 0)))
3407 return DAG.getNode(Opcode: ISD::UADDO, DL, VTList: N->getVTList(), N1: N0, N2: N1);
3408 }
3409
3410 // fold (uaddo_carry 0, 0, X) -> (and (ext/trunc X), 1) and no carry.
3411 if (isNullConstant(V: N0) && isNullConstant(V: N1)) {
3412 EVT VT = N0.getValueType();
3413 EVT CarryVT = CarryIn.getValueType();
3414 SDValue CarryExt = DAG.getBoolExtOrTrunc(Op: CarryIn, SL: DL, VT, OpVT: CarryVT);
3415 AddToWorklist(N: CarryExt.getNode());
3416 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::AND, DL, VT, N1: CarryExt,
3417 N2: DAG.getConstant(Val: 1, DL, VT)),
3418 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
3419 }
3420
3421 if (SDValue Combined = visitUADDO_CARRYLike(N0, N1, CarryIn, N))
3422 return Combined;
3423
3424 if (SDValue Combined = visitUADDO_CARRYLike(N0: N1, N1: N0, CarryIn, N))
3425 return Combined;
3426
3427 // We want to avoid useless duplication.
3428 // TODO: This is done automatically for binary operations. As UADDO_CARRY is
3429 // not a binary operation, it is not really possible to leverage this
3430 // existing mechanism for it. However, if more operations require the same
3431 // deduplication logic, then it may be worth generalizing.
3432 SDValue Ops[] = {N1, N0, CarryIn};
3433 SDNode *CSENode =
3434 DAG.getNodeIfExists(Opcode: ISD::UADDO_CARRY, VTList: N->getVTList(), Ops, Flags: N->getFlags());
3435 if (CSENode)
3436 return SDValue(CSENode, 0);
3437
3438 return SDValue();
3439}
3440
3441/**
3442 * If we are facing some sort of diamond carry propagation pattern try to
3443 * break it up to generate something like:
3444 * (uaddo_carry X, 0, (uaddo_carry A, B, Z):Carry)
3445 *
3446 * The end result is usually an increase in the number of operations required, but because the
3447 * carry is now linearized, other transforms can kick in and optimize the DAG.
3448 *
3449 * Patterns typically look something like
3450 * (uaddo A, B)
3451 * / \
3452 * Carry Sum
3453 * | \
3454 * | (uaddo_carry *, 0, Z)
3455 * | /
3456 * \ Carry
3457 * | /
3458 * (uaddo_carry X, *, *)
3459 *
3460 * But numerous variations exist. Our goal is to identify A, B, X and Z and
3461 * produce a combine with a single path for carry propagation.
3462 */
3463static SDValue combineUADDO_CARRYDiamond(DAGCombiner &Combiner,
3464 SelectionDAG &DAG, SDValue X,
3465 SDValue Carry0, SDValue Carry1,
3466 SDNode *N) {
3467 if (Carry1.getResNo() != 1 || Carry0.getResNo() != 1)
3468 return SDValue();
3469 if (Carry1.getOpcode() != ISD::UADDO)
3470 return SDValue();
3471
3472 SDValue Z;
3473
3474 /**
3475 * First look for a suitable Z. It will present itself in the form of
3476 * (uaddo_carry Y, 0, Z) or its equivalent (uaddo Y, 1) for Z=true
3477 */
3478 if (Carry0.getOpcode() == ISD::UADDO_CARRY &&
3479 isNullConstant(V: Carry0.getOperand(i: 1))) {
3480 Z = Carry0.getOperand(i: 2);
3481 } else if (Carry0.getOpcode() == ISD::UADDO &&
3482 isOneConstant(V: Carry0.getOperand(i: 1))) {
3483 EVT VT = Carry0->getValueType(ResNo: 1);
3484 Z = DAG.getConstant(Val: 1, DL: SDLoc(Carry0.getOperand(i: 1)), VT);
3485 } else {
3486 // We couldn't find a suitable Z.
3487 return SDValue();
3488 }
3489
3490
3491 auto cancelDiamond = [&](SDValue A, SDValue B) {
3492 SDLoc DL(N);
3493 SDValue NewY =
3494 DAG.getNode(Opcode: ISD::UADDO_CARRY, DL, VTList: Carry0->getVTList(), N1: A, N2: B, N3: Z);
3495 Combiner.AddToWorklist(N: NewY.getNode());
3496 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL, VTList: N->getVTList(), N1: X,
3497 N2: DAG.getConstant(Val: 0, DL, VT: X.getValueType()),
3498 N3: NewY.getValue(R: 1));
3499 };
3500
3501 /**
3502 * (uaddo A, B)
3503 * |
3504 * Sum
3505 * |
3506 * (uaddo_carry *, 0, Z)
3507 */
3508 if (Carry0.getOperand(i: 0) == Carry1.getValue(R: 0)) {
3509 return cancelDiamond(Carry1.getOperand(i: 0), Carry1.getOperand(i: 1));
3510 }
3511
3512 /**
3513 * (uaddo_carry A, 0, Z)
3514 * |
3515 * Sum
3516 * |
3517 * (uaddo *, B)
3518 */
3519 if (Carry1.getOperand(i: 0) == Carry0.getValue(R: 0)) {
3520 return cancelDiamond(Carry0.getOperand(i: 0), Carry1.getOperand(i: 1));
3521 }
3522
3523 if (Carry1.getOperand(i: 1) == Carry0.getValue(R: 0)) {
3524 return cancelDiamond(Carry1.getOperand(i: 0), Carry0.getOperand(i: 0));
3525 }
3526
3527 return SDValue();
3528}
3529
3530// If we are facing some sort of diamond carry/borrow in/out pattern try to
3531// match patterns like:
3532//
3533// (uaddo A, B) CarryIn
3534// | \ |
3535// | \ |
3536// PartialSum PartialCarryOutX /
3537// | | /
3538// | ____|____________/
3539// | / |
3540// (uaddo *, *) \________
3541// | \ \
3542// | \ |
3543// | PartialCarryOutY |
3544// | \ |
3545// | \ /
3546// AddCarrySum | ______/
3547// | /
3548// CarryOut = (or *, *)
3549//
3550// And generate UADDO_CARRY (or USUBO_CARRY) with two result values:
3551//
3552// {AddCarrySum, CarryOut} = (uaddo_carry A, B, CarryIn)
3553//
3554// Our goal is to identify A, B, and CarryIn and produce UADDO_CARRY/USUBO_CARRY
3555// with a single path for carry/borrow out propagation.
3556static SDValue combineCarryDiamond(SelectionDAG &DAG, const TargetLowering &TLI,
3557 SDValue N0, SDValue N1, SDNode *N) {
3558 SDValue Carry0 = getAsCarry(TLI, V: N0);
3559 if (!Carry0)
3560 return SDValue();
3561 SDValue Carry1 = getAsCarry(TLI, V: N1);
3562 if (!Carry1)
3563 return SDValue();
3564
3565 unsigned Opcode = Carry0.getOpcode();
3566 if (Opcode != Carry1.getOpcode())
3567 return SDValue();
3568 if (Opcode != ISD::UADDO && Opcode != ISD::USUBO)
3569 return SDValue();
3570 // Guarantee identical type of CarryOut
3571 EVT CarryOutType = N->getValueType(ResNo: 0);
3572 if (CarryOutType != Carry0.getValue(R: 1).getValueType() ||
3573 CarryOutType != Carry1.getValue(R: 1).getValueType())
3574 return SDValue();
3575
3576 // Canonicalize the add/sub of A and B (the top node in the above ASCII art)
3577 // as Carry0 and the add/sub of the carry in as Carry1 (the middle node).
3578 if (Carry1.getNode()->isOperandOf(N: Carry0.getNode()))
3579 std::swap(a&: Carry0, b&: Carry1);
3580
3581 // Check if nodes are connected in expected way.
3582 if (Carry1.getOperand(i: 0) != Carry0.getValue(R: 0) &&
3583 Carry1.getOperand(i: 1) != Carry0.getValue(R: 0))
3584 return SDValue();
3585
3586 // The carry-in value must be on the right-hand side for subtraction.
3587 unsigned CarryInOperandNum =
3588 Carry1.getOperand(i: 0) == Carry0.getValue(R: 0) ? 1 : 0;
3589 if (Opcode == ISD::USUBO && CarryInOperandNum != 1)
3590 return SDValue();
3591 SDValue CarryIn = Carry1.getOperand(i: CarryInOperandNum);
3592
3593 unsigned NewOp = Opcode == ISD::UADDO ? ISD::UADDO_CARRY : ISD::USUBO_CARRY;
3594 if (!TLI.isOperationLegalOrCustom(Op: NewOp, VT: Carry0.getValue(R: 0).getValueType()))
3595 return SDValue();
3596
3597 // Verify that the carry/borrow in is plausibly a carry/borrow bit.
3598 CarryIn = getAsCarry(TLI, V: CarryIn, ForceCarryReconstruction: true);
3599 if (!CarryIn)
3600 return SDValue();
3601
3602 SDLoc DL(N);
3603 CarryIn = DAG.getBoolExtOrTrunc(Op: CarryIn, SL: DL, VT: Carry1->getValueType(ResNo: 1),
3604 OpVT: Carry1->getValueType(ResNo: 0));
3605 SDValue Merged =
3606 DAG.getNode(Opcode: NewOp, DL, VTList: Carry1->getVTList(), N1: Carry0.getOperand(i: 0),
3607 N2: Carry0.getOperand(i: 1), N3: CarryIn);
3608
3609 // Please note that because we have proven that the result of the UADDO/USUBO
3610 // of A and B feeds into the UADDO/USUBO that does the carry/borrow in, we can
3611 // therefore prove that if the first UADDO/USUBO overflows, the second
3612 // UADDO/USUBO cannot. For example consider 8-bit numbers where 0xFF is the
3613 // maximum value.
3614 //
3615 // 0xFF + 0xFF == 0xFE with carry but 0xFE + 1 does not carry
3616 // 0x00 - 0xFF == 1 with a carry/borrow but 1 - 1 == 0 (no carry/borrow)
3617 //
3618 // This is important because it means that OR and XOR can be used to merge
3619 // carry flags; and that AND can return a constant zero.
3620 //
3621 // TODO: match other operations that can merge flags (ADD, etc)
3622 DAG.ReplaceAllUsesOfValueWith(From: Carry1.getValue(R: 0), To: Merged.getValue(R: 0));
3623 if (N->getOpcode() == ISD::AND)
3624 return DAG.getConstant(Val: 0, DL, VT: CarryOutType);
3625 return Merged.getValue(R: 1);
3626}
3627
3628SDValue DAGCombiner::visitUADDO_CARRYLike(SDValue N0, SDValue N1,
3629 SDValue CarryIn, SDNode *N) {
3630 // fold (uaddo_carry (xor a, -1), b, c) -> (usubo_carry b, a, !c) and flip
3631 // carry.
3632 if (isBitwiseNot(V: N0))
3633 if (SDValue NotC = extractBooleanFlip(V: CarryIn, DAG, TLI, Force: true)) {
3634 SDLoc DL(N);
3635 SDValue Sub = DAG.getNode(Opcode: ISD::USUBO_CARRY, DL, VTList: N->getVTList(), N1,
3636 N2: N0.getOperand(i: 0), N3: NotC);
3637 return CombineTo(
3638 N, Res0: Sub, Res1: DAG.getLogicalNOT(DL, Val: Sub.getValue(R: 1), VT: Sub->getValueType(ResNo: 1)));
3639 }
3640
3641 // Iff the flag result is dead:
3642 // (uaddo_carry (add|uaddo X, Y), 0, Carry) -> (uaddo_carry X, Y, Carry)
3643 // Don't do this if the Carry comes from the uaddo. It won't remove the uaddo
3644 // or the dependency between the instructions.
3645 if ((N0.getOpcode() == ISD::ADD ||
3646 (N0.getOpcode() == ISD::UADDO && N0.getResNo() == 0 &&
3647 N0.getValue(R: 1) != CarryIn)) &&
3648 isNullConstant(V: N1) && !N->hasAnyUseOfValue(Value: 1))
3649 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL: SDLoc(N), VTList: N->getVTList(),
3650 N1: N0.getOperand(i: 0), N2: N0.getOperand(i: 1), N3: CarryIn);
3651
3652 /**
3653 * When one of the uaddo_carry arguments is itself a carry, we may be facing
3654 * a diamond carry propagation. In that case we try to transform the DAG
3655 * to ensure linear carry propagation if that is possible.
3656 */
3657 if (auto Y = getAsCarry(TLI, V: N1)) {
3658 // Because both are carries, Y and Z can be swapped.
3659 if (auto R = combineUADDO_CARRYDiamond(Combiner&: *this, DAG, X: N0, Carry0: Y, Carry1: CarryIn, N))
3660 return R;
3661 if (auto R = combineUADDO_CARRYDiamond(Combiner&: *this, DAG, X: N0, Carry0: CarryIn, Carry1: Y, N))
3662 return R;
3663 }
3664
3665 return SDValue();
3666}
3667
3668SDValue DAGCombiner::visitSADDO_CARRYLike(SDValue N0, SDValue N1,
3669 SDValue CarryIn, SDNode *N) {
3670 // fold (saddo_carry (xor a, -1), b, c) -> (ssubo_carry b, a, !c)
3671 if (isBitwiseNot(V: N0)) {
3672 if (SDValue NotC = extractBooleanFlip(V: CarryIn, DAG, TLI, Force: true))
3673 return DAG.getNode(Opcode: ISD::SSUBO_CARRY, DL: SDLoc(N), VTList: N->getVTList(), N1,
3674 N2: N0.getOperand(i: 0), N3: NotC);
3675 }
3676
3677 return SDValue();
3678}
3679
3680SDValue DAGCombiner::visitSADDO_CARRY(SDNode *N) {
3681 SDValue N0 = N->getOperand(Num: 0);
3682 SDValue N1 = N->getOperand(Num: 1);
3683 SDValue CarryIn = N->getOperand(Num: 2);
3684 SDLoc DL(N);
3685
3686 // canonicalize constant to RHS
3687 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(Val&: N0);
3688 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
3689 if (N0C && !N1C)
3690 return DAG.getNode(Opcode: ISD::SADDO_CARRY, DL, VTList: N->getVTList(), N1, N2: N0, N3: CarryIn);
3691
3692 // fold (saddo_carry x, y, false) -> (saddo x, y)
3693 if (isNullConstant(V: CarryIn)) {
3694 if (!LegalOperations ||
3695 TLI.isOperationLegalOrCustom(Op: ISD::SADDO, VT: N->getValueType(ResNo: 0)))
3696 return DAG.getNode(Opcode: ISD::SADDO, DL, VTList: N->getVTList(), N1: N0, N2: N1);
3697 }
3698
3699 if (SDValue Combined = visitSADDO_CARRYLike(N0, N1, CarryIn, N))
3700 return Combined;
3701
3702 if (SDValue Combined = visitSADDO_CARRYLike(N0: N1, N1: N0, CarryIn, N))
3703 return Combined;
3704
3705 return SDValue();
3706}
3707
3708// Attempt to create a USUBSAT(LHS, RHS) node with DstVT, performing a
3709// clamp/truncation if necessary.
3710static SDValue getTruncatedUSUBSAT(EVT DstVT, EVT SrcVT, SDValue LHS,
3711 SDValue RHS, SelectionDAG &DAG,
3712 const SDLoc &DL) {
3713 assert(DstVT.getScalarSizeInBits() <= SrcVT.getScalarSizeInBits() &&
3714 "Illegal truncation");
3715
3716 if (DstVT == SrcVT)
3717 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT: DstVT, N1: LHS, N2: RHS);
3718
3719 // If the LHS is zero-extended then we can perform the USUBSAT as DstVT by
3720 // clamping RHS.
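// E.g. for i32 -> i16: the usubsat result never exceeds LHS, so if LHS fits in
// 16 bits, clamping RHS to 0xFFFF, truncating both operands, and performing the
// usubsat in i16 produces the same (truncated) result.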
3721 APInt UpperBits = APInt::getBitsSetFrom(numBits: SrcVT.getScalarSizeInBits(),
3722 loBit: DstVT.getScalarSizeInBits());
3723 if (!DAG.MaskedValueIsZero(Op: LHS, Mask: UpperBits))
3724 return SDValue();
3725
3726 SDValue SatLimit =
3727 DAG.getConstant(Val: APInt::getLowBitsSet(numBits: SrcVT.getScalarSizeInBits(),
3728 loBitsSet: DstVT.getScalarSizeInBits()),
3729 DL, VT: SrcVT);
3730 RHS = DAG.getNode(Opcode: ISD::UMIN, DL, VT: SrcVT, N1: RHS, N2: SatLimit);
3731 RHS = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: DstVT, Operand: RHS);
3732 LHS = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: DstVT, Operand: LHS);
3733 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT: DstVT, N1: LHS, N2: RHS);
3734}
3735
3736// Try to find umax(a,b) - b or a - umin(a,b) patterns that may be converted to
3737// usubsat(a,b), optionally as a truncated type.
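// E.g. umax(a,b) - b is a - b when a >= b and 0 otherwise, which is exactly
// usubsat(a,b); a - umin(a,b) expands the same way.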
3738SDValue DAGCombiner::foldSubToUSubSat(EVT DstVT, SDNode *N, const SDLoc &DL) {
3739 if (N->getOpcode() != ISD::SUB ||
3740 !(!LegalOperations || hasOperation(Opcode: ISD::USUBSAT, VT: DstVT)))
3741 return SDValue();
3742
3743 EVT SubVT = N->getValueType(ResNo: 0);
3744 SDValue Op0 = N->getOperand(Num: 0);
3745 SDValue Op1 = N->getOperand(Num: 1);
3746
3747 // Try to find umax(a,b) - b or a - umin(a,b) patterns, as
3748 // they may be converted to usubsat(a,b).
3749 if (Op0.getOpcode() == ISD::UMAX && Op0.hasOneUse()) {
3750 SDValue MaxLHS = Op0.getOperand(i: 0);
3751 SDValue MaxRHS = Op0.getOperand(i: 1);
3752 if (MaxLHS == Op1)
3753 return getTruncatedUSUBSAT(DstVT, SrcVT: SubVT, LHS: MaxRHS, RHS: Op1, DAG, DL);
3754 if (MaxRHS == Op1)
3755 return getTruncatedUSUBSAT(DstVT, SrcVT: SubVT, LHS: MaxLHS, RHS: Op1, DAG, DL);
3756 }
3757
3758 if (Op1.getOpcode() == ISD::UMIN && Op1.hasOneUse()) {
3759 SDValue MinLHS = Op1.getOperand(i: 0);
3760 SDValue MinRHS = Op1.getOperand(i: 1);
3761 if (MinLHS == Op0)
3762 return getTruncatedUSUBSAT(DstVT, SrcVT: SubVT, LHS: Op0, RHS: MinRHS, DAG, DL);
3763 if (MinRHS == Op0)
3764 return getTruncatedUSUBSAT(DstVT, SrcVT: SubVT, LHS: Op0, RHS: MinLHS, DAG, DL);
3765 }
3766
3767 // sub(a,trunc(umin(zext(a),b))) -> usubsat(a,trunc(umin(b,SatLimit)))
3768 if (Op1.getOpcode() == ISD::TRUNCATE &&
3769 Op1.getOperand(i: 0).getOpcode() == ISD::UMIN &&
3770 Op1.getOperand(i: 0).hasOneUse()) {
3771 SDValue MinLHS = Op1.getOperand(i: 0).getOperand(i: 0);
3772 SDValue MinRHS = Op1.getOperand(i: 0).getOperand(i: 1);
3773 if (MinLHS.getOpcode() == ISD::ZERO_EXTEND && MinLHS.getOperand(i: 0) == Op0)
3774 return getTruncatedUSUBSAT(DstVT, SrcVT: MinLHS.getValueType(), LHS: MinLHS, RHS: MinRHS,
3775 DAG, DL);
3776 if (MinRHS.getOpcode() == ISD::ZERO_EXTEND && MinRHS.getOperand(i: 0) == Op0)
3777 return getTruncatedUSUBSAT(DstVT, SrcVT: MinLHS.getValueType(), LHS: MinRHS, RHS: MinLHS,
3778 DAG, DL);
3779 }
3780
3781 return SDValue();
3782}
3783
3784 // Since it may not be valid to emit a fold to zero for vector initializers,
3785 // check if we can before folding.
3786static SDValue tryFoldToZero(const SDLoc &DL, const TargetLowering &TLI, EVT VT,
3787 SelectionDAG &DAG, bool LegalOperations) {
3788 if (!VT.isVector())
3789 return DAG.getConstant(Val: 0, DL, VT);
3790 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT))
3791 return DAG.getConstant(Val: 0, DL, VT);
3792 return SDValue();
3793}
3794
3795SDValue DAGCombiner::visitSUB(SDNode *N) {
3796 SDValue N0 = N->getOperand(Num: 0);
3797 SDValue N1 = N->getOperand(Num: 1);
3798 EVT VT = N0.getValueType();
3799 unsigned BitWidth = VT.getScalarSizeInBits();
3800 SDLoc DL(N);
3801
3802 auto PeekThroughFreeze = [](SDValue N) {
3803 if (N->getOpcode() == ISD::FREEZE && N.hasOneUse())
3804 return N->getOperand(Num: 0);
3805 return N;
3806 };
3807
3808 // fold (sub x, x) -> 0
3809 // FIXME: Refactor this and xor and other similar operations together.
3810 if (PeekThroughFreeze(N0) == PeekThroughFreeze(N1))
3811 return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
3812
3813 // fold (sub c1, c2) -> c3
3814 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N0, N1}))
3815 return C;
3816
3817 // fold vector ops
3818 if (VT.isVector()) {
3819 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
3820 return FoldedVOp;
3821
3822 // fold (sub x, 0) -> x, vector edition
3823 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
3824 return N0;
3825 }
3826
3827 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
3828 return NewSel;
3829
3830 // fold (sub x, c) -> (add x, -c)
3831 if (ConstantSDNode *N1C = getAsNonOpaqueConstant(N: N1))
3832 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0,
3833 N2: DAG.getConstant(Val: -N1C->getAPIntValue(), DL, VT));
3834
3835 if (isNullOrNullSplat(V: N0)) {
3836 // Right-shifting everything out but the sign bit followed by negation is
3837 // the same as flipping arithmetic/logical shift type without the negation:
3838 // -(X >>u 31) -> (X >>s 31)
3839 // -(X >>s 31) -> (X >>u 31)
3840 if (N1->getOpcode() == ISD::SRA || N1->getOpcode() == ISD::SRL) {
3841 ConstantSDNode *ShiftAmt = isConstOrConstSplat(N: N1.getOperand(i: 1));
3842 if (ShiftAmt && ShiftAmt->getAPIntValue() == (BitWidth - 1)) {
3843 auto NewSh = N1->getOpcode() == ISD::SRA ? ISD::SRL : ISD::SRA;
3844 if (!LegalOperations || TLI.isOperationLegal(Op: NewSh, VT))
3845 return DAG.getNode(Opcode: NewSh, DL, VT, N1: N1.getOperand(i: 0), N2: N1.getOperand(i: 1));
3846 }
3847 }
3848
3849 // 0 - X --> 0 if the sub is NUW.
3850 if (N->getFlags().hasNoUnsignedWrap())
3851 return N0;
3852
3853 if (DAG.MaskedValueIsZero(Op: N1, Mask: ~APInt::getSignMask(BitWidth))) {
3854 // N1 is either 0 or the minimum signed value. If the sub is NSW, then
3855 // N1 must be 0 because negating the minimum signed value is undefined.
3856 if (N->getFlags().hasNoSignedWrap())
3857 return N0;
3858
3859 // 0 - X --> X if X is 0 or the minimum signed value.
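// (Negating 0 gives 0, and negating the minimum signed value wraps back to
// itself, so 0 - X == X in both cases.)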
3860 return N1;
3861 }
3862
3863 // Convert 0 - abs(x).
3864 if (N1.getOpcode() == ISD::ABS && N1.hasOneUse() &&
3865 !TLI.isOperationLegalOrCustom(Op: ISD::ABS, VT))
3866 if (SDValue Result = TLI.expandABS(N: N1.getNode(), DAG, IsNegative: true))
3867 return Result;
3868
3869 // Fold neg(splat(neg(x))) -> splat(x)
3870 if (VT.isVector()) {
3871 SDValue N1S = DAG.getSplatValue(V: N1, LegalTypes: true);
3872 if (N1S && N1S.getOpcode() == ISD::SUB &&
3873 isNullConstant(V: N1S.getOperand(i: 0)))
3874 return DAG.getSplat(VT, DL, Op: N1S.getOperand(i: 1));
3875 }
3876 }
3877
3878 // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1)
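// (-1 has every bit set, so the subtraction never borrows and reduces to a
// bitwise complement.)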
3879 if (isAllOnesOrAllOnesSplat(V: N0))
3880 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0);
3881
3882 // fold (A - (0-B)) -> A+B
3883 if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(V: N1.getOperand(i: 0)))
3884 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1.getOperand(i: 1));
3885
3886 // fold A-(A-B) -> B
3887 if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(i: 0))
3888 return N1.getOperand(i: 1);
3889
3890 // fold (A+B)-A -> B
3891 if (N0.getOpcode() == ISD::ADD && N0.getOperand(i: 0) == N1)
3892 return N0.getOperand(i: 1);
3893
3894 // fold (A+B)-B -> A
3895 if (N0.getOpcode() == ISD::ADD && N0.getOperand(i: 1) == N1)
3896 return N0.getOperand(i: 0);
3897
3898 // fold (A+C1)-C2 -> A+(C1-C2)
3899 if (N0.getOpcode() == ISD::ADD) {
3900 SDValue N01 = N0.getOperand(i: 1);
3901 if (SDValue NewC = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N01, N1}))
3902 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: NewC);
3903 }
3904
3905 // fold C2-(A+C1) -> (C2-C1)-A
3906 if (N1.getOpcode() == ISD::ADD) {
3907 SDValue N11 = N1.getOperand(i: 1);
3908 if (SDValue NewC = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N0, N11}))
3909 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: NewC, N2: N1.getOperand(i: 0));
3910 }
3911
3912 // fold (A-C1)-C2 -> A-(C1+C2)
3913 if (N0.getOpcode() == ISD::SUB) {
3914 SDValue N01 = N0.getOperand(i: 1);
3915 if (SDValue NewC = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL, VT, Ops: {N01, N1}))
3916 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 0), N2: NewC);
3917 }
3918
3919 // fold (c1-A)-c2 -> (c1-c2)-A
3920 if (N0.getOpcode() == ISD::SUB) {
3921 SDValue N00 = N0.getOperand(i: 0);
3922 if (SDValue NewC = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N00, N1}))
3923 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: NewC, N2: N0.getOperand(i: 1));
3924 }
3925
3926 SDValue A, B, C;
3927
3928 // fold ((A+(B+C))-B) -> A+C
3929 if (sd_match(N: N0, P: m_Add(L: m_Value(N&: A), R: m_Add(L: m_Specific(N: N1), R: m_Value(N&: C)))))
3930 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: A, N2: C);
3931
3932 // fold ((A+(B-C))-B) -> A-C
3933 if (sd_match(N: N0, P: m_Add(L: m_Value(N&: A), R: m_Sub(L: m_Specific(N: N1), R: m_Value(N&: C)))))
3934 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: A, N2: C);
3935
3936 // fold ((A-(B-C))-C) -> A-B
3937 if (sd_match(N: N0, P: m_Sub(L: m_Value(N&: A), R: m_Sub(L: m_Value(N&: B), R: m_Specific(N: N1)))))
3938 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: A, N2: B);
3939
3940 // fold (A-(B-C)) -> A+(C-B)
3941 if (sd_match(N: N1, P: m_OneUse(P: m_Sub(L: m_Value(N&: B), R: m_Value(N&: C)))))
3942 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0,
3943 N2: DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: C, N2: B));
3944
3945 // A - (A & B) -> A & (~B)
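// (A & B) only contains bits that are already set in A, so subtracting it
// simply clears those bits, i.e. keeps the bits of A that are outside B.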
3946 if (sd_match(N: N1, P: m_And(L: m_Specific(N: N0), R: m_Value(N&: B))) &&
3947 (N1.hasOneUse() || isConstantOrConstantVector(N: B, /*NoOpaques=*/true)))
3948 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0, N2: DAG.getNOT(DL, Val: B, VT));
3949
3950 // fold (A - (-B * C)) -> (A + (B * C))
3951 if (sd_match(N: N1, P: m_OneUse(P: m_Mul(L: m_Neg(V: m_Value(N&: B)), R: m_Value(N&: C)))))
3952 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0,
3953 N2: DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: B, N2: C));
3954
3955 // If either operand of a sub is undef, the result is undef
3956 if (N0.isUndef())
3957 return N0;
3958 if (N1.isUndef())
3959 return N1;
3960
3961 if (SDValue V = foldAddSubBoolOfMaskedVal(N, DL, DAG))
3962 return V;
3963
3964 if (SDValue V = foldAddSubOfSignBit(N, DL, DAG))
3965 return V;
3966
3967 // Try to match AVGCEIL fixedwidth pattern
3968 if (SDValue V = foldSubToAvg(N, DL))
3969 return V;
3970
3971 if (SDValue V = foldAddSubMasked1(IsAdd: false, N0, N1, DAG, DL))
3972 return V;
3973
3974 if (SDValue V = foldSubToUSubSat(DstVT: VT, N, DL))
3975 return V;
3976
3977 // (A - B) - 1 -> add (xor B, -1), A
3978 if (sd_match(N, P: m_Sub(L: m_OneUse(P: m_Sub(L: m_Value(N&: A), R: m_Value(N&: B))), R: m_One())))
3979 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: A, N2: DAG.getNOT(DL, Val: B, VT));
3980
3981 // Look for:
3982 // sub y, (xor x, -1)
3983 // And if the target does not like this form then turn into:
3984 // add (add x, y), 1
3985 if (TLI.preferIncOfAddToSubOfNot(VT) && N1.hasOneUse() && isBitwiseNot(V: N1)) {
3986 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1.getOperand(i: 0));
3987 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Add, N2: DAG.getConstant(Val: 1, DL, VT));
3988 }
3989
3990 // Hoist one-use addition by non-opaque constant:
3991 // (x + C) - y -> (x - y) + C
3992 if (!reassociationCanBreakAddressingModePattern(Opc: ISD::SUB, DL, N, N0, N1) &&
3993 N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
3994 isConstantOrConstantVector(N: N0.getOperand(i: 1), /*NoOpaques=*/true)) {
3995 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 0), N2: N1);
3996 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Sub, N2: N0.getOperand(i: 1));
3997 }
3998 // y - (x + C) -> (y - x) - C
3999 if (N1.getOpcode() == ISD::ADD && N1.hasOneUse() &&
4000 isConstantOrConstantVector(N: N1.getOperand(i: 1), /*NoOpaques=*/true)) {
4001 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1.getOperand(i: 0));
4002 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Sub, N2: N1.getOperand(i: 1));
4003 }
4004 // (x - C) - y -> (x - y) - C
4005 // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
4006 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
4007 isConstantOrConstantVector(N: N0.getOperand(i: 1), /*NoOpaques=*/true)) {
4008 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 0), N2: N1);
4009 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Sub, N2: N0.getOperand(i: 1));
4010 }
4011 // (C - x) - y -> C - (x + y)
4012 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
4013 isConstantOrConstantVector(N: N0.getOperand(i: 0), /*NoOpaques=*/true)) {
4014 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 1), N2: N1);
4015 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 0), N2: Add);
4016 }
4017
4018 // If the target's bool is represented as 0/-1, prefer to make this 'add 0/-1'
4019 // rather than 'sub 0/1' (the sext should get folded).
4020 // sub X, (zext i1 Y) --> add X, (sext i1 Y)
4021 if (N1.getOpcode() == ISD::ZERO_EXTEND &&
4022 N1.getOperand(i: 0).getScalarValueSizeInBits() == 1 &&
4023 TLI.getBooleanContents(Type: VT) ==
4024 TargetLowering::ZeroOrNegativeOneBooleanContent) {
4025 SDValue SExt = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: N1.getOperand(i: 0));
4026 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: SExt);
4027 }
4028
4029 // fold B = sra (A, size(A)-1); sub (xor (A, B), B) -> (abs A)
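// B is 0 when A is non-negative and -1 when A is negative, so (A ^ B) - B is
// A in the first case and ~A + 1 == -A in the second: the branchless abs idiom.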
4030 if ((!LegalOperations || hasOperation(Opcode: ISD::ABS, VT)) &&
4031 sd_match(N: N1, P: m_Sra(L: m_Value(N&: A), R: m_SpecificInt(V: BitWidth - 1))) &&
4032 sd_match(N: N0, P: m_Xor(L: m_Specific(N: A), R: m_Specific(N: N1))))
4033 return DAG.getNode(Opcode: ISD::ABS, DL, VT, Operand: A);
4034
4035 // If the relocation model supports it, consider symbol offsets.
4036 if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Val&: N0))
4037 if (!LegalOperations && TLI.isOffsetFoldingLegal(GA)) {
4038 // fold (sub Sym+c1, Sym+c2) -> c1-c2
4039 if (GlobalAddressSDNode *GB = dyn_cast<GlobalAddressSDNode>(Val&: N1))
4040 if (GA->getGlobal() == GB->getGlobal())
4041 return DAG.getConstant(Val: (uint64_t)GA->getOffset() - GB->getOffset(),
4042 DL, VT);
4043 }
4044
4045 // sub X, (sextinreg Y i1) -> add X, (and Y 1)
4046 if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
4047 VTSDNode *TN = cast<VTSDNode>(Val: N1.getOperand(i: 1));
4048 if (TN->getVT() == MVT::i1) {
4049 SDValue ZExt = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N1.getOperand(i: 0),
4050 N2: DAG.getConstant(Val: 1, DL, VT));
4051 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: ZExt);
4052 }
4053 }
4054
4055 // canonicalize (sub X, (vscale * C)) to (add X, (vscale * -C))
4056 if (N1.getOpcode() == ISD::VSCALE && N1.hasOneUse()) {
4057 const APInt &IntVal = N1.getConstantOperandAPInt(i: 0);
4058 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: DAG.getVScale(DL, VT, MulImm: -IntVal));
4059 }
4060
4061 // canonicalize (sub X, step_vector(C)) to (add X, step_vector(-C))
4062 if (N1.getOpcode() == ISD::STEP_VECTOR && N1.hasOneUse()) {
4063 APInt NewStep = -N1.getConstantOperandAPInt(i: 0);
4064 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0,
4065 N2: DAG.getStepVector(DL, ResVT: VT, StepVal: NewStep));
4066 }
4067
4068 // Prefer an add for more folding potential and possibly better codegen:
4069 // sub N0, (lshr N10, width-1) --> add N0, (ashr N10, width-1)
4070 if (!LegalOperations && N1.getOpcode() == ISD::SRL && N1.hasOneUse()) {
4071 SDValue ShAmt = N1.getOperand(i: 1);
4072 ConstantSDNode *ShAmtC = isConstOrConstSplat(N: ShAmt);
4073 if (ShAmtC && ShAmtC->getAPIntValue() == (BitWidth - 1)) {
4074 SDValue SRA = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: N1.getOperand(i: 0), N2: ShAmt);
4075 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: SRA);
4076 }
4077 }
4078
4079 // As with the previous fold, prefer add for more folding potential.
4080 // Subtracting SMIN/0 is the same as adding SMIN/0:
4081 // N0 - (X << BW-1) --> N0 + (X << BW-1)
4082 if (N1.getOpcode() == ISD::SHL) {
4083 ConstantSDNode *ShlC = isConstOrConstSplat(N: N1.getOperand(i: 1));
4084 if (ShlC && ShlC->getAPIntValue() == (BitWidth - 1))
4085 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: N0);
4086 }
4087
4088 // (sub (usubo_carry X, 0, Carry), Y) -> (usubo_carry X, Y, Carry)
4089 if (N0.getOpcode() == ISD::USUBO_CARRY && isNullConstant(V: N0.getOperand(i: 1)) &&
4090 N0.getResNo() == 0 && N0.hasOneUse())
4091 return DAG.getNode(Opcode: ISD::USUBO_CARRY, DL, VTList: N0->getVTList(),
4092 N1: N0.getOperand(i: 0), N2: N1, N3: N0.getOperand(i: 2));
4093
4094 if (TLI.isOperationLegalOrCustom(Op: ISD::UADDO_CARRY, VT)) {
4095 // (sub Carry, X) -> (uaddo_carry (sub 0, X), 0, Carry)
4096 if (SDValue Carry = getAsCarry(TLI, V: N0)) {
4097 SDValue X = N1;
4098 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
4099 SDValue NegX = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Zero, N2: X);
4100 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL,
4101 VTList: DAG.getVTList(VT1: VT, VT2: Carry.getValueType()), N1: NegX, N2: Zero,
4102 N3: Carry);
4103 }
4104 }
4105
4106 // If there's no chance of borrowing from adjacent bits, then sub is xor:
4107 // sub C0, X --> xor X, C0
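// A borrow can only start at a bit position where X is set and C0 is clear;
// if the maximal pattern of possibly-set bits in X causes no borrow (the check
// below), then no actual value of X can, so the subtraction reduces to xor.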
4108 if (ConstantSDNode *C0 = isConstOrConstSplat(N: N0)) {
4109 if (!C0->isOpaque()) {
4110 const APInt &C0Val = C0->getAPIntValue();
4111 const APInt &MaybeOnes = ~DAG.computeKnownBits(Op: N1).Zero;
4112 if ((C0Val - MaybeOnes) == (C0Val ^ MaybeOnes))
4113 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0);
4114 }
4115 }
4116
4117 // smax(a,b) - smin(a,b) --> abds(a,b)
4118 if (hasOperation(Opcode: ISD::ABDS, VT) &&
4119 sd_match(N: N0, P: m_SMax(L: m_Value(N&: A), R: m_Value(N&: B))) &&
4120 sd_match(N: N1, P: m_SMin(L: m_Specific(N: A), R: m_Specific(N: B))))
4121 return DAG.getNode(Opcode: ISD::ABDS, DL, VT, N1: A, N2: B);
4122
4123 // umax(a,b) - umin(a,b) --> abdu(a,b)
4124 if (hasOperation(Opcode: ISD::ABDU, VT) &&
4125 sd_match(N: N0, P: m_UMax(L: m_Value(N&: A), R: m_Value(N&: B))) &&
4126 sd_match(N: N1, P: m_UMin(L: m_Specific(N: A), R: m_Specific(N: B))))
4127 return DAG.getNode(Opcode: ISD::ABDU, DL, VT, N1: A, N2: B);
4128
4129 return SDValue();
4130}
4131
4132SDValue DAGCombiner::visitSUBSAT(SDNode *N) {
4133 unsigned Opcode = N->getOpcode();
4134 SDValue N0 = N->getOperand(Num: 0);
4135 SDValue N1 = N->getOperand(Num: 1);
4136 EVT VT = N0.getValueType();
4137 bool IsSigned = Opcode == ISD::SSUBSAT;
4138 SDLoc DL(N);
4139
4140 // fold (sub_sat x, undef) -> 0
4141 if (N0.isUndef() || N1.isUndef())
4142 return DAG.getConstant(Val: 0, DL, VT);
4143
4144 // fold (sub_sat x, x) -> 0
4145 if (N0 == N1)
4146 return DAG.getConstant(Val: 0, DL, VT);
4147
4148 // fold (sub_sat c1, c2) -> c3
4149 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
4150 return C;
4151
4152 // fold vector ops
4153 if (VT.isVector()) {
4154 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4155 return FoldedVOp;
4156
4157 // fold (sub_sat x, 0) -> x, vector edition
4158 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
4159 return N0;
4160 }
4161
4162 // fold (sub_sat x, 0) -> x
4163 if (isNullConstant(V: N1))
4164 return N0;
4165
4166 // If it cannot overflow, transform into a sub.
4167 if (DAG.willNotOverflowSub(IsSigned, N0, N1))
4168 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1);
4169
4170 return SDValue();
4171}
4172
4173SDValue DAGCombiner::visitSUBC(SDNode *N) {
4174 SDValue N0 = N->getOperand(Num: 0);
4175 SDValue N1 = N->getOperand(Num: 1);
4176 EVT VT = N0.getValueType();
4177 SDLoc DL(N);
4178
4179 // If the flag result is dead, turn this into an SUB.
4180 if (!N->hasAnyUseOfValue(Value: 1))
4181 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1),
4182 Res1: DAG.getNode(Opcode: ISD::CARRY_FALSE, DL, VT: MVT::Glue));
4183
4184 // fold (subc x, x) -> 0 + no borrow
4185 if (N0 == N1)
4186 return CombineTo(N, Res0: DAG.getConstant(Val: 0, DL, VT),
4187 Res1: DAG.getNode(Opcode: ISD::CARRY_FALSE, DL, VT: MVT::Glue));
4188
4189 // fold (subc x, 0) -> x + no borrow
4190 if (isNullConstant(V: N1))
4191 return CombineTo(N, Res0: N0, Res1: DAG.getNode(Opcode: ISD::CARRY_FALSE, DL, VT: MVT::Glue));
4192
4193 // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1) + no borrow
4194 if (isAllOnesConstant(V: N0))
4195 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0),
4196 Res1: DAG.getNode(Opcode: ISD::CARRY_FALSE, DL, VT: MVT::Glue));
4197
4198 return SDValue();
4199}
4200
4201SDValue DAGCombiner::visitSUBO(SDNode *N) {
4202 SDValue N0 = N->getOperand(Num: 0);
4203 SDValue N1 = N->getOperand(Num: 1);
4204 EVT VT = N0.getValueType();
4205 bool IsSigned = (ISD::SSUBO == N->getOpcode());
4206
4207 EVT CarryVT = N->getValueType(ResNo: 1);
4208 SDLoc DL(N);
4209
4210 // If the flag result is dead, turn this into an SUB.
4211 if (!N->hasAnyUseOfValue(Value: 1))
4212 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1),
4213 Res1: DAG.getUNDEF(VT: CarryVT));
4214
4215 // fold (subo x, x) -> 0 + no borrow
4216 if (N0 == N1)
4217 return CombineTo(N, Res0: DAG.getConstant(Val: 0, DL, VT),
4218 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
4219
4220 // fold (subo x, c) -> (addo x, -c)
4221 if (ConstantSDNode *N1C = getAsNonOpaqueConstant(N: N1))
4222 if (IsSigned && !N1C->isMinSignedValue())
4223 return DAG.getNode(Opcode: ISD::SADDO, DL, VTList: N->getVTList(), N1: N0,
4224 N2: DAG.getConstant(Val: -N1C->getAPIntValue(), DL, VT));
4225
4226 // fold (subo x, 0) -> x + no borrow
4227 if (isNullOrNullSplat(V: N1))
4228 return CombineTo(N, Res0: N0, Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
4229
4230 // If it cannot overflow, transform into a sub.
4231 if (DAG.willNotOverflowSub(IsSigned, N0, N1))
4232 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1),
4233 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
4234
4235 // Canonicalize (usubo -1, x) -> ~x, i.e. (xor x, -1) + no borrow
4236 if (!IsSigned && isAllOnesOrAllOnesSplat(V: N0))
4237 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0),
4238 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
4239
4240 return SDValue();
4241}
4242
4243SDValue DAGCombiner::visitSUBE(SDNode *N) {
4244 SDValue N0 = N->getOperand(Num: 0);
4245 SDValue N1 = N->getOperand(Num: 1);
4246 SDValue CarryIn = N->getOperand(Num: 2);
4247
4248 // fold (sube x, y, false) -> (subc x, y)
4249 if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
4250 return DAG.getNode(Opcode: ISD::SUBC, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: N1);
4251
4252 return SDValue();
4253}
4254
4255SDValue DAGCombiner::visitUSUBO_CARRY(SDNode *N) {
4256 SDValue N0 = N->getOperand(Num: 0);
4257 SDValue N1 = N->getOperand(Num: 1);
4258 SDValue CarryIn = N->getOperand(Num: 2);
4259
4260 // fold (usubo_carry x, y, false) -> (usubo x, y)
4261 if (isNullConstant(V: CarryIn)) {
4262 if (!LegalOperations ||
4263 TLI.isOperationLegalOrCustom(Op: ISD::USUBO, VT: N->getValueType(ResNo: 0)))
4264 return DAG.getNode(Opcode: ISD::USUBO, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: N1);
4265 }
4266
4267 return SDValue();
4268}
4269
4270SDValue DAGCombiner::visitSSUBO_CARRY(SDNode *N) {
4271 SDValue N0 = N->getOperand(Num: 0);
4272 SDValue N1 = N->getOperand(Num: 1);
4273 SDValue CarryIn = N->getOperand(Num: 2);
4274
4275 // fold (ssubo_carry x, y, false) -> (ssubo x, y)
4276 if (isNullConstant(V: CarryIn)) {
4277 if (!LegalOperations ||
4278 TLI.isOperationLegalOrCustom(Op: ISD::SSUBO, VT: N->getValueType(ResNo: 0)))
4279 return DAG.getNode(Opcode: ISD::SSUBO, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: N1);
4280 }
4281
4282 return SDValue();
4283}
4284
4285// Notice that "mulfix" can be any of SMULFIX, SMULFIXSAT, UMULFIX and
4286// UMULFIXSAT here.
4287SDValue DAGCombiner::visitMULFIX(SDNode *N) {
4288 SDValue N0 = N->getOperand(Num: 0);
4289 SDValue N1 = N->getOperand(Num: 1);
4290 SDValue Scale = N->getOperand(Num: 2);
4291 EVT VT = N0.getValueType();
4292
4293 // fold (mulfix x, undef, scale) -> 0
4294 if (N0.isUndef() || N1.isUndef())
4295 return DAG.getConstant(Val: 0, DL: SDLoc(N), VT);
4296
4297 // Canonicalize constant to RHS (vector doesn't have to splat)
4298 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
4299 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
4300 return DAG.getNode(Opcode: N->getOpcode(), DL: SDLoc(N), VT, N1, N2: N0, N3: Scale);
4301
4302 // fold (mulfix x, 0, scale) -> 0
4303 if (isNullConstant(V: N1))
4304 return DAG.getConstant(Val: 0, DL: SDLoc(N), VT);
4305
4306 return SDValue();
4307}
4308
4309template <class MatchContextClass> SDValue DAGCombiner::visitMUL(SDNode *N) {
4310 SDValue N0 = N->getOperand(Num: 0);
4311 SDValue N1 = N->getOperand(Num: 1);
4312 EVT VT = N0.getValueType();
4313 unsigned BitWidth = VT.getScalarSizeInBits();
4314 SDLoc DL(N);
4315 bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
4316 MatchContextClass Matcher(DAG, TLI, N);
4317
4318 // fold (mul x, undef) -> 0
4319 if (N0.isUndef() || N1.isUndef())
4320 return DAG.getConstant(Val: 0, DL, VT);
4321
4322 // fold (mul c1, c2) -> c1*c2
4323 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::MUL, DL, VT, Ops: {N0, N1}))
4324 return C;
4325
4326 // canonicalize constant to RHS (vector doesn't have to splat)
4327 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
4328 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
4329 return Matcher.getNode(ISD::MUL, DL, VT, N1, N0);
4330
4331 bool N1IsConst = false;
4332 bool N1IsOpaqueConst = false;
4333 APInt ConstValue1;
4334
4335 // fold vector ops
4336 if (VT.isVector()) {
4337 // TODO: Change this to use SimplifyVBinOp when it supports VP op.
4338 if (!UseVP)
4339 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4340 return FoldedVOp;
4341
4342 N1IsConst = ISD::isConstantSplatVector(N: N1.getNode(), SplatValue&: ConstValue1);
4343 assert((!N1IsConst || ConstValue1.getBitWidth() == BitWidth) &&
4344 "Splat APInt should be element width");
4345 } else {
4346 N1IsConst = isa<ConstantSDNode>(Val: N1);
4347 if (N1IsConst) {
4348 ConstValue1 = N1->getAsAPIntVal();
4349 N1IsOpaqueConst = cast<ConstantSDNode>(Val&: N1)->isOpaque();
4350 }
4351 }
4352
4353 // fold (mul x, 0) -> 0
4354 if (N1IsConst && ConstValue1.isZero())
4355 return N1;
4356
4357 // fold (mul x, 1) -> x
4358 if (N1IsConst && ConstValue1.isOne())
4359 return N0;
4360
4361 if (!UseVP)
4362 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
4363 return NewSel;
4364
4365 // fold (mul x, -1) -> 0-x
4366 if (N1IsConst && ConstValue1.isAllOnes())
4367 return Matcher.getNode(ISD::SUB, DL, VT, DAG.getConstant(Val: 0, DL, VT), N0);
4368
4369 // fold (mul x, (1 << c)) -> x << c
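// e.g. (mul x, 8) --> (shl x, 3)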
4370 if (isConstantOrConstantVector(N: N1, /*NoOpaques*/ true) &&
4371 (!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
4372 if (SDValue LogBase2 = BuildLogBase2(V: N1, DL)) {
4373 EVT ShiftVT = getShiftAmountTy(LHSTy: N0.getValueType());
4374 SDValue Trunc = DAG.getZExtOrTrunc(Op: LogBase2, DL, VT: ShiftVT);
4375 return Matcher.getNode(ISD::SHL, DL, VT, N0, Trunc);
4376 }
4377 }
4378
4379 // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
4380 if (N1IsConst && !N1IsOpaqueConst && ConstValue1.isNegatedPowerOf2()) {
4381 unsigned Log2Val = (-ConstValue1).logBase2();
4382
4383 // FIXME: If the input is something that is easily negated (e.g. a
4384 // single-use add), we should put the negate there.
4385 return Matcher.getNode(
4386 ISD::SUB, DL, VT, DAG.getConstant(Val: 0, DL, VT),
4387 Matcher.getNode(ISD::SHL, DL, VT, N0,
4388 DAG.getShiftAmountConstant(Val: Log2Val, VT, DL)));
4389 }
4390
4391 // Attempt to reuse an existing umul_lohi/smul_lohi node, but only if the
4392 // hi result is in use in case we hit this mid-legalization.
4393 if (!UseVP) {
4394 for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) {
4395 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: LoHiOpc, VT)) {
4396 SDVTList LoHiVT = DAG.getVTList(VT1: VT, VT2: VT);
4397 // TODO: Can we match commutable operands with getNodeIfExists?
4398 if (SDNode *LoHi = DAG.getNodeIfExists(Opcode: LoHiOpc, VTList: LoHiVT, Ops: {N0, N1}))
4399 if (LoHi->hasAnyUseOfValue(Value: 1))
4400 return SDValue(LoHi, 0);
4401 if (SDNode *LoHi = DAG.getNodeIfExists(Opcode: LoHiOpc, VTList: LoHiVT, Ops: {N1, N0}))
4402 if (LoHi->hasAnyUseOfValue(Value: 1))
4403 return SDValue(LoHi, 0);
4404 }
4405 }
4406 }
4407
4408 // Try to transform:
4409 // (1) multiply-by-(power-of-2 +/- 1) into shift and add/sub.
4410 // mul x, (2^N + 1) --> add (shl x, N), x
4411 // mul x, (2^N - 1) --> sub (shl x, N), x
4412 // Examples: x * 33 --> (x << 5) + x
4413 // x * 15 --> (x << 4) - x
4414 // x * -33 --> -((x << 5) + x)
4415 // x * -15 --> -((x << 4) - x) ; this reduces --> x - (x << 4)
4416 // (2) multiply-by-(power-of-2 +/- power-of-2) into shifts and add/sub.
4417 // mul x, (2^N + 2^M) --> (add (shl x, N), (shl x, M))
4418 // mul x, (2^N - 2^M) --> (sub (shl x, N), (shl x, M))
4419 // Examples: x * 0x8800 --> (x << 15) + (x << 11)
4420 // x * 0xf800 --> (x << 16) - (x << 11)
4421 // x * -0x8800 --> -((x << 15) + (x << 11))
4422 // x * -0xf800 --> -((x << 16) - (x << 11)) ; (x << 11) - (x << 16)
4423 if (!UseVP && N1IsConst &&
4424 TLI.decomposeMulByConstant(Context&: *DAG.getContext(), VT, C: N1)) {
4425 // TODO: We could handle more general decomposition of any constant by
4426 // having the target set a limit on number of ops and making a
4427 // callback to determine that sequence (similar to sqrt expansion).
4428 unsigned MathOp = ISD::DELETED_NODE;
4429 APInt MulC = ConstValue1.abs();
4430 // The constant `2` should be treated as (2^0 + 1).
4431 unsigned TZeros = MulC == 2 ? 0 : MulC.countr_zero();
4432 MulC.lshrInPlace(ShiftAmt: TZeros);
4433 if ((MulC - 1).isPowerOf2())
4434 MathOp = ISD::ADD;
4435 else if ((MulC + 1).isPowerOf2())
4436 MathOp = ISD::SUB;
4437
4438 if (MathOp != ISD::DELETED_NODE) {
4439 unsigned ShAmt =
4440 MathOp == ISD::ADD ? (MulC - 1).logBase2() : (MulC + 1).logBase2();
4441 ShAmt += TZeros;
4442 assert(ShAmt < BitWidth &&
4443 "multiply-by-constant generated out of bounds shift");
4444 SDValue Shl =
4445 DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0, N2: DAG.getConstant(Val: ShAmt, DL, VT));
4446 SDValue R =
4447 TZeros ? DAG.getNode(Opcode: MathOp, DL, VT, N1: Shl,
4448 N2: DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0,
4449 N2: DAG.getConstant(Val: TZeros, DL, VT)))
4450 : DAG.getNode(Opcode: MathOp, DL, VT, N1: Shl, N2: N0);
4451 if (ConstValue1.isNegative())
4452 R = DAG.getNegative(Val: R, DL, VT);
4453 return R;
4454 }
4455 }
4456
4457 // (mul (shl X, c1), c2) -> (mul X, c2 << c1)
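// e.g. (mul (shl x, 2), 5) --> (mul x, 20)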
4458 if (sd_context_match(N0, Matcher, m_Opc(Opcode: ISD::SHL))) {
4459 SDValue N01 = N0.getOperand(i: 1);
4460 if (SDValue C3 = DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL, VT, Ops: {N1, N01}))
4461 return DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: N0.getOperand(i: 0), N2: C3);
4462 }
4463
4464 // Change (mul (shl X, C), Y) -> (shl (mul X, Y), C) when the shift has one
4465 // use.
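// e.g. (mul (shl x, 3), y) --> (shl (mul x, y), 3); the inner multiply may
// then combine further with y.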
4466 {
4467 SDValue Sh, Y;
4468
4469 // Check for both (mul (shl X, C), Y) and (mul Y, (shl X, C)).
4470 if (sd_context_match(N0, Matcher, m_OneUse(P: m_Opc(Opcode: ISD::SHL))) &&
4471 isConstantOrConstantVector(N: N0.getOperand(i: 1))) {
4472 Sh = N0; Y = N1;
4473 } else if (sd_context_match(N1, Matcher, m_OneUse(P: m_Opc(Opcode: ISD::SHL))) &&
4474 isConstantOrConstantVector(N: N1.getOperand(i: 1))) {
4475 Sh = N1; Y = N0;
4476 }
4477
4478 if (Sh.getNode()) {
4479 SDValue Mul = Matcher.getNode(ISD::MUL, DL, VT, Sh.getOperand(i: 0), Y);
4480 return Matcher.getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(i: 1));
4481 }
4482 }
4483
4484 // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
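// e.g. (mul (add x, 4), 5) --> (add (mul x, 5), 20)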
4485 if (sd_context_match(N0, Matcher, m_Opc(Opcode: ISD::ADD)) &&
4486 DAG.isConstantIntBuildVectorOrConstantInt(N: N1) &&
4487 DAG.isConstantIntBuildVectorOrConstantInt(N: N0.getOperand(i: 1)) &&
4488 isMulAddWithConstProfitable(MulNode: N, AddNode: N0, ConstNode: N1))
4489 return Matcher.getNode(
4490 ISD::ADD, DL, VT,
4491 Matcher.getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(i: 0), N1),
4492 Matcher.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(i: 1), N1));
4493
4494 // Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)).
4495 ConstantSDNode *NC1 = isConstOrConstSplat(N: N1);
4496 if (!UseVP && N0.getOpcode() == ISD::VSCALE && NC1) {
4497 const APInt &C0 = N0.getConstantOperandAPInt(i: 0);
4498 const APInt &C1 = NC1->getAPIntValue();
4499 return DAG.getVScale(DL, VT, MulImm: C0 * C1);
4500 }
4501
4502 // Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
4503 APInt MulVal;
4504 if (!UseVP && N0.getOpcode() == ISD::STEP_VECTOR &&
4505 ISD::isConstantSplatVector(N: N1.getNode(), SplatValue&: MulVal)) {
4506 const APInt &C0 = N0.getConstantOperandAPInt(i: 0);
4507 APInt NewStep = C0 * MulVal;
4508 return DAG.getStepVector(DL, ResVT: VT, StepVal: NewStep);
4509 }
4510
4511 // Fold Y = sra (X, size(X)-1); mul (or (Y, 1), X) -> (abs X)
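// Y is all sign bits, so (or Y, 1) is +1 when X >= 0 and -1 when X < 0;
// multiplying X by it therefore yields abs(X).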
4512 SDValue X;
4513 if (!UseVP && (!LegalOperations || hasOperation(Opcode: ISD::ABS, VT)) &&
4514 sd_context_match(
4515 N, Matcher,
4516 m_Mul(L: m_Or(L: m_Sra(L: m_Value(N&: X), R: m_SpecificInt(V: BitWidth - 1)), R: m_One()),
4517 R: m_Deferred(V&: X)))) {
4518 return Matcher.getNode(ISD::ABS, DL, VT, X);
4519 }
4520
4521 // Fold (mul x, 0/undef) -> 0 and
4522 // (mul x, 1) -> x
4523 // into (and x, mask).
4524 // We can replace vectors with '0' and '1' factors with a clearing mask.
4525 if (VT.isFixedLengthVector()) {
4526 unsigned NumElts = VT.getVectorNumElements();
4527 SmallBitVector ClearMask;
4528 ClearMask.reserve(N: NumElts);
4529 auto IsClearMask = [&ClearMask](ConstantSDNode *V) {
4530 if (!V || V->isZero()) {
4531 ClearMask.push_back(Val: true);
4532 return true;
4533 }
4534 ClearMask.push_back(Val: false);
4535 return V->isOne();
4536 };
4537 if ((!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::AND, VT)) &&
4538 ISD::matchUnaryPredicate(Op: N1, Match: IsClearMask, /*AllowUndefs*/ true)) {
4539 assert(N1.getOpcode() == ISD::BUILD_VECTOR && "Unknown constant vector");
4540 EVT LegalSVT = N1.getOperand(i: 0).getValueType();
4541 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: LegalSVT);
4542 SDValue AllOnes = DAG.getAllOnesConstant(DL, VT: LegalSVT);
4543 SmallVector<SDValue, 16> Mask(NumElts, AllOnes);
4544 for (unsigned I = 0; I != NumElts; ++I)
4545 if (ClearMask[I])
4546 Mask[I] = Zero;
4547 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0, N2: DAG.getBuildVector(VT, DL, Ops: Mask));
4548 }
4549 }
4550
4551 // reassociate mul
4552 // TODO: Change reassociateOps to support vp ops.
4553 if (!UseVP)
4554 if (SDValue RMUL = reassociateOps(Opc: ISD::MUL, DL, N0, N1, Flags: N->getFlags()))
4555 return RMUL;
4556
4557 // Fold mul(vecreduce(x), vecreduce(y)) -> vecreduce(mul(x, y))
4558 // TODO: Change reassociateReduction to support vp ops.
4559 if (!UseVP)
4560 if (SDValue SD =
4561 reassociateReduction(RedOpc: ISD::VECREDUCE_MUL, Opc: ISD::MUL, DL, VT, N0, N1))
4562 return SD;
4563
4564 // Simplify the operands using demanded-bits information.
4565 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
4566 return SDValue(N, 0);
4567
4568 return SDValue();
4569}
4570
4571/// Return true if divmod libcall is available.
4572static bool isDivRemLibcallAvailable(SDNode *Node, bool isSigned,
4573 const TargetLowering &TLI) {
4574 RTLIB::Libcall LC;
4575 EVT NodeType = Node->getValueType(ResNo: 0);
4576 if (!NodeType.isSimple())
4577 return false;
4578 switch (NodeType.getSimpleVT().SimpleTy) {
4579 default: return false; // No libcall for vector types.
4580 case MVT::i8: LC= isSigned ? RTLIB::SDIVREM_I8 : RTLIB::UDIVREM_I8; break;
4581 case MVT::i16: LC= isSigned ? RTLIB::SDIVREM_I16 : RTLIB::UDIVREM_I16; break;
4582 case MVT::i32: LC= isSigned ? RTLIB::SDIVREM_I32 : RTLIB::UDIVREM_I32; break;
4583 case MVT::i64: LC= isSigned ? RTLIB::SDIVREM_I64 : RTLIB::UDIVREM_I64; break;
4584 case MVT::i128: LC= isSigned ? RTLIB::SDIVREM_I128:RTLIB::UDIVREM_I128; break;
4585 }
4586
4587 return TLI.getLibcallName(Call: LC) != nullptr;
4588}
4589
4590/// Issue divrem if both quotient and remainder are needed.
4591SDValue DAGCombiner::useDivRem(SDNode *Node) {
4592 if (Node->use_empty())
4593 return SDValue(); // This is a dead node, leave it alone.
4594
4595 unsigned Opcode = Node->getOpcode();
4596 bool isSigned = (Opcode == ISD::SDIV) || (Opcode == ISD::SREM);
4597 unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;
4598
4599 // DivRem lib calls can still work on non-legal types via libcall expansion.
4600 EVT VT = Node->getValueType(ResNo: 0);
4601 if (VT.isVector() || !VT.isInteger())
4602 return SDValue();
4603
4604 if (!TLI.isTypeLegal(VT) && !TLI.isOperationCustom(Op: DivRemOpc, VT))
4605 return SDValue();
4606
4607 // If DIVREM is going to get expanded into a libcall,
4608 // but there is no libcall available, then don't combine.
4609 if (!TLI.isOperationLegalOrCustom(Op: DivRemOpc, VT) &&
4610 !isDivRemLibcallAvailable(Node, isSigned, TLI))
4611 return SDValue();
4612
4613 // If div is legal, it's better to do the normal expansion
4614 unsigned OtherOpcode = 0;
4615 if ((Opcode == ISD::SDIV) || (Opcode == ISD::UDIV)) {
4616 OtherOpcode = isSigned ? ISD::SREM : ISD::UREM;
4617 if (TLI.isOperationLegalOrCustom(Op: Opcode, VT))
4618 return SDValue();
4619 } else {
4620 OtherOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
4621 if (TLI.isOperationLegalOrCustom(Op: OtherOpcode, VT))
4622 return SDValue();
4623 }
4624
4625 SDValue Op0 = Node->getOperand(Num: 0);
4626 SDValue Op1 = Node->getOperand(Num: 1);
4627 SDValue combined;
4628 for (SDNode *User : Op0->uses()) {
4629 if (User == Node || User->getOpcode() == ISD::DELETED_NODE ||
4630 User->use_empty())
4631 continue;
4632 // Convert the other matching node(s), too;
4633 // otherwise, the DIVREM may get target-legalized into something
4634 // target-specific that we won't be able to recognize.
4635 unsigned UserOpc = User->getOpcode();
4636 if ((UserOpc == Opcode || UserOpc == OtherOpcode || UserOpc == DivRemOpc) &&
4637 User->getOperand(Num: 0) == Op0 &&
4638 User->getOperand(Num: 1) == Op1) {
4639 if (!combined) {
4640 if (UserOpc == OtherOpcode) {
4641 SDVTList VTs = DAG.getVTList(VT1: VT, VT2: VT);
4642 combined = DAG.getNode(Opcode: DivRemOpc, DL: SDLoc(Node), VTList: VTs, N1: Op0, N2: Op1);
4643 } else if (UserOpc == DivRemOpc) {
4644 combined = SDValue(User, 0);
4645 } else {
4646 assert(UserOpc == Opcode);
4647 continue;
4648 }
4649 }
4650 if (UserOpc == ISD::SDIV || UserOpc == ISD::UDIV)
4651 CombineTo(N: User, Res: combined);
4652 else if (UserOpc == ISD::SREM || UserOpc == ISD::UREM)
4653 CombineTo(N: User, Res: combined.getValue(R: 1));
4654 }
4655 }
4656 return combined;
4657}
4658
4659static SDValue simplifyDivRem(SDNode *N, SelectionDAG &DAG) {
4660 SDValue N0 = N->getOperand(Num: 0);
4661 SDValue N1 = N->getOperand(Num: 1);
4662 EVT VT = N->getValueType(ResNo: 0);
4663 SDLoc DL(N);
4664
4665 unsigned Opc = N->getOpcode();
4666 bool IsDiv = (ISD::SDIV == Opc) || (ISD::UDIV == Opc);
4667 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
4668
4669 // X / undef -> undef
4670 // X % undef -> undef
4671 // X / 0 -> undef
4672 // X % 0 -> undef
4673 // NOTE: This includes vectors where any divisor element is zero/undef.
4674 if (DAG.isUndef(Opcode: Opc, Ops: {N0, N1}))
4675 return DAG.getUNDEF(VT);
4676
4677 // undef / X -> 0
4678 // undef % X -> 0
4679 if (N0.isUndef())
4680 return DAG.getConstant(Val: 0, DL, VT);
4681
4682 // 0 / X -> 0
4683 // 0 % X -> 0
4684 ConstantSDNode *N0C = isConstOrConstSplat(N: N0);
4685 if (N0C && N0C->isZero())
4686 return N0;
4687
4688 // X / X -> 1
4689 // X % X -> 0
4690 if (N0 == N1)
4691 return DAG.getConstant(Val: IsDiv ? 1 : 0, DL, VT);
4692
4693 // X / 1 -> X
4694 // X % 1 -> 0
4695 // If this is a boolean op (single-bit element type), we can't have
4696 // division-by-zero or remainder-by-zero, so assume the divisor is 1.
4697 // TODO: Similarly, if we're zero-extending a boolean divisor, then assume
4698 // it's a 1.
4699 if ((N1C && N1C->isOne()) || (VT.getScalarType() == MVT::i1))
4700 return IsDiv ? N0 : DAG.getConstant(Val: 0, DL, VT);
4701
4702 return SDValue();
4703}
4704
4705SDValue DAGCombiner::visitSDIV(SDNode *N) {
4706 SDValue N0 = N->getOperand(Num: 0);
4707 SDValue N1 = N->getOperand(Num: 1);
4708 EVT VT = N->getValueType(ResNo: 0);
4709 EVT CCVT = getSetCCResultType(VT);
4710 SDLoc DL(N);
4711
4712 // fold (sdiv c1, c2) -> c1/c2
4713 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SDIV, DL, VT, Ops: {N0, N1}))
4714 return C;
4715
4716 // fold vector ops
4717 if (VT.isVector())
4718 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4719 return FoldedVOp;
4720
4721 // fold (sdiv X, -1) -> 0-X
4722 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
4723 if (N1C && N1C->isAllOnes())
4724 return DAG.getNegative(Val: N0, DL, VT);
4725
4726 // fold (sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0)
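// Only X == MIN_SIGNED divides to 1; every other numerator has a smaller
// magnitude than MIN_SIGNED and truncates toward zero.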
4727 if (N1C && N1C->isMinSignedValue())
4728 return DAG.getSelect(DL, VT, Cond: DAG.getSetCC(DL, VT: CCVT, LHS: N0, RHS: N1, Cond: ISD::SETEQ),
4729 LHS: DAG.getConstant(Val: 1, DL, VT),
4730 RHS: DAG.getConstant(Val: 0, DL, VT));
4731
4732 if (SDValue V = simplifyDivRem(N, DAG))
4733 return V;
4734
4735 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
4736 return NewSel;
4737
4738 // If we know the sign bits of both operands are zero, strength reduce to a
4739 // udiv instead. Handles (X&15) /s 4 -> X&15 >> 2
4740 if (DAG.SignBitIsZero(Op: N1) && DAG.SignBitIsZero(Op: N0))
4741 return DAG.getNode(Opcode: ISD::UDIV, DL, VT: N1.getValueType(), N1: N0, N2: N1);
4742
4743 if (SDValue V = visitSDIVLike(N0, N1, N)) {
4744 // If the corresponding remainder node exists, update its users with
4745 // (Dividend - (Quotient * Divisor)).
4746 if (SDNode *RemNode = DAG.getNodeIfExists(Opcode: ISD::SREM, VTList: N->getVTList(),
4747 Ops: { N0, N1 })) {
4748 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: V, N2: N1);
4749 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: Mul);
4750 AddToWorklist(N: Mul.getNode());
4751 AddToWorklist(N: Sub.getNode());
4752 CombineTo(N: RemNode, Res: Sub);
4753 }
4754 return V;
4755 }
4756
4757 // sdiv, srem -> sdivrem
4758 // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
4759 // true. Otherwise, we break the simplification logic in visitREM().
4760 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4761 if (!N1C || TLI.isIntDivCheap(VT: N->getValueType(ResNo: 0), Attr))
4762 if (SDValue DivRem = useDivRem(Node: N))
4763 return DivRem;
4764
4765 return SDValue();
4766}
4767
4768static bool isDivisorPowerOfTwo(SDValue Divisor) {
4769 // Helper for determining whether a value is a power-of-2 constant scalar or a
4770 // vector of such elements.
4771 auto IsPowerOfTwo = [](ConstantSDNode *C) {
4772 if (C->isZero() || C->isOpaque())
4773 return false;
4774 if (C->getAPIntValue().isPowerOf2())
4775 return true;
4776 if (C->getAPIntValue().isNegatedPowerOf2())
4777 return true;
4778 return false;
4779 };
4780
4781 return ISD::matchUnaryPredicate(Op: Divisor, Match: IsPowerOfTwo);
4782}
4783
4784SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
4785 SDLoc DL(N);
4786 EVT VT = N->getValueType(ResNo: 0);
4787 EVT CCVT = getSetCCResultType(VT);
4788 unsigned BitWidth = VT.getScalarSizeInBits();
4789
4790 // fold (sdiv X, pow2) -> simple ops after legalize
4791 // FIXME: We check for the exact bit here because the generic lowering gives
4792 // better results in that case. The target-specific lowering should learn how
4793 // to handle exact sdivs efficiently.
4794 if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(Divisor: N1)) {
4795 // Target-specific implementation of sdiv x, pow2.
4796 if (SDValue Res = BuildSDIVPow2(N))
4797 return Res;
4798
4799 // Create constants that are functions of the shift amount value.
4800 EVT ShiftAmtTy = getShiftAmountTy(LHSTy: N0.getValueType());
4801 SDValue Bits = DAG.getConstant(Val: BitWidth, DL, VT: ShiftAmtTy);
4802 SDValue C1 = DAG.getNode(Opcode: ISD::CTTZ, DL, VT, Operand: N1);
4803 C1 = DAG.getZExtOrTrunc(Op: C1, DL, VT: ShiftAmtTy);
4804 SDValue Inexact = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftAmtTy, N1: Bits, N2: C1);
4805 if (!isConstantOrConstantVector(N: Inexact))
4806 return SDValue();
4807
4808 // Splat the sign bit into the register
4809 SDValue Sign = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: N0,
4810 N2: DAG.getConstant(Val: BitWidth - 1, DL, VT: ShiftAmtTy));
4811 AddToWorklist(N: Sign.getNode());
4812
4813 // Add (N0 < 0) ? abs2 - 1 : 0;
4814 SDValue Srl = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Sign, N2: Inexact);
4815 AddToWorklist(N: Srl.getNode());
4816 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: Srl);
4817 AddToWorklist(N: Add.getNode());
4818 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: Add, N2: C1);
4819 AddToWorklist(N: Sra.getNode());
4820
4821 // Special case: (sdiv X, 1) -> X
4822 // Special case: (sdiv X, -1) -> 0-X
4823 SDValue One = DAG.getConstant(Val: 1, DL, VT);
4824 SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
4825 SDValue IsOne = DAG.getSetCC(DL, VT: CCVT, LHS: N1, RHS: One, Cond: ISD::SETEQ);
4826 SDValue IsAllOnes = DAG.getSetCC(DL, VT: CCVT, LHS: N1, RHS: AllOnes, Cond: ISD::SETEQ);
4827 SDValue IsOneOrAllOnes = DAG.getNode(Opcode: ISD::OR, DL, VT: CCVT, N1: IsOne, N2: IsAllOnes);
4828 Sra = DAG.getSelect(DL, VT, Cond: IsOneOrAllOnes, LHS: N0, RHS: Sra);
4829
4830 // If dividing by a positive value, we're done. Otherwise, the result must
4831 // be negated.
4832 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
4833 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Zero, N2: Sra);
4834
4835 // FIXME: Use SELECT_CC once we improve SELECT_CC constant-folding.
4836 SDValue IsNeg = DAG.getSetCC(DL, VT: CCVT, LHS: N1, RHS: Zero, Cond: ISD::SETLT);
4837 SDValue Res = DAG.getSelect(DL, VT, Cond: IsNeg, LHS: Sub, RHS: Sra);
4838 return Res;
4839 }
4840
4841 // If integer divide is expensive and we satisfy the requirements, emit an
4842 // alternate sequence. Targets may check function attributes for size/speed
4843 // trade-offs.
4844 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4845 if (isConstantOrConstantVector(N: N1) &&
4846 !TLI.isIntDivCheap(VT: N->getValueType(ResNo: 0), Attr))
4847 if (SDValue Op = BuildSDIV(N))
4848 return Op;
4849
4850 return SDValue();
4851}
4852
4853SDValue DAGCombiner::visitUDIV(SDNode *N) {
4854 SDValue N0 = N->getOperand(Num: 0);
4855 SDValue N1 = N->getOperand(Num: 1);
4856 EVT VT = N->getValueType(ResNo: 0);
4857 EVT CCVT = getSetCCResultType(VT);
4858 SDLoc DL(N);
4859
4860 // fold (udiv c1, c2) -> c1/c2
4861 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::UDIV, DL, VT, Ops: {N0, N1}))
4862 return C;
4863
4864 // fold vector ops
4865 if (VT.isVector())
4866 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4867 return FoldedVOp;
4868
4869 // fold (udiv X, -1) -> select(X == -1, 1, 0)
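// -1 is the maximum unsigned value, so only X == -1 divides to 1; all other
// values of X produce 0.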
4870 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
4871 if (N1C && N1C->isAllOnes() && CCVT.isVector() == VT.isVector()) {
4872 return DAG.getSelect(DL, VT, Cond: DAG.getSetCC(DL, VT: CCVT, LHS: N0, RHS: N1, Cond: ISD::SETEQ),
4873 LHS: DAG.getConstant(Val: 1, DL, VT),
4874 RHS: DAG.getConstant(Val: 0, DL, VT));
4875 }
4876
4877 if (SDValue V = simplifyDivRem(N, DAG))
4878 return V;
4879
4880 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
4881 return NewSel;
4882
4883 if (SDValue V = visitUDIVLike(N0, N1, N)) {
4884 // If the corresponding remainder node exists, update its users with
4885 // (Dividend - (Quotient * Divisor)).
4886 if (SDNode *RemNode = DAG.getNodeIfExists(Opcode: ISD::UREM, VTList: N->getVTList(),
4887 Ops: { N0, N1 })) {
4888 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: V, N2: N1);
4889 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: Mul);
4890 AddToWorklist(N: Mul.getNode());
4891 AddToWorklist(N: Sub.getNode());
4892 CombineTo(N: RemNode, Res: Sub);
4893 }
4894 return V;
4895 }
4896
4897 // udiv, urem -> udivrem
4898 // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
4899 // true. Otherwise, we break the simplification logic in visitREM().
4900 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4901 if (!N1C || TLI.isIntDivCheap(VT: N->getValueType(ResNo: 0), Attr))
4902 if (SDValue DivRem = useDivRem(Node: N))
4903 return DivRem;
4904
4905 return SDValue();
4906}
4907
4908SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
4909 SDLoc DL(N);
4910 EVT VT = N->getValueType(ResNo: 0);
4911
4912 // fold (udiv x, (1 << c)) -> x >>u c
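// e.g. (udiv x, 16) --> (srl x, 4)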
4913 if (isConstantOrConstantVector(N: N1, /*NoOpaques*/ true)) {
4914 if (SDValue LogBase2 = BuildLogBase2(V: N1, DL)) {
4915 AddToWorklist(N: LogBase2.getNode());
4916
4917 EVT ShiftVT = getShiftAmountTy(LHSTy: N0.getValueType());
4918 SDValue Trunc = DAG.getZExtOrTrunc(Op: LogBase2, DL, VT: ShiftVT);
4919 AddToWorklist(N: Trunc.getNode());
4920 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0, N2: Trunc);
4921 }
4922 }
4923
4924 // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
4925 if (N1.getOpcode() == ISD::SHL) {
4926 SDValue N10 = N1.getOperand(i: 0);
4927 if (isConstantOrConstantVector(N: N10, /*NoOpaques*/ true)) {
4928 if (SDValue LogBase2 = BuildLogBase2(V: N10, DL)) {
4929 AddToWorklist(N: LogBase2.getNode());
4930
4931 EVT ADDVT = N1.getOperand(i: 1).getValueType();
4932 SDValue Trunc = DAG.getZExtOrTrunc(Op: LogBase2, DL, VT: ADDVT);
4933 AddToWorklist(N: Trunc.getNode());
4934 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT: ADDVT, N1: N1.getOperand(i: 1), N2: Trunc);
4935 AddToWorklist(N: Add.getNode());
4936 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0, N2: Add);
4937 }
4938 }
4939 }
4940
4941 // fold (udiv x, c) -> alternate
4942 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4943 if (isConstantOrConstantVector(N: N1) &&
4944 !TLI.isIntDivCheap(VT: N->getValueType(ResNo: 0), Attr))
4945 if (SDValue Op = BuildUDIV(N))
4946 return Op;
4947
4948 return SDValue();
4949}
4950
4951SDValue DAGCombiner::buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N) {
4952 if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(Divisor: N1) &&
4953 !DAG.doesNodeExist(Opcode: ISD::SDIV, VTList: N->getVTList(), Ops: {N0, N1})) {
4954 // Target-specific implementation of srem x, pow2.
4955 if (SDValue Res = BuildSREMPow2(N))
4956 return Res;
4957 }
4958 return SDValue();
4959}
4960
4961// handles ISD::SREM and ISD::UREM
4962SDValue DAGCombiner::visitREM(SDNode *N) {
4963 unsigned Opcode = N->getOpcode();
4964 SDValue N0 = N->getOperand(Num: 0);
4965 SDValue N1 = N->getOperand(Num: 1);
4966 EVT VT = N->getValueType(ResNo: 0);
4967 EVT CCVT = getSetCCResultType(VT);
4968
4969 bool isSigned = (Opcode == ISD::SREM);
4970 SDLoc DL(N);
4971
4972 // fold (rem c1, c2) -> c1%c2
4973 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
4974 return C;
4975
4976 // fold (urem X, -1) -> select(FX == -1, 0, FX)
4977 // Freeze the numerator to avoid a miscompile with an undefined value.
4978 if (!isSigned && llvm::isAllOnesOrAllOnesSplat(V: N1, /*AllowUndefs*/ false) &&
4979 CCVT.isVector() == VT.isVector()) {
4980 SDValue F0 = DAG.getFreeze(V: N0);
4981 SDValue EqualsNeg1 = DAG.getSetCC(DL, VT: CCVT, LHS: F0, RHS: N1, Cond: ISD::SETEQ);
4982 return DAG.getSelect(DL, VT, Cond: EqualsNeg1, LHS: DAG.getConstant(Val: 0, DL, VT), RHS: F0);
4983 }
4984
4985 if (SDValue V = simplifyDivRem(N, DAG))
4986 return V;
4987
4988 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
4989 return NewSel;
4990
4991 if (isSigned) {
4992 // If we know the sign bits of both operands are zero, strength reduce to a
4993 // urem instead. Handles (X & 0x0FFFFFFF) %s 16 -> X&15
4994 if (DAG.SignBitIsZero(Op: N1) && DAG.SignBitIsZero(Op: N0))
4995 return DAG.getNode(Opcode: ISD::UREM, DL, VT, N1: N0, N2: N1);
4996 } else {
4997 if (DAG.isKnownToBeAPowerOfTwo(Val: N1)) {
4998 // fold (urem x, pow2) -> (and x, pow2-1)
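// e.g. (urem x, 8) --> (and x, 7)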
4999 SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
5000 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: NegOne);
5001 AddToWorklist(N: Add.getNode());
5002 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0, N2: Add);
5003 }
5004 // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1))
5005 // fold (urem x, (lshr pow2, y)) -> (and x, (add (lshr pow2, y), -1))
5006 // TODO: We should sink the following into isKnownToBePowerOfTwo
5007 // using a OrZero parameter analogous to our handling in ValueTracking.
5008 if ((N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) &&
5009 DAG.isKnownToBeAPowerOfTwo(Val: N1.getOperand(i: 0))) {
5010 SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
5011 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: NegOne);
5012 AddToWorklist(N: Add.getNode());
5013 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0, N2: Add);
5014 }
5015 }
5016
5017 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
5018
5019 // If X/C can be simplified by the division-by-constant logic, lower
5020 // X%C to the equivalent of X-X/C*C.
5021 // Reuse the SDIVLike/UDIVLike combines - to avoid mangling nodes, the
5022 // speculative DIV must not cause a DIVREM conversion. We guard against this
5023 // by skipping the simplification if isIntDivCheap(). When div is not cheap,
5024 // combine will not return a DIVREM. Regardless, checking cheapness here
5025 // makes sense since the simplification results in fatter code.
5026 if (DAG.isKnownNeverZero(Op: N1) && !TLI.isIntDivCheap(VT, Attr)) {
5027 if (isSigned) {
5028 // Check if we can build a faster implementation for srem.
5029 if (SDValue OptimizedRem = buildOptimizedSREM(N0, N1, N))
5030 return OptimizedRem;
5031 }
5032
5033 SDValue OptimizedDiv =
5034 isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N);
5035 if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != N) {
5036 // If the equivalent Div node also exists, update its users.
5037 unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
5038 if (SDNode *DivNode = DAG.getNodeIfExists(Opcode: DivOpcode, VTList: N->getVTList(),
5039 Ops: { N0, N1 }))
5040 CombineTo(N: DivNode, Res: OptimizedDiv);
5041 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: OptimizedDiv, N2: N1);
5042 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: Mul);
5043 AddToWorklist(N: OptimizedDiv.getNode());
5044 AddToWorklist(N: Mul.getNode());
5045 return Sub;
5046 }
5047 }
5048
5049 // sdiv/udiv, srem/urem -> sdivrem/udivrem
5050 if (SDValue DivRem = useDivRem(Node: N))
5051 return DivRem.getValue(R: 1);
5052
5053 return SDValue();
5054}
5055
5056SDValue DAGCombiner::visitMULHS(SDNode *N) {
5057 SDValue N0 = N->getOperand(Num: 0);
5058 SDValue N1 = N->getOperand(Num: 1);
5059 EVT VT = N->getValueType(ResNo: 0);
5060 SDLoc DL(N);
5061
5062 // fold (mulhs c1, c2)
5063 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::MULHS, DL, VT, Ops: {N0, N1}))
5064 return C;
5065
5066 // canonicalize constant to RHS.
5067 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5068 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5069 return DAG.getNode(Opcode: ISD::MULHS, DL, VTList: N->getVTList(), N1, N2: N0);
5070
5071 if (VT.isVector()) {
5072 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5073 return FoldedVOp;
5074
5075 // fold (mulhs x, 0) -> 0
5076 // Do not return N1, because an undef element may exist.
5077 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
5078 return DAG.getConstant(Val: 0, DL, VT);
5079 }
5080
5081 // fold (mulhs x, 0) -> 0
5082 if (isNullConstant(V: N1))
5083 return N1;
5084
5085 // fold (mulhs x, 1) -> (sra x, size(x)-1)
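// The high half of the sign-extended product x*1 is just the sign bits of x.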
5086 if (isOneConstant(V: N1))
5087 return DAG.getNode(
5088 Opcode: ISD::SRA, DL, VT, N1: N0,
5089 N2: DAG.getShiftAmountConstant(Val: N0.getScalarValueSizeInBits() - 1, VT, DL));
5090
5091 // fold (mulhs x, undef) -> 0
5092 if (N0.isUndef() || N1.isUndef())
5093 return DAG.getConstant(Val: 0, DL, VT);
5094
5095 // If the type twice as wide is legal, transform the mulhs to a wider multiply
5096 // plus a shift.
5097 if (!TLI.isOperationLegalOrCustom(Op: ISD::MULHS, VT) && VT.isSimple() &&
5098 !VT.isVector()) {
5099 MVT Simple = VT.getSimpleVT();
5100 unsigned SimpleSize = Simple.getSizeInBits();
5101 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SimpleSize*2);
5102 if (TLI.isOperationLegal(Op: ISD::MUL, VT: NewVT)) {
5103 N0 = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT: NewVT, Operand: N0);
5104 N1 = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT: NewVT, Operand: N1);
5105 N1 = DAG.getNode(Opcode: ISD::MUL, DL, VT: NewVT, N1: N0, N2: N1);
5106 N1 = DAG.getNode(Opcode: ISD::SRL, DL, VT: NewVT, N1,
5107 N2: DAG.getShiftAmountConstant(Val: SimpleSize, VT: NewVT, DL));
5108 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N1);
5109 }
5110 }
5111
5112 return SDValue();
5113}
5114
5115SDValue DAGCombiner::visitMULHU(SDNode *N) {
5116 SDValue N0 = N->getOperand(Num: 0);
5117 SDValue N1 = N->getOperand(Num: 1);
5118 EVT VT = N->getValueType(ResNo: 0);
5119 SDLoc DL(N);
5120
5121 // fold (mulhu c1, c2)
5122 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::MULHU, DL, VT, Ops: {N0, N1}))
5123 return C;
5124
5125 // canonicalize constant to RHS.
5126 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5127 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5128 return DAG.getNode(Opcode: ISD::MULHU, DL, VTList: N->getVTList(), N1, N2: N0);
5129
5130 if (VT.isVector()) {
5131 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5132 return FoldedVOp;
5133
5134 // fold (mulhu x, 0) -> 0
5135 // Do not return N1, because an undef element may exist.
5136 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
5137 return DAG.getConstant(Val: 0, DL, VT);
5138 }
5139
5140 // fold (mulhu x, 0) -> 0
5141 if (isNullConstant(V: N1))
5142 return N1;
5143
5144 // fold (mulhu x, 1) -> 0
5145 if (isOneConstant(V: N1))
5146 return DAG.getConstant(Val: 0, DL, VT);
5147
5148 // fold (mulhu x, undef) -> 0
5149 if (N0.isUndef() || N1.isUndef())
5150 return DAG.getConstant(Val: 0, DL, VT);
5151
5152 // fold (mulhu x, (1 << c)) -> x >> (bitwidth - c)
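// e.g. for i32: (mulhu x, 16) --> (srl x, 28)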
5153 if (isConstantOrConstantVector(N: N1, /*NoOpaques*/ true) &&
5154 hasOperation(Opcode: ISD::SRL, VT)) {
5155 if (SDValue LogBase2 = BuildLogBase2(V: N1, DL)) {
5156 unsigned NumEltBits = VT.getScalarSizeInBits();
5157 SDValue SRLAmt = DAG.getNode(
5158 Opcode: ISD::SUB, DL, VT, N1: DAG.getConstant(Val: NumEltBits, DL, VT), N2: LogBase2);
5159 EVT ShiftVT = getShiftAmountTy(LHSTy: N0.getValueType());
5160 SDValue Trunc = DAG.getZExtOrTrunc(Op: SRLAmt, DL, VT: ShiftVT);
5161 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0, N2: Trunc);
5162 }
5163 }
5164
5165 // If the type twice as wide is legal, transform the mulhu to a wider multiply
5166 // plus a shift.
5167 if (!TLI.isOperationLegalOrCustom(Op: ISD::MULHU, VT) && VT.isSimple() &&
5168 !VT.isVector()) {
5169 MVT Simple = VT.getSimpleVT();
5170 unsigned SimpleSize = Simple.getSizeInBits();
5171 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SimpleSize*2);
5172 if (TLI.isOperationLegal(Op: ISD::MUL, VT: NewVT)) {
5173 N0 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: NewVT, Operand: N0);
5174 N1 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: NewVT, Operand: N1);
5175 N1 = DAG.getNode(Opcode: ISD::MUL, DL, VT: NewVT, N1: N0, N2: N1);
5176 N1 = DAG.getNode(Opcode: ISD::SRL, DL, VT: NewVT, N1,
5177 N2: DAG.getShiftAmountConstant(Val: SimpleSize, VT: NewVT, DL));
5178 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N1);
5179 }
5180 }
5181
5182 // Simplify the operands using demanded-bits information.
5183 // We don't have demanded bits support for MULHU so this just enables constant
5184 // folding based on known bits.
5185 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
5186 return SDValue(N, 0);
5187
5188 return SDValue();
5189}
5190
5191SDValue DAGCombiner::visitAVG(SDNode *N) {
5192 unsigned Opcode = N->getOpcode();
5193 SDValue N0 = N->getOperand(Num: 0);
5194 SDValue N1 = N->getOperand(Num: 1);
5195 EVT VT = N->getValueType(ResNo: 0);
5196 SDLoc DL(N);
5197 bool IsSigned = Opcode == ISD::AVGCEILS || Opcode == ISD::AVGFLOORS;
5198
5199 // fold (avg c1, c2)
5200 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
5201 return C;
5202
5203 // canonicalize constant to RHS.
5204 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5205 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5206 return DAG.getNode(Opcode, DL, VTList: N->getVTList(), N1, N2: N0);
5207
5208 if (VT.isVector())
5209 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5210 return FoldedVOp;
5211
5212 // fold (avg x, undef) -> x
5213 if (N0.isUndef())
5214 return N1;
5215 if (N1.isUndef())
5216 return N0;
5217
5218 // fold (avg x, x) --> x
5219 if (N0 == N1 && Level >= AfterLegalizeTypes)
5220 return N0;
5221
5222 // fold (avgfloor x, 0) -> x >> 1
5223 SDValue X, Y;
5224 if (sd_match(N, P: m_c_BinOp(Opc: ISD::AVGFLOORS, L: m_Value(N&: X), R: m_Zero())))
5225 return DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: X,
5226 N2: DAG.getShiftAmountConstant(Val: 1, VT, DL));
5227 if (sd_match(N, P: m_c_BinOp(Opc: ISD::AVGFLOORU, L: m_Value(N&: X), R: m_Zero())))
5228 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: X,
5229 N2: DAG.getShiftAmountConstant(Val: 1, VT, DL));
5230
5231 // fold avgu(zext(x), zext(y)) -> zext(avgu(x, y))
5232 // fold avgs(sext(x), sext(y)) -> sext(avgs(x, y))
5233 if (!IsSigned &&
5234 sd_match(N, P: m_BinOp(Opc: Opcode, L: m_ZExt(Op: m_Value(N&: X)), R: m_ZExt(Op: m_Value(N&: Y)))) &&
5235 X.getValueType() == Y.getValueType() &&
5236 hasOperation(Opcode, VT: X.getValueType())) {
5237 SDValue AvgU = DAG.getNode(Opcode, DL, VT: X.getValueType(), N1: X, N2: Y);
5238 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: AvgU);
5239 }
5240 if (IsSigned &&
5241 sd_match(N, P: m_BinOp(Opc: Opcode, L: m_SExt(Op: m_Value(N&: X)), R: m_SExt(Op: m_Value(N&: Y)))) &&
5242 X.getValueType() == Y.getValueType() &&
5243 hasOperation(Opcode, VT: X.getValueType())) {
5244 SDValue AvgS = DAG.getNode(Opcode, DL, VT: X.getValueType(), N1: X, N2: Y);
5245 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: AvgS);
5246 }
5247
5248 // Fold avgflooru(x,y) -> avgceilu(x,y-1) iff y != 0
5249 // Fold avgflooru(x,y) -> avgceilu(x-1,y) iff x != 0
5250 // Check if avgflooru isn't legal/custom but avgceilu is.
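// This holds because avgceilu(x, y-1) == ((x + (y-1)) + 1) >> 1 == (x + y) >> 1
// == avgflooru(x, y), and y-1 does not wrap when y != 0 (similarly for x).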
5251 if (Opcode == ISD::AVGFLOORU && !hasOperation(Opcode: ISD::AVGFLOORU, VT) &&
5252 (!LegalOperations || hasOperation(Opcode: ISD::AVGCEILU, VT))) {
5253 if (DAG.isKnownNeverZero(Op: N1))
5254 return DAG.getNode(
5255 Opcode: ISD::AVGCEILU, DL, VT, N1: N0,
5256 N2: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: DAG.getAllOnesConstant(DL, VT)));
5257 if (DAG.isKnownNeverZero(Op: N0))
5258 return DAG.getNode(
5259 Opcode: ISD::AVGCEILU, DL, VT, N1,
5260 N2: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: DAG.getAllOnesConstant(DL, VT)));
5261 }
5262
5263 return SDValue();
5264}
5265
5266SDValue DAGCombiner::visitABD(SDNode *N) {
5267 unsigned Opcode = N->getOpcode();
5268 SDValue N0 = N->getOperand(Num: 0);
5269 SDValue N1 = N->getOperand(Num: 1);
5270 EVT VT = N->getValueType(ResNo: 0);
5271 SDLoc DL(N);
5272
5273 // fold (abd c1, c2)
5274 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
5275 return C;
5276
5277 // canonicalize constant to RHS.
5278 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5279 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5280 return DAG.getNode(Opcode, DL, VTList: N->getVTList(), N1, N2: N0);
5281
5282 if (VT.isVector())
5283 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5284 return FoldedVOp;
5285
5286 // fold (abd x, undef) -> 0
5287 if (N0.isUndef() || N1.isUndef())
5288 return DAG.getConstant(Val: 0, DL, VT);
5289
5290 SDValue X;
5291
5292 // fold (abds x, 0) -> abs x
5293 if (sd_match(N, P: m_c_BinOp(Opc: ISD::ABDS, L: m_Value(N&: X), R: m_Zero())) &&
5294 (!LegalOperations || hasOperation(Opcode: ISD::ABS, VT)))
5295 return DAG.getNode(Opcode: ISD::ABS, DL, VT, Operand: X);
5296
5297 // fold (abdu x, 0) -> x
5298 if (sd_match(N, P: m_c_BinOp(Opc: ISD::ABDU, L: m_Value(N&: X), R: m_Zero())))
5299 return X;
5300
5301 // fold (abds x, y) -> (abdu x, y) iff both args are known positive
5302 if (Opcode == ISD::ABDS && hasOperation(Opcode: ISD::ABDU, VT) &&
5303 DAG.SignBitIsZero(Op: N0) && DAG.SignBitIsZero(Op: N1))
5304 return DAG.getNode(Opcode: ISD::ABDU, DL, VT, N1, N2: N0);
5305
5306 return SDValue();
5307}
5308
5309/// Perform optimizations common to nodes that compute two values. LoOp and HiOp
5310 /// give the opcodes for the two computations that are being performed. Return
5311 /// the combined result if a simplification was made.
5312SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
5313 unsigned HiOp) {
5314 // If the high half is not needed, just compute the low half.
5315 bool HiExists = N->hasAnyUseOfValue(Value: 1);
5316 if (!HiExists && (!LegalOperations ||
5317 TLI.isOperationLegalOrCustom(Op: LoOp, VT: N->getValueType(ResNo: 0)))) {
5318 SDValue Res = DAG.getNode(Opcode: LoOp, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Ops: N->ops());
5319 return CombineTo(N, Res0: Res, Res1: Res);
5320 }
5321
5322 // If the low half is not needed, just compute the high half.
5323 bool LoExists = N->hasAnyUseOfValue(Value: 0);
5324 if (!LoExists && (!LegalOperations ||
5325 TLI.isOperationLegalOrCustom(Op: HiOp, VT: N->getValueType(ResNo: 1)))) {
5326 SDValue Res = DAG.getNode(Opcode: HiOp, DL: SDLoc(N), VT: N->getValueType(ResNo: 1), Ops: N->ops());
5327 return CombineTo(N, Res0: Res, Res1: Res);
5328 }
5329
5330 // If both halves are used, leave the node as it is.
5331 if (LoExists && HiExists)
5332 return SDValue();
5333
5334 // If the two computed results can be simplified separately, separate them.
5335 if (LoExists) {
5336 SDValue Lo = DAG.getNode(Opcode: LoOp, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Ops: N->ops());
5337 AddToWorklist(N: Lo.getNode());
5338 SDValue LoOpt = combine(N: Lo.getNode());
5339 if (LoOpt.getNode() && LoOpt.getNode() != Lo.getNode() &&
5340 (!LegalOperations ||
5341 TLI.isOperationLegalOrCustom(Op: LoOpt.getOpcode(), VT: LoOpt.getValueType())))
5342 return CombineTo(N, Res0: LoOpt, Res1: LoOpt);
5343 }
5344
5345 if (HiExists) {
5346 SDValue Hi = DAG.getNode(Opcode: HiOp, DL: SDLoc(N), VT: N->getValueType(ResNo: 1), Ops: N->ops());
5347 AddToWorklist(N: Hi.getNode());
5348 SDValue HiOpt = combine(N: Hi.getNode());
5349 if (HiOpt.getNode() && HiOpt != Hi &&
5350 (!LegalOperations ||
5351 TLI.isOperationLegalOrCustom(Op: HiOpt.getOpcode(), VT: HiOpt.getValueType())))
5352 return CombineTo(N, Res0: HiOpt, Res1: HiOpt);
5353 }
5354
5355 return SDValue();
5356}
5357
5358SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
5359 if (SDValue Res = SimplifyNodeWithTwoResults(N, LoOp: ISD::MUL, HiOp: ISD::MULHS))
5360 return Res;
5361
5362 SDValue N0 = N->getOperand(Num: 0);
5363 SDValue N1 = N->getOperand(Num: 1);
5364 EVT VT = N->getValueType(ResNo: 0);
5365 SDLoc DL(N);
5366
5367 // Constant fold.
5368 if (isa<ConstantSDNode>(Val: N0) && isa<ConstantSDNode>(Val: N1))
5369 return DAG.getNode(Opcode: ISD::SMUL_LOHI, DL, VTList: N->getVTList(), N1: N0, N2: N1);
5370
5371 // canonicalize constant to RHS (vector doesn't have to splat)
5372 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5373 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5374 return DAG.getNode(Opcode: ISD::SMUL_LOHI, DL, VTList: N->getVTList(), N1, N2: N0);
5375
5376 // If the type twice as wide is legal, transform this into a wider multiply
5377 // plus a shift.
5378 if (VT.isSimple() && !VT.isVector()) {
5379 MVT Simple = VT.getSimpleVT();
5380 unsigned SimpleSize = Simple.getSizeInBits();
5381 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SimpleSize*2);
5382 if (TLI.isOperationLegal(Op: ISD::MUL, VT: NewVT)) {
5383 SDValue Lo = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT: NewVT, Operand: N0);
5384 SDValue Hi = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT: NewVT, Operand: N1);
5385 Lo = DAG.getNode(Opcode: ISD::MUL, DL, VT: NewVT, N1: Lo, N2: Hi);
5386 // Compute the high part (result 1).
5387 Hi = DAG.getNode(Opcode: ISD::SRL, DL, VT: NewVT, N1: Lo,
5388 N2: DAG.getShiftAmountConstant(Val: SimpleSize, VT: NewVT, DL));
5389 Hi = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Hi);
5390 // Compute the low part (result 0).
5391 Lo = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Lo);
5392 return CombineTo(N, Res0: Lo, Res1: Hi);
5393 }
5394 }
5395
5396 return SDValue();
5397}
5398
5399SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
5400 if (SDValue Res = SimplifyNodeWithTwoResults(N, LoOp: ISD::MUL, HiOp: ISD::MULHU))
5401 return Res;
5402
5403 SDValue N0 = N->getOperand(Num: 0);
5404 SDValue N1 = N->getOperand(Num: 1);
5405 EVT VT = N->getValueType(ResNo: 0);
5406 SDLoc DL(N);
5407
5408 // Constant fold.
5409 if (isa<ConstantSDNode>(Val: N0) && isa<ConstantSDNode>(Val: N1))
5410 return DAG.getNode(Opcode: ISD::UMUL_LOHI, DL, VTList: N->getVTList(), N1: N0, N2: N1);
5411
5412 // canonicalize constant to RHS (vector doesn't have to splat)
5413 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5414 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5415 return DAG.getNode(Opcode: ISD::UMUL_LOHI, DL, VTList: N->getVTList(), N1, N2: N0);
5416
5417 // (umul_lohi N0, 0) -> (0, 0)
5418 if (isNullConstant(V: N1)) {
5419 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
5420 return CombineTo(N, Res0: Zero, Res1: Zero);
5421 }
5422
5423 // (umul_lohi N0, 1) -> (N0, 0)
5424 if (isOneConstant(V: N1)) {
5425 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
5426 return CombineTo(N, Res0: N0, Res1: Zero);
5427 }
5428
5429 // If the type twice as wide is legal, transform this into a wider multiply
5430 // plus a shift.
5431 if (VT.isSimple() && !VT.isVector()) {
5432 MVT Simple = VT.getSimpleVT();
5433 unsigned SimpleSize = Simple.getSizeInBits();
5434 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SimpleSize*2);
5435 if (TLI.isOperationLegal(Op: ISD::MUL, VT: NewVT)) {
5436 SDValue Lo = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: NewVT, Operand: N0);
5437 SDValue Hi = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: NewVT, Operand: N1);
5438 Lo = DAG.getNode(Opcode: ISD::MUL, DL, VT: NewVT, N1: Lo, N2: Hi);
5439 // Compute the high part (result 1).
5440 Hi = DAG.getNode(Opcode: ISD::SRL, DL, VT: NewVT, N1: Lo,
5441 N2: DAG.getShiftAmountConstant(Val: SimpleSize, VT: NewVT, DL));
5442 Hi = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Hi);
5443 // Compute the low part (result 0).
5444 Lo = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Lo);
5445 return CombineTo(N, Res0: Lo, Res1: Hi);
5446 }
5447 }
5448
5449 return SDValue();
5450}
5451
5452SDValue DAGCombiner::visitMULO(SDNode *N) {
5453 SDValue N0 = N->getOperand(Num: 0);
5454 SDValue N1 = N->getOperand(Num: 1);
5455 EVT VT = N0.getValueType();
5456 bool IsSigned = (ISD::SMULO == N->getOpcode());
5457
5458 EVT CarryVT = N->getValueType(ResNo: 1);
5459 SDLoc DL(N);
5460
5461 ConstantSDNode *N0C = isConstOrConstSplat(N: N0);
5462 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
5463
5464 // fold operation with constant operands.
5465 // TODO: Move this to FoldConstantArithmetic when it supports nodes with
5466 // multiple results.
5467 if (N0C && N1C) {
5468 bool Overflow;
5469 APInt Result =
5470 IsSigned ? N0C->getAPIntValue().smul_ov(RHS: N1C->getAPIntValue(), Overflow)
5471 : N0C->getAPIntValue().umul_ov(RHS: N1C->getAPIntValue(), Overflow);
5472 return CombineTo(N, Res0: DAG.getConstant(Val: Result, DL, VT),
5473 Res1: DAG.getBoolConstant(V: Overflow, DL, VT: CarryVT, OpVT: CarryVT));
5474 }
5475
5476 // canonicalize constant to RHS.
5477 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5478 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5479 return DAG.getNode(Opcode: N->getOpcode(), DL, VTList: N->getVTList(), N1, N2: N0);
5480
5481 // fold (mulo x, 0) -> 0 + no carry out
5482 if (isNullOrNullSplat(V: N1))
5483 return CombineTo(N, Res0: DAG.getConstant(Val: 0, DL, VT),
5484 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
5485
5486 // (mulo x, 2) -> (addo x, x)
5487 // FIXME: This needs a freeze.
5488 if (N1C && N1C->getAPIntValue() == 2 &&
5489 (!IsSigned || VT.getScalarSizeInBits() > 2))
5490 return DAG.getNode(Opcode: IsSigned ? ISD::SADDO : ISD::UADDO, DL,
5491 VTList: N->getVTList(), N1: N0, N2: N0);
5492
5493 // A 1 bit SMULO overflows if both inputs are 1.
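// In i1 the only non-zero value is -1, and -1 * -1 == 1 is not representable
// as a signed 1-bit value, so overflow occurs exactly when both bits are set.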
5494 if (IsSigned && VT.getScalarSizeInBits() == 1) {
5495 SDValue And = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0, N2: N1);
5496 SDValue Cmp = DAG.getSetCC(DL, VT: CarryVT, LHS: And,
5497 RHS: DAG.getConstant(Val: 0, DL, VT), Cond: ISD::SETNE);
5498 return CombineTo(N, Res0: And, Res1: Cmp);
5499 }
5500
5501 // If it cannot overflow, transform into a mul.
5502 if (DAG.willNotOverflowMul(IsSigned, N0, N1))
5503 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: N0, N2: N1),
5504 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
5505 return SDValue();
5506}
5507
5508 // Function to calculate whether the Min/Max pair of SDNodes (potentially
5509 // swapped around) forms a signed saturate pattern, clamping to between a signed
5510 // saturate of -2^(BW-1) and 2^(BW-1)-1, or an unsigned saturate of 0 and 2^BW-1.
5511// Returns the node being clamped and the bitwidth of the clamp in BW. Should
5512// work with both SMIN/SMAX nodes and setcc/select combo. The operands are the
5513// same as SimplifySelectCC. N0<N1 ? N2 : N3.
5514static SDValue isSaturatingMinMax(SDValue N0, SDValue N1, SDValue N2,
5515 SDValue N3, ISD::CondCode CC, unsigned &BW,
5516 bool &Unsigned, SelectionDAG &DAG) {
5517 auto isSignedMinMax = [&](SDValue N0, SDValue N1, SDValue N2, SDValue N3,
5518 ISD::CondCode CC) {
5519 // The compare and select operands should be the same, or the select operands
5520 // should be truncated versions of the comparison.
5521 if (N0 != N2 && (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(i: 0)))
5522 return 0;
5523 // The constants need to be the same or a truncated version of each other.
5524 ConstantSDNode *N1C = isConstOrConstSplat(N: peekThroughTruncates(V: N1));
5525 ConstantSDNode *N3C = isConstOrConstSplat(N: peekThroughTruncates(V: N3));
5526 if (!N1C || !N3C)
5527 return 0;
5528 const APInt &C1 = N1C->getAPIntValue().trunc(width: N1.getScalarValueSizeInBits());
5529 const APInt &C2 = N3C->getAPIntValue().trunc(width: N3.getScalarValueSizeInBits());
5530 if (C1.getBitWidth() < C2.getBitWidth() || C1 != C2.sext(width: C1.getBitWidth()))
5531 return 0;
5532 return CC == ISD::SETLT ? ISD::SMIN : (CC == ISD::SETGT ? ISD::SMAX : 0);
5533 };
5534
5535 // Check the initial value is a SMIN/SMAX equivalent.
5536 unsigned Opcode0 = isSignedMinMax(N0, N1, N2, N3, CC);
5537 if (!Opcode0)
5538 return SDValue();
5539
5540 // We may need only one range check if the fptosi can never produce
5541 // the upper value.
5542 if (N0.getOpcode() == ISD::FP_TO_SINT && Opcode0 == ISD::SMAX) {
5543 if (isNullOrNullSplat(V: N3)) {
5544 EVT IntVT = N0.getValueType().getScalarType();
5545 EVT FPVT = N0.getOperand(i: 0).getValueType().getScalarType();
5546 if (FPVT.isSimple()) {
5547 Type *InputTy = FPVT.getTypeForEVT(Context&: *DAG.getContext());
5548 const fltSemantics &Semantics = InputTy->getFltSemantics();
5549 uint32_t MinBitWidth =
5550 APFloatBase::semanticsIntSizeInBits(Semantics, /*isSigned*/ true);
5551 if (IntVT.getSizeInBits() >= MinBitWidth) {
5552 Unsigned = true;
5553 BW = PowerOf2Ceil(A: MinBitWidth);
5554 return N0;
5555 }
5556 }
5557 }
5558 }
5559
5560 SDValue N00, N01, N02, N03;
5561 ISD::CondCode N0CC;
5562 switch (N0.getOpcode()) {
5563 case ISD::SMIN:
5564 case ISD::SMAX:
5565 N00 = N02 = N0.getOperand(i: 0);
5566 N01 = N03 = N0.getOperand(i: 1);
5567 N0CC = N0.getOpcode() == ISD::SMIN ? ISD::SETLT : ISD::SETGT;
5568 break;
5569 case ISD::SELECT_CC:
5570 N00 = N0.getOperand(i: 0);
5571 N01 = N0.getOperand(i: 1);
5572 N02 = N0.getOperand(i: 2);
5573 N03 = N0.getOperand(i: 3);
5574 N0CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 4))->get();
5575 break;
5576 case ISD::SELECT:
5577 case ISD::VSELECT:
5578 if (N0.getOperand(i: 0).getOpcode() != ISD::SETCC)
5579 return SDValue();
5580 N00 = N0.getOperand(i: 0).getOperand(i: 0);
5581 N01 = N0.getOperand(i: 0).getOperand(i: 1);
5582 N02 = N0.getOperand(i: 1);
5583 N03 = N0.getOperand(i: 2);
5584 N0CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 0).getOperand(i: 2))->get();
5585 break;
5586 default:
5587 return SDValue();
5588 }
5589
5590 unsigned Opcode1 = isSignedMinMax(N00, N01, N02, N03, N0CC);
5591 if (!Opcode1 || Opcode0 == Opcode1)
5592 return SDValue();
5593
5594 ConstantSDNode *MinCOp = isConstOrConstSplat(N: Opcode0 == ISD::SMIN ? N1 : N01);
5595 ConstantSDNode *MaxCOp = isConstOrConstSplat(N: Opcode0 == ISD::SMIN ? N01 : N1);
5596 if (!MinCOp || !MaxCOp || MinCOp->getValueType(ResNo: 0) != MaxCOp->getValueType(ResNo: 0))
5597 return SDValue();
5598
5599 const APInt &MinC = MinCOp->getAPIntValue();
5600 const APInt &MaxC = MaxCOp->getAPIntValue();
5601 APInt MinCPlus1 = MinC + 1;
5602 if (-MaxC == MinCPlus1 && MinCPlus1.isPowerOf2()) {
5603 BW = MinCPlus1.exactLogBase2() + 1;
5604 Unsigned = false;
5605 return N02;
5606 }
5607
5608 if (MaxC == 0 && MinCPlus1.isPowerOf2()) {
5609 BW = MinCPlus1.exactLogBase2();
5610 Unsigned = true;
5611 return N02;
5612 }
5613
5614 return SDValue();
5615}
5616
5617static SDValue PerformMinMaxFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
5618 SDValue N3, ISD::CondCode CC,
5619 SelectionDAG &DAG) {
5620 unsigned BW;
5621 bool Unsigned;
5622 SDValue Fp = isSaturatingMinMax(N0, N1, N2, N3, CC, BW, Unsigned, DAG);
5623 if (!Fp || Fp.getOpcode() != ISD::FP_TO_SINT)
5624 return SDValue();
5625 EVT FPVT = Fp.getOperand(i: 0).getValueType();
5626 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: BW);
5627 if (FPVT.isVector())
5628 NewVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NewVT,
5629 EC: FPVT.getVectorElementCount());
5630 unsigned NewOpc = Unsigned ? ISD::FP_TO_UINT_SAT : ISD::FP_TO_SINT_SAT;
5631 if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(Op: NewOpc, FPVT, VT: NewVT))
5632 return SDValue();
5633 SDLoc DL(Fp);
5634 SDValue Sat = DAG.getNode(Opcode: NewOpc, DL, VT: NewVT, N1: Fp.getOperand(i: 0),
5635 N2: DAG.getValueType(NewVT.getScalarType()));
5636 return DAG.getExtOrTrunc(IsSigned: !Unsigned, Op: Sat, DL, VT: N2->getValueType(ResNo: 0));
5637}
5638
5639static SDValue PerformUMinFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
5640 SDValue N3, ISD::CondCode CC,
5641 SelectionDAG &DAG) {
5642 // We are looking for UMIN(FPTOUI(X), (2^n)-1), which may have come via a
5643 // select/vselect/select_cc. The two operand pairs for the select (N2/N3) may
5644 // be truncated versions of the setcc (N0/N1).
5645 if ((N0 != N2 &&
5646 (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(i: 0))) ||
5647 N0.getOpcode() != ISD::FP_TO_UINT || CC != ISD::SETULT)
5648 return SDValue();
5649 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
5650 ConstantSDNode *N3C = isConstOrConstSplat(N: N3);
5651 if (!N1C || !N3C)
5652 return SDValue();
5653 const APInt &C1 = N1C->getAPIntValue();
5654 const APInt &C3 = N3C->getAPIntValue();
5655 if (!(C1 + 1).isPowerOf2() || C1.getBitWidth() < C3.getBitWidth() ||
5656 C1 != C3.zext(width: C1.getBitWidth()))
5657 return SDValue();
5658
5659 unsigned BW = (C1 + 1).exactLogBase2();
5660 EVT FPVT = N0.getOperand(i: 0).getValueType();
5661 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: BW);
5662 if (FPVT.isVector())
5663 NewVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NewVT,
5664 EC: FPVT.getVectorElementCount());
5665 if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(Op: ISD::FP_TO_UINT_SAT,
5666 FPVT, VT: NewVT))
5667 return SDValue();
5668
5669 SDValue Sat =
5670 DAG.getNode(Opcode: ISD::FP_TO_UINT_SAT, DL: SDLoc(N0), VT: NewVT, N1: N0.getOperand(i: 0),
5671 N2: DAG.getValueType(NewVT.getScalarType()));
5672 return DAG.getZExtOrTrunc(Op: Sat, DL: SDLoc(N0), VT: N3.getValueType());
5673}
5674
5675SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
5676 SDValue N0 = N->getOperand(Num: 0);
5677 SDValue N1 = N->getOperand(Num: 1);
5678 EVT VT = N0.getValueType();
5679 unsigned Opcode = N->getOpcode();
5680 SDLoc DL(N);
5681
5682 // fold operation with constant operands.
5683 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
5684 return C;
5685
5686 // If the operands are the same, this is a no-op.
5687 if (N0 == N1)
5688 return N0;
5689
5690 // canonicalize constant to RHS
5691 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5692 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5693 return DAG.getNode(Opcode, DL, VT, N1, N2: N0);
5694
5695 // fold vector ops
5696 if (VT.isVector())
5697 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5698 return FoldedVOp;
5699
5700 // reassociate minmax
5701 if (SDValue RMINMAX = reassociateOps(Opc: Opcode, DL, N0, N1, Flags: N->getFlags()))
5702 return RMINMAX;
5703
5704 // If the sign bits are zero, flip between UMIN/UMAX and SMIN/SMAX.
5705 // Only do this if:
5706 // 1. The current op isn't legal and the flipped one is.
5707 // 2. The saturation pattern is broken by canonicalization in InstCombine.
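  // For example, with known-zero sign bits, (umin (smax X, 0), 255) can be
  // flipped to (smin (smax X, 0), 255), re-exposing the signed saturation.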
5708 bool IsOpIllegal = !TLI.isOperationLegal(Op: Opcode, VT);
5709 bool IsSatBroken = Opcode == ISD::UMIN && N0.getOpcode() == ISD::SMAX;
5710 if ((IsSatBroken || IsOpIllegal) && (N0.isUndef() || DAG.SignBitIsZero(Op: N0)) &&
5711 (N1.isUndef() || DAG.SignBitIsZero(Op: N1))) {
5712 unsigned AltOpcode;
5713 switch (Opcode) {
5714 case ISD::SMIN: AltOpcode = ISD::UMIN; break;
5715 case ISD::SMAX: AltOpcode = ISD::UMAX; break;
5716 case ISD::UMIN: AltOpcode = ISD::SMIN; break;
5717 case ISD::UMAX: AltOpcode = ISD::SMAX; break;
5718 default: llvm_unreachable("Unknown MINMAX opcode");
5719 }
5720 if ((IsSatBroken && IsOpIllegal) || TLI.isOperationLegal(Op: AltOpcode, VT))
5721 return DAG.getNode(Opcode: AltOpcode, DL, VT, N1: N0, N2: N1);
5722 }
5723
5724 if (Opcode == ISD::SMIN || Opcode == ISD::SMAX)
5725 if (SDValue S = PerformMinMaxFpToSatCombine(
5726 N0, N1, N2: N0, N3: N1, CC: Opcode == ISD::SMIN ? ISD::SETLT : ISD::SETGT, DAG))
5727 return S;
5728 if (Opcode == ISD::UMIN)
5729 if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N2: N0, N3: N1, CC: ISD::SETULT, DAG))
5730 return S;
5731
5732 // Fold min/max(vecreduce(x), vecreduce(y)) -> vecreduce(min/max(x, y))
5733 auto ReductionOpcode = [](unsigned Opcode) {
5734 switch (Opcode) {
5735 case ISD::SMIN:
5736 return ISD::VECREDUCE_SMIN;
5737 case ISD::SMAX:
5738 return ISD::VECREDUCE_SMAX;
5739 case ISD::UMIN:
5740 return ISD::VECREDUCE_UMIN;
5741 case ISD::UMAX:
5742 return ISD::VECREDUCE_UMAX;
5743 default:
5744 llvm_unreachable("Unexpected opcode");
5745 }
5746 };
5747 if (SDValue SD = reassociateReduction(RedOpc: ReductionOpcode(Opcode), Opc: Opcode,
5748 DL: SDLoc(N), VT, N0, N1))
5749 return SD;
5750
5751 // Simplify the operands using demanded-bits information.
5752 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
5753 return SDValue(N, 0);
5754
5755 return SDValue();
5756}
5757
5758/// If this is a bitwise logic instruction and both operands have the same
5759/// opcode, try to sink the other opcode after the logic instruction.
5760SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) {
5761 SDValue N0 = N->getOperand(Num: 0), N1 = N->getOperand(Num: 1);
5762 EVT VT = N0.getValueType();
5763 unsigned LogicOpcode = N->getOpcode();
5764 unsigned HandOpcode = N0.getOpcode();
5765 assert(ISD::isBitwiseLogicOp(LogicOpcode) && "Expected logic opcode");
5766 assert(HandOpcode == N1.getOpcode() && "Bad input!");
5767
5768 // Bail early if none of these transforms apply.
5769 if (N0.getNumOperands() == 0)
5770 return SDValue();
5771
5772 // FIXME: We should check the number of uses of the operands so that these
5773 // transforms do not increase the instruction count.
5774
5775 // Handle size-changing casts (or sign_extend_inreg).
5776 SDValue X = N0.getOperand(i: 0);
5777 SDValue Y = N1.getOperand(i: 0);
5778 EVT XVT = X.getValueType();
5779 SDLoc DL(N);
5780 if (ISD::isExtOpcode(Opcode: HandOpcode) || ISD::isExtVecInRegOpcode(Opcode: HandOpcode) ||
5781 (HandOpcode == ISD::SIGN_EXTEND_INREG &&
5782 N0.getOperand(i: 1) == N1.getOperand(i: 1))) {
5783 // If both operands have other uses, this transform would create extra
5784 // instructions without eliminating anything.
5785 if (!N0.hasOneUse() && !N1.hasOneUse())
5786 return SDValue();
5787 // We need matching integer source types.
5788 if (XVT != Y.getValueType())
5789 return SDValue();
5790 // Don't create an illegal op during or after legalization. Don't ever
5791 // create an unsupported vector op.
5792 if ((VT.isVector() || LegalOperations) &&
5793 !TLI.isOperationLegalOrCustom(Op: LogicOpcode, VT: XVT))
5794 return SDValue();
5795 // Avoid infinite looping with PromoteIntBinOp.
5796 // TODO: Should we apply desirable/legal constraints to all opcodes?
5797 if ((HandOpcode == ISD::ANY_EXTEND ||
5798 HandOpcode == ISD::ANY_EXTEND_VECTOR_INREG) &&
5799 LegalTypes && !TLI.isTypeDesirableForOp(LogicOpcode, VT: XVT))
5800 return SDValue();
5801 // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y)
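    // e.g. and (zext X), (zext Y) --> zext (and X, Y)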
5802 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y);
5803 if (HandOpcode == ISD::SIGN_EXTEND_INREG)
5804 return DAG.getNode(Opcode: HandOpcode, DL, VT, N1: Logic, N2: N0.getOperand(i: 1));
5805 return DAG.getNode(Opcode: HandOpcode, DL, VT, Operand: Logic);
5806 }
5807
5808 // logic_op (truncate x), (truncate y) --> truncate (logic_op x, y)
5809 if (HandOpcode == ISD::TRUNCATE) {
5810 // If both operands have other uses, this transform would create extra
5811 // instructions without eliminating anything.
5812 if (!N0.hasOneUse() && !N1.hasOneUse())
5813 return SDValue();
5814 // We need matching source types.
5815 if (XVT != Y.getValueType())
5816 return SDValue();
5817 // Don't create an illegal op during or after legalization.
5818 if (LegalOperations && !TLI.isOperationLegal(Op: LogicOpcode, VT: XVT))
5819 return SDValue();
5820 // Be extra careful sinking truncate. If it's free, there's no benefit in
5821 // widening a binop. Also, don't create a logic op on an illegal type.
5822 if (TLI.isZExtFree(FromTy: VT, ToTy: XVT) && TLI.isTruncateFree(FromVT: XVT, ToVT: VT))
5823 return SDValue();
5824 if (!TLI.isTypeLegal(VT: XVT))
5825 return SDValue();
5826 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y);
5827 return DAG.getNode(Opcode: HandOpcode, DL, VT, Operand: Logic);
5828 }
5829
5830 // For binops SHL/SRL/SRA/AND:
5831 // logic_op (OP x, z), (OP y, z) --> OP (logic_op x, y), z
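  // e.g. or (shl X, C), (shl Y, C) --> shl (or X, Y), C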
5832 if ((HandOpcode == ISD::SHL || HandOpcode == ISD::SRL ||
5833 HandOpcode == ISD::SRA || HandOpcode == ISD::AND) &&
5834 N0.getOperand(i: 1) == N1.getOperand(i: 1)) {
5835 // If either operand has other uses, this transform is not an improvement.
5836 if (!N0.hasOneUse() || !N1.hasOneUse())
5837 return SDValue();
5838 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y);
5839 return DAG.getNode(Opcode: HandOpcode, DL, VT, N1: Logic, N2: N0.getOperand(i: 1));
5840 }
5841
5842 // Unary ops: logic_op (bswap x), (bswap y) --> bswap (logic_op x, y)
5843 if (HandOpcode == ISD::BSWAP) {
5844 // If either operand has other uses, this transform is not an improvement.
5845 if (!N0.hasOneUse() || !N1.hasOneUse())
5846 return SDValue();
5847 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y);
5848 return DAG.getNode(Opcode: HandOpcode, DL, VT, Operand: Logic);
5849 }
5850
5851 // For funnel shifts FSHL/FSHR:
5852 // logic_op (OP x, x1, s), (OP y, y1, s) -->
5853 // --> OP (logic_op x, y), (logic_op x1, y1), s
5854 if ((HandOpcode == ISD::FSHL || HandOpcode == ISD::FSHR) &&
5855 N0.getOperand(i: 2) == N1.getOperand(i: 2)) {
5856 if (!N0.hasOneUse() || !N1.hasOneUse())
5857 return SDValue();
5858 SDValue X1 = N0.getOperand(i: 1);
5859 SDValue Y1 = N1.getOperand(i: 1);
5860 SDValue S = N0.getOperand(i: 2);
5861 SDValue Logic0 = DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: X, N2: Y);
5862 SDValue Logic1 = DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: X1, N2: Y1);
5863 return DAG.getNode(Opcode: HandOpcode, DL, VT, N1: Logic0, N2: Logic1, N3: S);
5864 }
5865
5866 // Simplify xor/and/or (bitcast(A), bitcast(B)) -> bitcast(op (A,B))
5867 // Only perform this optimization up until type legalization, before
5868 // LegalizeVectorOps. LegalizeVectorOps promotes vector operations by
5869 // adding bitcasts. For example (xor v4i32) is promoted to (v2i64), and
5870 // we don't want to undo this promotion.
5871 // We also handle SCALAR_TO_VECTOR because xor/or/and operations are cheaper
5872 // on scalars.
5873 if ((HandOpcode == ISD::BITCAST || HandOpcode == ISD::SCALAR_TO_VECTOR) &&
5874 Level <= AfterLegalizeTypes) {
5875 // Input types must be integer and the same.
5876 if (XVT.isInteger() && XVT == Y.getValueType() &&
5877 !(VT.isVector() && TLI.isTypeLegal(VT) &&
5878 !XVT.isVector() && !TLI.isTypeLegal(VT: XVT))) {
5879 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y);
5880 return DAG.getNode(Opcode: HandOpcode, DL, VT, Operand: Logic);
5881 }
5882 }
5883
5884 // Xor/and/or are indifferent to the swizzle operation (shuffle of one value).
5885 // Simplify xor/and/or (shuff(A), shuff(B)) -> shuff(op (A,B))
5886 // If both shuffles use the same mask, and both shuffle within a single
5887 // vector, then it is worthwhile to move the swizzle after the operation.
5888 // The type-legalizer generates this pattern when loading illegal
5889 // vector types from memory. In many cases this allows additional shuffle
5890 // optimizations.
5891 // There are other cases where moving the shuffle after the xor/and/or
5892 // is profitable even if shuffles don't perform a swizzle.
5893 // If both shuffles use the same mask, and both shuffles have the same first
5894 // or second operand, then it might still be profitable to move the shuffle
5895 // after the xor/and/or operation.
5896 if (HandOpcode == ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG) {
5897 auto *SVN0 = cast<ShuffleVectorSDNode>(Val&: N0);
5898 auto *SVN1 = cast<ShuffleVectorSDNode>(Val&: N1);
5899 assert(X.getValueType() == Y.getValueType() &&
5900 "Inputs to shuffles are not the same type");
5901
5902 // Check that both shuffles use the same mask. The masks are known to be of
5903 // the same length because the result vector type is the same.
5904 // Check also that shuffles have only one use to avoid introducing extra
5905 // instructions.
5906 if (!SVN0->hasOneUse() || !SVN1->hasOneUse() ||
5907 !SVN0->getMask().equals(RHS: SVN1->getMask()))
5908 return SDValue();
5909
5910 // Don't try to fold this node if it requires introducing a
5911 // build vector of all zeros that might be illegal at this stage.
5912 SDValue ShOp = N0.getOperand(i: 1);
5913 if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
5914 ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
5915
5916 // (logic_op (shuf (A, C), shuf (B, C))) --> shuf (logic_op (A, B), C)
5917 if (N0.getOperand(i: 1) == N1.getOperand(i: 1) && ShOp.getNode()) {
5918 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT,
5919 N1: N0.getOperand(i: 0), N2: N1.getOperand(i: 0));
5920 return DAG.getVectorShuffle(VT, dl: DL, N1: Logic, N2: ShOp, Mask: SVN0->getMask());
5921 }
5922
5923 // Don't try to fold this node if it requires introducing a
5924 // build vector of all zeros that might be illegal at this stage.
5925 ShOp = N0.getOperand(i: 0);
5926 if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
5927 ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
5928
5929 // (logic_op (shuf (C, A), shuf (C, B))) --> shuf (C, logic_op (A, B))
5930 if (N0.getOperand(i: 0) == N1.getOperand(i: 0) && ShOp.getNode()) {
5931 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: N0.getOperand(i: 1),
5932 N2: N1.getOperand(i: 1));
5933 return DAG.getVectorShuffle(VT, dl: DL, N1: ShOp, N2: Logic, Mask: SVN0->getMask());
5934 }
5935 }
5936
5937 return SDValue();
5938}
5939
5940/// Try to make (and/or setcc (LL, LR), setcc (RL, RR)) more efficient.
5941SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
5942 const SDLoc &DL) {
5943 SDValue LL, LR, RL, RR, N0CC, N1CC;
5944 if (!isSetCCEquivalent(N: N0, LHS&: LL, RHS&: LR, CC&: N0CC) ||
5945 !isSetCCEquivalent(N: N1, LHS&: RL, RHS&: RR, CC&: N1CC))
5946 return SDValue();
5947
5948 assert(N0.getValueType() == N1.getValueType() &&
5949 "Unexpected operand types for bitwise logic op");
5950 assert(LL.getValueType() == LR.getValueType() &&
5951 RL.getValueType() == RR.getValueType() &&
5952 "Unexpected operand types for setcc");
5953
5954 // If we're here post-legalization or the logic op type is not i1, the logic
5955 // op type must match a setcc result type. Also, all folds require new
5956 // operations on the left and right operands, so those types must match.
5957 EVT VT = N0.getValueType();
5958 EVT OpVT = LL.getValueType();
5959 if (LegalOperations || VT.getScalarType() != MVT::i1)
5960 if (VT != getSetCCResultType(VT: OpVT))
5961 return SDValue();
5962 if (OpVT != RL.getValueType())
5963 return SDValue();
5964
5965 ISD::CondCode CC0 = cast<CondCodeSDNode>(Val&: N0CC)->get();
5966 ISD::CondCode CC1 = cast<CondCodeSDNode>(Val&: N1CC)->get();
5967 bool IsInteger = OpVT.isInteger();
5968 if (LR == RR && CC0 == CC1 && IsInteger) {
5969 bool IsZero = isNullOrNullSplat(V: LR);
5970 bool IsNeg1 = isAllOnesOrAllOnesSplat(V: LR);
5971
5972 // All bits clear?
5973 bool AndEqZero = IsAnd && CC1 == ISD::SETEQ && IsZero;
5974 // All sign bits clear?
5975 bool AndGtNeg1 = IsAnd && CC1 == ISD::SETGT && IsNeg1;
5976 // Any bits set?
5977 bool OrNeZero = !IsAnd && CC1 == ISD::SETNE && IsZero;
5978 // Any sign bits set?
5979 bool OrLtZero = !IsAnd && CC1 == ISD::SETLT && IsZero;
5980
5981 // (and (seteq X, 0), (seteq Y, 0)) --> (seteq (or X, Y), 0)
5982 // (and (setgt X, -1), (setgt Y, -1)) --> (setgt (or X, Y), -1)
5983 // (or (setne X, 0), (setne Y, 0)) --> (setne (or X, Y), 0)
5984 // (or (setlt X, 0), (setlt Y, 0)) --> (setlt (or X, Y), 0)
5985 if (AndEqZero || AndGtNeg1 || OrNeZero || OrLtZero) {
5986 SDValue Or = DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N0), VT: OpVT, N1: LL, N2: RL);
5987 AddToWorklist(N: Or.getNode());
5988 return DAG.getSetCC(DL, VT, LHS: Or, RHS: LR, Cond: CC1);
5989 }
5990
5991 // All bits set?
5992 bool AndEqNeg1 = IsAnd && CC1 == ISD::SETEQ && IsNeg1;
5993 // All sign bits set?
5994 bool AndLtZero = IsAnd && CC1 == ISD::SETLT && IsZero;
5995 // Any bits clear?
5996 bool OrNeNeg1 = !IsAnd && CC1 == ISD::SETNE && IsNeg1;
5997 // Any sign bits clear?
5998 bool OrGtNeg1 = !IsAnd && CC1 == ISD::SETGT && IsNeg1;
5999
6000 // (and (seteq X, -1), (seteq Y, -1)) --> (seteq (and X, Y), -1)
6001 // (and (setlt X, 0), (setlt Y, 0)) --> (setlt (and X, Y), 0)
6002 // (or (setne X, -1), (setne Y, -1)) --> (setne (and X, Y), -1)
6003 // (or (setgt X, -1), (setgt Y -1)) --> (setgt (and X, Y), -1)
6004 if (AndEqNeg1 || AndLtZero || OrNeNeg1 || OrGtNeg1) {
6005 SDValue And = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(N0), VT: OpVT, N1: LL, N2: RL);
6006 AddToWorklist(N: And.getNode());
6007 return DAG.getSetCC(DL, VT, LHS: And, RHS: LR, Cond: CC1);
6008 }
6009 }
6010
6011 // TODO: What is the 'or' equivalent of this fold?
6012 // (and (setne X, 0), (setne X, -1)) --> (setuge (add X, 1), 2)
6013 if (IsAnd && LL == RL && CC0 == CC1 && OpVT.getScalarSizeInBits() > 1 &&
6014 IsInteger && CC0 == ISD::SETNE &&
6015 ((isNullConstant(V: LR) && isAllOnesConstant(V: RR)) ||
6016 (isAllOnesConstant(V: LR) && isNullConstant(V: RR)))) {
6017 SDValue One = DAG.getConstant(Val: 1, DL, VT: OpVT);
6018 SDValue Two = DAG.getConstant(Val: 2, DL, VT: OpVT);
6019 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL: SDLoc(N0), VT: OpVT, N1: LL, N2: One);
6020 AddToWorklist(N: Add.getNode());
6021 return DAG.getSetCC(DL, VT, LHS: Add, RHS: Two, Cond: ISD::SETUGE);
6022 }
6023
6024 // Try more general transforms if the predicates match and the only user of
6025 // the compares is the 'and' or 'or'.
6026 if (IsInteger && TLI.convertSetCCLogicToBitwiseLogic(VT: OpVT) && CC0 == CC1 &&
6027 N0.hasOneUse() && N1.hasOneUse()) {
6028 // and (seteq A, B), (seteq C, D) --> seteq (or (xor A, B), (xor C, D)), 0
6029 // or (setne A, B), (setne C, D) --> setne (or (xor A, B), (xor C, D)), 0
6030 if ((IsAnd && CC1 == ISD::SETEQ) || (!IsAnd && CC1 == ISD::SETNE)) {
6031 SDValue XorL = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N0), VT: OpVT, N1: LL, N2: LR);
6032 SDValue XorR = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N1), VT: OpVT, N1: RL, N2: RR);
6033 SDValue Or = DAG.getNode(Opcode: ISD::OR, DL, VT: OpVT, N1: XorL, N2: XorR);
6034 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: OpVT);
6035 return DAG.getSetCC(DL, VT, LHS: Or, RHS: Zero, Cond: CC1);
6036 }
6037
6038 // Turn compare of constants whose difference is 1 bit into add+and+setcc.
6039 if ((IsAnd && CC1 == ISD::SETNE) || (!IsAnd && CC1 == ISD::SETEQ)) {
6040 // Match a shared variable operand and 2 non-opaque constant operands.
6041 auto MatchDiffPow2 = [&](ConstantSDNode *C0, ConstantSDNode *C1) {
6042 // The difference of the constants must be a single bit.
6043 const APInt &CMax =
6044 APIntOps::umax(A: C0->getAPIntValue(), B: C1->getAPIntValue());
6045 const APInt &CMin =
6046 APIntOps::umin(A: C0->getAPIntValue(), B: C1->getAPIntValue());
6047 return !C0->isOpaque() && !C1->isOpaque() && (CMax - CMin).isPowerOf2();
6048 };
6049 if (LL == RL && ISD::matchBinaryPredicate(LHS: LR, RHS: RR, Match: MatchDiffPow2)) {
6050 // and/or (setcc X, CMax, ne), (setcc X, CMin, ne/eq) -->
6051 // setcc ((sub X, CMin), ~(CMax - CMin)), 0, ne/eq
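        // e.g. (X != 4) & (X != 6) --> ((X - 4) & ~2) != 0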
6052 SDValue Max = DAG.getNode(Opcode: ISD::UMAX, DL, VT: OpVT, N1: LR, N2: RR);
6053 SDValue Min = DAG.getNode(Opcode: ISD::UMIN, DL, VT: OpVT, N1: LR, N2: RR);
6054 SDValue Offset = DAG.getNode(Opcode: ISD::SUB, DL, VT: OpVT, N1: LL, N2: Min);
6055 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: OpVT, N1: Max, N2: Min);
6056 SDValue Mask = DAG.getNOT(DL, Val: Diff, VT: OpVT);
6057 SDValue And = DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: Offset, N2: Mask);
6058 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: OpVT);
6059 return DAG.getSetCC(DL, VT, LHS: And, RHS: Zero, Cond: CC0);
6060 }
6061 }
6062 }
6063
6064 // Canonicalize equivalent operands to LL == RL.
6065 if (LL == RR && LR == RL) {
6066 CC1 = ISD::getSetCCSwappedOperands(Operation: CC1);
6067 std::swap(a&: RL, b&: RR);
6068 }
6069
6070 // (and (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
6071 // (or (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
6072 if (LL == RL && LR == RR) {
6073 ISD::CondCode NewCC = IsAnd ? ISD::getSetCCAndOperation(Op1: CC0, Op2: CC1, Type: OpVT)
6074 : ISD::getSetCCOrOperation(Op1: CC0, Op2: CC1, Type: OpVT);
6075 if (NewCC != ISD::SETCC_INVALID &&
6076 (!LegalOperations ||
6077 (TLI.isCondCodeLegal(CC: NewCC, VT: LL.getSimpleValueType()) &&
6078 TLI.isOperationLegal(Op: ISD::SETCC, VT: OpVT))))
6079 return DAG.getSetCC(DL, VT, LHS: LL, RHS: LR, Cond: NewCC);
6080 }
6081
6082 return SDValue();
6083}
6084
6085static bool arebothOperandsNotSNan(SDValue Operand1, SDValue Operand2,
6086 SelectionDAG &DAG) {
6087 return DAG.isKnownNeverSNaN(Op: Operand2) && DAG.isKnownNeverSNaN(Op: Operand1);
6088}
6089
6090static bool arebothOperandsNotNan(SDValue Operand1, SDValue Operand2,
6091 SelectionDAG &DAG) {
6092 return DAG.isKnownNeverNaN(Op: Operand2) && DAG.isKnownNeverNaN(Op: Operand1);
6093}
6094
6095static unsigned getMinMaxOpcodeForFP(SDValue Operand1, SDValue Operand2,
6096 ISD::CondCode CC, unsigned OrAndOpcode,
6097 SelectionDAG &DAG,
6098 bool isFMAXNUMFMINNUM_IEEE,
6099 bool isFMAXNUMFMINNUM) {
6100 // The optimization cannot be applied for all the predicates because
6101 // of the way FMINNUM/FMAXNUM and FMINNUM_IEEE/FMAXNUM_IEEE handle
6102 // NaNs. For FMINNUM_IEEE/FMAXNUM_IEEE, the optimization cannot be
6103 // applied at all if one of the operands is a signaling NaN.
6104
6105 // It is safe to use FMINNUM_IEEE/FMAXNUM_IEEE if all the operands
6106 // are non-NaN values.
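  // The caller uses the opcode returned here to rewrite, e.g.,
  // (X < C) | (Y < C) into (fminnum_ieee X, Y) < C.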
6107 if (((CC == ISD::SETLT || CC == ISD::SETLE) && (OrAndOpcode == ISD::OR)) ||
6108 ((CC == ISD::SETGT || CC == ISD::SETGE) && (OrAndOpcode == ISD::AND)))
6109 return arebothOperandsNotNan(Operand1, Operand2, DAG) &&
6110 isFMAXNUMFMINNUM_IEEE
6111 ? ISD::FMINNUM_IEEE
6112 : ISD::DELETED_NODE;
6113 else if (((CC == ISD::SETGT || CC == ISD::SETGE) &&
6114 (OrAndOpcode == ISD::OR)) ||
6115 ((CC == ISD::SETLT || CC == ISD::SETLE) &&
6116 (OrAndOpcode == ISD::AND)))
6117 return arebothOperandsNotNan(Operand1, Operand2, DAG) &&
6118 isFMAXNUMFMINNUM_IEEE
6119 ? ISD::FMAXNUM_IEEE
6120 : ISD::DELETED_NODE;
6121 // Both FMINNUM/FMAXNUM and FMINNUM_IEEE/FMAXNUM_IEEE handle quiet
6122 // NaNs in the same way. But, FMINNUM/FMAXNUM and FMINNUM_IEEE/
6123 // FMAXNUM_IEEE handle signaling NaNs differently. If we cannot prove
6124 // that there are not any sNaNs, then the optimization is not valid
6125 // for FMINNUM_IEEE/FMAXNUM_IEEE. In the presence of sNaNs, we apply
6126 // the optimization using FMINNUM/FMAXNUM for the following cases. If
6127 // we can prove that we do not have any sNaNs, then we can do the
6128 // optimization using FMINNUM_IEEE/FMAXNUM_IEEE for the following
6129 // cases.
6130 else if (((CC == ISD::SETOLT || CC == ISD::SETOLE) &&
6131 (OrAndOpcode == ISD::OR)) ||
6132 ((CC == ISD::SETUGT || CC == ISD::SETUGE) &&
6133 (OrAndOpcode == ISD::AND)))
6134 return isFMAXNUMFMINNUM ? ISD::FMINNUM
6135 : arebothOperandsNotSNan(Operand1, Operand2, DAG) &&
6136 isFMAXNUMFMINNUM_IEEE
6137 ? ISD::FMINNUM_IEEE
6138 : ISD::DELETED_NODE;
6139 else if (((CC == ISD::SETOGT || CC == ISD::SETOGE) &&
6140 (OrAndOpcode == ISD::OR)) ||
6141 ((CC == ISD::SETULT || CC == ISD::SETULE) &&
6142 (OrAndOpcode == ISD::AND)))
6143 return isFMAXNUMFMINNUM ? ISD::FMAXNUM
6144 : arebothOperandsNotSNan(Operand1, Operand2, DAG) &&
6145 isFMAXNUMFMINNUM_IEEE
6146 ? ISD::FMAXNUM_IEEE
6147 : ISD::DELETED_NODE;
6148 return ISD::DELETED_NODE;
6149}
6150
6151static SDValue foldAndOrOfSETCC(SDNode *LogicOp, SelectionDAG &DAG) {
6152 using AndOrSETCCFoldKind = TargetLowering::AndOrSETCCFoldKind;
6153 assert(
6154 (LogicOp->getOpcode() == ISD::AND || LogicOp->getOpcode() == ISD::OR) &&
6155 "Invalid Op to combine SETCC with");
6156
6157 // TODO: Search past casts/truncates.
6158 SDValue LHS = LogicOp->getOperand(Num: 0);
6159 SDValue RHS = LogicOp->getOperand(Num: 1);
6160 if (LHS->getOpcode() != ISD::SETCC || RHS->getOpcode() != ISD::SETCC ||
6161 !LHS->hasOneUse() || !RHS->hasOneUse())
6162 return SDValue();
6163
6164 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6165 AndOrSETCCFoldKind TargetPreference = TLI.isDesirableToCombineLogicOpOfSETCC(
6166 LogicOp, SETCC0: LHS.getNode(), SETCC1: RHS.getNode());
6167
6168 SDValue LHS0 = LHS->getOperand(Num: 0);
6169 SDValue RHS0 = RHS->getOperand(Num: 0);
6170 SDValue LHS1 = LHS->getOperand(Num: 1);
6171 SDValue RHS1 = RHS->getOperand(Num: 1);
6172 // TODO: We don't actually need a splat here; for vectors we just need the
6173 // invariants to hold for each element.
6174 auto *LHS1C = isConstOrConstSplat(N: LHS1);
6175 auto *RHS1C = isConstOrConstSplat(N: RHS1);
6176 ISD::CondCode CCL = cast<CondCodeSDNode>(Val: LHS.getOperand(i: 2))->get();
6177 ISD::CondCode CCR = cast<CondCodeSDNode>(Val: RHS.getOperand(i: 2))->get();
6178 EVT VT = LogicOp->getValueType(ResNo: 0);
6179 EVT OpVT = LHS0.getValueType();
6180 SDLoc DL(LogicOp);
6181
6182 // Check if the operands of an and/or operation are comparisons and if they
6183 // compare against the same value. Replace the and/or-cmp-cmp sequence with
6184 // a min/max-cmp sequence. If LHS1 is equal to RHS1, then the or-cmp-cmp
6185 // sequence will be replaced with a min-cmp sequence:
6186 // (LHS0 < LHS1) | (RHS0 < RHS1) -> min(LHS0, RHS0) < LHS1
6187 // and the and-cmp-cmp sequence will be replaced with a max-cmp sequence:
6188 // (LHS0 < LHS1) & (RHS0 < RHS1) -> max(LHS0, RHS0) < LHS1
6189 // The optimization does not work for `==` or `!=`.
6190 // The two comparisons should either use the same predicate, or the
6191 // predicate of one comparison should be the swapped form of the other.
6192 bool isFMAXNUMFMINNUM_IEEE = TLI.isOperationLegal(Op: ISD::FMAXNUM_IEEE, VT: OpVT) &&
6193 TLI.isOperationLegal(Op: ISD::FMINNUM_IEEE, VT: OpVT);
6194 bool isFMAXNUMFMINNUM = TLI.isOperationLegalOrCustom(Op: ISD::FMAXNUM, VT: OpVT) &&
6195 TLI.isOperationLegalOrCustom(Op: ISD::FMINNUM, VT: OpVT);
6196 if (((OpVT.isInteger() && TLI.isOperationLegal(Op: ISD::UMAX, VT: OpVT) &&
6197 TLI.isOperationLegal(Op: ISD::SMAX, VT: OpVT) &&
6198 TLI.isOperationLegal(Op: ISD::UMIN, VT: OpVT) &&
6199 TLI.isOperationLegal(Op: ISD::SMIN, VT: OpVT)) ||
6200 (OpVT.isFloatingPoint() &&
6201 (isFMAXNUMFMINNUM_IEEE || isFMAXNUMFMINNUM))) &&
6202 !ISD::isIntEqualitySetCC(Code: CCL) && !ISD::isFPEqualitySetCC(Code: CCL) &&
6203 CCL != ISD::SETFALSE && CCL != ISD::SETO && CCL != ISD::SETUO &&
6204 CCL != ISD::SETTRUE &&
6205 (CCL == CCR || CCL == ISD::getSetCCSwappedOperands(Operation: CCR))) {
6206
6207 SDValue CommonValue, Operand1, Operand2;
6208 ISD::CondCode CC = ISD::SETCC_INVALID;
6209 if (CCL == CCR) {
6210 if (LHS0 == RHS0) {
6211 CommonValue = LHS0;
6212 Operand1 = LHS1;
6213 Operand2 = RHS1;
6214 CC = ISD::getSetCCSwappedOperands(Operation: CCL);
6215 } else if (LHS1 == RHS1) {
6216 CommonValue = LHS1;
6217 Operand1 = LHS0;
6218 Operand2 = RHS0;
6219 CC = CCL;
6220 }
6221 } else {
6222 assert(CCL == ISD::getSetCCSwappedOperands(CCR) && "Unexpected CC");
6223 if (LHS0 == RHS1) {
6224 CommonValue = LHS0;
6225 Operand1 = LHS1;
6226 Operand2 = RHS0;
6227 CC = CCR;
6228 } else if (RHS0 == LHS1) {
6229 CommonValue = LHS1;
6230 Operand1 = LHS0;
6231 Operand2 = RHS1;
6232 CC = CCL;
6233 }
6234 }
6235
6236 // Don't do this transform for sign bit tests. Let foldLogicOfSetCCs
6237 // handle it using OR/AND.
6238 if (CC == ISD::SETLT && isNullOrNullSplat(V: CommonValue))
6239 CC = ISD::SETCC_INVALID;
6240 else if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(V: CommonValue))
6241 CC = ISD::SETCC_INVALID;
6242
6243 if (CC != ISD::SETCC_INVALID) {
6244 unsigned NewOpcode = ISD::DELETED_NODE;
6245 bool IsSigned = isSignedIntSetCC(Code: CC);
6246 if (OpVT.isInteger()) {
6247 bool IsLess = (CC == ISD::SETLE || CC == ISD::SETULE ||
6248 CC == ISD::SETLT || CC == ISD::SETULT);
6249 bool IsOr = (LogicOp->getOpcode() == ISD::OR);
6250 if (IsLess == IsOr)
6251 NewOpcode = IsSigned ? ISD::SMIN : ISD::UMIN;
6252 else
6253 NewOpcode = IsSigned ? ISD::SMAX : ISD::UMAX;
6254 } else if (OpVT.isFloatingPoint())
6255 NewOpcode =
6256 getMinMaxOpcodeForFP(Operand1, Operand2, CC, OrAndOpcode: LogicOp->getOpcode(),
6257 DAG, isFMAXNUMFMINNUM_IEEE, isFMAXNUMFMINNUM);
6258
6259 if (NewOpcode != ISD::DELETED_NODE) {
6260 SDValue MinMaxValue =
6261 DAG.getNode(Opcode: NewOpcode, DL, VT: OpVT, N1: Operand1, N2: Operand2);
6262 return DAG.getSetCC(DL, VT, LHS: MinMaxValue, RHS: CommonValue, Cond: CC);
6263 }
6264 }
6265 }
6266
6267 if (TargetPreference == AndOrSETCCFoldKind::None)
6268 return SDValue();
6269
6270 if (CCL == CCR &&
6271 CCL == (LogicOp->getOpcode() == ISD::AND ? ISD::SETNE : ISD::SETEQ) &&
6272 LHS0 == RHS0 && LHS1C && RHS1C && OpVT.isInteger()) {
6273 const APInt &APLhs = LHS1C->getAPIntValue();
6274 const APInt &APRhs = RHS1C->getAPIntValue();
6275
6276 // Preference is to use ISD::ABS or we already have an ISD::ABS (in which
6277 // case this is just a compare).
6278 if (APLhs == (-APRhs) &&
6279 ((TargetPreference & AndOrSETCCFoldKind::ABS) ||
6280 DAG.doesNodeExist(Opcode: ISD::ABS, VTList: DAG.getVTList(VT: OpVT), Ops: {LHS0}))) {
6281 const APInt &C = APLhs.isNegative() ? APRhs : APLhs;
6282 // (icmp eq A, C) | (icmp eq A, -C)
6283 // -> (icmp eq Abs(A), C)
6284 // (icmp ne A, C) & (icmp ne A, -C)
6285 // -> (icmp ne Abs(A), C)
6286 SDValue AbsOp = DAG.getNode(Opcode: ISD::ABS, DL, VT: OpVT, Operand: LHS0);
6287 return DAG.getNode(Opcode: ISD::SETCC, DL, VT, N1: AbsOp,
6288 N2: DAG.getConstant(Val: C, DL, VT: OpVT), N3: LHS.getOperand(i: 2));
6289 } else if (TargetPreference &
6290 (AndOrSETCCFoldKind::AddAnd | AndOrSETCCFoldKind::NotAnd)) {
6291
6292 // AndOrSETCCFoldKind::AddAnd:
6293 // A == C0 | A == C1
6294 // IF IsPow2(smax(C0, C1)-smin(C0, C1))
6295 // -> ((A - smin(C0, C1)) & ~(smax(C0, C1)-smin(C0, C1))) == 0
6296 // A != C0 & A != C1
6297 // IF IsPow2(smax(C0, C1)-smin(C0, C1))
6298 // -> ((A - smin(C0, C1)) & ~(smax(C0, C1)-smin(C0, C1))) != 0
6299
6300 // AndOrSETCCFoldKind::NotAnd:
6301 // A == C0 | A == C1
6302 // IF smax(C0, C1) == -1 AND IsPow2(smax(C0, C1) - smin(C0, C1))
6303 // -> ~A & smin(C0, C1) == 0
6304 // A != C0 & A != C1
6305 // IF smax(C0, C1) == -1 AND IsPow2(smax(C0, C1) - smin(C0, C1))
6306 // -> ~A & smin(C0, C1) != 0
6307
6308 const APInt &MaxC = APIntOps::smax(A: APRhs, B: APLhs);
6309 const APInt &MinC = APIntOps::smin(A: APRhs, B: APLhs);
6310 APInt Dif = MaxC - MinC;
6311 if (!Dif.isZero() && Dif.isPowerOf2()) {
6312 if (MaxC.isAllOnes() &&
6313 (TargetPreference & AndOrSETCCFoldKind::NotAnd)) {
6314 SDValue NotOp = DAG.getNOT(DL, Val: LHS0, VT: OpVT);
6315 SDValue AndOp = DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: NotOp,
6316 N2: DAG.getConstant(Val: MinC, DL, VT: OpVT));
6317 return DAG.getNode(Opcode: ISD::SETCC, DL, VT, N1: AndOp,
6318 N2: DAG.getConstant(Val: 0, DL, VT: OpVT), N3: LHS.getOperand(i: 2));
6319 } else if (TargetPreference & AndOrSETCCFoldKind::AddAnd) {
6320
6321 SDValue AddOp = DAG.getNode(Opcode: ISD::ADD, DL, VT: OpVT, N1: LHS0,
6322 N2: DAG.getConstant(Val: -MinC, DL, VT: OpVT));
6323 SDValue AndOp = DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: AddOp,
6324 N2: DAG.getConstant(Val: ~Dif, DL, VT: OpVT));
6325 return DAG.getNode(Opcode: ISD::SETCC, DL, VT, N1: AndOp,
6326 N2: DAG.getConstant(Val: 0, DL, VT: OpVT), N3: LHS.getOperand(i: 2));
6327 }
6328 }
6329 }
6330 }
6331
6332 return SDValue();
6333}
6334
6335// Combine `(select c, (X & 1), 0)` -> `(and (zext c), X)`.
6336// We canonicalize to the `select` form in the middle end, but the `and` form
6337 // gets better codegen on all tested targets (arm, x86, riscv).
6338static SDValue combineSelectAsExtAnd(SDValue Cond, SDValue T, SDValue F,
6339 const SDLoc &DL, SelectionDAG &DAG) {
6340 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6341 if (!isNullConstant(V: F))
6342 return SDValue();
6343
6344 EVT CondVT = Cond.getValueType();
6345 if (TLI.getBooleanContents(Type: CondVT) !=
6346 TargetLoweringBase::ZeroOrOneBooleanContent)
6347 return SDValue();
6348
6349 if (T.getOpcode() != ISD::AND)
6350 return SDValue();
6351
6352 if (!isOneConstant(V: T.getOperand(i: 1)))
6353 return SDValue();
6354
6355 EVT OpVT = T.getValueType();
6356
6357 SDValue CondMask =
6358 OpVT == CondVT ? Cond : DAG.getBoolExtOrTrunc(Op: Cond, SL: DL, VT: OpVT, OpVT: CondVT);
6359 return DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: CondMask, N2: T.getOperand(i: 0));
6360}
6361
6362/// This contains all DAGCombine rules which reduce two values combined by
6363/// an And operation to a single value. This makes them reusable in the context
6364/// of visitSELECT(). Rules involving constants are not included as
6365/// visitSELECT() already handles those cases.
6366SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) {
6367 EVT VT = N1.getValueType();
6368 SDLoc DL(N);
6369
6370 // fold (and x, undef) -> 0
6371 if (N0.isUndef() || N1.isUndef())
6372 return DAG.getConstant(Val: 0, DL, VT);
6373
6374 if (SDValue V = foldLogicOfSetCCs(IsAnd: true, N0, N1, DL))
6375 return V;
6376
6377 // Canonicalize:
6378 // and(x, add) -> and(add, x)
6379 if (N1.getOpcode() == ISD::ADD)
6380 std::swap(a&: N0, b&: N1);
6381
6382 // TODO: Rewrite this to return a new 'AND' instead of using CombineTo.
6383 if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::SRL &&
6384 VT.isScalarInteger() && VT.getSizeInBits() <= 64 && N0->hasOneUse()) {
6385 if (ConstantSDNode *ADDI = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1))) {
6386 if (ConstantSDNode *SRLI = dyn_cast<ConstantSDNode>(Val: N1.getOperand(i: 1))) {
6387 // Look for (and (add x, c1), (lshr y, c2)). If c1 isn't a legal
6388 // immediate for an add, but it becomes legal once its top c2 bits are
6389 // set (the 'and' clears those bits anyway), transform the ADD so the
6390 // immediate doesn't need to be materialized in a register.
6391 APInt ADDC = ADDI->getAPIntValue();
6392 APInt SRLC = SRLI->getAPIntValue();
6393 if (ADDC.getSignificantBits() <= 64 && SRLC.ult(RHS: VT.getSizeInBits()) &&
6394 !TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
6395 APInt Mask = APInt::getHighBitsSet(numBits: VT.getSizeInBits(),
6396 hiBitsSet: SRLC.getZExtValue());
6397 if (DAG.MaskedValueIsZero(Op: N0.getOperand(i: 1), Mask)) {
6398 ADDC |= Mask;
6399 if (TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
6400 SDLoc DL0(N0);
6401 SDValue NewAdd =
6402 DAG.getNode(Opcode: ISD::ADD, DL: DL0, VT,
6403 N1: N0.getOperand(i: 0), N2: DAG.getConstant(Val: ADDC, DL, VT));
6404 CombineTo(N: N0.getNode(), Res: NewAdd);
6405 // Return N so it doesn't get rechecked!
6406 return SDValue(N, 0);
6407 }
6408 }
6409 }
6410 }
6411 }
6412 }
6413
6414 return SDValue();
6415}
6416
6417bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
6418 EVT LoadResultTy, EVT &ExtVT) {
6419 if (!AndC->getAPIntValue().isMask())
6420 return false;
6421
6422 unsigned ActiveBits = AndC->getAPIntValue().countr_one();
6423
6424 ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ActiveBits);
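  // e.g. an AND mask of 0xFF has 8 trailing ones, so ExtVT is i8.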
6425 EVT LoadedVT = LoadN->getMemoryVT();
6426
6427 if (ExtVT == LoadedVT &&
6428 (!LegalOperations ||
6429 TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: LoadResultTy, MemVT: ExtVT))) {
6430 // ZEXTLOAD will match without needing to change the size of the value being
6431 // loaded.
6432 return true;
6433 }
6434
6435 // Do not change the width of volatile or atomic loads.
6436 if (!LoadN->isSimple())
6437 return false;
6438
6439 // Do not generate loads of non-round integer types since these can
6440 // be expensive (and would be wrong if the type is not byte sized).
6441 if (!LoadedVT.bitsGT(VT: ExtVT) || !ExtVT.isRound())
6442 return false;
6443
6444 if (LegalOperations &&
6445 !TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: LoadResultTy, MemVT: ExtVT))
6446 return false;
6447
6448 if (!TLI.shouldReduceLoadWidth(Load: LoadN, ExtTy: ISD::ZEXTLOAD, NewVT: ExtVT))
6449 return false;
6450
6451 return true;
6452}
6453
6454bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST,
6455 ISD::LoadExtType ExtType, EVT &MemVT,
6456 unsigned ShAmt) {
6457 if (!LDST)
6458 return false;
6459 // Only allow byte offsets.
6460 if (ShAmt % 8)
6461 return false;
6462
6463 // Do not generate loads of non-round integer types since these can
6464 // be expensive (and would be wrong if the type is not byte sized).
6465 if (!MemVT.isRound())
6466 return false;
6467
6468 // Don't change the width of volatile or atomic loads.
6469 if (!LDST->isSimple())
6470 return false;
6471
6472 EVT LdStMemVT = LDST->getMemoryVT();
6473
6474 // Bail out when changing the scalable property, since we can't be sure that
6475 // we're actually narrowing here.
6476 if (LdStMemVT.isScalableVector() != MemVT.isScalableVector())
6477 return false;
6478
6479 // Verify that we are actually reducing a load width here.
6480 if (LdStMemVT.bitsLT(VT: MemVT))
6481 return false;
6482
6483 // Ensure that this isn't going to produce an unsupported memory access.
6484 if (ShAmt) {
6485 assert(ShAmt % 8 == 0 && "ShAmt is byte offset");
6486 const unsigned ByteShAmt = ShAmt / 8;
6487 const Align LDSTAlign = LDST->getAlign();
6488 const Align NarrowAlign = commonAlignment(A: LDSTAlign, Offset: ByteShAmt);
6489 if (!TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: MemVT,
6490 AddrSpace: LDST->getAddressSpace(), Alignment: NarrowAlign,
6491 Flags: LDST->getMemOperand()->getFlags()))
6492 return false;
6493 }
6494
6495 // It's not possible to generate a constant of extended or untyped type.
6496 EVT PtrType = LDST->getBasePtr().getValueType();
6497 if (PtrType == MVT::Untyped || PtrType.isExtended())
6498 return false;
6499
6500 if (isa<LoadSDNode>(Val: LDST)) {
6501 LoadSDNode *Load = cast<LoadSDNode>(Val: LDST);
6502 // Don't transform a load with multiple uses; this would require adding a
6503 // new load.
6504 if (!SDValue(Load, 0).hasOneUse())
6505 return false;
6506
6507 if (LegalOperations &&
6508 !TLI.isLoadExtLegal(ExtType, ValVT: Load->getValueType(ResNo: 0), MemVT))
6509 return false;
6510
6511 // For the transform to be legal, the load must produce only two values
6512 // (the value loaded and the chain). Don't transform a pre-increment
6513 // load, for example, which produces an extra value. Otherwise the
6514 // transformation is not equivalent, and the downstream logic to replace
6515 // uses gets things wrong.
6516 if (Load->getNumValues() > 2)
6517 return false;
6518
6519 // If the load that we're shrinking is an extload and we're not just
6520 // discarding the extension, we can't simply shrink the load. Bail.
6521 // TODO: It would be possible to merge the extensions in some cases.
6522 if (Load->getExtensionType() != ISD::NON_EXTLOAD &&
6523 Load->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
6524 return false;
6525
6526 if (!TLI.shouldReduceLoadWidth(Load, ExtTy: ExtType, NewVT: MemVT))
6527 return false;
6528 } else {
6529 assert(isa<StoreSDNode>(LDST) && "It is not a Load nor a Store SDNode");
6530 StoreSDNode *Store = cast<StoreSDNode>(Val: LDST);
6531 // Can't write outside the original store
6532 if (Store->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
6533 return false;
6534
6535 if (LegalOperations &&
6536 !TLI.isTruncStoreLegal(ValVT: Store->getValue().getValueType(), MemVT))
6537 return false;
6538 }
6539 return true;
6540}
6541
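/// Walk the operand tree of an AND with constant mask \p Mask, collecting
/// loads that can be narrowed to the mask width in \p Loads, logic nodes whose
/// constant operands may need re-masking in \p NodesWithConsts, and at most
/// one other node that must be masked explicitly in \p NodeToMask. Returns
/// false if the tree cannot be handled this way.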
6542bool DAGCombiner::SearchForAndLoads(SDNode *N,
6543 SmallVectorImpl<LoadSDNode*> &Loads,
6544 SmallPtrSetImpl<SDNode*> &NodesWithConsts,
6545 ConstantSDNode *Mask,
6546 SDNode *&NodeToMask) {
6547 // Recursively search the operands, looking for loads which can be
6548 // narrowed.
6549 for (SDValue Op : N->op_values()) {
6550 if (Op.getValueType().isVector())
6551 return false;
6552
6553 // Some constants may need fixing up later if they are too large.
6554 if (auto *C = dyn_cast<ConstantSDNode>(Val&: Op)) {
6555 if ((N->getOpcode() == ISD::OR || N->getOpcode() == ISD::XOR) &&
6556 (Mask->getAPIntValue() & C->getAPIntValue()) != C->getAPIntValue())
6557 NodesWithConsts.insert(Ptr: N);
6558 continue;
6559 }
6560
6561 if (!Op.hasOneUse())
6562 return false;
6563
6564 switch(Op.getOpcode()) {
6565 case ISD::LOAD: {
6566 auto *Load = cast<LoadSDNode>(Val&: Op);
6567 EVT ExtVT;
6568 if (isAndLoadExtLoad(AndC: Mask, LoadN: Load, LoadResultTy: Load->getValueType(ResNo: 0), ExtVT) &&
6569 isLegalNarrowLdSt(LDST: Load, ExtType: ISD::ZEXTLOAD, MemVT&: ExtVT)) {
6570
6571 // ZEXTLOAD is already small enough.
6572 if (Load->getExtensionType() == ISD::ZEXTLOAD &&
6573 ExtVT.bitsGE(VT: Load->getMemoryVT()))
6574 continue;
6575
6576 // Use LE to convert equal sized loads to zext.
6577 if (ExtVT.bitsLE(VT: Load->getMemoryVT()))
6578 Loads.push_back(Elt: Load);
6579
6580 continue;
6581 }
6582 return false;
6583 }
6584 case ISD::ZERO_EXTEND:
6585 case ISD::AssertZext: {
6586 unsigned ActiveBits = Mask->getAPIntValue().countr_one();
6587 EVT ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ActiveBits);
6588 EVT VT = Op.getOpcode() == ISD::AssertZext ?
6589 cast<VTSDNode>(Val: Op.getOperand(i: 1))->getVT() :
6590 Op.getOperand(i: 0).getValueType();
6591
6592 // We can accept extending nodes if the mask is wider than or equal in
6593 // width to the original type.
6594 if (ExtVT.bitsGE(VT))
6595 continue;
6596 break;
6597 }
6598 case ISD::OR:
6599 case ISD::XOR:
6600 case ISD::AND:
6601 if (!SearchForAndLoads(N: Op.getNode(), Loads, NodesWithConsts, Mask,
6602 NodeToMask))
6603 return false;
6604 continue;
6605 }
6606
6607 // Allow one node which will be masked along with any loads found.
6608 if (NodeToMask)
6609 return false;
6610
6611 // Also ensure that the node to be masked only produces one data result.
6612 NodeToMask = Op.getNode();
6613 if (NodeToMask->getNumValues() > 1) {
6614 bool HasValue = false;
6615 for (unsigned i = 0, e = NodeToMask->getNumValues(); i < e; ++i) {
6616 MVT VT = SDValue(NodeToMask, i).getSimpleValueType();
6617 if (VT != MVT::Glue && VT != MVT::Other) {
6618 if (HasValue) {
6619 NodeToMask = nullptr;
6620 return false;
6621 }
6622 HasValue = true;
6623 }
6624 }
6625 assert(HasValue && "Node to be masked has no data result?");
6626 }
6627 }
6628 return true;
6629}
6630
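/// Given an AND of a tree of or/xor/and nodes and loads with a constant mask,
/// try to push the mask back through the tree so that the loads can be
/// narrowed to zero-extending loads of the mask width.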
6631bool DAGCombiner::BackwardsPropagateMask(SDNode *N) {
6632 auto *Mask = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
6633 if (!Mask)
6634 return false;
6635
6636 if (!Mask->getAPIntValue().isMask())
6637 return false;
6638
6639 // No need to do anything if the and directly uses a load.
6640 if (isa<LoadSDNode>(Val: N->getOperand(Num: 0)))
6641 return false;
6642
6643 SmallVector<LoadSDNode*, 8> Loads;
6644 SmallPtrSet<SDNode*, 2> NodesWithConsts;
6645 SDNode *FixupNode = nullptr;
6646 if (SearchForAndLoads(N, Loads, NodesWithConsts, Mask, NodeToMask&: FixupNode)) {
6647 if (Loads.empty())
6648 return false;
6649
6650 LLVM_DEBUG(dbgs() << "Backwards propagate AND: "; N->dump());
6651 SDValue MaskOp = N->getOperand(Num: 1);
6652
6653 // If it exists, fixup the single node we allow in the tree that needs
6654 // masking.
6655 if (FixupNode) {
6656 LLVM_DEBUG(dbgs() << "First, need to fix up: "; FixupNode->dump());
6657 SDValue And = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(FixupNode),
6658 VT: FixupNode->getValueType(ResNo: 0),
6659 N1: SDValue(FixupNode, 0), N2: MaskOp);
6660 DAG.ReplaceAllUsesOfValueWith(From: SDValue(FixupNode, 0), To: And);
6661 if (And.getOpcode() == ISD::AND)
6662 DAG.UpdateNodeOperands(N: And.getNode(), Op1: SDValue(FixupNode, 0), Op2: MaskOp);
6663 }
6664
6665 // Narrow any constants that need it.
6666 for (auto *LogicN : NodesWithConsts) {
6667 SDValue Op0 = LogicN->getOperand(Num: 0);
6668 SDValue Op1 = LogicN->getOperand(Num: 1);
6669
6670 if (isa<ConstantSDNode>(Val: Op0))
6671 Op0 =
6672 DAG.getNode(Opcode: ISD::AND, DL: SDLoc(Op0), VT: Op0.getValueType(), N1: Op0, N2: MaskOp);
6673
6674 if (isa<ConstantSDNode>(Val: Op1))
6675 Op1 =
6676 DAG.getNode(Opcode: ISD::AND, DL: SDLoc(Op1), VT: Op1.getValueType(), N1: Op1, N2: MaskOp);
6677
6678 if (isa<ConstantSDNode>(Val: Op0) && !isa<ConstantSDNode>(Val: Op1))
6679 std::swap(a&: Op0, b&: Op1);
6680
6681 DAG.UpdateNodeOperands(N: LogicN, Op1: Op0, Op2: Op1);
6682 }
6683
6684 // Create narrow loads.
6685 for (auto *Load : Loads) {
6686 LLVM_DEBUG(dbgs() << "Propagate AND back to: "; Load->dump());
6687 SDValue And = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(Load), VT: Load->getValueType(ResNo: 0),
6688 N1: SDValue(Load, 0), N2: MaskOp);
6689 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Load, 0), To: And);
6690 if (And.getOpcode() == ISD::AND)
6691 And = SDValue(
6692 DAG.UpdateNodeOperands(N: And.getNode(), Op1: SDValue(Load, 0), Op2: MaskOp), 0);
6693 SDValue NewLoad = reduceLoadWidth(N: And.getNode());
6694 assert(NewLoad &&
6695 "Shouldn't be masking the load if it can't be narrowed");
6696 CombineTo(N: Load, Res0: NewLoad, Res1: NewLoad.getValue(R: 1));
6697 }
6698 DAG.ReplaceAllUsesWith(From: N, To: N->getOperand(Num: 0).getNode());
6699 return true;
6700 }
6701 return false;
6702}
6703
6704// Unfold
6705// x & (-1 'logical shift' y)
6706// To
6707// (x 'opposite logical shift' y) 'logical shift' y
6708// if it is better for performance.
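// e.g. and X, (srl -1, Y) --> srl (shl X, Y), Y, which clears the top Y bits
// of X without materializing the all-ones mask.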
6709SDValue DAGCombiner::unfoldExtremeBitClearingToShifts(SDNode *N) {
6710 assert(N->getOpcode() == ISD::AND);
6711
6712 SDValue N0 = N->getOperand(Num: 0);
6713 SDValue N1 = N->getOperand(Num: 1);
6714
6715 // Do we actually prefer shifts over the mask?
6716 if (!TLI.shouldFoldMaskToVariableShiftPair(X: N0))
6717 return SDValue();
6718
6719 // Try to match (-1 '[outer] logical shift' y)
6720 unsigned OuterShift;
6721 unsigned InnerShift; // The opposite direction to the OuterShift.
6722 SDValue Y; // Shift amount.
6723 auto matchMask = [&OuterShift, &InnerShift, &Y](SDValue M) -> bool {
6724 if (!M.hasOneUse())
6725 return false;
6726 OuterShift = M->getOpcode();
6727 if (OuterShift == ISD::SHL)
6728 InnerShift = ISD::SRL;
6729 else if (OuterShift == ISD::SRL)
6730 InnerShift = ISD::SHL;
6731 else
6732 return false;
6733 if (!isAllOnesConstant(V: M->getOperand(Num: 0)))
6734 return false;
6735 Y = M->getOperand(Num: 1);
6736 return true;
6737 };
6738
6739 SDValue X;
6740 if (matchMask(N1))
6741 X = N0;
6742 else if (matchMask(N0))
6743 X = N1;
6744 else
6745 return SDValue();
6746
6747 SDLoc DL(N);
6748 EVT VT = N->getValueType(ResNo: 0);
6749
6750 // tmp = x 'opposite logical shift' y
6751 SDValue T0 = DAG.getNode(Opcode: InnerShift, DL, VT, N1: X, N2: Y);
6752 // ret = tmp 'logical shift' y
6753 SDValue T1 = DAG.getNode(Opcode: OuterShift, DL, VT, N1: T0, N2: Y);
6754
6755 return T1;
6756}
6757
6758/// Try to replace shift/logic that tests if a bit is clear with mask + setcc.
6759/// For a target with a bit test, this is expected to become test + set and save
6760/// at least 1 instruction.
6761static SDValue combineShiftAnd1ToBitTest(SDNode *And, SelectionDAG &DAG) {
6762 assert(And->getOpcode() == ISD::AND && "Expected an 'and' op");
6763
6764 // Look through an optional extension.
6765 SDValue And0 = And->getOperand(Num: 0), And1 = And->getOperand(Num: 1);
6766 if (And0.getOpcode() == ISD::ANY_EXTEND && And0.hasOneUse())
6767 And0 = And0.getOperand(i: 0);
6768 if (!isOneConstant(V: And1) || !And0.hasOneUse())
6769 return SDValue();
6770
6771 SDValue Src = And0;
6772
6773 // Attempt to find a 'not' op.
6774 // TODO: Should we favor test+set even without the 'not' op?
6775 bool FoundNot = false;
6776 if (isBitwiseNot(V: Src)) {
6777 FoundNot = true;
6778 Src = Src.getOperand(i: 0);
6779
6780 // Look through an optional truncation. The source operand may not be the
6781 // same type as the original 'and', but that is ok because we are masking
6782 // off everything but the low bit.
6783 if (Src.getOpcode() == ISD::TRUNCATE && Src.hasOneUse())
6784 Src = Src.getOperand(i: 0);
6785 }
6786
6787 // Match a shift-right by constant.
6788 if (Src.getOpcode() != ISD::SRL || !Src.hasOneUse())
6789 return SDValue();
6790
6791 // This is probably not worthwhile without a supported type.
6792 EVT SrcVT = Src.getValueType();
6793 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6794 if (!TLI.isTypeLegal(VT: SrcVT))
6795 return SDValue();
6796
6797 // We might have looked through casts that make this transform invalid.
6798 unsigned BitWidth = SrcVT.getScalarSizeInBits();
6799 SDValue ShiftAmt = Src.getOperand(i: 1);
6800 auto *ShiftAmtC = dyn_cast<ConstantSDNode>(Val&: ShiftAmt);
6801 if (!ShiftAmtC || !ShiftAmtC->getAPIntValue().ult(RHS: BitWidth))
6802 return SDValue();
6803
6804 // Set source to shift source.
6805 Src = Src.getOperand(i: 0);
6806
6807 // Try again to find a 'not' op.
6808 // TODO: Should we favor test+set even with two 'not' ops?
6809 if (!FoundNot) {
6810 if (!isBitwiseNot(V: Src))
6811 return SDValue();
6812 Src = Src.getOperand(i: 0);
6813 }
6814
6815 if (!TLI.hasBitTest(X: Src, Y: ShiftAmt))
6816 return SDValue();
6817
6818 // Turn this into a bit-test pattern using mask op + setcc:
6819 // and (not (srl X, C)), 1 --> (and X, 1<<C) == 0
6820 // and (srl (not X), C), 1 --> (and X, 1<<C) == 0
6821 SDLoc DL(And);
6822 SDValue X = DAG.getZExtOrTrunc(Op: Src, DL, VT: SrcVT);
6823 EVT CCVT =
6824 TLI.getSetCCResultType(DL: DAG.getDataLayout(), Context&: *DAG.getContext(), VT: SrcVT);
6825 SDValue Mask = DAG.getConstant(
6826 Val: APInt::getOneBitSet(numBits: BitWidth, BitNo: ShiftAmtC->getZExtValue()), DL, VT: SrcVT);
6827 SDValue NewAnd = DAG.getNode(Opcode: ISD::AND, DL, VT: SrcVT, N1: X, N2: Mask);
6828 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: SrcVT);
6829 SDValue Setcc = DAG.getSetCC(DL, VT: CCVT, LHS: NewAnd, RHS: Zero, Cond: ISD::SETEQ);
6830 return DAG.getZExtOrTrunc(Op: Setcc, DL, VT: And->getValueType(ResNo: 0));
6831}
6832
6833/// For targets that support usubsat, match a bit-hack form of that operation
6834/// that ends in 'and' and convert it.
6835static SDValue foldAndToUsubsat(SDNode *N, SelectionDAG &DAG, const SDLoc &DL) {
6836 EVT VT = N->getValueType(ResNo: 0);
6837 unsigned BitWidth = VT.getScalarSizeInBits();
6838 APInt SignMask = APInt::getSignMask(BitWidth);
6839
6840 // (i8 X ^ 128) & (i8 X s>> 7) --> usubsat X, 128
6841 // (i8 X + 128) & (i8 X s>> 7) --> usubsat X, 128
6842 // xor/add with SMIN (signmask) are logically equivalent.
6843 SDValue X;
6844 if (!sd_match(N, P: m_And(L: m_OneUse(P: m_Xor(L: m_Value(N&: X), R: m_SpecificInt(V: SignMask))),
6845 R: m_OneUse(P: m_Sra(L: m_Deferred(V&: X),
6846 R: m_SpecificInt(V: BitWidth - 1))))) &&
6847 !sd_match(N, P: m_And(L: m_OneUse(P: m_Add(L: m_Value(N&: X), R: m_SpecificInt(V: SignMask))),
6848 R: m_OneUse(P: m_Sra(L: m_Deferred(V&: X),
6849 R: m_SpecificInt(V: BitWidth - 1))))))
6850 return SDValue();
6851
6852 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: X,
6853 N2: DAG.getConstant(Val: SignMask, DL, VT));
6854}
6855
6856/// Given a bitwise logic operation N with a matching bitwise logic operand,
6857/// fold a pattern where 2 of the source operands are identically shifted
6858/// values. For example:
6859/// ((X0 << Y) | Z) | (X1 << Y) --> ((X0 | X1) << Y) | Z
6860static SDValue foldLogicOfShifts(SDNode *N, SDValue LogicOp, SDValue ShiftOp,
6861 SelectionDAG &DAG) {
6862 unsigned LogicOpcode = N->getOpcode();
6863 assert(ISD::isBitwiseLogicOp(LogicOpcode) &&
6864 "Expected bitwise logic operation");
6865
6866 if (!LogicOp.hasOneUse() || !ShiftOp.hasOneUse())
6867 return SDValue();
6868
6869 // Match another bitwise logic op and a shift.
6870 unsigned ShiftOpcode = ShiftOp.getOpcode();
6871 if (LogicOp.getOpcode() != LogicOpcode ||
6872 !(ShiftOpcode == ISD::SHL || ShiftOpcode == ISD::SRL ||
6873 ShiftOpcode == ISD::SRA))
6874 return SDValue();
6875
6876 // Match another shift op inside the first logic operand. Handle both commuted
6877 // possibilities.
6878 // LOGIC (LOGIC (SH X0, Y), Z), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
6879 // LOGIC (LOGIC Z, (SH X0, Y)), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
6880 SDValue X1 = ShiftOp.getOperand(i: 0);
6881 SDValue Y = ShiftOp.getOperand(i: 1);
6882 SDValue X0, Z;
6883 if (LogicOp.getOperand(i: 0).getOpcode() == ShiftOpcode &&
6884 LogicOp.getOperand(i: 0).getOperand(i: 1) == Y) {
6885 X0 = LogicOp.getOperand(i: 0).getOperand(i: 0);
6886 Z = LogicOp.getOperand(i: 1);
6887 } else if (LogicOp.getOperand(i: 1).getOpcode() == ShiftOpcode &&
6888 LogicOp.getOperand(i: 1).getOperand(i: 1) == Y) {
6889 X0 = LogicOp.getOperand(i: 1).getOperand(i: 0);
6890 Z = LogicOp.getOperand(i: 0);
6891 } else {
6892 return SDValue();
6893 }
6894
6895 EVT VT = N->getValueType(ResNo: 0);
6896 SDLoc DL(N);
6897 SDValue LogicX = DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: X0, N2: X1);
6898 SDValue NewShift = DAG.getNode(Opcode: ShiftOpcode, DL, VT, N1: LogicX, N2: Y);
6899 return DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: NewShift, N2: Z);
6900}
6901
6902/// Given a tree of logic operations with shape like
6903/// (LOGIC (LOGIC (X, Y), LOGIC (Z, Y)))
6904/// try to match and fold shift operations with the same shift amount.
6905/// For example:
6906/// LOGIC (LOGIC (SH X0, Y), Z), (LOGIC (SH X1, Y), W) -->
6907/// --> LOGIC (SH (LOGIC X0, X1), Y), (LOGIC Z, W)
6908static SDValue foldLogicTreeOfShifts(SDNode *N, SDValue LeftHand,
6909 SDValue RightHand, SelectionDAG &DAG) {
6910 unsigned LogicOpcode = N->getOpcode();
6911 assert(ISD::isBitwiseLogicOp(LogicOpcode) &&
6912 "Expected bitwise logic operation");
6913 if (LeftHand.getOpcode() != LogicOpcode ||
6914 RightHand.getOpcode() != LogicOpcode)
6915 return SDValue();
6916 if (!LeftHand.hasOneUse() || !RightHand.hasOneUse())
6917 return SDValue();
6918
6919 // Try to match one of following patterns:
6920 // LOGIC (LOGIC (SH X0, Y), Z), (LOGIC (SH X1, Y), W)
6921 // LOGIC (LOGIC (SH X0, Y), Z), (LOGIC W, (SH X1, Y))
6922 // Note that foldLogicOfShifts will handle commuted versions of the left hand
6923 // itself.
6924 SDValue CombinedShifts, W;
6925 SDValue R0 = RightHand.getOperand(i: 0);
6926 SDValue R1 = RightHand.getOperand(i: 1);
6927 if ((CombinedShifts = foldLogicOfShifts(N, LogicOp: LeftHand, ShiftOp: R0, DAG)))
6928 W = R1;
6929 else if ((CombinedShifts = foldLogicOfShifts(N, LogicOp: LeftHand, ShiftOp: R1, DAG)))
6930 W = R0;
6931 else
6932 return SDValue();
6933
6934 EVT VT = N->getValueType(ResNo: 0);
6935 SDLoc DL(N);
6936 return DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: CombinedShifts, N2: W);
6937}
6938
6939SDValue DAGCombiner::visitAND(SDNode *N) {
6940 SDValue N0 = N->getOperand(Num: 0);
6941 SDValue N1 = N->getOperand(Num: 1);
6942 EVT VT = N1.getValueType();
6943 SDLoc DL(N);
6944
6945 // x & x --> x
6946 if (N0 == N1)
6947 return N0;
6948
6949 // fold (and c1, c2) -> c1&c2
6950 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::AND, DL, VT, Ops: {N0, N1}))
6951 return C;
6952
6953 // canonicalize constant to RHS
6954 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
6955 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
6956 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1, N2: N0);
6957
6958 if (areBitwiseNotOfEachother(Op0: N0, Op1: N1))
6959 return DAG.getConstant(Val: APInt::getZero(numBits: VT.getScalarSizeInBits()), DL, VT);
6960
6961 // fold vector ops
6962 if (VT.isVector()) {
6963 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
6964 return FoldedVOp;
6965
6966 // fold (and x, 0) -> 0, vector edition
6967 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
6968 // do not return N1, because an undef node may exist in N1
6969 return DAG.getConstant(Val: APInt::getZero(numBits: N1.getScalarValueSizeInBits()), DL,
6970 VT: N1.getValueType());
6971
6972 // fold (and x, -1) -> x, vector edition
6973 if (ISD::isConstantSplatVectorAllOnes(N: N1.getNode()))
6974 return N0;
6975
6976 // fold (and (masked_load) (splat_vec (x, ...))) to zext_masked_load
6977 auto *MLoad = dyn_cast<MaskedLoadSDNode>(Val&: N0);
6978 ConstantSDNode *Splat = isConstOrConstSplat(N: N1, AllowUndefs: true, AllowTruncation: true);
6979 if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat &&
6980 N1.hasOneUse()) {
6981 EVT LoadVT = MLoad->getMemoryVT();
6982 EVT ExtVT = VT;
6983 if (TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: ExtVT, MemVT: LoadVT)) {
6984 // For this AND to be a zero extension of the masked load, the elements
6985 // of the BuildVec must mask the bottom bits of the extended element
6986 // type.
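        // (e.g. an extending load of i8 elements combined with a splat of
        // 0xFF).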
6987 uint64_t ElementSize =
6988 LoadVT.getVectorElementType().getScalarSizeInBits();
6989 if (Splat->getAPIntValue().isMask(numBits: ElementSize)) {
6990 SDValue NewLoad = DAG.getMaskedLoad(
6991 VT: ExtVT, dl: DL, Chain: MLoad->getChain(), Base: MLoad->getBasePtr(),
6992 Offset: MLoad->getOffset(), Mask: MLoad->getMask(), Src0: MLoad->getPassThru(),
6993 MemVT: LoadVT, MMO: MLoad->getMemOperand(), AM: MLoad->getAddressingMode(),
6994 ISD::ZEXTLOAD, IsExpanding: MLoad->isExpandingLoad());
6995 bool LoadHasOtherUsers = !N0.hasOneUse();
6996 CombineTo(N, Res: NewLoad);
6997 if (LoadHasOtherUsers)
6998 CombineTo(N: MLoad, Res0: NewLoad.getValue(R: 0), Res1: NewLoad.getValue(R: 1));
6999 return SDValue(N, 0);
7000 }
7001 }
7002 }
7003 }
7004
7005 // fold (and x, -1) -> x
7006 if (isAllOnesConstant(V: N1))
7007 return N0;
7008
7009 // if (and x, c) is known to be zero, return 0
7010 unsigned BitWidth = VT.getScalarSizeInBits();
7011 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
7012 if (N1C && DAG.MaskedValueIsZero(Op: SDValue(N, 0), Mask: APInt::getAllOnes(numBits: BitWidth)))
7013 return DAG.getConstant(Val: 0, DL, VT);
7014
7015 if (SDValue R = foldAndOrOfSETCC(LogicOp: N, DAG))
7016 return R;
7017
7018 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
7019 return NewSel;
7020
7021 // reassociate and
7022 if (SDValue RAND = reassociateOps(Opc: ISD::AND, DL, N0, N1, Flags: N->getFlags()))
7023 return RAND;
7024
7025 // Fold and(vecreduce(x), vecreduce(y)) -> vecreduce(and(x, y))
7026 if (SDValue SD =
7027 reassociateReduction(RedOpc: ISD::VECREDUCE_AND, Opc: ISD::AND, DL, VT, N0, N1))
7028 return SD;
7029
7030 // fold (and (or x, C), D) -> D if (C & D) == D
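// e.g. (and (or x, 0xf0), 0x30) -> 0x30, because 0x30 is a subset of 0xf0.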
7031 auto MatchSubset = [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
7032 return RHS->getAPIntValue().isSubsetOf(RHS: LHS->getAPIntValue());
7033 };
7034 if (N0.getOpcode() == ISD::OR &&
7035 ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchSubset))
7036 return N1;
7037
7038 if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
7039 SDValue N0Op0 = N0.getOperand(i: 0);
7040 EVT SrcVT = N0Op0.getValueType();
7041 unsigned SrcBitWidth = SrcVT.getScalarSizeInBits();
7042 APInt Mask = ~N1C->getAPIntValue();
7043 Mask = Mask.trunc(width: SrcBitWidth);
7044
7045 // fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits.
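// e.g. (and (any_extend i8 V to i32), 0xff) -> (zero_extend V to i32), since
// the mask only clears bits above the original i8 value.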
7046 if (DAG.MaskedValueIsZero(Op: N0Op0, Mask))
7047 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0Op0);
7048
7049 // fold (and (any_ext V), c) -> (zero_ext (and (trunc V), c)) if profitable.
7050 if (N1C->getAPIntValue().countLeadingZeros() >= (BitWidth - SrcBitWidth) &&
7051 TLI.isTruncateFree(FromVT: VT, ToVT: SrcVT) && TLI.isZExtFree(FromTy: SrcVT, ToTy: VT) &&
7052 TLI.isTypeDesirableForOp(ISD::AND, VT: SrcVT) &&
7053 TLI.isNarrowingProfitable(SrcVT: VT, DestVT: SrcVT))
7054 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT,
7055 Operand: DAG.getNode(Opcode: ISD::AND, DL, VT: SrcVT, N1: N0Op0,
7056 N2: DAG.getZExtOrTrunc(Op: N1, DL, VT: SrcVT)));
7057 }
7058
7059 // fold (and (ext (and V, c1)), c2) -> (and (ext V), (and c1, (ext c2)))
7060 if (ISD::isExtOpcode(Opcode: N0.getOpcode())) {
7061 unsigned ExtOpc = N0.getOpcode();
7062 SDValue N0Op0 = N0.getOperand(i: 0);
7063 if (N0Op0.getOpcode() == ISD::AND &&
7064 (ExtOpc != ISD::ZERO_EXTEND || !TLI.isZExtFree(Val: N0Op0, VT2: VT)) &&
7065 DAG.isConstantIntBuildVectorOrConstantInt(N: N1) &&
7066 DAG.isConstantIntBuildVectorOrConstantInt(N: N0Op0.getOperand(i: 1)) &&
7067 N0->hasOneUse() && N0Op0->hasOneUse()) {
7068 SDValue NewMask =
7069 DAG.getNode(Opcode: ISD::AND, DL, VT, N1,
7070 N2: DAG.getNode(Opcode: ExtOpc, DL, VT, Operand: N0Op0.getOperand(i: 1)));
7071 return DAG.getNode(Opcode: ISD::AND, DL, VT,
7072 N1: DAG.getNode(Opcode: ExtOpc, DL, VT, Operand: N0Op0.getOperand(i: 0)),
7073 N2: NewMask);
7074 }
7075 }
7076
7077 // similarly fold (and (X (load ([non_ext|any_ext|zero_ext] V))), c) ->
7078 // (X (load ([non_ext|zero_ext] V))) if 'and' only clears top bits which must
7079 // already be zero by virtue of the width of the base type of the load.
7080 //
7081 // the 'X' node here can either be nothing or an extract_vector_elt to catch
7082 // more cases.
7083 if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
7084 N0.getValueSizeInBits() == N0.getOperand(i: 0).getScalarValueSizeInBits() &&
7085 N0.getOperand(i: 0).getOpcode() == ISD::LOAD &&
7086 N0.getOperand(i: 0).getResNo() == 0) ||
7087 (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) {
7088 auto *Load =
7089 cast<LoadSDNode>(Val: (N0.getOpcode() == ISD::LOAD) ? N0 : N0.getOperand(i: 0));
7090
7091 // Get the constant (if applicable) the zero'th operand is being ANDed with.
7092 // This can be a pure constant or a vector splat, in which case we treat the
7093 // vector as a scalar and use the splat value.
7094 APInt Constant = APInt::getZero(numBits: 1);
7095 if (const ConstantSDNode *C = isConstOrConstSplat(
7096 N: N1, /*AllowUndef=*/AllowUndefs: false, /*AllowTruncation=*/true)) {
7097 Constant = C->getAPIntValue();
7098 } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(Val&: N1)) {
7099 unsigned EltBitWidth = Vector->getValueType(ResNo: 0).getScalarSizeInBits();
7100 APInt SplatValue, SplatUndef;
7101 unsigned SplatBitSize;
7102 bool HasAnyUndefs;
7103 // Endianness should not matter here. Code below makes sure that we only
7104 // use the result if the SplatBitSize is a multiple of the vector element
7105 // size. And after that we AND all element sized parts of the splat
7106 // together. So the end result should be the same regardless of in which
7107 // order we do those operations.
7108 const bool IsBigEndian = false;
7109 bool IsSplat =
7110 Vector->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
7111 HasAnyUndefs, MinSplatBits: EltBitWidth, isBigEndian: IsBigEndian);
7112
7113 // Make sure that variable 'Constant' is only set if 'SplatBitSize' is a
7114 // multiple of 'EltBitWidth'. Otherwise, we could propagate a wrong value.
7115 if (IsSplat && (SplatBitSize % EltBitWidth) == 0) {
7116 // Undef bits can contribute to a possible optimisation if set, so
7117 // set them.
7118 SplatValue |= SplatUndef;
7119
7120 // The splat value may be something like "0x00FFFFFF", which means 0 for
7121 // the first vector value and FF for the rest, repeating. We need a mask
7122 // that will apply equally to all members of the vector, so AND all the
7123 // lanes of the constant together.
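// e.g. for a v4i8 operand and a 32-bit splat of 0x00ffffff, the 8-bit lanes
// are 0xff, 0xff, 0xff and 0x00; ANDing them together gives 0x00, so the
// combined mask is not all-ones and the AND cannot be removed below.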
7124 Constant = APInt::getAllOnes(numBits: EltBitWidth);
7125 for (unsigned i = 0, n = (SplatBitSize / EltBitWidth); i < n; ++i)
7126 Constant &= SplatValue.extractBits(numBits: EltBitWidth, bitPosition: i * EltBitWidth);
7127 }
7128 }
7129
7130 // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
7131 // actually legal and isn't going to get expanded, else this is a false
7132 // optimisation.
7133 bool CanZextLoadProfitably = TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD,
7134 ValVT: Load->getValueType(ResNo: 0),
7135 MemVT: Load->getMemoryVT());
7136
7137 // Resize the constant to the same size as the original memory access before
7138 // extension. If it is still the AllOnesValue then this AND is completely
7139 // unneeded.
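// e.g. for (and (extload i8 -> i32, x), 0xff), the constant truncated to the
// i8 memory type is all-ones, so, assuming a zextload from i8 is legal for
// the target, the AND is removed and the extload is rewritten as a zextload.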
7140 Constant = Constant.zextOrTrunc(width: Load->getMemoryVT().getScalarSizeInBits());
7141
7142 bool B;
7143 switch (Load->getExtensionType()) {
7144 default: B = false; break;
7145 case ISD::EXTLOAD: B = CanZextLoadProfitably; break;
7146 case ISD::ZEXTLOAD:
7147 case ISD::NON_EXTLOAD: B = true; break;
7148 }
7149
7150 if (B && Constant.isAllOnes()) {
7151 // If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
7152 // preserve semantics once we get rid of the AND.
7153 SDValue NewLoad(Load, 0);
7154
7155 // Fold the AND away. NewLoad may get replaced immediately.
7156 CombineTo(N, Res: (N0.getNode() == Load) ? NewLoad : N0);
7157
7158 if (Load->getExtensionType() == ISD::EXTLOAD) {
7159 NewLoad = DAG.getLoad(AM: Load->getAddressingMode(), ExtType: ISD::ZEXTLOAD,
7160 VT: Load->getValueType(ResNo: 0), dl: SDLoc(Load),
7161 Chain: Load->getChain(), Ptr: Load->getBasePtr(),
7162 Offset: Load->getOffset(), MemVT: Load->getMemoryVT(),
7163 MMO: Load->getMemOperand());
7164 // Replace uses of the EXTLOAD with the new ZEXTLOAD.
7165 if (Load->getNumValues() == 3) {
7166 // PRE/POST_INC loads have 3 values.
7167 SDValue To[] = { NewLoad.getValue(R: 0), NewLoad.getValue(R: 1),
7168 NewLoad.getValue(R: 2) };
7169 CombineTo(N: Load, To, NumTo: 3, AddTo: true);
7170 } else {
7171 CombineTo(N: Load, Res0: NewLoad.getValue(R: 0), Res1: NewLoad.getValue(R: 1));
7172 }
7173 }
7174
7175 return SDValue(N, 0); // Return N so it doesn't get rechecked!
7176 }
7177 }
7178
7179 // Try to convert a constant mask AND into a shuffle clear mask.
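// e.g. (and x, (build_vector -1, 0, -1, 0)) can become a shuffle of x with a
// zero vector that keeps lanes 0 and 2 and clears lanes 1 and 3.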
7180 if (VT.isVector())
7181 if (SDValue Shuffle = XformToShuffleWithZero(N))
7182 return Shuffle;
7183
7184 if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
7185 return Combined;
7186
7187 if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() && N1C &&
7188 ISD::isExtOpcode(Opcode: N0.getOperand(i: 0).getOpcode())) {
7189 SDValue Ext = N0.getOperand(i: 0);
7190 EVT ExtVT = Ext->getValueType(ResNo: 0);
7191 SDValue Extendee = Ext->getOperand(Num: 0);
7192
7193 unsigned ScalarWidth = Extendee.getValueType().getScalarSizeInBits();
7194 if (N1C->getAPIntValue().isMask(numBits: ScalarWidth) &&
7195 (!LegalOperations || TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT: ExtVT))) {
7196 // (and (extract_subvector (zext|anyext|sext v) _) iN_mask)
7197 // => (extract_subvector (iN_zeroext v))
7198 SDValue ZeroExtExtendee =
7199 DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: ExtVT, Operand: Extendee);
7200
7201 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT, N1: ZeroExtExtendee,
7202 N2: N0.getOperand(i: 1));
7203 }
7204 }
7205
7206 // fold (and (masked_gather x)) -> (zext_masked_gather x)
7207 if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(Val&: N0)) {
7208 EVT MemVT = GN0->getMemoryVT();
7209 EVT ScalarVT = MemVT.getScalarType();
7210
7211 if (SDValue(GN0, 0).hasOneUse() &&
7212 isConstantSplatVectorMaskForType(N: N1.getNode(), ScalarTy: ScalarVT) &&
7213 TLI.isVectorLoadExtDesirable(ExtVal: SDValue(GN0, 0))) {
7214 SDValue Ops[] = {GN0->getChain(), GN0->getPassThru(), GN0->getMask(),
7215 GN0->getBasePtr(), GN0->getIndex(), GN0->getScale()};
7216
7217 SDValue ZExtLoad = DAG.getMaskedGather(
7218 VTs: DAG.getVTList(VT1: VT, VT2: MVT::Other), MemVT, dl: DL, Ops, MMO: GN0->getMemOperand(),
7219 IndexType: GN0->getIndexType(), ExtTy: ISD::ZEXTLOAD);
7220
7221 CombineTo(N, Res: ZExtLoad);
7222 AddToWorklist(N: ZExtLoad.getNode());
7223 // Avoid recheck of N.
7224 return SDValue(N, 0);
7225 }
7226 }
7227
7228 // fold (and (load x), 255) -> (zextload x, i8)
7229 // fold (and (extload x, i16), 255) -> (zextload x, i8)
7230 if (N1C && N0.getOpcode() == ISD::LOAD && !VT.isVector())
7231 if (SDValue Res = reduceLoadWidth(N))
7232 return Res;
7233
7234 if (LegalTypes) {
7235 // Attempt to propagate the AND back up to the leaves which, if they're
7236 // loads, can be combined to narrow loads and the AND node can be removed.
7237 // Perform after legalization so that extend nodes will already be
7238 // combined into the loads.
7239 if (BackwardsPropagateMask(N))
7240 return SDValue(N, 0);
7241 }
7242
7243 if (SDValue Combined = visitANDLike(N0, N1, N))
7244 return Combined;
7245
7246 // Simplify: (and (op x...), (op y...)) -> (op (and x, y))
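// e.g. (and (zero_extend a), (zero_extend b)) can be rewritten as
// (zero_extend (and a, b)) when both sources have the same type and the
// transform is considered profitable.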
7247 if (N0.getOpcode() == N1.getOpcode())
7248 if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
7249 return V;
7250
7251 if (SDValue R = foldLogicOfShifts(N, LogicOp: N0, ShiftOp: N1, DAG))
7252 return R;
7253 if (SDValue R = foldLogicOfShifts(N, LogicOp: N1, ShiftOp: N0, DAG))
7254 return R;
7255
7256 // Masking the negated extension of a boolean is just the zero-extended
7257 // boolean:
7258 // and (sub 0, zext(bool X)), 1 --> zext(bool X)
7259 // and (sub 0, sext(bool X)), 1 --> zext(bool X)
7260 //
7261 // Note: the SimplifyDemandedBits fold below can make an information-losing
7262 // transform, and then we have no way to find this better fold.
7263 if (N1C && N1C->isOne() && N0.getOpcode() == ISD::SUB) {
7264 if (isNullOrNullSplat(V: N0.getOperand(i: 0))) {
7265 SDValue SubRHS = N0.getOperand(i: 1);
7266 if (SubRHS.getOpcode() == ISD::ZERO_EXTEND &&
7267 SubRHS.getOperand(i: 0).getScalarValueSizeInBits() == 1)
7268 return SubRHS;
7269 if (SubRHS.getOpcode() == ISD::SIGN_EXTEND &&
7270 SubRHS.getOperand(i: 0).getScalarValueSizeInBits() == 1)
7271 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: SubRHS.getOperand(i: 0));
7272 }
7273 }
7274
7275 // fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1)
7276 // fold (and (sra)) -> (and (srl)) when possible.
7277 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
7278 return SDValue(N, 0);
7279
7280 // fold (zext_inreg (extload x)) -> (zextload x)
7281 // fold (zext_inreg (sextload x)) -> (zextload x) iff load has one use
7282 if (ISD::isUNINDEXEDLoad(N: N0.getNode()) &&
7283 (ISD::isEXTLoad(N: N0.getNode()) ||
7284 (ISD::isSEXTLoad(N: N0.getNode()) && N0.hasOneUse()))) {
7285 auto *LN0 = cast<LoadSDNode>(Val&: N0);
7286 EVT MemVT = LN0->getMemoryVT();
7287 // If we zero all the possible extended bits, then we can turn this into
7288 // a zextload if we are running before legalize or the operation is legal.
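// e.g. (and (sextload i8 -> i32, x), 0xff): the mask zeroes bits 31:8, so the
// sign-extending load can be replaced by a zero-extending load of i8.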
7289 unsigned ExtBitSize = N1.getScalarValueSizeInBits();
7290 unsigned MemBitSize = MemVT.getScalarSizeInBits();
7291 APInt ExtBits = APInt::getHighBitsSet(numBits: ExtBitSize, hiBitsSet: ExtBitSize - MemBitSize);
7292 if (DAG.MaskedValueIsZero(Op: N1, Mask: ExtBits) &&
7293 ((!LegalOperations && LN0->isSimple()) ||
7294 TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: VT, MemVT))) {
7295 SDValue ExtLoad =
7296 DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl: SDLoc(N0), VT, Chain: LN0->getChain(),
7297 Ptr: LN0->getBasePtr(), MemVT, MMO: LN0->getMemOperand());
7298 AddToWorklist(N);
7299 CombineTo(N: N0.getNode(), Res0: ExtLoad, Res1: ExtLoad.getValue(R: 1));
7300 return SDValue(N, 0); // Return N so it doesn't get rechecked!
7301 }
7302 }
7303
7304 // fold (and (or (srl N, 8), (shl N, 8)), 0xffff) -> (srl (bswap N), const)
7305 if (N1C && N1C->getAPIntValue() == 0xffff && N0.getOpcode() == ISD::OR) {
7306 if (SDValue BSwap = MatchBSwapHWordLow(N: N0.getNode(), N0: N0.getOperand(i: 0),
7307 N1: N0.getOperand(i: 1), DemandHighBits: false))
7308 return BSwap;
7309 }
7310
7311 if (SDValue Shifts = unfoldExtremeBitClearingToShifts(N))
7312 return Shifts;
7313
7314 if (SDValue V = combineShiftAnd1ToBitTest(And: N, DAG))
7315 return V;
7316
7317 // Recognize the following pattern:
7318 //
7319 // AndVT = (and (sign_extend NarrowVT to AndVT) #bitmask)
7320 //
7321 // where bitmask is a mask that clears the upper bits of AndVT. The
7322 // number of bits in bitmask must be a power of two.
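// e.g. (and (sign_extend i8 x to i32), 0xff) becomes (zero_extend x to i32).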
7323 auto IsAndZeroExtMask = [](SDValue LHS, SDValue RHS) {
7324 if (LHS->getOpcode() != ISD::SIGN_EXTEND)
7325 return false;
7326
7327 auto *C = dyn_cast<ConstantSDNode>(Val&: RHS);
7328 if (!C)
7329 return false;
7330
7331 if (!C->getAPIntValue().isMask(
7332 numBits: LHS.getOperand(i: 0).getValueType().getFixedSizeInBits()))
7333 return false;
7334
7335 return true;
7336 };
7337
7338 // Replace (and (sign_extend ...) #bitmask) with (zero_extend ...).
7339 if (IsAndZeroExtMask(N0, N1))
7340 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0.getOperand(i: 0));
7341
7342 if (hasOperation(Opcode: ISD::USUBSAT, VT))
7343 if (SDValue V = foldAndToUsubsat(N, DAG, DL))
7344 return V;
7345
7346 // Postpone until legalization completed to avoid interference with bswap
7347 // folding
7348 if (LegalOperations || VT.isVector())
7349 if (SDValue R = foldLogicTreeOfShifts(N, LeftHand: N0, RightHand: N1, DAG))
7350 return R;
7351
7352 return SDValue();
7353}
7354
7355/// Match (a >> 8) | (a << 8) as (bswap a) >> 16.
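/// e.g. for i32: ((a & 0xff00) >> 8) | ((a & 0x00ff) << 8) becomes
/// (srl (bswap a), 16), which swaps the two low bytes and clears the rest.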
7356SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
7357 bool DemandHighBits) {
7358 if (!LegalOperations)
7359 return SDValue();
7360
7361 EVT VT = N->getValueType(ResNo: 0);
7362 if (VT != MVT::i64 && VT != MVT::i32 && VT != MVT::i16)
7363 return SDValue();
7364 if (!TLI.isOperationLegalOrCustom(Op: ISD::BSWAP, VT))
7365 return SDValue();
7366
7367 // Recognize (and (shl a, 8), 0xff00), (and (srl a, 8), 0xff)
7368 bool LookPassAnd0 = false;
7369 bool LookPassAnd1 = false;
7370 if (N0.getOpcode() == ISD::AND && N0.getOperand(i: 0).getOpcode() == ISD::SRL)
7371 std::swap(a&: N0, b&: N1);
7372 if (N1.getOpcode() == ISD::AND && N1.getOperand(i: 0).getOpcode() == ISD::SHL)
7373 std::swap(a&: N0, b&: N1);
7374 if (N0.getOpcode() == ISD::AND) {
7375 if (!N0->hasOneUse())
7376 return SDValue();
7377 ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
7378 // Also handle 0xffff since the LHS is guaranteed to have zeros there.
7379 // This is needed for X86.
7380 if (!N01C || (N01C->getZExtValue() != 0xFF00 &&
7381 N01C->getZExtValue() != 0xFFFF))
7382 return SDValue();
7383 N0 = N0.getOperand(i: 0);
7384 LookPassAnd0 = true;
7385 }
7386
7387 if (N1.getOpcode() == ISD::AND) {
7388 if (!N1->hasOneUse())
7389 return SDValue();
7390 ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(Val: N1.getOperand(i: 1));
7391 if (!N11C || N11C->getZExtValue() != 0xFF)
7392 return SDValue();
7393 N1 = N1.getOperand(i: 0);
7394 LookPassAnd1 = true;
7395 }
7396
7397 if (N0.getOpcode() == ISD::SRL && N1.getOpcode() == ISD::SHL)
7398 std::swap(a&: N0, b&: N1);
7399 if (N0.getOpcode() != ISD::SHL || N1.getOpcode() != ISD::SRL)
7400 return SDValue();
7401 if (!N0->hasOneUse() || !N1->hasOneUse())
7402 return SDValue();
7403
7404 ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
7405 ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(Val: N1.getOperand(i: 1));
7406 if (!N01C || !N11C)
7407 return SDValue();
7408 if (N01C->getZExtValue() != 8 || N11C->getZExtValue() != 8)
7409 return SDValue();
7410
7411 // Look for (shl (and a, 0xff), 8), (srl (and a, 0xff00), 8)
7412 SDValue N00 = N0->getOperand(Num: 0);
7413 if (!LookPassAnd0 && N00.getOpcode() == ISD::AND) {
7414 if (!N00->hasOneUse())
7415 return SDValue();
7416 ConstantSDNode *N001C = dyn_cast<ConstantSDNode>(Val: N00.getOperand(i: 1));
7417 if (!N001C || N001C->getZExtValue() != 0xFF)
7418 return SDValue();
7419 N00 = N00.getOperand(i: 0);
7420 LookPassAnd0 = true;
7421 }
7422
7423 SDValue N10 = N1->getOperand(Num: 0);
7424 if (!LookPassAnd1 && N10.getOpcode() == ISD::AND) {
7425 if (!N10->hasOneUse())
7426 return SDValue();
7427 ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(Val: N10.getOperand(i: 1));
7428 // Also allow 0xFFFF since the bits will be shifted out. This is needed
7429 // for X86.
7430 if (!N101C || (N101C->getZExtValue() != 0xFF00 &&
7431 N101C->getZExtValue() != 0xFFFF))
7432 return SDValue();
7433 N10 = N10.getOperand(i: 0);
7434 LookPassAnd1 = true;
7435 }
7436
7437 if (N00 != N10)
7438 return SDValue();
7439
7440 // Make sure everything beyond the low halfword gets set to zero since the SRL
7441 // 16 will clear the top bits.
7442 unsigned OpSizeInBits = VT.getSizeInBits();
7443 if (OpSizeInBits > 16) {
7444 // If the left-shift isn't masked out then the only way this is a bswap is
7445 // if all bits beyond the low 8 are 0. In that case the entire pattern
7446 // reduces to a left shift anyway: leave it for other parts of the combiner.
7447 if (DemandHighBits && !LookPassAnd0)
7448 return SDValue();
7449
7450 // However, if the right shift isn't masked out then it might be because
7451 // it's not needed. See if we can spot that too. If the high bits aren't
7452 // demanded, we only need bits 23:16 to be zero. Otherwise, we need all
7453 // upper bits to be zero.
7454 if (!LookPassAnd1) {
7455 unsigned HighBit = DemandHighBits ? OpSizeInBits : 24;
7456 if (!DAG.MaskedValueIsZero(Op: N10,
7457 Mask: APInt::getBitsSet(numBits: OpSizeInBits, loBit: 16, hiBit: HighBit)))
7458 return SDValue();
7459 }
7460 }
7461
7462 SDValue Res = DAG.getNode(Opcode: ISD::BSWAP, DL: SDLoc(N), VT, Operand: N00);
7463 if (OpSizeInBits > 16) {
7464 SDLoc DL(N);
7465 Res = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Res,
7466 N2: DAG.getShiftAmountConstant(Val: OpSizeInBits - 16, VT, DL));
7467 }
7468 return Res;
7469}
7470
7471/// Return true if the specified node is an element that makes up a 32-bit
7472/// packed halfword byteswap.
7473/// ((x & 0x000000ff) << 8) |
7474/// ((x & 0x0000ff00) >> 8) |
7475/// ((x & 0x00ff0000) << 8) |
7476/// ((x & 0xff000000) >> 8)
7477static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
7478 if (!N->hasOneUse())
7479 return false;
7480
7481 unsigned Opc = N.getOpcode();
7482 if (Opc != ISD::AND && Opc != ISD::SHL && Opc != ISD::SRL)
7483 return false;
7484
7485 SDValue N0 = N.getOperand(i: 0);
7486 unsigned Opc0 = N0.getOpcode();
7487 if (Opc0 != ISD::AND && Opc0 != ISD::SHL && Opc0 != ISD::SRL)
7488 return false;
7489
7490 ConstantSDNode *N1C = nullptr;
7491 // SHL or SRL: look upstream for AND mask operand
7492 if (Opc == ISD::AND)
7493 N1C = dyn_cast<ConstantSDNode>(Val: N.getOperand(i: 1));
7494 else if (Opc0 == ISD::AND)
7495 N1C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
7496 if (!N1C)
7497 return false;
7498
7499 unsigned MaskByteOffset;
7500 switch (N1C->getZExtValue()) {
7501 default:
7502 return false;
7503 case 0xFF: MaskByteOffset = 0; break;
7504 case 0xFF00: MaskByteOffset = 1; break;
7505 case 0xFFFF:
7506 // In case demanded bits didn't clear the bits that will be shifted out.
7507 // This is needed for X86.
7508 if (Opc == ISD::SRL || (Opc == ISD::AND && Opc0 == ISD::SHL)) {
7509 MaskByteOffset = 1;
7510 break;
7511 }
7512 return false;
7513 case 0xFF0000: MaskByteOffset = 2; break;
7514 case 0xFF000000: MaskByteOffset = 3; break;
7515 }
7516
7517 // Look for (x & 0xff) << 8 as well as ((x << 8) & 0xff00).
7518 if (Opc == ISD::AND) {
7519 if (MaskByteOffset == 0 || MaskByteOffset == 2) {
7520 // (x >> 8) & 0xff
7521 // (x >> 8) & 0xff0000
7522 if (Opc0 != ISD::SRL)
7523 return false;
7524 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
7525 if (!C || C->getZExtValue() != 8)
7526 return false;
7527 } else {
7528 // (x << 8) & 0xff00
7529 // (x << 8) & 0xff000000
7530 if (Opc0 != ISD::SHL)
7531 return false;
7532 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
7533 if (!C || C->getZExtValue() != 8)
7534 return false;
7535 }
7536 } else if (Opc == ISD::SHL) {
7537 // (x & 0xff) << 8
7538 // (x & 0xff0000) << 8
7539 if (MaskByteOffset != 0 && MaskByteOffset != 2)
7540 return false;
7541 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val: N.getOperand(i: 1));
7542 if (!C || C->getZExtValue() != 8)
7543 return false;
7544 } else { // Opc == ISD::SRL
7545 // (x & 0xff00) >> 8
7546 // (x & 0xff000000) >> 8
7547 if (MaskByteOffset != 1 && MaskByteOffset != 3)
7548 return false;
7549 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val: N.getOperand(i: 1));
7550 if (!C || C->getZExtValue() != 8)
7551 return false;
7552 }
7553
7554 if (Parts[MaskByteOffset])
7555 return false;
7556
7557 Parts[MaskByteOffset] = N0.getOperand(i: 0).getNode();
7558 return true;
7559}
7560
7561// Match 2 elements of a packed halfword bswap.
7562static bool isBSwapHWordPair(SDValue N, MutableArrayRef<SDNode *> Parts) {
7563 if (N.getOpcode() == ISD::OR)
7564 return isBSwapHWordElement(N: N.getOperand(i: 0), Parts) &&
7565 isBSwapHWordElement(N: N.getOperand(i: 1), Parts);
7566
7567 if (N.getOpcode() == ISD::SRL && N.getOperand(i: 0).getOpcode() == ISD::BSWAP) {
7568 ConstantSDNode *C = isConstOrConstSplat(N: N.getOperand(i: 1));
7569 if (!C || C->getAPIntValue() != 16)
7570 return false;
7571 Parts[0] = Parts[1] = N.getOperand(i: 0).getOperand(i: 0).getNode();
7572 return true;
7573 }
7574
7575 return false;
7576}
7577
7578// Match this pattern:
7579// (or (and (shl (A, 8)), 0xff00ff00), (and (srl (A, 8)), 0x00ff00ff))
7580// And rewrite this to:
7581// (rotr (bswap A), 16)
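// For a 32-bit A with bytes [b3 b2 b1 b0], the OR of the two masked shifts is
// [b2 b3 b0 b1], i.e. a byte swap within each halfword; (bswap A) is
// [b0 b1 b2 b3], and rotating that right by 16 gives the same [b2 b3 b0 b1].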
7582static SDValue matchBSwapHWordOrAndAnd(const TargetLowering &TLI,
7583 SelectionDAG &DAG, SDNode *N, SDValue N0,
7584 SDValue N1, EVT VT) {
7585 assert(N->getOpcode() == ISD::OR && VT == MVT::i32 &&
7586 "MatchBSwapHWordOrAndAnd: expecting i32");
7587 if (!TLI.isOperationLegalOrCustom(Op: ISD::ROTR, VT))
7588 return SDValue();
7589 if (N0.getOpcode() != ISD::AND || N1.getOpcode() != ISD::AND)
7590 return SDValue();
7591 // TODO: this is too restrictive; lifting this restriction requires more tests
7592 if (!N0->hasOneUse() || !N1->hasOneUse())
7593 return SDValue();
7594 ConstantSDNode *Mask0 = isConstOrConstSplat(N: N0.getOperand(i: 1));
7595 ConstantSDNode *Mask1 = isConstOrConstSplat(N: N1.getOperand(i: 1));
7596 if (!Mask0 || !Mask1)
7597 return SDValue();
7598 if (Mask0->getAPIntValue() != 0xff00ff00 ||
7599 Mask1->getAPIntValue() != 0x00ff00ff)
7600 return SDValue();
7601 SDValue Shift0 = N0.getOperand(i: 0);
7602 SDValue Shift1 = N1.getOperand(i: 0);
7603 if (Shift0.getOpcode() != ISD::SHL || Shift1.getOpcode() != ISD::SRL)
7604 return SDValue();
7605 ConstantSDNode *ShiftAmt0 = isConstOrConstSplat(N: Shift0.getOperand(i: 1));
7606 ConstantSDNode *ShiftAmt1 = isConstOrConstSplat(N: Shift1.getOperand(i: 1));
7607 if (!ShiftAmt0 || !ShiftAmt1)
7608 return SDValue();
7609 if (ShiftAmt0->getAPIntValue() != 8 || ShiftAmt1->getAPIntValue() != 8)
7610 return SDValue();
7611 if (Shift0.getOperand(i: 0) != Shift1.getOperand(i: 0))
7612 return SDValue();
7613
7614 SDLoc DL(N);
7615 SDValue BSwap = DAG.getNode(Opcode: ISD::BSWAP, DL, VT, Operand: Shift0.getOperand(i: 0));
7616 SDValue ShAmt = DAG.getShiftAmountConstant(Val: 16, VT, DL);
7617 return DAG.getNode(Opcode: ISD::ROTR, DL, VT, N1: BSwap, N2: ShAmt);
7618}
7619
7620/// Match a 32-bit packed halfword bswap. That is
7621/// ((x & 0x000000ff) << 8) |
7622/// ((x & 0x0000ff00) >> 8) |
7623/// ((x & 0x00ff0000) << 8) |
7624/// ((x & 0xff000000) >> 8)
7625/// => (rotl (bswap x), 16)
7626SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
7627 if (!LegalOperations)
7628 return SDValue();
7629
7630 EVT VT = N->getValueType(ResNo: 0);
7631 if (VT != MVT::i32)
7632 return SDValue();
7633 if (!TLI.isOperationLegalOrCustom(Op: ISD::BSWAP, VT))
7634 return SDValue();
7635
7636 if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT))
7637 return BSwap;
7638
7639 // Try again with commuted operands.
7640 if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0: N1, N1: N0, VT))
7641 return BSwap;
7642
7643
7644 // Look for either
7645 // (or (bswaphpair), (bswaphpair))
7646 // (or (or (bswaphpair), (and)), (and))
7647 // (or (or (and), (bswaphpair)), (and))
7648 SDNode *Parts[4] = {};
7649
7650 if (isBSwapHWordPair(N: N0, Parts)) {
7651 // (or (or (and), (and)), (or (and), (and)))
7652 if (!isBSwapHWordPair(N: N1, Parts))
7653 return SDValue();
7654 } else if (N0.getOpcode() == ISD::OR) {
7655 // (or (or (or (and), (and)), (and)), (and))
7656 if (!isBSwapHWordElement(N: N1, Parts))
7657 return SDValue();
7658 SDValue N00 = N0.getOperand(i: 0);
7659 SDValue N01 = N0.getOperand(i: 1);
7660 if (!(isBSwapHWordElement(N: N01, Parts) && isBSwapHWordPair(N: N00, Parts)) &&
7661 !(isBSwapHWordElement(N: N00, Parts) && isBSwapHWordPair(N: N01, Parts)))
7662 return SDValue();
7663 } else {
7664 return SDValue();
7665 }
7666
7667 // Make sure the parts are all coming from the same node.
7668 if (Parts[0] != Parts[1] || Parts[0] != Parts[2] || Parts[0] != Parts[3])
7669 return SDValue();
7670
7671 SDLoc DL(N);
7672 SDValue BSwap = DAG.getNode(Opcode: ISD::BSWAP, DL, VT,
7673 Operand: SDValue(Parts[0], 0));
7674
7675 // Result of the bswap should be rotated by 16. If it's not legal, then
7676 // do (x << 16) | (x >> 16).
7677 SDValue ShAmt = DAG.getShiftAmountConstant(Val: 16, VT, DL);
7678 if (TLI.isOperationLegalOrCustom(Op: ISD::ROTL, VT))
7679 return DAG.getNode(Opcode: ISD::ROTL, DL, VT, N1: BSwap, N2: ShAmt);
7680 if (TLI.isOperationLegalOrCustom(Op: ISD::ROTR, VT))
7681 return DAG.getNode(Opcode: ISD::ROTR, DL, VT, N1: BSwap, N2: ShAmt);
7682 return DAG.getNode(Opcode: ISD::OR, DL, VT,
7683 N1: DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: BSwap, N2: ShAmt),
7684 N2: DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: BSwap, N2: ShAmt));
7685}
7686
7687/// This contains all DAGCombine rules which reduce two values combined by
7688/// an Or operation to a single value \see visitANDLike().
7689SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, const SDLoc &DL) {
7690 EVT VT = N1.getValueType();
7691
7692 // fold (or x, undef) -> -1
7693 if (!LegalOperations && (N0.isUndef() || N1.isUndef()))
7694 return DAG.getAllOnesConstant(DL, VT);
7695
7696 if (SDValue V = foldLogicOfSetCCs(IsAnd: false, N0, N1, DL))
7697 return V;
7698
7699 // (or (and X, C1), (and Y, C2)) -> (and (or X, Y), C3) if possible.
7700 if (N0.getOpcode() == ISD::AND && N1.getOpcode() == ISD::AND &&
7701 // Don't increase # computations.
7702 (N0->hasOneUse() || N1->hasOneUse())) {
7703 // We can only do this xform if we know that bits from X that are set in C2
7704 // but not in C1 are already zero. Likewise for Y.
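// e.g. (or (and X, 0x0f), (and Y, 0xf0)) -> (and (or X, Y), 0xff) when X has
// no bits set in 0xf0 and Y has no bits set in 0x0f.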
7705 if (const ConstantSDNode *N0O1C =
7706 getAsNonOpaqueConstant(N: N0.getOperand(i: 1))) {
7707 if (const ConstantSDNode *N1O1C =
7708 getAsNonOpaqueConstant(N: N1.getOperand(i: 1))) {
7709 // We can only do this xform if we know that bits from X that are set in
7710 // C2 but not in C1 are already zero. Likewise for Y.
7711 const APInt &LHSMask = N0O1C->getAPIntValue();
7712 const APInt &RHSMask = N1O1C->getAPIntValue();
7713
7714 if (DAG.MaskedValueIsZero(Op: N0.getOperand(i: 0), Mask: RHSMask&~LHSMask) &&
7715 DAG.MaskedValueIsZero(Op: N1.getOperand(i: 0), Mask: LHSMask&~RHSMask)) {
7716 SDValue X = DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N0), VT,
7717 N1: N0.getOperand(i: 0), N2: N1.getOperand(i: 0));
7718 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: X,
7719 N2: DAG.getConstant(Val: LHSMask | RHSMask, DL, VT));
7720 }
7721 }
7722 }
7723 }
7724
7725 // (or (and X, M), (and X, N)) -> (and X, (or M, N))
7726 if (N0.getOpcode() == ISD::AND &&
7727 N1.getOpcode() == ISD::AND &&
7728 N0.getOperand(i: 0) == N1.getOperand(i: 0) &&
7729 // Don't increase # computations.
7730 (N0->hasOneUse() || N1->hasOneUse())) {
7731 SDValue X = DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N0), VT,
7732 N1: N0.getOperand(i: 1), N2: N1.getOperand(i: 1));
7733 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0.getOperand(i: 0), N2: X);
7734 }
7735
7736 return SDValue();
7737}
7738
7739/// OR combines for which the commuted variant will be tried as well.
7740static SDValue visitORCommutative(SelectionDAG &DAG, SDValue N0, SDValue N1,
7741 SDNode *N) {
7742 EVT VT = N0.getValueType();
7743 unsigned BW = VT.getScalarSizeInBits();
7744 SDLoc DL(N);
7745
7746 auto peekThroughResize = [](SDValue V) {
7747 if (V->getOpcode() == ISD::ZERO_EXTEND || V->getOpcode() == ISD::TRUNCATE)
7748 return V->getOperand(Num: 0);
7749 return V;
7750 };
7751
7752 SDValue N0Resized = peekThroughResize(N0);
7753 if (N0Resized.getOpcode() == ISD::AND) {
7754 SDValue N1Resized = peekThroughResize(N1);
7755 SDValue N00 = N0Resized.getOperand(i: 0);
7756 SDValue N01 = N0Resized.getOperand(i: 1);
7757
7758 // fold or (and x, y), x --> x
7759 if (N00 == N1Resized || N01 == N1Resized)
7760 return N1;
7761
7762 // fold (or (and X, (xor Y, -1)), Y) -> (or X, Y)
7763 // TODO: Set AllowUndefs = true.
7764 if (SDValue NotOperand = getBitwiseNotOperand(V: N01, Mask: N00,
7765 /* AllowUndefs */ false)) {
7766 if (peekThroughResize(NotOperand) == N1Resized)
7767 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: DAG.getZExtOrTrunc(Op: N00, DL, VT),
7768 N2: N1);
7769 }
7770
7771 // fold (or (and (xor Y, -1), X), Y) -> (or X, Y)
7772 if (SDValue NotOperand = getBitwiseNotOperand(V: N00, Mask: N01,
7773 /* AllowUndefs */ false)) {
7774 if (peekThroughResize(NotOperand) == N1Resized)
7775 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: DAG.getZExtOrTrunc(Op: N01, DL, VT),
7776 N2: N1);
7777 }
7778 }
7779
7780 SDValue X, Y;
7781
7782 // fold or (xor X, N1), N1 --> or X, N1
7783 if (sd_match(N: N0, P: m_Xor(L: m_Value(N&: X), R: m_Specific(N: N1))))
7784 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: X, N2: N1);
7785
7786 // fold or (xor x, y), (x and/or y) --> or x, y
7787 if (sd_match(N: N0, P: m_Xor(L: m_Value(N&: X), R: m_Value(N&: Y))) &&
7788 (sd_match(N: N1, P: m_And(L: m_Specific(N: X), R: m_Specific(N: Y))) ||
7789 sd_match(N: N1, P: m_Or(L: m_Specific(N: X), R: m_Specific(N: Y)))))
7790 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: X, N2: Y);
7791
7792 if (SDValue R = foldLogicOfShifts(N, LogicOp: N0, ShiftOp: N1, DAG))
7793 return R;
7794
7795 auto peekThroughZext = [](SDValue V) {
7796 if (V->getOpcode() == ISD::ZERO_EXTEND)
7797 return V->getOperand(Num: 0);
7798 return V;
7799 };
7800
7801 // (fshl X, ?, Y) | (shl X, Y) --> fshl X, ?, Y
7802 if (N0.getOpcode() == ISD::FSHL && N1.getOpcode() == ISD::SHL &&
7803 N0.getOperand(i: 0) == N1.getOperand(i: 0) &&
7804 peekThroughZext(N0.getOperand(i: 2)) == peekThroughZext(N1.getOperand(i: 1)))
7805 return N0;
7806
7807 // (fshr ?, X, Y) | (srl X, Y) --> fshr ?, X, Y
7808 if (N0.getOpcode() == ISD::FSHR && N1.getOpcode() == ISD::SRL &&
7809 N0.getOperand(i: 1) == N1.getOperand(i: 0) &&
7810 peekThroughZext(N0.getOperand(i: 2)) == peekThroughZext(N1.getOperand(i: 1)))
7811 return N0;
7812
7813 // Attempt to match a legalized build_pair-esque pattern:
7814 // or(shl(aext(Hi),BW/2),zext(Lo))
7815 SDValue Lo, Hi;
7816 if (sd_match(N: N0,
7817 P: m_OneUse(P: m_Shl(L: m_AnyExt(Op: m_Value(N&: Hi)), R: m_SpecificInt(V: BW / 2)))) &&
7818 sd_match(N: N1, P: m_ZExt(Op: m_Value(N&: Lo))) &&
7819 Lo.getScalarValueSizeInBits() == (BW / 2) &&
7820 Lo.getValueType() == Hi.getValueType()) {
7821 // Fold build_pair(not(Lo),not(Hi)) -> not(build_pair(Lo,Hi)).
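// e.g. with BW == 64 and i32 halves:
//   or (shl (anyext (xor Hi, -1)), 32), (zext (xor Lo, -1))
//     -> xor (or (shl (anyext Hi), 32), (zext Lo)), -1
// so the two NOTs are merged into a single NOT of the combined value.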
7822 SDValue NotLo, NotHi;
7823 if (sd_match(N: Lo, P: m_OneUse(P: m_Not(V: m_Value(N&: NotLo)))) &&
7824 sd_match(N: Hi, P: m_OneUse(P: m_Not(V: m_Value(N&: NotHi))))) {
7825 Lo = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: NotLo);
7826 Hi = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT, Operand: NotHi);
7827 Hi = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Hi,
7828 N2: DAG.getShiftAmountConstant(Val: BW / 2, VT, DL));
7829 return DAG.getNOT(DL, Val: DAG.getNode(Opcode: ISD::OR, DL, VT, N1: Lo, N2: Hi), VT);
7830 }
7831 }
7832
7833 return SDValue();
7834}
7835
7836SDValue DAGCombiner::visitOR(SDNode *N) {
7837 SDValue N0 = N->getOperand(Num: 0);
7838 SDValue N1 = N->getOperand(Num: 1);
7839 EVT VT = N1.getValueType();
7840 SDLoc DL(N);
7841
7842 // x | x --> x
7843 if (N0 == N1)
7844 return N0;
7845
7846 // fold (or c1, c2) -> c1|c2
7847 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::OR, DL, VT, Ops: {N0, N1}))
7848 return C;
7849
7850 // canonicalize constant to RHS
7851 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
7852 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
7853 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1, N2: N0);
7854
7855 // fold vector ops
7856 if (VT.isVector()) {
7857 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
7858 return FoldedVOp;
7859
7860 // fold (or x, 0) -> x, vector edition
7861 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
7862 return N0;
7863
7864 // fold (or x, -1) -> -1, vector edition
7865 if (ISD::isConstantSplatVectorAllOnes(N: N1.getNode()))
7866 // do not return N1, because an undef node may exist in N1
7867 return DAG.getAllOnesConstant(DL, VT: N1.getValueType());
7868
7869 // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask)
7870 // Do this only if the resulting type / shuffle is legal.
7871 auto *SV0 = dyn_cast<ShuffleVectorSDNode>(Val&: N0);
7872 auto *SV1 = dyn_cast<ShuffleVectorSDNode>(Val&: N1);
7873 if (SV0 && SV1 && TLI.isTypeLegal(VT)) {
7874 bool ZeroN00 = ISD::isBuildVectorAllZeros(N: N0.getOperand(i: 0).getNode());
7875 bool ZeroN01 = ISD::isBuildVectorAllZeros(N: N0.getOperand(i: 1).getNode());
7876 bool ZeroN10 = ISD::isBuildVectorAllZeros(N: N1.getOperand(i: 0).getNode());
7877 bool ZeroN11 = ISD::isBuildVectorAllZeros(N: N1.getOperand(i: 1).getNode());
7878 // Ensure both shuffles have a zero input.
7879 if ((ZeroN00 != ZeroN01) && (ZeroN10 != ZeroN11)) {
7880 assert((!ZeroN00 || !ZeroN01) && "Both inputs zero!");
7881 assert((!ZeroN10 || !ZeroN11) && "Both inputs zero!");
7882 bool CanFold = true;
7883 int NumElts = VT.getVectorNumElements();
7884 SmallVector<int, 4> Mask(NumElts, -1);
7885
7886 for (int i = 0; i != NumElts; ++i) {
7887 int M0 = SV0->getMaskElt(Idx: i);
7888 int M1 = SV1->getMaskElt(Idx: i);
7889
7890 // Determine if either index is pointing to a zero vector.
7891 bool M0Zero = M0 < 0 || (ZeroN00 == (M0 < NumElts));
7892 bool M1Zero = M1 < 0 || (ZeroN10 == (M1 < NumElts));
7893
7894 // If one element is zero and the other side is undef, keep undef.
7895 // This also handles the case that both are undef.
7896 if ((M0Zero && M1 < 0) || (M1Zero && M0 < 0))
7897 continue;
7898
7899 // Make sure only one of the elements is zero.
7900 if (M0Zero == M1Zero) {
7901 CanFold = false;
7902 break;
7903 }
7904
7905 assert((M0 >= 0 || M1 >= 0) && "Undef index!");
7906
7907 // We have a zero and non-zero element. If the non-zero came from
7908 // SV0 make the index a LHS index. If it came from SV1, make it
7909 // a RHS index. We need to mod by NumElts because we don't care
7910 // which operand it came from in the original shuffles.
7911 Mask[i] = M1Zero ? M0 % NumElts : (M1 % NumElts) + NumElts;
7912 }
7913
7914 if (CanFold) {
7915 SDValue NewLHS = ZeroN00 ? N0.getOperand(i: 1) : N0.getOperand(i: 0);
7916 SDValue NewRHS = ZeroN10 ? N1.getOperand(i: 1) : N1.getOperand(i: 0);
7917 SDValue LegalShuffle =
7918 TLI.buildLegalVectorShuffle(VT, DL, N0: NewLHS, N1: NewRHS, Mask, DAG);
7919 if (LegalShuffle)
7920 return LegalShuffle;
7921 }
7922 }
7923 }
7924 }
7925
7926 // fold (or x, 0) -> x
7927 if (isNullConstant(V: N1))
7928 return N0;
7929
7930 // fold (or x, -1) -> -1
7931 if (isAllOnesConstant(V: N1))
7932 return N1;
7933
7934 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
7935 return NewSel;
7936
7937 // fold (or x, c) -> c iff (x & ~c) == 0
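// e.g. if x is known to have only its low 4 bits possibly set, then
// (or x, 0xff) folds to the constant 0xff.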
7938 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
7939 if (N1C && DAG.MaskedValueIsZero(Op: N0, Mask: ~N1C->getAPIntValue()))
7940 return N1;
7941
7942 if (SDValue R = foldAndOrOfSETCC(LogicOp: N, DAG))
7943 return R;
7944
7945 if (SDValue Combined = visitORLike(N0, N1, DL))
7946 return Combined;
7947
7948 if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
7949 return Combined;
7950
7951 // Recognize halfword bswaps as (bswap + rotl 16) or (bswap + shl 16)
7952 if (SDValue BSwap = MatchBSwapHWord(N, N0, N1))
7953 return BSwap;
7954 if (SDValue BSwap = MatchBSwapHWordLow(N, N0, N1))
7955 return BSwap;
7956
7957 // reassociate or
7958 if (SDValue ROR = reassociateOps(Opc: ISD::OR, DL, N0, N1, Flags: N->getFlags()))
7959 return ROR;
7960
7961 // Fold or(vecreduce(x), vecreduce(y)) -> vecreduce(or(x, y))
7962 if (SDValue SD =
7963 reassociateReduction(RedOpc: ISD::VECREDUCE_OR, Opc: ISD::OR, DL, VT, N0, N1))
7964 return SD;
7965
7966 // Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2)
7967 // iff (c1 & c2) != 0 or c1/c2 are undef.
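// e.g. (or (and X, 0x0c), 0x05) -> (and (or X, 0x05), 0x0d), since
// 0x0c & 0x05 = 0x04 != 0.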
7968 auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) {
7969 return !C1 || !C2 || C1->getAPIntValue().intersects(RHS: C2->getAPIntValue());
7970 };
7971 if (N0.getOpcode() == ISD::AND && N0->hasOneUse() &&
7972 ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchIntersect, AllowUndefs: true)) {
7973 if (SDValue COR = DAG.FoldConstantArithmetic(Opcode: ISD::OR, DL: SDLoc(N1), VT,
7974 Ops: {N1, N0.getOperand(i: 1)})) {
7975 SDValue IOR = DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N0), VT, N1: N0.getOperand(i: 0), N2: N1);
7976 AddToWorklist(N: IOR.getNode());
7977 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: COR, N2: IOR);
7978 }
7979 }
7980
7981 if (SDValue Combined = visitORCommutative(DAG, N0, N1, N))
7982 return Combined;
7983 if (SDValue Combined = visitORCommutative(DAG, N0: N1, N1: N0, N))
7984 return Combined;
7985
7986 // Simplify: (or (op x...), (op y...)) -> (op (or x, y))
7987 if (N0.getOpcode() == N1.getOpcode())
7988 if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
7989 return V;
7990
7991 // See if this is some rotate idiom.
7992 if (SDValue Rot = MatchRotate(LHS: N0, RHS: N1, DL))
7993 return Rot;
7994
7995 if (SDValue Load = MatchLoadCombine(N))
7996 return Load;
7997
7998 // Simplify the operands using demanded-bits information.
7999 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
8000 return SDValue(N, 0);
8001
8002 // If OR can be rewritten into ADD, try combines based on ADD.
8003 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::ADD, VT)) &&
8004 DAG.isADDLike(Op: SDValue(N, 0)))
8005 if (SDValue Combined = visitADDLike(N))
8006 return Combined;
8007
8008 // Postpone until legalization completed to avoid interference with bswap
8009 // folding
8010 if (LegalOperations || VT.isVector())
8011 if (SDValue R = foldLogicTreeOfShifts(N, LeftHand: N0, RightHand: N1, DAG))
8012 return R;
8013
8014 return SDValue();
8015}
8016
8017static SDValue stripConstantMask(const SelectionDAG &DAG, SDValue Op,
8018 SDValue &Mask) {
8019 if (Op.getOpcode() == ISD::AND &&
8020 DAG.isConstantIntBuildVectorOrConstantInt(N: Op.getOperand(i: 1))) {
8021 Mask = Op.getOperand(i: 1);
8022 return Op.getOperand(i: 0);
8023 }
8024 return Op;
8025}
8026
8027/// Match "(X shl/srl V1) & V2" where V2 may not be present.
8028static bool matchRotateHalf(const SelectionDAG &DAG, SDValue Op, SDValue &Shift,
8029 SDValue &Mask) {
8030 Op = stripConstantMask(DAG, Op, Mask);
8031 if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) {
8032 Shift = Op;
8033 return true;
8034 }
8035 return false;
8036}
8037
8038/// Helper function for visitOR to extract the needed side of a rotate idiom
8039/// from a shl/srl/mul/udiv. This is meant to handle cases where
8040/// InstCombine merged some outside op with one of the shifts from
8041/// the rotate pattern.
8042/// \returns An empty \c SDValue if the needed shift couldn't be extracted.
8043/// Otherwise, returns an expansion of \p ExtractFrom based on the following
8044/// patterns:
8045///
8046/// (or (add v v) (shrl v bitwidth-1)):
8047/// expands (add v v) -> (shl v 1)
8048///
8049/// (or (mul v c0) (shrl (mul v c1) c2)):
8050/// expands (mul v c0) -> (shl (mul v c1) c3)
8051///
8052/// (or (udiv v c0) (shl (udiv v c1) c2)):
8053/// expands (udiv v c0) -> (shrl (udiv v c1) c3)
8054///
8055/// (or (shl v c0) (shrl (shl v c1) c2)):
8056/// expands (shl v c0) -> (shl (shl v c1) c3)
8057///
8058/// (or (shrl v c0) (shl (shrl v c1) c2)):
8059/// expands (shrl v c0) -> (shrl (shrl v c1) c3)
8060///
8061/// Such that in all cases, c3+c2==bitwidth(op v c1).
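/// e.g. for i8 values, (or (mul v, 24), (shrl (mul v, 3), 5)) expands the
/// left-hand side to (shl (mul v, 3), 3), since 24 == 3 << 3 and 3 + 5 == 8.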
8062static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift,
8063 SDValue ExtractFrom, SDValue &Mask,
8064 const SDLoc &DL) {
8065 assert(OppShift && ExtractFrom && "Empty SDValue");
8066 if (OppShift.getOpcode() != ISD::SHL && OppShift.getOpcode() != ISD::SRL)
8067 return SDValue();
8068
8069 ExtractFrom = stripConstantMask(DAG, Op: ExtractFrom, Mask);
8070
8071 // Value and Type of the shift.
8072 SDValue OppShiftLHS = OppShift.getOperand(i: 0);
8073 EVT ShiftedVT = OppShiftLHS.getValueType();
8074
8075 // Amount of the existing shift.
8076 ConstantSDNode *OppShiftCst = isConstOrConstSplat(N: OppShift.getOperand(i: 1));
8077
8078 // (add v v) -> (shl v 1)
8079 // TODO: Should this be a general DAG canonicalization?
8080 if (OppShift.getOpcode() == ISD::SRL && OppShiftCst &&
8081 ExtractFrom.getOpcode() == ISD::ADD &&
8082 ExtractFrom.getOperand(i: 0) == ExtractFrom.getOperand(i: 1) &&
8083 ExtractFrom.getOperand(i: 0) == OppShiftLHS &&
8084 OppShiftCst->getAPIntValue() == ShiftedVT.getScalarSizeInBits() - 1)
8085 return DAG.getNode(Opcode: ISD::SHL, DL, VT: ShiftedVT, N1: OppShiftLHS,
8086 N2: DAG.getShiftAmountConstant(Val: 1, VT: ShiftedVT, DL));
8087
8088 // Preconditions:
8089 // (or (op0 v c0) (shiftl/r (op0 v c1) c2))
8090 //
8091 // Find opcode of the needed shift to be extracted from (op0 v c0).
8092 unsigned Opcode = ISD::DELETED_NODE;
8093 bool IsMulOrDiv = false;
8094 // Set Opcode and IsMulOrDiv if the extract opcode matches the needed shift
8095 // opcode or its arithmetic (mul or udiv) variant.
8096 auto SelectOpcode = [&](unsigned NeededShift, unsigned MulOrDivVariant) {
8097 IsMulOrDiv = ExtractFrom.getOpcode() == MulOrDivVariant;
8098 if (!IsMulOrDiv && ExtractFrom.getOpcode() != NeededShift)
8099 return false;
8100 Opcode = NeededShift;
8101 return true;
8102 };
8103 // op0 must be either the needed shift opcode or the mul/udiv equivalent
8104 // that the needed shift can be extracted from.
8105 if ((OppShift.getOpcode() != ISD::SRL || !SelectOpcode(ISD::SHL, ISD::MUL)) &&
8106 (OppShift.getOpcode() != ISD::SHL || !SelectOpcode(ISD::SRL, ISD::UDIV)))
8107 return SDValue();
8108
8109 // op0 must be the same opcode on both sides, have the same LHS argument,
8110 // and produce the same value type.
8111 if (OppShiftLHS.getOpcode() != ExtractFrom.getOpcode() ||
8112 OppShiftLHS.getOperand(i: 0) != ExtractFrom.getOperand(i: 0) ||
8113 ShiftedVT != ExtractFrom.getValueType())
8114 return SDValue();
8115
8116 // Constant mul/udiv/shift amount from the RHS of the shift's LHS op.
8117 ConstantSDNode *OppLHSCst = isConstOrConstSplat(N: OppShiftLHS.getOperand(i: 1));
8118 // Constant mul/udiv/shift amount from the RHS of the ExtractFrom op.
8119 ConstantSDNode *ExtractFromCst =
8120 isConstOrConstSplat(N: ExtractFrom.getOperand(i: 1));
8121 // TODO: We should be able to handle non-uniform constant vectors for these values
8122 // Check that we have constant values.
8123 if (!OppShiftCst || !OppShiftCst->getAPIntValue() ||
8124 !OppLHSCst || !OppLHSCst->getAPIntValue() ||
8125 !ExtractFromCst || !ExtractFromCst->getAPIntValue())
8126 return SDValue();
8127
8128 // Compute the shift amount we need to extract to complete the rotate.
8129 const unsigned VTWidth = ShiftedVT.getScalarSizeInBits();
8130 if (OppShiftCst->getAPIntValue().ugt(RHS: VTWidth))
8131 return SDValue();
8132 APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue();
8133 // Normalize the bitwidth of the two mul/udiv/shift constant operands.
8134 APInt ExtractFromAmt = ExtractFromCst->getAPIntValue();
8135 APInt OppLHSAmt = OppLHSCst->getAPIntValue();
8136 zeroExtendToMatch(LHS&: ExtractFromAmt, RHS&: OppLHSAmt);
8137
8138 // Now try extract the needed shift from the ExtractFrom op and see if the
8139 // result matches up with the existing shift's LHS op.
8140 if (IsMulOrDiv) {
8141 // Op to extract from is a mul or udiv by a constant.
8142 // Check:
8143 // c2 / (1 << (bitwidth(op0 v c0) - c1)) == c0
8144 // c2 % (1 << (bitwidth(op0 v c0) - c1)) == 0
8145 const APInt ExtractDiv = APInt::getOneBitSet(numBits: ExtractFromAmt.getBitWidth(),
8146 BitNo: NeededShiftAmt.getZExtValue());
8147 APInt ResultAmt;
8148 APInt Rem;
8149 APInt::udivrem(LHS: ExtractFromAmt, RHS: ExtractDiv, Quotient&: ResultAmt, Remainder&: Rem);
8150 if (Rem != 0 || ResultAmt != OppLHSAmt)
8151 return SDValue();
8152 } else {
8153 // Op to extract from is a shift by a constant.
8154 // Check:
8155 // c2 - (bitwidth(op0 v c0) - c1) == c0
8156 if (OppLHSAmt != ExtractFromAmt - NeededShiftAmt.zextOrTrunc(
8157 width: ExtractFromAmt.getBitWidth()))
8158 return SDValue();
8159 }
8160
8161 // Return the expanded shift op that should allow a rotate to be formed.
8162 EVT ShiftVT = OppShift.getOperand(i: 1).getValueType();
8163 EVT ResVT = ExtractFrom.getValueType();
8164 SDValue NewShiftNode = DAG.getConstant(Val: NeededShiftAmt, DL, VT: ShiftVT);
8165 return DAG.getNode(Opcode, DL, VT: ResVT, N1: OppShiftLHS, N2: NewShiftNode);
8166}
8167
8168// Return true if we can prove that, whenever Neg and Pos are both in the
8169// range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos). This means that
8170// for two opposing shifts shift1 and shift2 and a value X with OpBits bits:
8171//
8172// (or (shift1 X, Neg), (shift2 X, Pos))
8173//
8174// reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate
8175// in direction shift1 by Neg. The range [0, EltSize) means that we only need
8176// to consider shift amounts with defined behavior.
8177//
8178// The IsRotate flag should be set when the LHS of both shifts is the same.
8179// Otherwise if matching a general funnel shift, it should be clear.
8180static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
8181 SelectionDAG &DAG, bool IsRotate) {
8182 const auto &TLI = DAG.getTargetLoweringInfo();
8183 // If EltSize is a power of 2 then:
8184 //
8185 // (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
8186 // (b) Neg == Neg & (EltSize - 1) whenever Neg is in [0, EltSize).
8187 //
8188 // So if EltSize is a power of 2 and Neg is (and Neg', EltSize-1), we check
8189 // for the stronger condition:
8190 //
8191 // Neg & (EltSize - 1) == (EltSize - Pos) & (EltSize - 1) [A]
8192 //
8193 // for all Neg and Pos. Since Neg & (EltSize - 1) == Neg' & (EltSize - 1)
8194 // we can just replace Neg with Neg' for the rest of the function.
8195 //
8196 // In other cases we check for the even stronger condition:
8197 //
8198 // Neg == EltSize - Pos [B]
8199 //
8200 // for all Neg and Pos. Note that the (or ...) then invokes undefined
8201 // behavior if Pos == 0 (and consequently Neg == EltSize).
8202 //
8203 // We could actually use [A] whenever EltSize is a power of 2, but the
8204 // only extra cases that it would match are those uninteresting ones
8205 // where Neg and Pos are never in range at the same time. E.g. for
8206 // EltSize == 32, using [A] would allow a Neg of the form (sub 64, Pos)
8207 // as well as (sub 32, Pos), but:
8208 //
8209 // (or (shift1 X, (sub 64, Pos)), (shift2 X, Pos))
8210 //
8211 // always invokes undefined behavior for 32-bit X.
8212 //
8213 // Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise.
8214 // This allows us to peek through any operations that only affect Mask's
8215 // un-demanded bits.
8216 //
8217 // NOTE: We can only do this when matching operations which won't modify the
8218 // least Log2(EltSize) significant bits and not a general funnel shift.
8219 unsigned MaskLoBits = 0;
8220 if (IsRotate && isPowerOf2_64(Value: EltSize)) {
8221 unsigned Bits = Log2_64(Value: EltSize);
8222 unsigned NegBits = Neg.getScalarValueSizeInBits();
8223 if (NegBits >= Bits) {
8224 APInt DemandedBits = APInt::getLowBitsSet(numBits: NegBits, loBitsSet: Bits);
8225 if (SDValue Inner =
8226 TLI.SimplifyMultipleUseDemandedBits(Op: Neg, DemandedBits, DAG)) {
8227 Neg = Inner;
8228 MaskLoBits = Bits;
8229 }
8230 }
8231 }
8232
8233 // Check whether Neg has the form (sub NegC, NegOp1) for some NegC and NegOp1.
8234 if (Neg.getOpcode() != ISD::SUB)
8235 return false;
8236 ConstantSDNode *NegC = isConstOrConstSplat(N: Neg.getOperand(i: 0));
8237 if (!NegC)
8238 return false;
8239 SDValue NegOp1 = Neg.getOperand(i: 1);
8240
8241 // On the RHS of [A], if Pos is the result of operation on Pos' that won't
8242 // affect Mask's demanded bits, just replace Pos with Pos'. These operations
8243 // are redundant for the purpose of the equality.
8244 if (MaskLoBits) {
8245 unsigned PosBits = Pos.getScalarValueSizeInBits();
8246 if (PosBits >= MaskLoBits) {
8247 APInt DemandedBits = APInt::getLowBitsSet(numBits: PosBits, loBitsSet: MaskLoBits);
8248 if (SDValue Inner =
8249 TLI.SimplifyMultipleUseDemandedBits(Op: Pos, DemandedBits, DAG)) {
8250 Pos = Inner;
8251 }
8252 }
8253 }
8254
8255 // The condition we need is now:
8256 //
8257 // (NegC - NegOp1) & Mask == (EltSize - Pos) & Mask
8258 //
8259 // If NegOp1 == Pos then we need:
8260 //
8261 // EltSize & Mask == NegC & Mask
8262 //
8263 // (because "x & Mask" is a truncation and distributes through subtraction).
8264 //
8265 // We also need to account for a potential truncation of NegOp1 if the amount
8266 // has already been legalized to a shift amount type.
8267 APInt Width;
8268 if ((Pos == NegOp1) ||
8269 (NegOp1.getOpcode() == ISD::TRUNCATE && Pos == NegOp1.getOperand(i: 0)))
8270 Width = NegC->getAPIntValue();
8271
8272 // Check for cases where Pos has the form (add NegOp1, PosC) for some PosC.
8273 // Then the condition we want to prove becomes:
8274 //
8275 // (NegC - NegOp1) & Mask == (EltSize - (NegOp1 + PosC)) & Mask
8276 //
8277 // which, again because "x & Mask" is a truncation, becomes:
8278 //
8279 // NegC & Mask == (EltSize - PosC) & Mask
8280 // EltSize & Mask == (NegC + PosC) & Mask
8281 else if (Pos.getOpcode() == ISD::ADD && Pos.getOperand(i: 0) == NegOp1) {
8282 if (ConstantSDNode *PosC = isConstOrConstSplat(N: Pos.getOperand(i: 1)))
8283 Width = PosC->getAPIntValue() + NegC->getAPIntValue();
8284 else
8285 return false;
8286 } else
8287 return false;
8288
8289 // Now we just need to check that EltSize & Mask == Width & Mask.
8290 if (MaskLoBits)
8291 // EltSize & Mask is 0 since Mask is EltSize - 1.
8292 return Width.getLoBits(numBits: MaskLoBits) == 0;
8293 return Width == EltSize;
8294}
8295
8296// A subroutine of MatchRotate used once we have found an OR of two opposite
8297// shifts of Shifted. If Neg == <operand size> - Pos then the OR reduces
8298// to both (PosOpcode Shifted, Pos) and (NegOpcode Shifted, Neg), with the
8299// former being preferred if supported. InnerPos and InnerNeg are Pos and
8300// Neg with outer conversions stripped away.
8301SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
8302 SDValue Neg, SDValue InnerPos,
8303 SDValue InnerNeg, bool HasPos,
8304 unsigned PosOpcode, unsigned NegOpcode,
8305 const SDLoc &DL) {
8306 // fold (or (shl x, (*ext y)),
8307 // (srl x, (*ext (sub 32, y)))) ->
8308 // (rotl x, y) or (rotr x, (sub 32, y))
8309 //
8310 // fold (or (shl x, (*ext (sub 32, y))),
8311 // (srl x, (*ext y))) ->
8312 // (rotr x, y) or (rotl x, (sub 32, y))
8313 EVT VT = Shifted.getValueType();
8314 if (matchRotateSub(Pos: InnerPos, Neg: InnerNeg, EltSize: VT.getScalarSizeInBits(), DAG,
8315 /*IsRotate*/ true)) {
8316 return DAG.getNode(Opcode: HasPos ? PosOpcode : NegOpcode, DL, VT, N1: Shifted,
8317 N2: HasPos ? Pos : Neg);
8318 }
8319
8320 return SDValue();
8321}
8322
8323// A subroutine of MatchRotate used once we have found an OR of two opposite
8324// shifts of N0 + N1. If Neg == <operand size> - Pos then the OR reduces
8325// to both (PosOpcode N0, N1, Pos) and (NegOpcode N0, N1, Neg), with the
8326// former being preferred if supported. InnerPos and InnerNeg are Pos and
8327// Neg with outer conversions stripped away.
8328// TODO: Merge with MatchRotatePosNeg.
8329SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
8330 SDValue Neg, SDValue InnerPos,
8331 SDValue InnerNeg, bool HasPos,
8332 unsigned PosOpcode, unsigned NegOpcode,
8333 const SDLoc &DL) {
8334 EVT VT = N0.getValueType();
8335 unsigned EltBits = VT.getScalarSizeInBits();
8336
8337 // fold (or (shl x0, (*ext y)),
8338 // (srl x1, (*ext (sub 32, y)))) ->
8339 // (fshl x0, x1, y) or (fshr x0, x1, (sub 32, y))
8340 //
8341 // fold (or (shl x0, (*ext (sub 32, y))),
8342 // (srl x1, (*ext y))) ->
8343 // (fshr x0, x1, y) or (fshl x0, x1, (sub 32, y))
8344 if (matchRotateSub(Pos: InnerPos, Neg: InnerNeg, EltSize: EltBits, DAG, /*IsRotate*/ N0 == N1)) {
8345 return DAG.getNode(Opcode: HasPos ? PosOpcode : NegOpcode, DL, VT, N1: N0, N2: N1,
8346 N3: HasPos ? Pos : Neg);
8347 }
8348
8349 // Matching the shift+xor cases, we can't easily use the xor'd shift amount
8350 // so for now just use the PosOpcode case if its legal.
8351 // TODO: When can we use the NegOpcode case?
8352 if (PosOpcode == ISD::FSHL && isPowerOf2_32(Value: EltBits)) {
8353 auto IsBinOpImm = [](SDValue Op, unsigned BinOpc, unsigned Imm) {
8354 if (Op.getOpcode() != BinOpc)
8355 return false;
8356 ConstantSDNode *Cst = isConstOrConstSplat(N: Op.getOperand(i: 1));
8357 return Cst && (Cst->getAPIntValue() == Imm);
8358 };
8359
8360 // fold (or (shl x0, y), (srl (srl x1, 1), (xor y, 31)))
8361 // -> (fshl x0, x1, y)
8362 if (IsBinOpImm(N1, ISD::SRL, 1) &&
8363 IsBinOpImm(InnerNeg, ISD::XOR, EltBits - 1) &&
8364 InnerPos == InnerNeg.getOperand(i: 0) &&
8365 TLI.isOperationLegalOrCustom(Op: ISD::FSHL, VT)) {
8366 return DAG.getNode(Opcode: ISD::FSHL, DL, VT, N1: N0, N2: N1.getOperand(i: 0), N3: Pos);
8367 }
8368
8369 // fold (or (shl (shl x0, 1), (xor y, 31)), (srl x1, y))
8370 // -> (fshr x0, x1, y)
8371 if (IsBinOpImm(N0, ISD::SHL, 1) &&
8372 IsBinOpImm(InnerPos, ISD::XOR, EltBits - 1) &&
8373 InnerNeg == InnerPos.getOperand(i: 0) &&
8374 TLI.isOperationLegalOrCustom(Op: ISD::FSHR, VT)) {
8375 return DAG.getNode(Opcode: ISD::FSHR, DL, VT, N1: N0.getOperand(i: 0), N2: N1, N3: Neg);
8376 }
8377
8378 // fold (or (shl (add x0, x0), (xor y, 31)), (srl x1, y))
8379 // -> (fshr x0, x1, y)
8380 // TODO: Should add(x,x) -> shl(x,1) be a general DAG canonicalization?
8381 if (N0.getOpcode() == ISD::ADD && N0.getOperand(i: 0) == N0.getOperand(i: 1) &&
8382 IsBinOpImm(InnerPos, ISD::XOR, EltBits - 1) &&
8383 InnerNeg == InnerPos.getOperand(i: 0) &&
8384 TLI.isOperationLegalOrCustom(Op: ISD::FSHR, VT)) {
8385 return DAG.getNode(Opcode: ISD::FSHR, DL, VT, N1: N0.getOperand(i: 0), N2: N1, N3: Neg);
8386 }
8387 }
8388
8389 return SDValue();
8390}
8391
8392// MatchRotate - Handle an 'or' of two operands. If this is one of the many
8393// idioms for rotate, and if the target supports rotation instructions, generate
8394// a rot[lr]. This also matches funnel shift patterns, similar to rotation but
8395// with different shifted sources.
8396SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
8397 EVT VT = LHS.getValueType();
8398
8399 // The target must have at least one rotate/funnel flavor.
8400 // We still try to match rotate by constant pre-legalization.
8401 // TODO: Support pre-legalization funnel-shift by constant.
8402 bool HasROTL = hasOperation(Opcode: ISD::ROTL, VT);
8403 bool HasROTR = hasOperation(Opcode: ISD::ROTR, VT);
8404 bool HasFSHL = hasOperation(Opcode: ISD::FSHL, VT);
8405 bool HasFSHR = hasOperation(Opcode: ISD::FSHR, VT);
8406
8407 // If the type is going to be promoted and the target has enabled custom
8408 // lowering for rotate, allow matching rotate by non-constants. Only allow
8409 // this for scalar types.
8410 if (VT.isScalarInteger() && TLI.getTypeAction(Context&: *DAG.getContext(), VT) ==
8411 TargetLowering::TypePromoteInteger) {
8412 HasROTL |= TLI.getOperationAction(Op: ISD::ROTL, VT) == TargetLowering::Custom;
8413 HasROTR |= TLI.getOperationAction(Op: ISD::ROTR, VT) == TargetLowering::Custom;
8414 }
8415
8416 if (LegalOperations && !HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
8417 return SDValue();
8418
8419 // Check for truncated rotate.
8420 if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE &&
8421 LHS.getOperand(i: 0).getValueType() == RHS.getOperand(i: 0).getValueType()) {
8422 assert(LHS.getValueType() == RHS.getValueType());
8423 if (SDValue Rot = MatchRotate(LHS: LHS.getOperand(i: 0), RHS: RHS.getOperand(i: 0), DL)) {
8424 return DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(LHS), VT: LHS.getValueType(), Operand: Rot);
8425 }
8426 }
8427
8428 // Match "(X shl/srl V1) & V2" where V2 may not be present.
8429 SDValue LHSShift; // The shift.
8430 SDValue LHSMask; // AND value if any.
8431 matchRotateHalf(DAG, Op: LHS, Shift&: LHSShift, Mask&: LHSMask);
8432
8433 SDValue RHSShift; // The shift.
8434 SDValue RHSMask; // AND value if any.
8435 matchRotateHalf(DAG, Op: RHS, Shift&: RHSShift, Mask&: RHSMask);
8436
8437 // If neither side matched a rotate half, bail
8438 if (!LHSShift && !RHSShift)
8439 return SDValue();
8440
8441 // InstCombine may have combined a constant shl, srl, mul, or udiv with one
8442 // side of the rotate, so try to handle that here. In all cases we need to
8443 // pass the matched shift from the opposite side to compute the opcode and
8444 // needed shift amount to extract. We still want to do this if both sides
8445 // matched a rotate half because one half may be a potential overshift that
8446 // can be broken down (i.e. if InstCombine merged two shl or srl ops into a
8447 // single one).
8448
8449 // Have LHS side of the rotate, try to extract the needed shift from the RHS.
8450 if (LHSShift)
8451 if (SDValue NewRHSShift =
8452 extractShiftForRotate(DAG, OppShift: LHSShift, ExtractFrom: RHS, Mask&: RHSMask, DL))
8453 RHSShift = NewRHSShift;
8454 // Have RHS side of the rotate, try to extract the needed shift from the LHS.
8455 if (RHSShift)
8456 if (SDValue NewLHSShift =
8457 extractShiftForRotate(DAG, OppShift: RHSShift, ExtractFrom: LHS, Mask&: LHSMask, DL))
8458 LHSShift = NewLHSShift;
8459
8460 // If a side is still missing, there is nothing else we can do.
8461 if (!RHSShift || !LHSShift)
8462 return SDValue();
8463
8464 // At this point we've matched or extracted a shift op on each side.
8465
8466 if (LHSShift.getOpcode() == RHSShift.getOpcode())
8467 return SDValue(); // Shifts must disagree.
8468
8469 // Canonicalize shl to left side in a shl/srl pair.
8470 if (RHSShift.getOpcode() == ISD::SHL) {
8471 std::swap(a&: LHS, b&: RHS);
8472 std::swap(a&: LHSShift, b&: RHSShift);
8473 std::swap(a&: LHSMask, b&: RHSMask);
8474 }
8475
8476 // Something has gone wrong - we've lost the shl/srl pair - bail.
8477 if (LHSShift.getOpcode() != ISD::SHL || RHSShift.getOpcode() != ISD::SRL)
8478 return SDValue();
8479
8480 unsigned EltSizeInBits = VT.getScalarSizeInBits();
8481 SDValue LHSShiftArg = LHSShift.getOperand(i: 0);
8482 SDValue LHSShiftAmt = LHSShift.getOperand(i: 1);
8483 SDValue RHSShiftArg = RHSShift.getOperand(i: 0);
8484 SDValue RHSShiftAmt = RHSShift.getOperand(i: 1);
8485
8486 auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS,
8487 ConstantSDNode *RHS) {
8488 return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits;
8489 };
8490
8491 auto ApplyMasks = [&](SDValue Res) {
8492 // If there is an AND of either shifted operand, apply it to the result.
8493 if (LHSMask.getNode() || RHSMask.getNode()) {
8494 SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
8495 SDValue Mask = AllOnes;
8496
8497 if (LHSMask.getNode()) {
8498 SDValue RHSBits = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: AllOnes, N2: RHSShiftAmt);
8499 Mask = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Mask,
8500 N2: DAG.getNode(Opcode: ISD::OR, DL, VT, N1: LHSMask, N2: RHSBits));
8501 }
8502 if (RHSMask.getNode()) {
8503 SDValue LHSBits = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: AllOnes, N2: LHSShiftAmt);
8504 Mask = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Mask,
8505 N2: DAG.getNode(Opcode: ISD::OR, DL, VT, N1: RHSMask, N2: LHSBits));
8506 }
8507
8508 Res = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Res, N2: Mask);
8509 }
8510
8511 return Res;
8512 };
8513
8514 // TODO: Support pre-legalization funnel-shift by constant.
8515 bool IsRotate = LHSShiftArg == RHSShiftArg;
8516 if (!IsRotate && !(HasFSHL || HasFSHR)) {
8517 if (TLI.isTypeLegal(VT) && LHS.hasOneUse() && RHS.hasOneUse() &&
8518 ISD::matchBinaryPredicate(LHS: LHSShiftAmt, RHS: RHSShiftAmt, Match: MatchRotateSum)) {
8519 // Look for a disguised rotate by constant.
8520 // The common shifted operand X may be hidden inside another 'or'.
8521 SDValue X, Y;
8522 auto matchOr = [&X, &Y](SDValue Or, SDValue CommonOp) {
8523 if (!Or.hasOneUse() || Or.getOpcode() != ISD::OR)
8524 return false;
8525 if (CommonOp == Or.getOperand(i: 0)) {
8526 X = CommonOp;
8527 Y = Or.getOperand(i: 1);
8528 return true;
8529 }
8530 if (CommonOp == Or.getOperand(i: 1)) {
8531 X = CommonOp;
8532 Y = Or.getOperand(i: 0);
8533 return true;
8534 }
8535 return false;
8536 };
8537
8538 SDValue Res;
8539 if (matchOr(LHSShiftArg, RHSShiftArg)) {
8540 // (shl (X | Y), C1) | (srl X, C2) --> (rotl X, C1) | (shl Y, C1)
8541 SDValue RotX = DAG.getNode(Opcode: ISD::ROTL, DL, VT, N1: X, N2: LHSShiftAmt);
8542 SDValue ShlY = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Y, N2: LHSShiftAmt);
8543 Res = DAG.getNode(Opcode: ISD::OR, DL, VT, N1: RotX, N2: ShlY);
8544 } else if (matchOr(RHSShiftArg, LHSShiftArg)) {
8545 // (shl X, C1) | (srl (X | Y), C2) --> (rotl X, C1) | (srl Y, C2)
8546 SDValue RotX = DAG.getNode(Opcode: ISD::ROTL, DL, VT, N1: X, N2: LHSShiftAmt);
8547 SDValue SrlY = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Y, N2: RHSShiftAmt);
8548 Res = DAG.getNode(Opcode: ISD::OR, DL, VT, N1: RotX, N2: SrlY);
8549 } else {
8550 return SDValue();
8551 }
8552
8553 return ApplyMasks(Res);
8554 }
8555
8556 return SDValue(); // Requires funnel shift support.
8557 }
8558
8559 // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
8560 // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2)
8561 // fold (or (shl x, C1), (srl y, C2)) -> (fshl x, y, C1)
8562 // fold (or (shl x, C1), (srl y, C2)) -> (fshr x, y, C2)
8563 // iff C1+C2 == EltSizeInBits
8564 if (ISD::matchBinaryPredicate(LHS: LHSShiftAmt, RHS: RHSShiftAmt, Match: MatchRotateSum)) {
8565 SDValue Res;
8566 if (IsRotate && (HasROTL || HasROTR || !(HasFSHL || HasFSHR))) {
8567 bool UseROTL = !LegalOperations || HasROTL;
8568 Res = DAG.getNode(Opcode: UseROTL ? ISD::ROTL : ISD::ROTR, DL, VT, N1: LHSShiftArg,
8569 N2: UseROTL ? LHSShiftAmt : RHSShiftAmt);
8570 } else {
8571 bool UseFSHL = !LegalOperations || HasFSHL;
8572 Res = DAG.getNode(Opcode: UseFSHL ? ISD::FSHL : ISD::FSHR, DL, VT, N1: LHSShiftArg,
8573 N2: RHSShiftArg, N3: UseFSHL ? LHSShiftAmt : RHSShiftAmt);
8574 }
8575
8576 return ApplyMasks(Res);
8577 }
8578
8579 // Even pre-legalization, we can't easily rotate/funnel-shift by a variable
8580 // shift.
8581 if (!HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
8582 return SDValue();
8583
8584 // If there is a mask here, and we have a variable shift, we can't be sure
8585 // that we're masking out the right stuff.
8586 if (LHSMask.getNode() || RHSMask.getNode())
8587 return SDValue();
8588
8589 // If the shift amount is sign/zext/any-extended just peel it off.
8590 SDValue LExtOp0 = LHSShiftAmt;
8591 SDValue RExtOp0 = RHSShiftAmt;
8592 if ((LHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
8593 LHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
8594 LHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
8595 LHSShiftAmt.getOpcode() == ISD::TRUNCATE) &&
8596 (RHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
8597 RHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
8598 RHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
8599 RHSShiftAmt.getOpcode() == ISD::TRUNCATE)) {
8600 LExtOp0 = LHSShiftAmt.getOperand(i: 0);
8601 RExtOp0 = RHSShiftAmt.getOperand(i: 0);
8602 }
8603
8604 if (IsRotate && (HasROTL || HasROTR)) {
8605 SDValue TryL =
8606 MatchRotatePosNeg(Shifted: LHSShiftArg, Pos: LHSShiftAmt, Neg: RHSShiftAmt, InnerPos: LExtOp0,
8607 InnerNeg: RExtOp0, HasPos: HasROTL, PosOpcode: ISD::ROTL, NegOpcode: ISD::ROTR, DL);
8608 if (TryL)
8609 return TryL;
8610
8611 SDValue TryR =
8612 MatchRotatePosNeg(Shifted: RHSShiftArg, Pos: RHSShiftAmt, Neg: LHSShiftAmt, InnerPos: RExtOp0,
8613 InnerNeg: LExtOp0, HasPos: HasROTR, PosOpcode: ISD::ROTR, NegOpcode: ISD::ROTL, DL);
8614 if (TryR)
8615 return TryR;
8616 }
8617
8618 SDValue TryL =
8619 MatchFunnelPosNeg(N0: LHSShiftArg, N1: RHSShiftArg, Pos: LHSShiftAmt, Neg: RHSShiftAmt,
8620 InnerPos: LExtOp0, InnerNeg: RExtOp0, HasPos: HasFSHL, PosOpcode: ISD::FSHL, NegOpcode: ISD::FSHR, DL);
8621 if (TryL)
8622 return TryL;
8623
8624 SDValue TryR =
8625 MatchFunnelPosNeg(N0: LHSShiftArg, N1: RHSShiftArg, Pos: RHSShiftAmt, Neg: LHSShiftAmt,
8626 InnerPos: RExtOp0, InnerNeg: LExtOp0, HasPos: HasFSHR, PosOpcode: ISD::FSHR, NegOpcode: ISD::FSHL, DL);
8627 if (TryR)
8628 return TryR;
8629
8630 return SDValue();
8631}
8632
8633/// Recursively traverses the expression calculating the origin of the requested
8634/// byte of the given value. Returns std::nullopt if the provider can't be
8635/// calculated.
8636///
8637/// For all the values except the root of the expression, we verify that the
8638/// value has exactly one use and if not then return std::nullopt. This way if
8639/// the origin of the byte is returned it's guaranteed that the values which
8640/// contribute to the byte are not used outside of this expression.
8641///
8642/// However, there is a special case when dealing with vector loads -- we allow
8643/// more than one use if the load is a vector type. Since the values that
8644/// contribute to the byte ultimately come from the ExtractVectorElements of the
8645/// Load, we don't care if the Load has uses other than ExtractVectorElements,
8646/// because those operations are independent from the pattern to be combined.
8647/// For vector loads, we simply care that the ByteProviders are adjacent
8648/// positions of the same vector, and their index matches the byte that is being
8649/// provided. This is captured by the \p VectorIndex algorithm. \p VectorIndex
8650/// is the index used in an ExtractVectorElement, and \p StartingIndex is the
8651/// byte position we are trying to provide for the LoadCombine. If these do
8652/// not match, then we can not combine the vector loads. \p Index uses the
8653/// byte position we are trying to provide for and is matched against the
8654/// shl and load size. The \p Index algorithm ensures the requested byte is
8655/// provided for by the pattern, and the pattern does not over provide bytes.
8656///
8657///
8658/// The supported LoadCombine pattern for vector loads is as follows
8659/// or
8660/// / \
8661/// or shl
8662/// / \ |
8663/// or shl zext
8664/// / \ | |
8665/// shl zext zext EVE*
8666/// | | | |
8667/// zext EVE* EVE* LOAD
8668/// | | |
8669/// EVE* LOAD LOAD
8670/// |
8671/// LOAD
8672///
8673/// *ExtractVectorElement
8674using SDByteProvider = ByteProvider<SDNode *>;
8675
8676static std::optional<SDByteProvider>
8677calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
8678 std::optional<uint64_t> VectorIndex,
8679 unsigned StartingIndex = 0) {
8680
8681 // A typical i64-by-i8 pattern requires recursion up to a depth of 8 calls.
8682 if (Depth == 10)
8683 return std::nullopt;
8684
8685 // Only allow multiple uses if the instruction is a vector load (in which
8686 // case we will use the load for every ExtractVectorElement)
8687 if (Depth && !Op.hasOneUse() &&
8688 (Op.getOpcode() != ISD::LOAD || !Op.getValueType().isVector()))
8689 return std::nullopt;
8690
8691 // Fail to combine if we have encountered anything but a LOAD after handling
8692 // an ExtractVectorElement.
8693 if (Op.getOpcode() != ISD::LOAD && VectorIndex.has_value())
8694 return std::nullopt;
8695
8696 unsigned BitWidth = Op.getValueSizeInBits();
8697 if (BitWidth % 8 != 0)
8698 return std::nullopt;
8699 unsigned ByteWidth = BitWidth / 8;
8700 assert(Index < ByteWidth && "invalid index requested");
8701 (void) ByteWidth;
8702
8703 switch (Op.getOpcode()) {
8704 case ISD::OR: {
8705 auto LHS =
8706 calculateByteProvider(Op: Op->getOperand(Num: 0), Index, Depth: Depth + 1, VectorIndex);
8707 if (!LHS)
8708 return std::nullopt;
8709 auto RHS =
8710 calculateByteProvider(Op: Op->getOperand(Num: 1), Index, Depth: Depth + 1, VectorIndex);
8711 if (!RHS)
8712 return std::nullopt;
8713
8714 if (LHS->isConstantZero())
8715 return RHS;
8716 if (RHS->isConstantZero())
8717 return LHS;
8718 return std::nullopt;
8719 }
8720 case ISD::SHL: {
8721 auto ShiftOp = dyn_cast<ConstantSDNode>(Val: Op->getOperand(Num: 1));
8722 if (!ShiftOp)
8723 return std::nullopt;
8724
8725 uint64_t BitShift = ShiftOp->getZExtValue();
8726
8727 if (BitShift % 8 != 0)
8728 return std::nullopt;
8729 uint64_t ByteShift = BitShift / 8;
8730
8731 // If we are shifting by an amount greater than the index we are trying to
8732 // provide, then do not provide anything. Otherwise, subtract the shift
8733 // amount (in bytes) from the index.
8734 return Index < ByteShift
8735 ? SDByteProvider::getConstantZero()
8736 : calculateByteProvider(Op: Op->getOperand(Num: 0), Index: Index - ByteShift,
8737 Depth: Depth + 1, VectorIndex, StartingIndex: Index);
8738 }
8739 case ISD::ANY_EXTEND:
8740 case ISD::SIGN_EXTEND:
8741 case ISD::ZERO_EXTEND: {
8742 SDValue NarrowOp = Op->getOperand(Num: 0);
8743 unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
8744 if (NarrowBitWidth % 8 != 0)
8745 return std::nullopt;
8746 uint64_t NarrowByteWidth = NarrowBitWidth / 8;
8747
8748 if (Index >= NarrowByteWidth)
8749 return Op.getOpcode() == ISD::ZERO_EXTEND
8750 ? std::optional<SDByteProvider>(
8751 SDByteProvider::getConstantZero())
8752 : std::nullopt;
8753 return calculateByteProvider(Op: NarrowOp, Index, Depth: Depth + 1, VectorIndex,
8754 StartingIndex);
8755 }
8756 case ISD::BSWAP:
8757 return calculateByteProvider(Op: Op->getOperand(Num: 0), Index: ByteWidth - Index - 1,
8758 Depth: Depth + 1, VectorIndex, StartingIndex);
8759 case ISD::EXTRACT_VECTOR_ELT: {
8760 auto OffsetOp = dyn_cast<ConstantSDNode>(Val: Op->getOperand(Num: 1));
8761 if (!OffsetOp)
8762 return std::nullopt;
8763
8764 VectorIndex = OffsetOp->getZExtValue();
8765
8766 SDValue NarrowOp = Op->getOperand(Num: 0);
8767 unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
8768 if (NarrowBitWidth % 8 != 0)
8769 return std::nullopt;
8770 uint64_t NarrowByteWidth = NarrowBitWidth / 8;
8771 // EXTRACT_VECTOR_ELT can extend the element type to the width of the return
8772 // type, leaving the high bits undefined.
8773 if (Index >= NarrowByteWidth)
8774 return std::nullopt;
8775
8776 // Check to see if the position of the element in the vector corresponds
8777 // with the byte we are trying to provide for. In the case of a vector of
8778 // i8, this simply means VectorIndex == StartingIndex. For non-i8 cases,
8779 // the element will provide a range of bytes. For example, if we have a
8780 // vector of i16s, each element provides two bytes (V[1] provides bytes 2
8781 // and 3).
8782 if (*VectorIndex * NarrowByteWidth > StartingIndex)
8783 return std::nullopt;
8784 if ((*VectorIndex + 1) * NarrowByteWidth <= StartingIndex)
8785 return std::nullopt;
8786
8787 return calculateByteProvider(Op: Op->getOperand(Num: 0), Index, Depth: Depth + 1,
8788 VectorIndex, StartingIndex);
8789 }
8790 case ISD::LOAD: {
8791 auto L = cast<LoadSDNode>(Val: Op.getNode());
8792 if (!L->isSimple() || L->isIndexed())
8793 return std::nullopt;
8794
8795 unsigned NarrowBitWidth = L->getMemoryVT().getSizeInBits();
8796 if (NarrowBitWidth % 8 != 0)
8797 return std::nullopt;
8798 uint64_t NarrowByteWidth = NarrowBitWidth / 8;
8799
8800 // If the width of the load does not reach the byte we are trying to provide
8801 // for and it is not a ZEXTLOAD, then the load does not provide for the byte
8802 // in question.
8803 if (Index >= NarrowByteWidth)
8804 return L->getExtensionType() == ISD::ZEXTLOAD
8805 ? std::optional<SDByteProvider>(
8806 SDByteProvider::getConstantZero())
8807 : std::nullopt;
8808
8809 unsigned BPVectorIndex = VectorIndex.value_or(u: 0U);
8810 return SDByteProvider::getSrc(Val: L, ByteOffset: Index, VectorOffset: BPVectorIndex);
8811 }
8812 }
8813
8814 return std::nullopt;
8815}
8816
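// In a little-endian layout, byte i of a BW-byte value (counting from the least
// significant byte) is stored at offset i.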
8817static unsigned littleEndianByteAt(unsigned BW, unsigned i) {
8818 return i;
8819}
8820
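// In a big-endian layout, byte i of a BW-byte value (counting from the least
// significant byte) is stored at offset BW - i - 1.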
8821static unsigned bigEndianByteAt(unsigned BW, unsigned i) {
8822 return BW - i - 1;
8823}
8824
8825// Check if the byte offsets we are looking at match either a big- or
8826// little-endian value load. Return true for big endian, false for little
8827// endian, and std::nullopt if the match failed.
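// For example, relative byte offsets {0, 1, 2, 3} are a little-endian match and
// {3, 2, 1, 0} are a big-endian match.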
8828static std::optional<bool> isBigEndian(const ArrayRef<int64_t> ByteOffsets,
8829 int64_t FirstOffset) {
8830 // The endian can be decided only when it is 2 bytes at least.
8831 unsigned Width = ByteOffsets.size();
8832 if (Width < 2)
8833 return std::nullopt;
8834
8835 bool BigEndian = true, LittleEndian = true;
8836 for (unsigned i = 0; i < Width; i++) {
8837 int64_t CurrentByteOffset = ByteOffsets[i] - FirstOffset;
8838 LittleEndian &= CurrentByteOffset == littleEndianByteAt(BW: Width, i);
8839 BigEndian &= CurrentByteOffset == bigEndianByteAt(BW: Width, i);
8840 if (!BigEndian && !LittleEndian)
8841 return std::nullopt;
8842 }
8843
8844 assert((BigEndian != LittleEndian) && "It should be either big endian or "
8845 "little endian");
8846 return BigEndian;
8847}
8848
8849// Look through one layer of truncate or extend.
8850static SDValue stripTruncAndExt(SDValue Value) {
8851 switch (Value.getOpcode()) {
8852 case ISD::TRUNCATE:
8853 case ISD::ZERO_EXTEND:
8854 case ISD::SIGN_EXTEND:
8855 case ISD::ANY_EXTEND:
8856 return Value.getOperand(i: 0);
8857 }
8858 return SDValue();
8859}
8860
8861/// Match a pattern where a wide type scalar value is stored by several narrow
8862/// stores. Fold it into a single store or a BSWAP and a store if the target
8863/// supports it.
8864///
8865/// Assuming little endian target:
8866/// i8 *p = ...
8867/// i32 val = ...
8868/// p[0] = (val >> 0) & 0xFF;
8869/// p[1] = (val >> 8) & 0xFF;
8870/// p[2] = (val >> 16) & 0xFF;
8871/// p[3] = (val >> 24) & 0xFF;
8872/// =>
8873/// *((i32)p) = val;
8874///
8875/// i8 *p = ...
8876/// i32 val = ...
8877/// p[0] = (val >> 24) & 0xFF;
8878/// p[1] = (val >> 16) & 0xFF;
8879/// p[2] = (val >> 8) & 0xFF;
8880/// p[3] = (val >> 0) & 0xFF;
8881/// =>
8882/// *((i32)p) = BSWAP(val);
8883SDValue DAGCombiner::mergeTruncStores(StoreSDNode *N) {
8884 // The matching looks for "store (trunc x)" patterns that appear early but are
8885 // likely to be replaced by truncating store nodes during combining.
8886 // TODO: If there is evidence that running this later would help, this
8887 // limitation could be removed. Legality checks may need to be added
8888 // for the created store and optional bswap/rotate.
8889 if (LegalOperations || OptLevel == CodeGenOptLevel::None)
8890 return SDValue();
8891
8892 // We only handle merging simple stores of 1-4 bytes.
8893 // TODO: Allow unordered atomics when wider type is legal (see D66309)
8894 EVT MemVT = N->getMemoryVT();
8895 if (!(MemVT == MVT::i8 || MemVT == MVT::i16 || MemVT == MVT::i32) ||
8896 !N->isSimple() || N->isIndexed())
8897 return SDValue();
8898
8899 // Collect all of the stores in the chain, up to the maximum store width (i64).
8900 SDValue Chain = N->getChain();
8901 SmallVector<StoreSDNode *, 8> Stores = {N};
8902 unsigned NarrowNumBits = MemVT.getScalarSizeInBits();
8903 unsigned MaxWideNumBits = 64;
8904 unsigned MaxStores = MaxWideNumBits / NarrowNumBits;
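// E.g. at most eight i8 stores or four i16 stores can be merged into a single i64.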
8905 while (auto *Store = dyn_cast<StoreSDNode>(Val&: Chain)) {
8906 // All stores must be the same size to ensure that we are writing all of the
8907 // bytes in the wide value.
8908 // This store should have exactly one use as a chain operand for another
8909 // store in the merging set. If there are other chain uses, then the
8910 // transform may not be safe because order of loads/stores outside of this
8911 // set may not be preserved.
8912 // TODO: We could allow multiple sizes by tracking each stored byte.
8913 if (Store->getMemoryVT() != MemVT || !Store->isSimple() ||
8914 Store->isIndexed() || !Store->hasOneUse())
8915 return SDValue();
8916 Stores.push_back(Elt: Store);
8917 Chain = Store->getChain();
8918 if (MaxStores < Stores.size())
8919 return SDValue();
8920 }
8921 // There is no reason to continue if we do not have at least a pair of stores.
8922 if (Stores.size() < 2)
8923 return SDValue();
8924
8925 // Handle simple types only.
8926 LLVMContext &Context = *DAG.getContext();
8927 unsigned NumStores = Stores.size();
8928 unsigned WideNumBits = NumStores * NarrowNumBits;
8929 EVT WideVT = EVT::getIntegerVT(Context, BitWidth: WideNumBits);
8930 if (WideVT != MVT::i16 && WideVT != MVT::i32 && WideVT != MVT::i64)
8931 return SDValue();
8932
8933 // Check if all bytes of the source value that we are looking at are stored
8934 // to the same base address. Collect offsets from Base address into OffsetMap.
8935 SDValue SourceValue;
8936 SmallVector<int64_t, 8> OffsetMap(NumStores, INT64_MAX);
8937 int64_t FirstOffset = INT64_MAX;
8938 StoreSDNode *FirstStore = nullptr;
8939 std::optional<BaseIndexOffset> Base;
8940 for (auto *Store : Stores) {
8941 // All the stores store different parts of the CombinedValue. A truncate is
8942 // required to get the partial value.
8943 SDValue Trunc = Store->getValue();
8944 if (Trunc.getOpcode() != ISD::TRUNCATE)
8945 return SDValue();
8946 // Other than the first/last part, a shift operation is required to get the
8947 // offset.
8948 int64_t Offset = 0;
8949 SDValue WideVal = Trunc.getOperand(i: 0);
8950 if ((WideVal.getOpcode() == ISD::SRL || WideVal.getOpcode() == ISD::SRA) &&
8951 isa<ConstantSDNode>(Val: WideVal.getOperand(i: 1))) {
8952 // The shift amount must be a constant multiple of the narrow type.
8953 // It is translated to the offset address in the wide source value "y".
8954 //
8955 // x = srl y, ShiftAmtC
8956 // i8 z = trunc x
8957 // store z, ...
8958 uint64_t ShiftAmtC = WideVal.getConstantOperandVal(i: 1);
8959 if (ShiftAmtC % NarrowNumBits != 0)
8960 return SDValue();
8961
8962 // Make sure we aren't reading bits that are shifted in.
8963 if (ShiftAmtC > WideVal.getScalarValueSizeInBits() - NarrowNumBits)
8964 return SDValue();
8965
8966 Offset = ShiftAmtC / NarrowNumBits;
8967 WideVal = WideVal.getOperand(i: 0);
8968 }
8969
8970 // Stores must share the same source value with different offsets.
8971 if (!SourceValue)
8972 SourceValue = WideVal;
8973 else if (SourceValue != WideVal) {
8974 // Truncates and extends can be stripped to see if the values are related.
8975 if (stripTruncAndExt(Value: SourceValue) != WideVal &&
8976 stripTruncAndExt(Value: WideVal) != SourceValue)
8977 return SDValue();
8978
8979 if (WideVal.getScalarValueSizeInBits() >
8980 SourceValue.getScalarValueSizeInBits())
8981 SourceValue = WideVal;
8982
8983 // Give up if the source value type is smaller than the store size.
8984 if (SourceValue.getScalarValueSizeInBits() < WideVT.getScalarSizeInBits())
8985 return SDValue();
8986 }
8987
8988 // Stores must share the same base address.
8989 BaseIndexOffset Ptr = BaseIndexOffset::match(N: Store, DAG);
8990 int64_t ByteOffsetFromBase = 0;
8991 if (!Base)
8992 Base = Ptr;
8993 else if (!Base->equalBaseIndex(Other: Ptr, DAG, Off&: ByteOffsetFromBase))
8994 return SDValue();
8995
8996 // Remember the first store.
8997 if (ByteOffsetFromBase < FirstOffset) {
8998 FirstStore = Store;
8999 FirstOffset = ByteOffsetFromBase;
9000 }
9001 // Map the offset in the store to the offset in the combined value, and
9002 // return early if it has already been set.
9003 if (Offset < 0 || Offset >= NumStores || OffsetMap[Offset] != INT64_MAX)
9004 return SDValue();
9005 OffsetMap[Offset] = ByteOffsetFromBase;
9006 }
9007
9008 assert(FirstOffset != INT64_MAX && "First byte offset must be set");
9009 assert(FirstStore && "First store must be set");
9010
9011 // Check that a store of the wide type is both allowed and fast on the target
9012 const DataLayout &Layout = DAG.getDataLayout();
9013 unsigned Fast = 0;
9014 bool Allowed = TLI.allowsMemoryAccess(Context, DL: Layout, VT: WideVT,
9015 MMO: *FirstStore->getMemOperand(), Fast: &Fast);
9016 if (!Allowed || !Fast)
9017 return SDValue();
9018
9019 // Check if the pieces of the value are going to the expected places in memory
9020 // to merge the stores.
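// E.g. when merging four i8 stores on a little-endian target, OffsetMap must be
// {FirstOffset, FirstOffset + 1, FirstOffset + 2, FirstOffset + 3}.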
9021 auto checkOffsets = [&](bool MatchLittleEndian) {
9022 if (MatchLittleEndian) {
9023 for (unsigned i = 0; i != NumStores; ++i)
9024 if (OffsetMap[i] != i * (NarrowNumBits / 8) + FirstOffset)
9025 return false;
9026 } else { // MatchBigEndian by reversing loop counter.
9027 for (unsigned i = 0, j = NumStores - 1; i != NumStores; ++i, --j)
9028 if (OffsetMap[j] != i * (NarrowNumBits / 8) + FirstOffset)
9029 return false;
9030 }
9031 return true;
9032 };
9033
9034 // Check if the offsets line up for the native data layout of this target.
9035 bool NeedBswap = false;
9036 bool NeedRotate = false;
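// If the offsets instead match the opposite endianness we can still merge:
// single-byte pieces can be fixed up with a BSWAP, and a pair of wider halves
// can be swapped with a rotate by half the width.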
9037 if (!checkOffsets(Layout.isLittleEndian())) {
9038 // Special-case: check if byte offsets line up for the opposite endian.
9039 if (NarrowNumBits == 8 && checkOffsets(Layout.isBigEndian()))
9040 NeedBswap = true;
9041 else if (NumStores == 2 && checkOffsets(Layout.isBigEndian()))
9042 NeedRotate = true;
9043 else
9044 return SDValue();
9045 }
9046
9047 SDLoc DL(N);
9048 if (WideVT != SourceValue.getValueType()) {
9049 assert(SourceValue.getValueType().getScalarSizeInBits() > WideNumBits &&
9050 "Unexpected store value to merge");
9051 SourceValue = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: WideVT, Operand: SourceValue);
9052 }
9053
9054 // Before legalize we can introduce illegal bswaps/rotates which will be later
9055 // converted to an explicit bswap sequence. This way we end up with a single
9056 // store and byte shuffling instead of several stores and byte shuffling.
9057 if (NeedBswap) {
9058 SourceValue = DAG.getNode(Opcode: ISD::BSWAP, DL, VT: WideVT, Operand: SourceValue);
9059 } else if (NeedRotate) {
9060 assert(WideNumBits % 2 == 0 && "Unexpected type for rotate");
9061 SDValue RotAmt = DAG.getConstant(Val: WideNumBits / 2, DL, VT: WideVT);
9062 SourceValue = DAG.getNode(Opcode: ISD::ROTR, DL, VT: WideVT, N1: SourceValue, N2: RotAmt);
9063 }
9064
9065 SDValue NewStore =
9066 DAG.getStore(Chain, dl: DL, Val: SourceValue, Ptr: FirstStore->getBasePtr(),
9067 PtrInfo: FirstStore->getPointerInfo(), Alignment: FirstStore->getAlign());
9068
9069 // Rely on other DAG combine rules to remove the other individual stores.
9070 DAG.ReplaceAllUsesWith(From: N, To: NewStore.getNode());
9071 return NewStore;
9072}
9073
9074/// Match a pattern where a wide type scalar value is loaded by several narrow
9075/// loads and combined by shifts and ors. Fold it into a single load or a load
9076/// and a BSWAP if the target supports it.
9077///
9078/// Assuming little endian target:
9079/// i8 *a = ...
9080/// i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
9081/// =>
9082/// i32 val = *((i32)a)
9083///
9084/// i8 *a = ...
9085/// i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
9086/// =>
9087/// i32 val = BSWAP(*((i32)a))
9088///
9089/// TODO: This rule matches complex patterns with OR node roots and doesn't
9090/// interact well with the worklist mechanism. When a part of the pattern is
9091/// updated (e.g. one of the loads) its direct users are put into the worklist,
9092/// but the root node of the pattern which triggers the load combine is not
9093/// necessarily a direct user of the changed node. For example, once the address
9094/// of the t28 load is reassociated, the load combine won't be triggered:
9095/// t25: i32 = add t4, Constant:i32<2>
9096/// t26: i64 = sign_extend t25
9097/// t27: i64 = add t2, t26
9098/// t28: i8,ch = load<LD1[%tmp9]> t0, t27, undef:i64
9099/// t29: i32 = zero_extend t28
9100/// t32: i32 = shl t29, Constant:i8<8>
9101/// t33: i32 = or t23, t32
9102/// As a possible fix visitLoad can check if the load can be a part of a load
9103/// combine pattern and add corresponding OR roots to the worklist.
9104SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
9105 assert(N->getOpcode() == ISD::OR &&
9106 "Can only match load combining against OR nodes");
9107
9108 // Handles simple types only
9109 EVT VT = N->getValueType(ResNo: 0);
9110 if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
9111 return SDValue();
9112 unsigned ByteWidth = VT.getSizeInBits() / 8;
9113
9114 bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
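// Map a provider's byte index within its load to that byte's offset in memory,
// accounting for the endianness of the target.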
9115 auto MemoryByteOffset = [&](SDByteProvider P) {
9116 assert(P.hasSrc() && "Must be a memory byte provider");
9117 auto *Load = cast<LoadSDNode>(Val: P.Src.value());
9118
9119 unsigned LoadBitWidth = Load->getMemoryVT().getScalarSizeInBits();
9120
9121 assert(LoadBitWidth % 8 == 0 &&
9122 "can only analyze providers for individual bytes, not bits");
9123 unsigned LoadByteWidth = LoadBitWidth / 8;
9124 return IsBigEndianTarget ? bigEndianByteAt(BW: LoadByteWidth, i: P.DestOffset)
9125 : littleEndianByteAt(BW: LoadByteWidth, i: P.DestOffset);
9126 };
9127
9128 std::optional<BaseIndexOffset> Base;
9129 SDValue Chain;
9130
9131 SmallPtrSet<LoadSDNode *, 8> Loads;
9132 std::optional<SDByteProvider> FirstByteProvider;
9133 int64_t FirstOffset = INT64_MAX;
9134
9135 // Check if all the bytes of the OR we are looking at are loaded from the same
9136 // base address. Collect byte offsets from the base address in ByteOffsets.
9137 SmallVector<int64_t, 8> ByteOffsets(ByteWidth);
9138 unsigned ZeroExtendedBytes = 0;
9139 for (int i = ByteWidth - 1; i >= 0; --i) {
9140 auto P =
9141 calculateByteProvider(Op: SDValue(N, 0), Index: i, Depth: 0, /*VectorIndex*/ std::nullopt,
9142 /*StartingIndex*/ i);
9143 if (!P)
9144 return SDValue();
9145
9146 if (P->isConstantZero()) {
9147 // It's OK for the N most significant bytes to be 0, we can just
9148 // zero-extend the load.
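// (We iterate from the most significant byte down, so the constant-zero bytes
// must form a contiguous run at the top of the value for this check to hold.)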
9149 if (++ZeroExtendedBytes != (ByteWidth - static_cast<unsigned>(i)))
9150 return SDValue();
9151 continue;
9152 }
9153 assert(P->hasSrc() && "provenance should either be memory or zero");
9154 auto *L = cast<LoadSDNode>(Val: P->Src.value());
9155
9156 // All loads must share the same chain
9157 SDValue LChain = L->getChain();
9158 if (!Chain)
9159 Chain = LChain;
9160 else if (Chain != LChain)
9161 return SDValue();
9162
9163 // Loads must share the same base address
9164 BaseIndexOffset Ptr = BaseIndexOffset::match(N: L, DAG);
9165 int64_t ByteOffsetFromBase = 0;
9166
9167 // For vector loads, the expected load combine pattern will have an
9168 // ExtractElement for each index in the vector. While each of these
9169 // ExtractElements will be accessing the same base address as determined
9170 // by the load instruction, the actual bytes they interact with will differ
9171 // due to different ExtractElement indices. To accurately determine the
9172 // byte position of an ExtractElement, we offset the base load ptr with
9173 // the index multiplied by the byte size of each element in the vector.
9174 if (L->getMemoryVT().isVector()) {
9175 unsigned LoadWidthInBit = L->getMemoryVT().getScalarSizeInBits();
9176 if (LoadWidthInBit % 8 != 0)
9177 return SDValue();
9178 unsigned ByteOffsetFromVector = P->SrcOffset * LoadWidthInBit / 8;
9179 Ptr.addToOffset(VectorOff: ByteOffsetFromVector);
9180 }
9181
9182 if (!Base)
9183 Base = Ptr;
9184
9185 else if (!Base->equalBaseIndex(Other: Ptr, DAG, Off&: ByteOffsetFromBase))
9186 return SDValue();
9187
9188 // Calculate the offset of the current byte from the base address
9189 ByteOffsetFromBase += MemoryByteOffset(*P);
9190 ByteOffsets[i] = ByteOffsetFromBase;
9191
9192 // Remember the first byte load
9193 if (ByteOffsetFromBase < FirstOffset) {
9194 FirstByteProvider = P;
9195 FirstOffset = ByteOffsetFromBase;
9196 }
9197
9198 Loads.insert(Ptr: L);
9199 }
9200
9201 assert(!Loads.empty() && "All the bytes of the value must be loaded from "
9202 "memory, so there must be at least one load which produces the value");
9203 assert(Base && "Base address of the accessed memory location must be set");
9204 assert(FirstOffset != INT64_MAX && "First byte offset must be set");
9205
9206 bool NeedsZext = ZeroExtendedBytes > 0;
9207
9208 EVT MemVT =
9209 EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: (ByteWidth - ZeroExtendedBytes) * 8);
9210
9211 if (!MemVT.isSimple())
9212 return SDValue();
9213
9214 // Before legalize we can introduce too wide illegal loads which will be later
9215 // split into legal sized loads. This enables us to combine i64 load by i8
9216 // patterns to a couple of i32 loads on 32 bit targets.
9217 if (LegalOperations &&
9218 !TLI.isOperationLegal(Op: NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD,
9219 VT: MemVT))
9220 return SDValue();
9221
9222 // Check if the bytes of the OR we are looking at match either a big- or
9223 // little-endian value load.
9224 std::optional<bool> IsBigEndian = isBigEndian(
9225 ByteOffsets: ArrayRef(ByteOffsets).drop_back(N: ZeroExtendedBytes), FirstOffset);
9226 if (!IsBigEndian)
9227 return SDValue();
9228
9229 assert(FirstByteProvider && "must be set");
9230
9231 // Ensure that the first byte is loaded from offset zero of the first load,
9232 // so the combined value can be loaded from the address of the first load.
9233 if (MemoryByteOffset(*FirstByteProvider) != 0)
9234 return SDValue();
9235 auto *FirstLoad = cast<LoadSDNode>(Val: FirstByteProvider->Src.value());
9236
9237 // The node we are looking at matches the pattern; check if we can replace
9238 // it with a single (possibly zero-extended) load and a bswap + shift if
9239 // needed.
9240
9241 // If the load needs a byte swap, check whether the target supports it.
9242 bool NeedsBswap = IsBigEndianTarget != *IsBigEndian;
9243
9244 // Before legalize we can introduce illegal bswaps which will be later
9245 // converted to an explicit bswap sequence. This way we end up with a single
9246 // load and byte shuffling instead of several loads and byte shuffling.
9247 // We do not introduce illegal bswaps when zero-extending as this tends to
9248 // introduce too many arithmetic instructions.
9249 if (NeedsBswap && (LegalOperations || NeedsZext) &&
9250 !TLI.isOperationLegal(Op: ISD::BSWAP, VT))
9251 return SDValue();
9252
9253 // If we need to bswap and zero extend, we have to insert a shift. Check that
9254 // it is legal.
9255 if (NeedsBswap && NeedsZext && LegalOperations &&
9256 !TLI.isOperationLegal(Op: ISD::SHL, VT))
9257 return SDValue();
9258
9259 // Check that a load of the wide type is both allowed and fast on the target
9260 unsigned Fast = 0;
9261 bool Allowed =
9262 TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: MemVT,
9263 MMO: *FirstLoad->getMemOperand(), Fast: &Fast);
9264 if (!Allowed || !Fast)
9265 return SDValue();
9266
9267 SDValue NewLoad =
9268 DAG.getExtLoad(ExtType: NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, dl: SDLoc(N), VT,
9269 Chain, Ptr: FirstLoad->getBasePtr(),
9270 PtrInfo: FirstLoad->getPointerInfo(), MemVT, Alignment: FirstLoad->getAlign());
9271
9272 // Transfer chain users from old loads to the new load.
9273 for (LoadSDNode *L : Loads)
9274 DAG.makeEquivalentMemoryOrdering(OldLoad: L, NewMemOp: NewLoad);
9275
9276 if (!NeedsBswap)
9277 return NewLoad;
9278
9279 SDValue ShiftedLoad =
9280 NeedsZext ? DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N), VT, N1: NewLoad,
9281 N2: DAG.getShiftAmountConstant(Val: ZeroExtendedBytes * 8,
9282 VT, DL: SDLoc(N)))
9283 : NewLoad;
9284 return DAG.getNode(Opcode: ISD::BSWAP, DL: SDLoc(N), VT, Operand: ShiftedLoad);
9285}
9286
9287// If the target has andn, bsl, or a similar bit-select instruction,
9288// we want to unfold masked merge, with canonical pattern of:
9289// | A | |B|
9290// ((x ^ y) & m) ^ y
9291// | D |
9292// Into:
9293// (x & m) | (y & ~m)
9294// If y is a constant, m is not a 'not', and the 'andn' does not work with
9295// immediates, we unfold into a different pattern:
9296// ~(~x & m) & (m | y)
9297// If x is a constant, m is a 'not', and the 'andn' does not work with
9298// immediates, we unfold into a different pattern:
9299// (x | ~m) & ~(~m & ~y)
9300// NOTE: we don't unfold the pattern if 'xor' is actually a 'not', because at
9301// the very least that breaks andnpd / andnps patterns, and because those
9302// patterns are simplified in IR and shouldn't be created in the DAG
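// For example, with i8 operands and m = 0x0F, ((x ^ y) & m) ^ y takes the low
// nibble from x and the high nibble from y, i.e. (x & 0x0F) | (y & 0xF0).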
9303SDValue DAGCombiner::unfoldMaskedMerge(SDNode *N) {
9304 assert(N->getOpcode() == ISD::XOR);
9305
9306 // Don't touch 'not' (i.e. where y = -1).
9307 if (isAllOnesOrAllOnesSplat(V: N->getOperand(Num: 1)))
9308 return SDValue();
9309
9310 EVT VT = N->getValueType(ResNo: 0);
9311
9312 // There are 3 commutable operators in the pattern,
9313 // so we have to deal with 8 possible variants of the basic pattern.
9314 SDValue X, Y, M;
9315 auto matchAndXor = [&X, &Y, &M](SDValue And, unsigned XorIdx, SDValue Other) {
9316 if (And.getOpcode() != ISD::AND || !And.hasOneUse())
9317 return false;
9318 SDValue Xor = And.getOperand(i: XorIdx);
9319 if (Xor.getOpcode() != ISD::XOR || !Xor.hasOneUse())
9320 return false;
9321 SDValue Xor0 = Xor.getOperand(i: 0);
9322 SDValue Xor1 = Xor.getOperand(i: 1);
9323 // Don't touch 'not' (i.e. where y = -1).
9324 if (isAllOnesOrAllOnesSplat(V: Xor1))
9325 return false;
9326 if (Other == Xor0)
9327 std::swap(a&: Xor0, b&: Xor1);
9328 if (Other != Xor1)
9329 return false;
9330 X = Xor0;
9331 Y = Xor1;
9332 M = And.getOperand(i: XorIdx ? 0 : 1);
9333 return true;
9334 };
9335
9336 SDValue N0 = N->getOperand(Num: 0);
9337 SDValue N1 = N->getOperand(Num: 1);
9338 if (!matchAndXor(N0, 0, N1) && !matchAndXor(N0, 1, N1) &&
9339 !matchAndXor(N1, 0, N0) && !matchAndXor(N1, 1, N0))
9340 return SDValue();
9341
9342 // Don't do anything if the mask is constant. This should not be reachable.
9343 // InstCombine should have already unfolded this pattern, and DAGCombiner
9344 // probably shouldn't produce it either.
9345 if (isa<ConstantSDNode>(Val: M.getNode()))
9346 return SDValue();
9347
9348 // We can transform if the target has AndNot
9349 if (!TLI.hasAndNot(X: M))
9350 return SDValue();
9351
9352 SDLoc DL(N);
9353
9354 // If Y is a constant, check that 'andn' works with immediates. Unless M is
9355 // a bitwise not that would already allow ANDN to be used.
9356 if (!TLI.hasAndNot(X: Y) && !isBitwiseNot(V: M)) {
9357 assert(TLI.hasAndNot(X) && "Only mask is a variable? Unreachable.");
9358 // If not, we need to do a bit more work to make sure andn is still used.
9359 SDValue NotX = DAG.getNOT(DL, Val: X, VT);
9360 SDValue LHS = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: NotX, N2: M);
9361 SDValue NotLHS = DAG.getNOT(DL, Val: LHS, VT);
9362 SDValue RHS = DAG.getNode(Opcode: ISD::OR, DL, VT, N1: M, N2: Y);
9363 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: NotLHS, N2: RHS);
9364 }
9365
9366 // If X is a constant and M is a bitwise not, check that 'andn' works with
9367 // immediates.
9368 if (!TLI.hasAndNot(X) && isBitwiseNot(V: M)) {
9369 assert(TLI.hasAndNot(Y) && "Only mask is a variable? Unreachable.");
9370 // If not, we need to do a bit more work to make sure andn is still used.
9371 SDValue NotM = M.getOperand(i: 0);
9372 SDValue LHS = DAG.getNode(Opcode: ISD::OR, DL, VT, N1: X, N2: NotM);
9373 SDValue NotY = DAG.getNOT(DL, Val: Y, VT);
9374 SDValue RHS = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: NotM, N2: NotY);
9375 SDValue NotRHS = DAG.getNOT(DL, Val: RHS, VT);
9376 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: LHS, N2: NotRHS);
9377 }
9378
9379 SDValue LHS = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: X, N2: M);
9380 SDValue NotM = DAG.getNOT(DL, Val: M, VT);
9381 SDValue RHS = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Y, N2: NotM);
9382
9383 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: LHS, N2: RHS);
9384}
9385
9386SDValue DAGCombiner::visitXOR(SDNode *N) {
9387 SDValue N0 = N->getOperand(Num: 0);
9388 SDValue N1 = N->getOperand(Num: 1);
9389 EVT VT = N0.getValueType();
9390 SDLoc DL(N);
9391
9392 // fold (xor undef, undef) -> 0. This is a common idiom (misuse).
9393 if (N0.isUndef() && N1.isUndef())
9394 return DAG.getConstant(Val: 0, DL, VT);
9395
9396 // fold (xor x, undef) -> undef
9397 if (N0.isUndef())
9398 return N0;
9399 if (N1.isUndef())
9400 return N1;
9401
9402 // fold (xor c1, c2) -> c1^c2
9403 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::XOR, DL, VT, Ops: {N0, N1}))
9404 return C;
9405
9406 // canonicalize constant to RHS
9407 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
9408 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
9409 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0);
9410
9411 // fold vector ops
9412 if (VT.isVector()) {
9413 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
9414 return FoldedVOp;
9415
9416 // fold (xor x, 0) -> x, vector edition
9417 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
9418 return N0;
9419 }
9420
9421 // fold (xor x, 0) -> x
9422 if (isNullConstant(V: N1))
9423 return N0;
9424
9425 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
9426 return NewSel;
9427
9428 // reassociate xor
9429 if (SDValue RXOR = reassociateOps(Opc: ISD::XOR, DL, N0, N1, Flags: N->getFlags()))
9430 return RXOR;
9431
9432 // Fold xor(vecreduce(x), vecreduce(y)) -> vecreduce(xor(x, y))
9433 if (SDValue SD =
9434 reassociateReduction(RedOpc: ISD::VECREDUCE_XOR, Opc: ISD::XOR, DL, VT, N0, N1))
9435 return SD;
9436
9437 // fold (a^b) -> (a|b) iff a and b share no bits.
9438 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::OR, VT)) &&
9439 DAG.haveNoCommonBitsSet(A: N0, B: N1)) {
9440 SDNodeFlags Flags;
9441 Flags.setDisjoint(true);
9442 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: N0, N2: N1, Flags);
9443 }
9444
9445 // look for 'add-like' folds:
9446 // XOR(N0,MIN_SIGNED_VALUE) == ADD(N0,MIN_SIGNED_VALUE)
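// (Adding the sign-bit constant and xoring with it both simply flip the sign bit,
// since the carry out of the top bit is discarded.)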
9447 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::ADD, VT)) &&
9448 isMinSignedConstant(V: N1))
9449 if (SDValue Combined = visitADDLike(N))
9450 return Combined;
9451
9452 // fold !(x cc y) -> (x !cc y)
9453 unsigned N0Opcode = N0.getOpcode();
9454 SDValue LHS, RHS, CC;
9455 if (TLI.isConstTrueVal(N: N1) &&
9456 isSetCCEquivalent(N: N0, LHS, RHS, CC, /*MatchStrict*/ true)) {
9457 ISD::CondCode NotCC = ISD::getSetCCInverse(Operation: cast<CondCodeSDNode>(Val&: CC)->get(),
9458 Type: LHS.getValueType());
9459 if (!LegalOperations ||
9460 TLI.isCondCodeLegal(CC: NotCC, VT: LHS.getSimpleValueType())) {
9461 switch (N0Opcode) {
9462 default:
9463 llvm_unreachable("Unhandled SetCC Equivalent!");
9464 case ISD::SETCC:
9465 return DAG.getSetCC(DL: SDLoc(N0), VT, LHS, RHS, Cond: NotCC);
9466 case ISD::SELECT_CC:
9467 return DAG.getSelectCC(DL: SDLoc(N0), LHS, RHS, True: N0.getOperand(i: 2),
9468 False: N0.getOperand(i: 3), Cond: NotCC);
9469 case ISD::STRICT_FSETCC:
9470 case ISD::STRICT_FSETCCS: {
9471 if (N0.hasOneUse()) {
9472 // FIXME Can we handle multiple uses? Could we token factor the chain
9473 // results from the new/old setcc?
9474 SDValue SetCC =
9475 DAG.getSetCC(DL: SDLoc(N0), VT, LHS, RHS, Cond: NotCC,
9476 Chain: N0.getOperand(i: 0), IsSignaling: N0Opcode == ISD::STRICT_FSETCCS);
9477 CombineTo(N, Res: SetCC);
9478 DAG.ReplaceAllUsesOfValueWith(From: N0.getValue(R: 1), To: SetCC.getValue(R: 1));
9479 recursivelyDeleteUnusedNodes(N: N0.getNode());
9480 return SDValue(N, 0); // Return N so it doesn't get rechecked!
9481 }
9482 break;
9483 }
9484 }
9485 }
9486 }
9487
9488 // fold (not (zext (setcc x, y))) -> (zext (not (setcc x, y)))
9489 if (isOneConstant(V: N1) && N0Opcode == ISD::ZERO_EXTEND && N0.hasOneUse() &&
9490 isSetCCEquivalent(N: N0.getOperand(i: 0), LHS, RHS, CC)){
9491 SDValue V = N0.getOperand(i: 0);
9492 SDLoc DL0(N0);
9493 V = DAG.getNode(Opcode: ISD::XOR, DL: DL0, VT: V.getValueType(), N1: V,
9494 N2: DAG.getConstant(Val: 1, DL: DL0, VT: V.getValueType()));
9495 AddToWorklist(N: V.getNode());
9496 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: V);
9497 }
9498
9499 // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are setcc
9500 if (isOneConstant(V: N1) && VT == MVT::i1 && N0.hasOneUse() &&
9501 (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
9502 SDValue N00 = N0.getOperand(i: 0), N01 = N0.getOperand(i: 1);
9503 if (isOneUseSetCC(N: N01) || isOneUseSetCC(N: N00)) {
9504 unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
9505 N00 = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N00), VT, N1: N00, N2: N1); // N00 = ~N00
9506 N01 = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N01), VT, N1: N01, N2: N1); // N01 = ~N01
9507 AddToWorklist(N: N00.getNode()); AddToWorklist(N: N01.getNode());
9508 return DAG.getNode(Opcode: NewOpcode, DL, VT, N1: N00, N2: N01);
9509 }
9510 }
9511 // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are constants
9512 if (isAllOnesConstant(V: N1) && N0.hasOneUse() &&
9513 (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
9514 SDValue N00 = N0.getOperand(i: 0), N01 = N0.getOperand(i: 1);
9515 if (isa<ConstantSDNode>(Val: N01) || isa<ConstantSDNode>(Val: N00)) {
9516 unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
9517 N00 = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N00), VT, N1: N00, N2: N1); // N00 = ~N00
9518 N01 = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N01), VT, N1: N01, N2: N1); // N01 = ~N01
9519 AddToWorklist(N: N00.getNode()); AddToWorklist(N: N01.getNode());
9520 return DAG.getNode(Opcode: NewOpcode, DL, VT, N1: N00, N2: N01);
9521 }
9522 }
9523
9524 // fold (not (neg x)) -> (add X, -1)
9525 // FIXME: This can be generalized to (not (sub Y, X)) -> (add X, ~Y) if
9526 // Y is a constant or the subtract has a single use.
9527 if (isAllOnesConstant(V: N1) && N0.getOpcode() == ISD::SUB &&
9528 isNullConstant(V: N0.getOperand(i: 0))) {
9529 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 1),
9530 N2: DAG.getAllOnesConstant(DL, VT));
9531 }
9532
9533 // fold (not (add X, -1)) -> (neg X)
9534 if (isAllOnesConstant(V: N1) && N0.getOpcode() == ISD::ADD &&
9535 isAllOnesOrAllOnesSplat(V: N0.getOperand(i: 1))) {
9536 return DAG.getNegative(Val: N0.getOperand(i: 0), DL, VT);
9537 }
9538
9539 // fold (xor (and x, y), y) -> (and (not x), y)
9540 if (N0Opcode == ISD::AND && N0.hasOneUse() && N0->getOperand(Num: 1) == N1) {
9541 SDValue X = N0.getOperand(i: 0);
9542 SDValue NotX = DAG.getNOT(DL: SDLoc(X), Val: X, VT);
9543 AddToWorklist(N: NotX.getNode());
9544 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: NotX, N2: N1);
9545 }
9546
9547 // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X)
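// This matches the classic branchless abs idiom, e.g. for i32:
// (x + (x >> 31)) ^ (x >> 31).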
9548 if (!LegalOperations || hasOperation(Opcode: ISD::ABS, VT)) {
9549 SDValue A = N0Opcode == ISD::ADD ? N0 : N1;
9550 SDValue S = N0Opcode == ISD::SRA ? N0 : N1;
9551 if (A.getOpcode() == ISD::ADD && S.getOpcode() == ISD::SRA) {
9552 SDValue A0 = A.getOperand(i: 0), A1 = A.getOperand(i: 1);
9553 SDValue S0 = S.getOperand(i: 0);
9554 if ((A0 == S && A1 == S0) || (A1 == S && A0 == S0))
9555 if (ConstantSDNode *C = isConstOrConstSplat(N: S.getOperand(i: 1)))
9556 if (C->getAPIntValue() == (VT.getScalarSizeInBits() - 1))
9557 return DAG.getNode(Opcode: ISD::ABS, DL, VT, Operand: S0);
9558 }
9559 }
9560
9561 // fold (xor x, x) -> 0
9562 if (N0 == N1)
9563 return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
9564
9565 // fold (xor (shl 1, x), -1) -> (rotl ~1, x)
9566 // Here is a concrete example of this equivalence:
9567 // i16 x == 14
9568 // i16 shl == 1 << 14 == 16384 == 0b0100000000000000
9569 // i16 xor == ~(1 << 14) == 49151 == 0b1011111111111111
9570 //
9571 // =>
9572 //
9573 // i16 ~1 == 0b1111111111111110
9574 // i16 rol(~1, 14) == 0b1011111111111111
9575 //
9576 // Some additional tips to help conceptualize this transform:
9577 // - Try to see the operation as placing a single zero in a value of all ones.
9578 // - There exists no value for x which would allow the result to contain zero.
9579 // - Values of x larger than the bitwidth are undefined and do not require a
9580 // consistent result.
9581 // - Pushing the zero left requires shifting one-bits in from the right.
9582 // A rotate left of ~1 is a nice way of achieving the desired result.
9583 if (TLI.isOperationLegalOrCustom(Op: ISD::ROTL, VT) && N0Opcode == ISD::SHL &&
9584 isAllOnesConstant(V: N1) && isOneConstant(V: N0.getOperand(i: 0))) {
9585 return DAG.getNode(Opcode: ISD::ROTL, DL, VT, N1: DAG.getConstant(Val: ~1, DL, VT),
9586 N2: N0.getOperand(i: 1));
9587 }
9588
9589 // Simplify: xor (op x...), (op y...) -> (op (xor x, y))
9590 if (N0Opcode == N1.getOpcode())
9591 if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
9592 return V;
9593
9594 if (SDValue R = foldLogicOfShifts(N, LogicOp: N0, ShiftOp: N1, DAG))
9595 return R;
9596 if (SDValue R = foldLogicOfShifts(N, LogicOp: N1, ShiftOp: N0, DAG))
9597 return R;
9598 if (SDValue R = foldLogicTreeOfShifts(N, LeftHand: N0, RightHand: N1, DAG))
9599 return R;
9600
9601 // Unfold ((x ^ y) & m) ^ y into (x & m) | (y & ~m) if profitable
9602 if (SDValue MM = unfoldMaskedMerge(N))
9603 return MM;
9604
9605 // Simplify the expression using non-local knowledge.
9606 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
9607 return SDValue(N, 0);
9608
9609 if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
9610 return Combined;
9611
9612 return SDValue();
9613}
9614
9615/// If we have a shift-by-constant of a bitwise logic op that itself has a
9616/// shift-by-constant operand with identical opcode, we may be able to convert
9617/// that into 2 independent shifts followed by the logic op. This is a
9618/// throughput improvement.
9619static SDValue combineShiftOfShiftedLogic(SDNode *Shift, SelectionDAG &DAG) {
9620 // Match a one-use bitwise logic op.
9621 SDValue LogicOp = Shift->getOperand(Num: 0);
9622 if (!LogicOp.hasOneUse())
9623 return SDValue();
9624
9625 unsigned LogicOpcode = LogicOp.getOpcode();
9626 if (LogicOpcode != ISD::AND && LogicOpcode != ISD::OR &&
9627 LogicOpcode != ISD::XOR)
9628 return SDValue();
9629
9630 // Find a matching one-use shift by constant.
9631 unsigned ShiftOpcode = Shift->getOpcode();
9632 SDValue C1 = Shift->getOperand(Num: 1);
9633 ConstantSDNode *C1Node = isConstOrConstSplat(N: C1);
9634 assert(C1Node && "Expected a shift with constant operand");
9635 const APInt &C1Val = C1Node->getAPIntValue();
9636 auto matchFirstShift = [&](SDValue V, SDValue &ShiftOp,
9637 const APInt *&ShiftAmtVal) {
9638 if (V.getOpcode() != ShiftOpcode || !V.hasOneUse())
9639 return false;
9640
9641 ConstantSDNode *ShiftCNode = isConstOrConstSplat(N: V.getOperand(i: 1));
9642 if (!ShiftCNode)
9643 return false;
9644
9645 // Capture the shifted operand and shift amount value.
9646 ShiftOp = V.getOperand(i: 0);
9647 ShiftAmtVal = &ShiftCNode->getAPIntValue();
9648
9649 // Shift amount types do not have to match their operand type, so check that
9650 // the constants are the same width.
9651 if (ShiftAmtVal->getBitWidth() != C1Val.getBitWidth())
9652 return false;
9653
9654 // The fold is not valid if the sum of the shift values doesn't fit in the
9655 // given shift amount type.
9656 bool Overflow = false;
9657 APInt NewShiftAmt = C1Val.uadd_ov(RHS: *ShiftAmtVal, Overflow);
9658 if (Overflow)
9659 return false;
9660
9661 // The fold is not valid if the sum of the shift values exceeds bitwidth.
9662 if (NewShiftAmt.uge(RHS: V.getScalarValueSizeInBits()))
9663 return false;
9664
9665 return true;
9666 };
9667
9668 // Logic ops are commutative, so check each operand for a match.
9669 SDValue X, Y;
9670 const APInt *C0Val;
9671 if (matchFirstShift(LogicOp.getOperand(i: 0), X, C0Val))
9672 Y = LogicOp.getOperand(i: 1);
9673 else if (matchFirstShift(LogicOp.getOperand(i: 1), X, C0Val))
9674 Y = LogicOp.getOperand(i: 0);
9675 else
9676 return SDValue();
9677
9678 // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
9679 SDLoc DL(Shift);
9680 EVT VT = Shift->getValueType(ResNo: 0);
9681 EVT ShiftAmtVT = Shift->getOperand(Num: 1).getValueType();
9682 SDValue ShiftSumC = DAG.getConstant(Val: *C0Val + C1Val, DL, VT: ShiftAmtVT);
9683 SDValue NewShift1 = DAG.getNode(Opcode: ShiftOpcode, DL, VT, N1: X, N2: ShiftSumC);
9684 SDValue NewShift2 = DAG.getNode(Opcode: ShiftOpcode, DL, VT, N1: Y, N2: C1);
9685 return DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: NewShift1, N2: NewShift2,
9686 Flags: LogicOp->getFlags());
9687}
9688
9689/// Handle transforms common to the three shifts, when the shift amount is a
9690/// constant.
9691/// We are looking for: (shift being one of shl/sra/srl)
9692/// shift (binop X, C0), C1
9693/// And want to transform into:
9694/// binop (shift X, C1), (shift C0, C1)
9695SDValue DAGCombiner::visitShiftByConstant(SDNode *N) {
9696 assert(isConstOrConstSplat(N->getOperand(1)) && "Expected constant operand");
9697
9698 // Do not turn a 'not' into a regular xor.
9699 if (isBitwiseNot(V: N->getOperand(Num: 0)))
9700 return SDValue();
9701
9702 // The inner binop must be one-use, since we want to replace it.
9703 SDValue LHS = N->getOperand(Num: 0);
9704 if (!LHS.hasOneUse() || !TLI.isDesirableToCommuteWithShift(N, Level))
9705 return SDValue();
9706
9707 // Fold shift(bitop(shift(x,c1),y), c2) -> bitop(shift(x,c1+c2),shift(y,c2)).
9708 if (SDValue R = combineShiftOfShiftedLogic(Shift: N, DAG))
9709 return R;
9710
9711 // We want to pull some binops through shifts, so that we have (and (shift))
9712 // instead of (shift (and)), likewise for add, or, xor, etc. This sort of
9713 // thing happens with address calculations, so it's important to canonicalize
9714 // it.
9715 switch (LHS.getOpcode()) {
9716 default:
9717 return SDValue();
9718 case ISD::OR:
9719 case ISD::XOR:
9720 case ISD::AND:
9721 break;
9722 case ISD::ADD:
9723 if (N->getOpcode() != ISD::SHL)
9724 return SDValue(); // only shl(add) not sr[al](add).
9725 break;
9726 }
9727
9728 // FIXME: disable this unless the input to the binop is a shift by a constant
9729 // or is a copy/select. Enable this in other cases when we figure out exactly
9730 // when it's profitable.
9731 SDValue BinOpLHSVal = LHS.getOperand(i: 0);
9732 bool IsShiftByConstant = (BinOpLHSVal.getOpcode() == ISD::SHL ||
9733 BinOpLHSVal.getOpcode() == ISD::SRA ||
9734 BinOpLHSVal.getOpcode() == ISD::SRL) &&
9735 isa<ConstantSDNode>(Val: BinOpLHSVal.getOperand(i: 1));
9736 bool IsCopyOrSelect = BinOpLHSVal.getOpcode() == ISD::CopyFromReg ||
9737 BinOpLHSVal.getOpcode() == ISD::SELECT;
9738
9739 if (!IsShiftByConstant && !IsCopyOrSelect)
9740 return SDValue();
9741
9742 if (IsCopyOrSelect && N->hasOneUse())
9743 return SDValue();
9744
9745 // Attempt to fold the constants, shifting the binop RHS by the shift amount.
9746 SDLoc DL(N);
9747 EVT VT = N->getValueType(ResNo: 0);
9748 if (SDValue NewRHS = DAG.FoldConstantArithmetic(
9749 Opcode: N->getOpcode(), DL, VT, Ops: {LHS.getOperand(i: 1), N->getOperand(Num: 1)})) {
9750 SDValue NewShift = DAG.getNode(Opcode: N->getOpcode(), DL, VT, N1: LHS.getOperand(i: 0),
9751 N2: N->getOperand(Num: 1));
9752 return DAG.getNode(Opcode: LHS.getOpcode(), DL, VT, N1: NewShift, N2: NewRHS);
9753 }
9754
9755 return SDValue();
9756}
9757
9758SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) {
9759 assert(N->getOpcode() == ISD::TRUNCATE);
9760 assert(N->getOperand(0).getOpcode() == ISD::AND);
9761
9762 // (truncate:TruncVT (and N00, N01C)) -> (and (truncate:TruncVT N00), TruncC)
9763 EVT TruncVT = N->getValueType(ResNo: 0);
9764 if (N->hasOneUse() && N->getOperand(Num: 0).hasOneUse() &&
9765 TLI.isTypeDesirableForOp(ISD::AND, VT: TruncVT)) {
9766 SDValue N01 = N->getOperand(Num: 0).getOperand(i: 1);
9767 if (isConstantOrConstantVector(N: N01, /* NoOpaques */ true)) {
9768 SDLoc DL(N);
9769 SDValue N00 = N->getOperand(Num: 0).getOperand(i: 0);
9770 SDValue Trunc00 = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: TruncVT, Operand: N00);
9771 SDValue Trunc01 = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: TruncVT, Operand: N01);
9772 AddToWorklist(N: Trunc00.getNode());
9773 AddToWorklist(N: Trunc01.getNode());
9774 return DAG.getNode(Opcode: ISD::AND, DL, VT: TruncVT, N1: Trunc00, N2: Trunc01);
9775 }
9776 }
9777
9778 return SDValue();
9779}
9780
9781SDValue DAGCombiner::visitRotate(SDNode *N) {
9782 SDLoc dl(N);
9783 SDValue N0 = N->getOperand(Num: 0);
9784 SDValue N1 = N->getOperand(Num: 1);
9785 EVT VT = N->getValueType(ResNo: 0);
9786 unsigned Bitsize = VT.getScalarSizeInBits();
9787
9788 // fold (rot x, 0) -> x
9789 if (isNullOrNullSplat(V: N1))
9790 return N0;
9791
9792 // fold (rot x, c) -> x iff (c % BitSize) == 0
9793 if (isPowerOf2_32(Value: Bitsize) && Bitsize > 1) {
9794 APInt ModuloMask(N1.getScalarValueSizeInBits(), Bitsize - 1);
9795 if (DAG.MaskedValueIsZero(Op: N1, Mask: ModuloMask))
9796 return N0;
9797 }
9798
9799 // fold (rot x, c) -> (rot x, c % BitSize)
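// Illustrative example (added): on i8, (rotl X, 11) --> (rotl X, 3) since 11 % 8 == 3.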
9800 bool OutOfRange = false;
9801 auto MatchOutOfRange = [Bitsize, &OutOfRange](ConstantSDNode *C) {
9802 OutOfRange |= C->getAPIntValue().uge(RHS: Bitsize);
9803 return true;
9804 };
9805 if (ISD::matchUnaryPredicate(Op: N1, Match: MatchOutOfRange) && OutOfRange) {
9806 EVT AmtVT = N1.getValueType();
9807 SDValue Bits = DAG.getConstant(Val: Bitsize, DL: dl, VT: AmtVT);
9808 if (SDValue Amt =
9809 DAG.FoldConstantArithmetic(Opcode: ISD::UREM, DL: dl, VT: AmtVT, Ops: {N1, Bits}))
9810 return DAG.getNode(Opcode: N->getOpcode(), DL: dl, VT, N1: N0, N2: Amt);
9811 }
9812
9813 // rot i16 X, 8 --> bswap X
9814 auto *RotAmtC = isConstOrConstSplat(N: N1);
9815 if (RotAmtC && RotAmtC->getAPIntValue() == 8 &&
9816 VT.getScalarSizeInBits() == 16 && hasOperation(Opcode: ISD::BSWAP, VT))
9817 return DAG.getNode(Opcode: ISD::BSWAP, DL: dl, VT, Operand: N0);
9818
9819 // Simplify the operands using demanded-bits information.
9820 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
9821 return SDValue(N, 0);
9822
9823 // fold (rot* x, (trunc (and y, c))) -> (rot* x, (and (trunc y), (trunc c))).
9824 if (N1.getOpcode() == ISD::TRUNCATE &&
9825 N1.getOperand(i: 0).getOpcode() == ISD::AND) {
9826 if (SDValue NewOp1 = distributeTruncateThroughAnd(N: N1.getNode()))
9827 return DAG.getNode(Opcode: N->getOpcode(), DL: dl, VT, N1: N0, N2: NewOp1);
9828 }
9829
9830 unsigned NextOp = N0.getOpcode();
9831
9832 // fold (rot* (rot* x, c2), c1)
9833 // -> (rot* x, ((c1 % bitsize) +- (c2 % bitsize) + bitsize) % bitsize)
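// Illustrative examples (added, assuming i32): (rotl (rotl X, 5), 10) --> (rotl X, 15),
// and (rotr (rotl X, 5), 3) --> (rotr X, 30), i.e. a net rotl by 2.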
9834 if (NextOp == ISD::ROTL || NextOp == ISD::ROTR) {
9835 SDNode *C1 = DAG.isConstantIntBuildVectorOrConstantInt(N: N1);
9836 SDNode *C2 = DAG.isConstantIntBuildVectorOrConstantInt(N: N0.getOperand(i: 1));
9837 if (C1 && C2 && C1->getValueType(ResNo: 0) == C2->getValueType(ResNo: 0)) {
9838 EVT ShiftVT = C1->getValueType(ResNo: 0);
9839 bool SameSide = (N->getOpcode() == NextOp);
9840 unsigned CombineOp = SameSide ? ISD::ADD : ISD::SUB;
9841 SDValue BitsizeC = DAG.getConstant(Val: Bitsize, DL: dl, VT: ShiftVT);
9842 SDValue Norm1 = DAG.FoldConstantArithmetic(Opcode: ISD::UREM, DL: dl, VT: ShiftVT,
9843 Ops: {N1, BitsizeC});
9844 SDValue Norm2 = DAG.FoldConstantArithmetic(Opcode: ISD::UREM, DL: dl, VT: ShiftVT,
9845 Ops: {N0.getOperand(i: 1), BitsizeC});
9846 if (Norm1 && Norm2)
9847 if (SDValue CombinedShift = DAG.FoldConstantArithmetic(
9848 Opcode: CombineOp, DL: dl, VT: ShiftVT, Ops: {Norm1, Norm2})) {
9849 CombinedShift = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL: dl, VT: ShiftVT,
9850 Ops: {CombinedShift, BitsizeC});
9851 SDValue CombinedShiftNorm = DAG.FoldConstantArithmetic(
9852 Opcode: ISD::UREM, DL: dl, VT: ShiftVT, Ops: {CombinedShift, BitsizeC});
9853 return DAG.getNode(Opcode: N->getOpcode(), DL: dl, VT, N1: N0->getOperand(Num: 0),
9854 N2: CombinedShiftNorm);
9855 }
9856 }
9857 }
9858 return SDValue();
9859}
9860
9861SDValue DAGCombiner::visitSHL(SDNode *N) {
9862 SDValue N0 = N->getOperand(Num: 0);
9863 SDValue N1 = N->getOperand(Num: 1);
9864 if (SDValue V = DAG.simplifyShift(X: N0, Y: N1))
9865 return V;
9866
9867 SDLoc DL(N);
9868 EVT VT = N0.getValueType();
9869 EVT ShiftVT = N1.getValueType();
9870 unsigned OpSizeInBits = VT.getScalarSizeInBits();
9871
9872 // fold (shl c1, c2) -> c1<<c2
9873 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL, VT, Ops: {N0, N1}))
9874 return C;
9875
9876 // fold vector ops
9877 if (VT.isVector()) {
9878 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
9879 return FoldedVOp;
9880
9881 BuildVectorSDNode *N1CV = dyn_cast<BuildVectorSDNode>(Val&: N1);
9882 // If setcc produces an all-ones true value then:
9883 // (shl (and (setcc) N01CV) N1CV) -> (and (setcc) N01CV<<N1CV)
9884 if (N1CV && N1CV->isConstant()) {
9885 if (N0.getOpcode() == ISD::AND) {
9886 SDValue N00 = N0->getOperand(Num: 0);
9887 SDValue N01 = N0->getOperand(Num: 1);
9888 BuildVectorSDNode *N01CV = dyn_cast<BuildVectorSDNode>(Val&: N01);
9889
9890 if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC &&
9891 TLI.getBooleanContents(Type: N00.getOperand(i: 0).getValueType()) ==
9892 TargetLowering::ZeroOrNegativeOneBooleanContent) {
9893 if (SDValue C =
9894 DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL, VT, Ops: {N01, N1}))
9895 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N00, N2: C);
9896 }
9897 }
9898 }
9899 }
9900
9901 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
9902 return NewSel;
9903
9904 // if (shl x, c) is known to be zero, return 0
9905 if (DAG.MaskedValueIsZero(Op: SDValue(N, 0), Mask: APInt::getAllOnes(numBits: OpSizeInBits)))
9906 return DAG.getConstant(Val: 0, DL, VT);
9907
9908 // fold (shl x, (trunc (and y, c))) -> (shl x, (and (trunc y), (trunc c))).
9909 if (N1.getOpcode() == ISD::TRUNCATE &&
9910 N1.getOperand(i: 0).getOpcode() == ISD::AND) {
9911 if (SDValue NewOp1 = distributeTruncateThroughAnd(N: N1.getNode()))
9912 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0, N2: NewOp1);
9913 }
9914
9915 // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
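// Illustrative examples (added, assuming i32): (shl (shl X, 3), 5) --> (shl X, 8),
// while (shl (shl X, 20), 15) --> 0 because 20 + 15 >= 32.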
9916 if (N0.getOpcode() == ISD::SHL) {
9917 auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
9918 ConstantSDNode *RHS) {
9919 APInt c1 = LHS->getAPIntValue();
9920 APInt c2 = RHS->getAPIntValue();
9921 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
9922 return (c1 + c2).uge(RHS: OpSizeInBits);
9923 };
9924 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchOutOfRange))
9925 return DAG.getConstant(Val: 0, DL, VT);
9926
9927 auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
9928 ConstantSDNode *RHS) {
9929 APInt c1 = LHS->getAPIntValue();
9930 APInt c2 = RHS->getAPIntValue();
9931 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
9932 return (c1 + c2).ult(RHS: OpSizeInBits);
9933 };
9934 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchInRange)) {
9935 SDValue Sum = DAG.getNode(Opcode: ISD::ADD, DL, VT: ShiftVT, N1, N2: N0.getOperand(i: 1));
9936 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0.getOperand(i: 0), N2: Sum);
9937 }
9938 }
9939
9940 // fold (shl (ext (shl x, c1)), c2) -> (shl (ext x), (add c1, c2))
9941 // For this to be valid, the second form must not preserve any of the bits
9942 // that are shifted out by the inner shift in the first form. This means
9943 // the outer shift size must be >= the number of bits added by the ext.
9944 // As a corollary, we don't care what kind of ext it is.
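// Illustrative examples (added, assuming i16 extended to i32):
// (shl (zext (shl X, 2)), 18) --> (shl (zext X), 20), and
// (shl (zext (shl X, 2)), 30) --> 0 because 2 + 30 >= 32.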
9945 if ((N0.getOpcode() == ISD::ZERO_EXTEND ||
9946 N0.getOpcode() == ISD::ANY_EXTEND ||
9947 N0.getOpcode() == ISD::SIGN_EXTEND) &&
9948 N0.getOperand(i: 0).getOpcode() == ISD::SHL) {
9949 SDValue N0Op0 = N0.getOperand(i: 0);
9950 SDValue InnerShiftAmt = N0Op0.getOperand(i: 1);
9951 EVT InnerVT = N0Op0.getValueType();
9952 uint64_t InnerBitwidth = InnerVT.getScalarSizeInBits();
9953
9954 auto MatchOutOfRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
9955 ConstantSDNode *RHS) {
9956 APInt c1 = LHS->getAPIntValue();
9957 APInt c2 = RHS->getAPIntValue();
9958 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
9959 return c2.uge(RHS: OpSizeInBits - InnerBitwidth) &&
9960 (c1 + c2).uge(RHS: OpSizeInBits);
9961 };
9962 if (ISD::matchBinaryPredicate(LHS: InnerShiftAmt, RHS: N1, Match: MatchOutOfRange,
9963 /*AllowUndefs*/ false,
9964 /*AllowTypeMismatch*/ true))
9965 return DAG.getConstant(Val: 0, DL, VT);
9966
9967 auto MatchInRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
9968 ConstantSDNode *RHS) {
9969 APInt c1 = LHS->getAPIntValue();
9970 APInt c2 = RHS->getAPIntValue();
9971 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
9972 return c2.uge(RHS: OpSizeInBits - InnerBitwidth) &&
9973 (c1 + c2).ult(RHS: OpSizeInBits);
9974 };
9975 if (ISD::matchBinaryPredicate(LHS: InnerShiftAmt, RHS: N1, Match: MatchInRange,
9976 /*AllowUndefs*/ false,
9977 /*AllowTypeMismatch*/ true)) {
9978 SDValue Ext = DAG.getNode(Opcode: N0.getOpcode(), DL, VT, Operand: N0Op0.getOperand(i: 0));
9979 SDValue Sum = DAG.getZExtOrTrunc(Op: InnerShiftAmt, DL, VT: ShiftVT);
9980 Sum = DAG.getNode(Opcode: ISD::ADD, DL, VT: ShiftVT, N1: Sum, N2: N1);
9981 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Ext, N2: Sum);
9982 }
9983 }
9984
9985 // fold (shl (zext (srl x, C)), C) -> (zext (shl (srl x, C), C))
9986 // Only fold this if the inner zext has no other uses to avoid increasing
9987 // the total number of instructions.
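// Illustrative example (added, assuming i16 zero-extended to i32):
// (shl (zext (srl X, 4)), 4) --> (zext (shl (srl X, 4), 4)).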
9988 if (N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() &&
9989 N0.getOperand(i: 0).getOpcode() == ISD::SRL) {
9990 SDValue N0Op0 = N0.getOperand(i: 0);
9991 SDValue InnerShiftAmt = N0Op0.getOperand(i: 1);
9992
9993 auto MatchEqual = [VT](ConstantSDNode *LHS, ConstantSDNode *RHS) {
9994 APInt c1 = LHS->getAPIntValue();
9995 APInt c2 = RHS->getAPIntValue();
9996 zeroExtendToMatch(LHS&: c1, RHS&: c2);
9997 return c1.ult(RHS: VT.getScalarSizeInBits()) && (c1 == c2);
9998 };
9999 if (ISD::matchBinaryPredicate(LHS: InnerShiftAmt, RHS: N1, Match: MatchEqual,
10000 /*AllowUndefs*/ false,
10001 /*AllowTypeMismatch*/ true)) {
10002 EVT InnerShiftAmtVT = N0Op0.getOperand(i: 1).getValueType();
10003 SDValue NewSHL = DAG.getZExtOrTrunc(Op: N1, DL, VT: InnerShiftAmtVT);
10004 NewSHL = DAG.getNode(Opcode: ISD::SHL, DL, VT: N0Op0.getValueType(), N1: N0Op0, N2: NewSHL);
10005 AddToWorklist(N: NewSHL.getNode());
10006 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: SDLoc(N0), VT, Operand: NewSHL);
10007 }
10008 }
10009
10010 if (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) {
10011 auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
10012 ConstantSDNode *RHS) {
10013 const APInt &LHSC = LHS->getAPIntValue();
10014 const APInt &RHSC = RHS->getAPIntValue();
10015 return LHSC.ult(RHS: OpSizeInBits) && RHSC.ult(RHS: OpSizeInBits) &&
10016 LHSC.getZExtValue() <= RHSC.getZExtValue();
10017 };
10018
10019 // fold (shl (sr[la] exact X, C1), C2) -> (shl X, (C2-C1)) if C1 <= C2
10020 // fold (shl (sr[la] exact X, C1), C2) -> (sr[la] X, (C2-C1)) if C1 >= C2
10021 if (N0->getFlags().hasExact()) {
10022 if (ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchShiftAmount,
10023 /*AllowUndefs*/ false,
10024 /*AllowTypeMismatch*/ true)) {
10025 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
10026 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1, N2: N01);
10027 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
10028 }
10029 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchShiftAmount,
10030 /*AllowUndefs*/ false,
10031 /*AllowTypeMismatch*/ true)) {
10032 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
10033 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1: N01, N2: N1);
10034 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
10035 }
10036 }
10037
10038 // fold (shl (srl x, c1), c2) -> (and (shl x, (sub c2, c1)), MASK) or
10039 // (and (srl x, (sub c1, c2)), MASK)
10040 // Only fold this if the inner shift has no other uses -- if it does,
10041 // folding this will increase the total number of instructions.
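// Illustrative examples (added, assuming i32):
// (shl (srl X, 2), 5) --> (and (shl X, 3), 0xFFFFFFE0), and
// (shl (srl X, 5), 2) --> (and (srl X, 3), 0x1FFFFFFC).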
10042 if (N0.getOpcode() == ISD::SRL &&
10043 (N0.getOperand(i: 1) == N1 || N0.hasOneUse()) &&
10044 TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
10045 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchShiftAmount,
10046 /*AllowUndefs*/ false,
10047 /*AllowTypeMismatch*/ true)) {
10048 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
10049 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1: N01, N2: N1);
10050 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10051 Mask = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Mask, N2: N01);
10052 Mask = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Mask, N2: Diff);
10053 SDValue Shift = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
10054 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shift, N2: Mask);
10055 }
10056 if (ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchShiftAmount,
10057 /*AllowUndefs*/ false,
10058 /*AllowTypeMismatch*/ true)) {
10059 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
10060 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1, N2: N01);
10061 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10062 Mask = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Mask, N2: N1);
10063 SDValue Shift = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
10064 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shift, N2: Mask);
10065 }
10066 }
10067 }
10068
10069 // fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1))
10070 if (N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(i: 1) &&
10071 isConstantOrConstantVector(N: N1, /* No Opaques */ NoOpaques: true)) {
10072 SDValue AllBits = DAG.getAllOnesConstant(DL, VT);
10073 SDValue HiBitsMask = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: AllBits, N2: N1);
10074 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0.getOperand(i: 0), N2: HiBitsMask);
10075 }
10076
10077 // fold (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
10078 // fold (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
10079 // Variant of version done on multiply, except mul by a power of 2 is turned
10080 // into a shift.
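// Illustrative example (added, assuming i32): (shl (add X, 7), 2) --> (add (shl X, 2), 28).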
10081 if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) &&
10082 N0->hasOneUse() && TLI.isDesirableToCommuteWithShift(N, Level)) {
10083 SDValue N01 = N0.getOperand(i: 1);
10084 if (SDValue Shl1 =
10085 DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL: SDLoc(N1), VT, Ops: {N01, N1})) {
10086 SDValue Shl0 = DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N0), VT, N1: N0.getOperand(i: 0), N2: N1);
10087 AddToWorklist(N: Shl0.getNode());
10088 SDNodeFlags Flags;
10089 // Preserve the disjoint flag for Or.
10090 if (N0.getOpcode() == ISD::OR && N0->getFlags().hasDisjoint())
10091 Flags.setDisjoint(true);
10092 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: Shl0, N2: Shl1, Flags);
10093 }
10094 }
10095
10096 // fold (shl (sext (add_nsw x, c1)), c2) -> (add (shl (sext x), c2), c1 << c2)
10097 // TODO: Add zext/add_nuw variant with suitable test coverage
10098 // TODO: Should we limit this with isLegalAddImmediate?
10099 if (N0.getOpcode() == ISD::SIGN_EXTEND &&
10100 N0.getOperand(i: 0).getOpcode() == ISD::ADD &&
10101 N0.getOperand(i: 0)->getFlags().hasNoSignedWrap() && N0->hasOneUse() &&
10102 N0.getOperand(i: 0)->hasOneUse() &&
10103 TLI.isDesirableToCommuteWithShift(N, Level)) {
10104 SDValue Add = N0.getOperand(i: 0);
10105 SDLoc DL(N0);
10106 if (SDValue ExtC = DAG.FoldConstantArithmetic(Opcode: N0.getOpcode(), DL, VT,
10107 Ops: {Add.getOperand(i: 1)})) {
10108 if (SDValue ShlC =
10109 DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL, VT, Ops: {ExtC, N1})) {
10110 SDValue ExtX = DAG.getNode(Opcode: N0.getOpcode(), DL, VT, Operand: Add.getOperand(i: 0));
10111 SDValue ShlX = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: ExtX, N2: N1);
10112 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: ShlX, N2: ShlC);
10113 }
10114 }
10115 }
10116
10117 // fold (shl (mul x, c1), c2) -> (mul x, c1 << c2)
10118 if (N0.getOpcode() == ISD::MUL && N0->hasOneUse()) {
10119 SDValue N01 = N0.getOperand(i: 1);
10120 if (SDValue Shl =
10121 DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL: SDLoc(N1), VT, Ops: {N01, N1}))
10122 return DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: N0.getOperand(i: 0), N2: Shl);
10123 }
10124
10125 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
10126 if (N1C && !N1C->isOpaque())
10127 if (SDValue NewSHL = visitShiftByConstant(N))
10128 return NewSHL;
10129
10130 // fold (shl X, cttz(Y)) -> (mul (Y & -Y), X) if cttz is unsupported on the
10131 // target.
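// Note (added): Y & -Y isolates the lowest set bit of Y, i.e. 1 << cttz(Y); e.g. for
// Y = 0b101000, Y & -Y = 0b1000 = 1 << 3, so the shift becomes a multiply by that value.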
10132 if (((N1.getOpcode() == ISD::CTTZ &&
10133 VT.getScalarSizeInBits() <= ShiftVT.getScalarSizeInBits()) ||
10134 N1.getOpcode() == ISD::CTTZ_ZERO_UNDEF) &&
10135 N1.hasOneUse() && !TLI.isOperationLegalOrCustom(Op: ISD::CTTZ, VT: ShiftVT) &&
10136 TLI.isOperationLegalOrCustom(Op: ISD::MUL, VT)) {
10137 SDValue Y = N1.getOperand(i: 0);
10138 SDLoc DL(N);
10139 SDValue NegY = DAG.getNegative(Val: Y, DL, VT: ShiftVT);
10140 SDValue And =
10141 DAG.getZExtOrTrunc(Op: DAG.getNode(Opcode: ISD::AND, DL, VT: ShiftVT, N1: Y, N2: NegY), DL, VT);
10142 return DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: And, N2: N0);
10143 }
10144
10145 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
10146 return SDValue(N, 0);
10147
10148 // Fold (shl (vscale * C0), C1) to (vscale * (C0 << C1)).
10149 if (N0.getOpcode() == ISD::VSCALE && N1C) {
10150 const APInt &C0 = N0.getConstantOperandAPInt(i: 0);
10151 const APInt &C1 = N1C->getAPIntValue();
10152 return DAG.getVScale(DL, VT, MulImm: C0 << C1);
10153 }
10154
10155 // Fold (shl step_vector(C0), C1) to (step_vector(C0 << C1)).
10156 APInt ShlVal;
10157 if (N0.getOpcode() == ISD::STEP_VECTOR &&
10158 ISD::isConstantSplatVector(N: N1.getNode(), SplatValue&: ShlVal)) {
10159 const APInt &C0 = N0.getConstantOperandAPInt(i: 0);
10160 if (ShlVal.ult(RHS: C0.getBitWidth())) {
10161 APInt NewStep = C0 << ShlVal;
10162 return DAG.getStepVector(DL, ResVT: VT, StepVal: NewStep);
10163 }
10164 }
10165
10166 return SDValue();
10167}
10168
10169// Transform a right shift of a multiply into a multiply-high.
10170// Examples:
10171 // (srl (mul (zext i32:$a to i64), (zext i32:$b to i64)), 32) -> (mulhu $a, $b)
10172 // (sra (mul (sext i32:$a to i64), (sext i32:$b to i64)), 32) -> (mulhs $a, $b)
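// Illustrative example (added): with $a = 0x80000000 and $b = 2, the widened i64
// product is 0x100000000, whose high 32 bits (1) are exactly (mulhu $a, $b).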
10173static SDValue combineShiftToMULH(SDNode *N, const SDLoc &DL, SelectionDAG &DAG,
10174 const TargetLowering &TLI) {
10175 assert((N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) &&
10176 "SRL or SRA node is required here!");
10177
10178 // Check the shift amount. Proceed with the transformation if the shift
10179 // amount is constant.
10180 ConstantSDNode *ShiftAmtSrc = isConstOrConstSplat(N: N->getOperand(Num: 1));
10181 if (!ShiftAmtSrc)
10182 return SDValue();
10183
10184 // The operation feeding into the shift must be a multiply.
10185 SDValue ShiftOperand = N->getOperand(Num: 0);
10186 if (ShiftOperand.getOpcode() != ISD::MUL)
10187 return SDValue();
10188
10189 // Both operands must be equivalent extend nodes.
10190 SDValue LeftOp = ShiftOperand.getOperand(i: 0);
10191 SDValue RightOp = ShiftOperand.getOperand(i: 1);
10192
10193 bool IsSignExt = LeftOp.getOpcode() == ISD::SIGN_EXTEND;
10194 bool IsZeroExt = LeftOp.getOpcode() == ISD::ZERO_EXTEND;
10195
10196 if (!IsSignExt && !IsZeroExt)
10197 return SDValue();
10198
10199 EVT NarrowVT = LeftOp.getOperand(i: 0).getValueType();
10200 unsigned NarrowVTSize = NarrowVT.getScalarSizeInBits();
10201
10202 // return true if U may use the lower bits of its operands
10203 auto UserOfLowerBits = [NarrowVTSize](SDNode *U) {
10204 if (U->getOpcode() != ISD::SRL && U->getOpcode() != ISD::SRA) {
10205 return true;
10206 }
10207 ConstantSDNode *UShiftAmtSrc = isConstOrConstSplat(N: U->getOperand(Num: 1));
10208 if (!UShiftAmtSrc) {
10209 return true;
10210 }
10211 unsigned UShiftAmt = UShiftAmtSrc->getZExtValue();
10212 return UShiftAmt < NarrowVTSize;
10213 };
10214
10215 // If the lower part of the MUL is also used and MUL_LOHI is supported,
10216 // do not introduce the MULH in favor of MUL_LOHI.
10217 unsigned MulLoHiOp = IsSignExt ? ISD::SMUL_LOHI : ISD::UMUL_LOHI;
10218 if (!ShiftOperand.hasOneUse() &&
10219 TLI.isOperationLegalOrCustom(Op: MulLoHiOp, VT: NarrowVT) &&
10220 llvm::any_of(Range: ShiftOperand->uses(), P: UserOfLowerBits)) {
10221 return SDValue();
10222 }
10223
10224 SDValue MulhRightOp;
10225 if (ConstantSDNode *Constant = isConstOrConstSplat(N: RightOp)) {
10226 unsigned ActiveBits = IsSignExt
10227 ? Constant->getAPIntValue().getSignificantBits()
10228 : Constant->getAPIntValue().getActiveBits();
10229 if (ActiveBits > NarrowVTSize)
10230 return SDValue();
10231 MulhRightOp = DAG.getConstant(
10232 Val: Constant->getAPIntValue().trunc(width: NarrowVT.getScalarSizeInBits()), DL,
10233 VT: NarrowVT);
10234 } else {
10235 if (LeftOp.getOpcode() != RightOp.getOpcode())
10236 return SDValue();
10237 // Check that the two extend nodes are the same type.
10238 if (NarrowVT != RightOp.getOperand(i: 0).getValueType())
10239 return SDValue();
10240 MulhRightOp = RightOp.getOperand(i: 0);
10241 }
10242
10243 EVT WideVT = LeftOp.getValueType();
10244 // Proceed with the transformation if the wide types match.
10245 assert((WideVT == RightOp.getValueType()) &&
10246 "Cannot have a multiply node with two different operand types.");
10247
10248 // Proceed with the transformation if the wide type is twice as large
10249 // as the narrow type.
10250 if (WideVT.getScalarSizeInBits() != 2 * NarrowVTSize)
10251 return SDValue();
10252
10253 // Check the shift amount with the narrow type size.
10254 // Proceed with the transformation if the shift amount is the width
10255 // of the narrow type.
10256 unsigned ShiftAmt = ShiftAmtSrc->getZExtValue();
10257 if (ShiftAmt != NarrowVTSize)
10258 return SDValue();
10259
10260 // If the operation feeding into the MUL is a sign extend (sext),
10261 // we use mulhs. Otherwise, zero extends (zext) use mulhu.
10262 unsigned MulhOpcode = IsSignExt ? ISD::MULHS : ISD::MULHU;
10263
10264 // Combine to mulh if mulh is legal/custom for the narrow type on the target,
10265 // or, if it is a vector type, transform to an acceptable type and rely on
10266 // legalization to split/combine the result.
10267 if (NarrowVT.isVector()) {
10268 EVT TransformVT = TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: NarrowVT);
10269 if (TransformVT.getVectorElementType() != NarrowVT.getVectorElementType() ||
10270 !TLI.isOperationLegalOrCustom(Op: MulhOpcode, VT: TransformVT))
10271 return SDValue();
10272 } else {
10273 if (!TLI.isOperationLegalOrCustom(Op: MulhOpcode, VT: NarrowVT))
10274 return SDValue();
10275 }
10276
10277 SDValue Result =
10278 DAG.getNode(Opcode: MulhOpcode, DL, VT: NarrowVT, N1: LeftOp.getOperand(i: 0), N2: MulhRightOp);
10279 bool IsSigned = N->getOpcode() == ISD::SRA;
10280 return DAG.getExtOrTrunc(IsSigned, Op: Result, DL, VT: WideVT);
10281}
10282
10283// fold (bswap (logic_op(bswap(x),y))) -> logic_op(x,bswap(y))
10284 // This helper accepts nodes with opcode ISD::BSWAP or ISD::BITREVERSE.
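// Illustrative example (added): (bswap (xor (bswap X), Y)) --> (xor X, (bswap Y)).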
10285static SDValue foldBitOrderCrossLogicOp(SDNode *N, SelectionDAG &DAG) {
10286 unsigned Opcode = N->getOpcode();
10287 if (Opcode != ISD::BSWAP && Opcode != ISD::BITREVERSE)
10288 return SDValue();
10289
10290 SDValue N0 = N->getOperand(Num: 0);
10291 EVT VT = N->getValueType(ResNo: 0);
10292 SDLoc DL(N);
10293 if (ISD::isBitwiseLogicOp(Opcode: N0.getOpcode()) && N0.hasOneUse()) {
10294 SDValue OldLHS = N0.getOperand(i: 0);
10295 SDValue OldRHS = N0.getOperand(i: 1);
10296
10297 // If both operands are bswap/bitreverse, ignore the multiuse check.
10298 // Otherwise we need to ensure logic_op and bswap/bitreverse(x) have one use.
10299 if (OldLHS.getOpcode() == Opcode && OldRHS.getOpcode() == Opcode) {
10300 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: OldLHS.getOperand(i: 0),
10301 N2: OldRHS.getOperand(i: 0));
10302 }
10303
10304 if (OldLHS.getOpcode() == Opcode && OldLHS.hasOneUse()) {
10305 SDValue NewBitReorder = DAG.getNode(Opcode, DL, VT, Operand: OldRHS);
10306 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: OldLHS.getOperand(i: 0),
10307 N2: NewBitReorder);
10308 }
10309
10310 if (OldRHS.getOpcode() == Opcode && OldRHS.hasOneUse()) {
10311 SDValue NewBitReorder = DAG.getNode(Opcode, DL, VT, Operand: OldLHS);
10312 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: NewBitReorder,
10313 N2: OldRHS.getOperand(i: 0));
10314 }
10315 }
10316 return SDValue();
10317}
10318
10319SDValue DAGCombiner::visitSRA(SDNode *N) {
10320 SDValue N0 = N->getOperand(Num: 0);
10321 SDValue N1 = N->getOperand(Num: 1);
10322 if (SDValue V = DAG.simplifyShift(X: N0, Y: N1))
10323 return V;
10324
10325 SDLoc DL(N);
10326 EVT VT = N0.getValueType();
10327 unsigned OpSizeInBits = VT.getScalarSizeInBits();
10328
10329 // fold (sra c1, c2) -> c1 >>s c2
10330 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SRA, DL, VT, Ops: {N0, N1}))
10331 return C;
10332
10333 // Arithmetic shifting an all-sign-bit value is a no-op.
10334 // fold (sra 0, x) -> 0
10335 // fold (sra -1, x) -> -1
10336 if (DAG.ComputeNumSignBits(Op: N0) == OpSizeInBits)
10337 return N0;
10338
10339 // fold vector ops
10340 if (VT.isVector())
10341 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
10342 return FoldedVOp;
10343
10344 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
10345 return NewSel;
10346
10347 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
10348
10349 // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
10350 // clamp (add c1, c2) to max shift.
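// Illustrative examples (added, assuming i32): (sra (sra X, 10), 4) --> (sra X, 14),
// and (sra (sra X, 30), 10) --> (sra X, 31) after clamping.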
10351 if (N0.getOpcode() == ISD::SRA) {
10352 EVT ShiftVT = N1.getValueType();
10353 EVT ShiftSVT = ShiftVT.getScalarType();
10354 SmallVector<SDValue, 16> ShiftValues;
10355
10356 auto SumOfShifts = [&](ConstantSDNode *LHS, ConstantSDNode *RHS) {
10357 APInt c1 = LHS->getAPIntValue();
10358 APInt c2 = RHS->getAPIntValue();
10359 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
10360 APInt Sum = c1 + c2;
10361 unsigned ShiftSum =
10362 Sum.uge(RHS: OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
10363 ShiftValues.push_back(Elt: DAG.getConstant(Val: ShiftSum, DL, VT: ShiftSVT));
10364 return true;
10365 };
10366 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: SumOfShifts)) {
10367 SDValue ShiftValue;
10368 if (N1.getOpcode() == ISD::BUILD_VECTOR)
10369 ShiftValue = DAG.getBuildVector(VT: ShiftVT, DL, Ops: ShiftValues);
10370 else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
10371 assert(ShiftValues.size() == 1 &&
10372 "Expected matchBinaryPredicate to return one element for "
10373 "SPLAT_VECTORs");
10374 ShiftValue = DAG.getSplatVector(VT: ShiftVT, DL, Op: ShiftValues[0]);
10375 } else
10376 ShiftValue = ShiftValues[0];
10377 return DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: N0.getOperand(i: 0), N2: ShiftValue);
10378 }
10379 }
10380
10381 // fold (sra (shl X, m), (sub result_size, n))
10382 // -> (sign_extend (trunc (shl X, (sub (sub result_size, n), m)))) for
10383 // result_size - n != m.
10384 // If truncate is free for the target sext(shl) is likely to result in better
10385 // code.
10386 if (N0.getOpcode() == ISD::SHL && N1C) {
10387 // Get the two constants of the shifts, CN0 = m, CN = n.
10388 const ConstantSDNode *N01C = isConstOrConstSplat(N: N0.getOperand(i: 1));
10389 if (N01C) {
10390 LLVMContext &Ctx = *DAG.getContext();
10391 // Determine what the truncate's result bitsize and type would be.
10392 EVT TruncVT = EVT::getIntegerVT(Context&: Ctx, BitWidth: OpSizeInBits - N1C->getZExtValue());
10393
10394 if (VT.isVector())
10395 TruncVT = EVT::getVectorVT(Context&: Ctx, VT: TruncVT, EC: VT.getVectorElementCount());
10396
10397 // Determine the residual right-shift amount.
10398 int ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue();
10399
10400 // If the shift is not a no-op (in which case this should be just a sign
10401 // extend already), the type to truncate to is legal, sign_extend is legal
10402 // on that type, and the truncate to that type is both legal and free,
10403 // perform the transform.
10404 if ((ShiftAmt > 0) &&
10405 TLI.isOperationLegalOrCustom(Op: ISD::SIGN_EXTEND, VT: TruncVT) &&
10406 TLI.isOperationLegalOrCustom(Op: ISD::TRUNCATE, VT) &&
10407 TLI.isTruncateFree(FromVT: VT, ToVT: TruncVT)) {
10408 SDValue Amt = DAG.getShiftAmountConstant(Val: ShiftAmt, VT, DL);
10409 SDValue Shift = DAG.getNode(Opcode: ISD::SRL, DL, VT,
10410 N1: N0.getOperand(i: 0), N2: Amt);
10411 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: TruncVT,
10412 Operand: Shift);
10413 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL,
10414 VT: N->getValueType(ResNo: 0), Operand: Trunc);
10415 }
10416 }
10417 }
10418
10419 // We convert trunc/ext to opposing shifts in IR, but casts may be cheaper.
10420 // sra (add (shl X, N1C), AddC), N1C -->
10421 // sext (add (trunc X to (width - N1C)), AddC')
10422 // sra (sub AddC, (shl X, N1C)), N1C -->
10423 // sext (sub AddC1',(trunc X to (width - N1C)))
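// Illustrative example (added, assuming i32 with N1C = 16 and AddC = 0x30000):
// (sra (add (shl X, 16), 0x30000), 16) --> (sext (add (trunc X to i16), 3)).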
10424 if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB) && N1C &&
10425 N0.hasOneUse()) {
10426 bool IsAdd = N0.getOpcode() == ISD::ADD;
10427 SDValue Shl = N0.getOperand(i: IsAdd ? 0 : 1);
10428 if (Shl.getOpcode() == ISD::SHL && Shl.getOperand(i: 1) == N1 &&
10429 Shl.hasOneUse()) {
10430 // TODO: AddC does not need to be a splat.
10431 if (ConstantSDNode *AddC =
10432 isConstOrConstSplat(N: N0.getOperand(i: IsAdd ? 1 : 0))) {
10433 // Determine what the truncate's type would be and ask the target if
10434 // that is a free operation.
10435 LLVMContext &Ctx = *DAG.getContext();
10436 unsigned ShiftAmt = N1C->getZExtValue();
10437 EVT TruncVT = EVT::getIntegerVT(Context&: Ctx, BitWidth: OpSizeInBits - ShiftAmt);
10438 if (VT.isVector())
10439 TruncVT = EVT::getVectorVT(Context&: Ctx, VT: TruncVT, EC: VT.getVectorElementCount());
10440
10441 // TODO: The simple type check probably belongs in the default hook
10442 // implementation and/or target-specific overrides (because
10443 // non-simple types likely require masking when legalized), but
10444 // that restriction may conflict with other transforms.
10445 if (TruncVT.isSimple() && isTypeLegal(VT: TruncVT) &&
10446 TLI.isTruncateFree(FromVT: VT, ToVT: TruncVT)) {
10447 SDValue Trunc = DAG.getZExtOrTrunc(Op: Shl.getOperand(i: 0), DL, VT: TruncVT);
10448 SDValue ShiftC =
10449 DAG.getConstant(Val: AddC->getAPIntValue().lshr(shiftAmt: ShiftAmt).trunc(
10450 width: TruncVT.getScalarSizeInBits()),
10451 DL, VT: TruncVT);
10452 SDValue Add;
10453 if (IsAdd)
10454 Add = DAG.getNode(Opcode: ISD::ADD, DL, VT: TruncVT, N1: Trunc, N2: ShiftC);
10455 else
10456 Add = DAG.getNode(Opcode: ISD::SUB, DL, VT: TruncVT, N1: ShiftC, N2: Trunc);
10457 return DAG.getSExtOrTrunc(Op: Add, DL, VT);
10458 }
10459 }
10460 }
10461 }
10462
10463 // fold (sra x, (trunc (and y, c))) -> (sra x, (and (trunc y), (trunc c))).
10464 if (N1.getOpcode() == ISD::TRUNCATE &&
10465 N1.getOperand(i: 0).getOpcode() == ISD::AND) {
10466 if (SDValue NewOp1 = distributeTruncateThroughAnd(N: N1.getNode()))
10467 return DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: N0, N2: NewOp1);
10468 }
10469
10470 // fold (sra (trunc (sra x, c1)), c2) -> (trunc (sra x, c1 + c2))
10471 // fold (sra (trunc (srl x, c1)), c2) -> (trunc (sra x, c1 + c2))
10472 // if c1 is equal to the number of bits the trunc removes
10473 // TODO - support non-uniform vector shift amounts.
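// Illustrative example (added, i64 truncated to i32):
// (sra (trunc (srl X, 32)), 5) --> (trunc (sra X, 37)).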
10474 if (N0.getOpcode() == ISD::TRUNCATE &&
10475 (N0.getOperand(i: 0).getOpcode() == ISD::SRL ||
10476 N0.getOperand(i: 0).getOpcode() == ISD::SRA) &&
10477 N0.getOperand(i: 0).hasOneUse() &&
10478 N0.getOperand(i: 0).getOperand(i: 1).hasOneUse() && N1C) {
10479 SDValue N0Op0 = N0.getOperand(i: 0);
10480 if (ConstantSDNode *LargeShift = isConstOrConstSplat(N: N0Op0.getOperand(i: 1))) {
10481 EVT LargeVT = N0Op0.getValueType();
10482 unsigned TruncBits = LargeVT.getScalarSizeInBits() - OpSizeInBits;
10483 if (LargeShift->getAPIntValue() == TruncBits) {
10484 EVT LargeShiftVT = getShiftAmountTy(LHSTy: LargeVT);
10485 SDValue Amt = DAG.getZExtOrTrunc(Op: N1, DL, VT: LargeShiftVT);
10486 Amt = DAG.getNode(Opcode: ISD::ADD, DL, VT: LargeShiftVT, N1: Amt,
10487 N2: DAG.getConstant(Val: TruncBits, DL, VT: LargeShiftVT));
10488 SDValue SRA =
10489 DAG.getNode(Opcode: ISD::SRA, DL, VT: LargeVT, N1: N0Op0.getOperand(i: 0), N2: Amt);
10490 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: SRA);
10491 }
10492 }
10493 }
10494
10495 // Simplify, based on bits shifted out of the LHS.
10496 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
10497 return SDValue(N, 0);
10498
10499 // If the sign bit is known to be zero, switch this to a SRL.
10500 if (DAG.SignBitIsZero(Op: N0))
10501 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0, N2: N1);
10502
10503 if (N1C && !N1C->isOpaque())
10504 if (SDValue NewSRA = visitShiftByConstant(N))
10505 return NewSRA;
10506
10507 // Try to transform this shift into a multiply-high if
10508 // it matches the appropriate pattern detected in combineShiftToMULH.
10509 if (SDValue MULH = combineShiftToMULH(N, DL, DAG, TLI))
10510 return MULH;
10511
10512 // Attempt to convert a sra of a load into a narrower sign-extending load.
10513 if (SDValue NarrowLoad = reduceLoadWidth(N))
10514 return NarrowLoad;
10515
10516 return SDValue();
10517}
10518
10519SDValue DAGCombiner::visitSRL(SDNode *N) {
10520 SDValue N0 = N->getOperand(Num: 0);
10521 SDValue N1 = N->getOperand(Num: 1);
10522 if (SDValue V = DAG.simplifyShift(X: N0, Y: N1))
10523 return V;
10524
10525 SDLoc DL(N);
10526 EVT VT = N0.getValueType();
10527 EVT ShiftVT = N1.getValueType();
10528 unsigned OpSizeInBits = VT.getScalarSizeInBits();
10529
10530 // fold (srl c1, c2) -> c1 >>u c2
10531 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SRL, DL, VT, Ops: {N0, N1}))
10532 return C;
10533
10534 // fold vector ops
10535 if (VT.isVector())
10536 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
10537 return FoldedVOp;
10538
10539 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
10540 return NewSel;
10541
10542 // if (srl x, c) is known to be zero, return 0
10543 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
10544 if (N1C &&
10545 DAG.MaskedValueIsZero(Op: SDValue(N, 0), Mask: APInt::getAllOnes(numBits: OpSizeInBits)))
10546 return DAG.getConstant(Val: 0, DL, VT);
10547
10548 // fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2))
10549 if (N0.getOpcode() == ISD::SRL) {
10550 auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
10551 ConstantSDNode *RHS) {
10552 APInt c1 = LHS->getAPIntValue();
10553 APInt c2 = RHS->getAPIntValue();
10554 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
10555 return (c1 + c2).uge(RHS: OpSizeInBits);
10556 };
10557 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchOutOfRange))
10558 return DAG.getConstant(Val: 0, DL, VT);
10559
10560 auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
10561 ConstantSDNode *RHS) {
10562 APInt c1 = LHS->getAPIntValue();
10563 APInt c2 = RHS->getAPIntValue();
10564 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
10565 return (c1 + c2).ult(RHS: OpSizeInBits);
10566 };
10567 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchInRange)) {
10568 SDValue Sum = DAG.getNode(Opcode: ISD::ADD, DL, VT: ShiftVT, N1, N2: N0.getOperand(i: 1));
10569 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0.getOperand(i: 0), N2: Sum);
10570 }
10571 }
10572
10573 if (N1C && N0.getOpcode() == ISD::TRUNCATE &&
10574 N0.getOperand(i: 0).getOpcode() == ISD::SRL) {
10575 SDValue InnerShift = N0.getOperand(i: 0);
10576 // TODO - support non-uniform vector shift amounts.
10577 if (auto *N001C = isConstOrConstSplat(N: InnerShift.getOperand(i: 1))) {
10578 uint64_t c1 = N001C->getZExtValue();
10579 uint64_t c2 = N1C->getZExtValue();
10580 EVT InnerShiftVT = InnerShift.getValueType();
10581 EVT ShiftAmtVT = InnerShift.getOperand(i: 1).getValueType();
10582 uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits();
10583 // srl (trunc (srl x, c1)), c2 --> 0 or (trunc (srl x, (add c1, c2)))
10584 // This is only valid if the OpSizeInBits + c1 = size of inner shift.
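// Illustrative example (added, i64 truncated to i32):
// (srl (trunc (srl X, 32)), 3) --> (trunc (srl X, 35)).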
10585 if (c1 + OpSizeInBits == InnerShiftSize) {
10586 if (c1 + c2 >= InnerShiftSize)
10587 return DAG.getConstant(Val: 0, DL, VT);
10588 SDValue NewShiftAmt = DAG.getConstant(Val: c1 + c2, DL, VT: ShiftAmtVT);
10589 SDValue NewShift = DAG.getNode(Opcode: ISD::SRL, DL, VT: InnerShiftVT,
10590 N1: InnerShift.getOperand(i: 0), N2: NewShiftAmt);
10591 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: NewShift);
10592 }
10593 // In the more general case, we can clear the high bits after the shift:
10594 // srl (trunc (srl x, c1)), c2 --> trunc (and (srl x, (c1+c2)), Mask)
10595 if (N0.hasOneUse() && InnerShift.hasOneUse() &&
10596 c1 + c2 < InnerShiftSize) {
10597 SDValue NewShiftAmt = DAG.getConstant(Val: c1 + c2, DL, VT: ShiftAmtVT);
10598 SDValue NewShift = DAG.getNode(Opcode: ISD::SRL, DL, VT: InnerShiftVT,
10599 N1: InnerShift.getOperand(i: 0), N2: NewShiftAmt);
10600 SDValue Mask = DAG.getConstant(Val: APInt::getLowBitsSet(numBits: InnerShiftSize,
10601 loBitsSet: OpSizeInBits - c2),
10602 DL, VT: InnerShiftVT);
10603 SDValue And = DAG.getNode(Opcode: ISD::AND, DL, VT: InnerShiftVT, N1: NewShift, N2: Mask);
10604 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: And);
10605 }
10606 }
10607 }
10608
10609 // fold (srl (shl x, c1), c2) -> (and (shl x, (sub c1, c2)), MASK) or
10610 // (and (srl x, (sub c2, c1)), MASK)
10611 if (N0.getOpcode() == ISD::SHL &&
10612 (N0.getOperand(i: 1) == N1 || N0->hasOneUse()) &&
10613 TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
10614 auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
10615 ConstantSDNode *RHS) {
10616 const APInt &LHSC = LHS->getAPIntValue();
10617 const APInt &RHSC = RHS->getAPIntValue();
10618 return LHSC.ult(RHS: OpSizeInBits) && RHSC.ult(RHS: OpSizeInBits) &&
10619 LHSC.getZExtValue() <= RHSC.getZExtValue();
10620 };
10621 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchShiftAmount,
10622 /*AllowUndefs*/ false,
10623 /*AllowTypeMismatch*/ true)) {
10624 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
10625 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1: N01, N2: N1);
10626 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10627 Mask = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Mask, N2: N01);
10628 Mask = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Mask, N2: Diff);
10629 SDValue Shift = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
10630 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shift, N2: Mask);
10631 }
10632 if (ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchShiftAmount,
10633 /*AllowUndefs*/ false,
10634 /*AllowTypeMismatch*/ true)) {
10635 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
10636 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1, N2: N01);
10637 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10638 Mask = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Mask, N2: N1);
10639 SDValue Shift = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
10640 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shift, N2: Mask);
10641 }
10642 }
10643
10644 // fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask)
10645 // TODO - support non-uniform vector shift amounts.
10646 if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
10647 // Shifting in all undef bits?
10648 EVT SmallVT = N0.getOperand(i: 0).getValueType();
10649 unsigned BitSize = SmallVT.getScalarSizeInBits();
10650 if (N1C->getAPIntValue().uge(RHS: BitSize))
10651 return DAG.getUNDEF(VT);
10652
10653 if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, VT: SmallVT)) {
10654 uint64_t ShiftAmt = N1C->getZExtValue();
10655 SDLoc DL0(N0);
10656 SDValue SmallShift =
10657 DAG.getNode(Opcode: ISD::SRL, DL: DL0, VT: SmallVT, N1: N0.getOperand(i: 0),
10658 N2: DAG.getShiftAmountConstant(Val: ShiftAmt, VT: SmallVT, DL: DL0));
10659 AddToWorklist(N: SmallShift.getNode());
10660 APInt Mask = APInt::getLowBitsSet(numBits: OpSizeInBits, loBitsSet: OpSizeInBits - ShiftAmt);
10661 return DAG.getNode(Opcode: ISD::AND, DL, VT,
10662 N1: DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT, Operand: SmallShift),
10663 N2: DAG.getConstant(Val: Mask, DL, VT));
10664 }
10665 }
10666
10667 // fold (srl (sra X, Y), 31) -> (srl X, 31). This srl only looks at the sign
10668 // bit, which is unmodified by sra.
10669 if (N1C && N1C->getAPIntValue() == (OpSizeInBits - 1)) {
10670 if (N0.getOpcode() == ISD::SRA)
10671 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0.getOperand(i: 0), N2: N1);
10672 }
10673
10674 // fold (srl (ctlz x), "5") -> x iff x has one bit set (the low bit), and x has a
10675 // power-of-two bitwidth. The "5" represents (log2 (bitwidth x)).
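// Illustrative example (added, assuming i32): if x is known to be 0 or 1, ctlz(x) is
// 32 or 31, so (srl (ctlz x), 5) is 1 or 0, which the code below turns into (xor x, 1).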
10676 if (N1C && N0.getOpcode() == ISD::CTLZ &&
10677 isPowerOf2_32(Value: OpSizeInBits) &&
10678 N1C->getAPIntValue() == Log2_32(Value: OpSizeInBits)) {
10679 KnownBits Known = DAG.computeKnownBits(Op: N0.getOperand(i: 0));
10680
10681 // If any of the input bits are KnownOne, then the input couldn't be all
10682 // zeros, thus the result of the srl will always be zero.
10683 if (Known.One.getBoolValue()) return DAG.getConstant(Val: 0, DL: SDLoc(N0), VT);
10684
10685 // If all of the bits input to the ctlz node are known to be zero, then
10686 // the result of the ctlz is "32" and the result of the shift is one.
10687 APInt UnknownBits = ~Known.Zero;
10688 if (UnknownBits == 0) return DAG.getConstant(Val: 1, DL: SDLoc(N0), VT);
10689
10690 // Otherwise, check to see if there is exactly one bit input to the ctlz.
10691 if (UnknownBits.isPowerOf2()) {
10692 // Okay, we know that only the single bit specified by UnknownBits
10693 // could be set on input to the CTLZ node. If this bit is set, the SRL
10694 // will return 0, if it is clear, it returns 1. Change the CTLZ/SRL pair
10695 // to an SRL/XOR pair, which is likely to simplify more.
10696 unsigned ShAmt = UnknownBits.countr_zero();
10697 SDValue Op = N0.getOperand(i: 0);
10698
10699 if (ShAmt) {
10700 SDLoc DL(N0);
10701 Op = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Op,
10702 N2: DAG.getShiftAmountConstant(Val: ShAmt, VT, DL));
10703 AddToWorklist(N: Op.getNode());
10704 }
10705 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: Op, N2: DAG.getConstant(Val: 1, DL, VT));
10706 }
10707 }
10708
10709 // fold (srl x, (trunc (and y, c))) -> (srl x, (and (trunc y), (trunc c))).
10710 if (N1.getOpcode() == ISD::TRUNCATE &&
10711 N1.getOperand(i: 0).getOpcode() == ISD::AND) {
10712 if (SDValue NewOp1 = distributeTruncateThroughAnd(N: N1.getNode()))
10713 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0, N2: NewOp1);
10714 }
10715
10716 // fold operands of srl based on knowledge that the low bits are not
10717 // demanded.
10718 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
10719 return SDValue(N, 0);
10720
10721 if (N1C && !N1C->isOpaque())
10722 if (SDValue NewSRL = visitShiftByConstant(N))
10723 return NewSRL;
10724
10725 // Attempt to convert a srl of a load into a narrower zero-extending load.
10726 if (SDValue NarrowLoad = reduceLoadWidth(N))
10727 return NarrowLoad;
10728
10729 // Here is a common situation. We want to optimize:
10730 //
10731 // %a = ...
10732 // %b = and i32 %a, 2
10733 // %c = srl i32 %b, 1
10734 // brcond i32 %c ...
10735 //
10736 // into
10737 //
10738 // %a = ...
10739 // %b = and %a, 2
10740 // %c = setcc eq %b, 0
10741 // brcond %c ...
10742 //
10743 // However, after the source operand of the SRL is optimized into an AND, the
10744 // SRL itself may not be optimized further. Look for it and add the BRCOND into
10745 // the worklist.
10746 //
10747 // This also tends to happen for binary operations when SimplifyDemandedBits
10748 // is involved.
10749 //
10750 // FIXME: This is unnecessary if we process the DAG in topological order,
10751 // which we plan to do. This workaround can be removed once the DAG is
10752 // processed in topological order.
10753 if (N->hasOneUse()) {
10754 SDNode *Use = *N->use_begin();
10755
10756 // Look past the truncate.
10757 if (Use->getOpcode() == ISD::TRUNCATE && Use->hasOneUse())
10758 Use = *Use->use_begin();
10759
10760 if (Use->getOpcode() == ISD::BRCOND || Use->getOpcode() == ISD::AND ||
10761 Use->getOpcode() == ISD::OR || Use->getOpcode() == ISD::XOR)
10762 AddToWorklist(N: Use);
10763 }
10764
10765 // Try to transform this shift into a multiply-high if
10766 // it matches the appropriate pattern detected in combineShiftToMULH.
10767 if (SDValue MULH = combineShiftToMULH(N, DL, DAG, TLI))
10768 return MULH;
10769
10770 return SDValue();
10771}
10772
10773SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
10774 EVT VT = N->getValueType(ResNo: 0);
10775 SDValue N0 = N->getOperand(Num: 0);
10776 SDValue N1 = N->getOperand(Num: 1);
10777 SDValue N2 = N->getOperand(Num: 2);
10778 bool IsFSHL = N->getOpcode() == ISD::FSHL;
10779 unsigned BitWidth = VT.getScalarSizeInBits();
10780 SDLoc DL(N);
10781
10782 // fold (fshl N0, N1, 0) -> N0
10783 // fold (fshr N0, N1, 0) -> N1
10784 if (isPowerOf2_32(Value: BitWidth))
10785 if (DAG.MaskedValueIsZero(
10786 Op: N2, Mask: APInt(N2.getScalarValueSizeInBits(), BitWidth - 1)))
10787 return IsFSHL ? N0 : N1;
10788
10789 auto IsUndefOrZero = [](SDValue V) {
10790 return V.isUndef() || isNullOrNullSplat(V, /*AllowUndefs*/ true);
10791 };
10792
10793 // TODO - support non-uniform vector shift amounts.
10794 if (ConstantSDNode *Cst = isConstOrConstSplat(N: N2)) {
10795 EVT ShAmtTy = N2.getValueType();
10796
10797 // fold (fsh* N0, N1, c) -> (fsh* N0, N1, c % BitWidth)
10798 if (Cst->getAPIntValue().uge(RHS: BitWidth)) {
10799 uint64_t RotAmt = Cst->getAPIntValue().urem(RHS: BitWidth);
10800 return DAG.getNode(Opcode: N->getOpcode(), DL, VT, N1: N0, N2: N1,
10801 N3: DAG.getConstant(Val: RotAmt, DL, VT: ShAmtTy));
10802 }
10803
10804 unsigned ShAmt = Cst->getZExtValue();
10805 if (ShAmt == 0)
10806 return IsFSHL ? N0 : N1;
10807
10808 // fold fshl(undef_or_zero, N1, C) -> lshr(N1, BW-C)
10809 // fold fshr(undef_or_zero, N1, C) -> lshr(N1, C)
10810 // fold fshl(N0, undef_or_zero, C) -> shl(N0, C)
10811 // fold fshr(N0, undef_or_zero, C) -> shl(N0, BW-C)
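// Illustrative examples (added, assuming i32): (fshl X, 0, 8) --> (shl X, 8),
// (fshr 0, Y, 8) --> (srl Y, 8), and (fshl 0, Y, 8) --> (srl Y, 24).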
10812 if (IsUndefOrZero(N0))
10813 return DAG.getNode(
10814 Opcode: ISD::SRL, DL, VT, N1,
10815 N2: DAG.getConstant(Val: IsFSHL ? BitWidth - ShAmt : ShAmt, DL, VT: ShAmtTy));
10816 if (IsUndefOrZero(N1))
10817 return DAG.getNode(
10818 Opcode: ISD::SHL, DL, VT, N1: N0,
10819 N2: DAG.getConstant(Val: IsFSHL ? ShAmt : BitWidth - ShAmt, DL, VT: ShAmtTy));
10820
10821 // fold (fshl ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
10822 // fold (fshr ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
10823 // TODO - bigendian support once we have test coverage.
10824 // TODO - can we merge this with CombineConsecutiveLoads/MatchLoadCombine?
10825 // TODO - permit LHS EXTLOAD if extensions are shifted out.
10826 if ((BitWidth % 8) == 0 && (ShAmt % 8) == 0 && !VT.isVector() &&
10827 !DAG.getDataLayout().isBigEndian()) {
10828 auto *LHS = dyn_cast<LoadSDNode>(Val&: N0);
10829 auto *RHS = dyn_cast<LoadSDNode>(Val&: N1);
10830 if (LHS && RHS && LHS->isSimple() && RHS->isSimple() &&
10831 LHS->getAddressSpace() == RHS->getAddressSpace() &&
10832 (LHS->hasOneUse() || RHS->hasOneUse()) && ISD::isNON_EXTLoad(N: RHS) &&
10833 ISD::isNON_EXTLoad(N: LHS)) {
10834 if (DAG.areNonVolatileConsecutiveLoads(LD: LHS, Base: RHS, Bytes: BitWidth / 8, Dist: 1)) {
10835 SDLoc DL(RHS);
10836 uint64_t PtrOff =
10837 IsFSHL ? (((BitWidth - ShAmt) % BitWidth) / 8) : (ShAmt / 8);
10838 Align NewAlign = commonAlignment(A: RHS->getAlign(), Offset: PtrOff);
10839 unsigned Fast = 0;
10840 if (TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT,
10841 AddrSpace: RHS->getAddressSpace(), Alignment: NewAlign,
10842 Flags: RHS->getMemOperand()->getFlags(), Fast: &Fast) &&
10843 Fast) {
10844 SDValue NewPtr = DAG.getMemBasePlusOffset(
10845 Base: RHS->getBasePtr(), Offset: TypeSize::getFixed(ExactSize: PtrOff), DL);
10846 AddToWorklist(N: NewPtr.getNode());
10847 SDValue Load = DAG.getLoad(
10848 VT, dl: DL, Chain: RHS->getChain(), Ptr: NewPtr,
10849 PtrInfo: RHS->getPointerInfo().getWithOffset(O: PtrOff), Alignment: NewAlign,
10850 MMOFlags: RHS->getMemOperand()->getFlags(), AAInfo: RHS->getAAInfo());
10851 // Replace the old load's chain with the new load's chain.
10852 WorklistRemover DeadNodes(*this);
10853 DAG.ReplaceAllUsesOfValueWith(From: N1.getValue(R: 1), To: Load.getValue(R: 1));
10854 return Load;
10855 }
10856 }
10857 }
10858 }
10859 }
10860
10861 // fold fshr(undef_or_zero, N1, N2) -> lshr(N1, N2)
10862 // fold fshl(N0, undef_or_zero, N2) -> shl(N0, N2)
10863 // iff we know the shift amount is in range.
10864 // TODO: when is it worth doing SUB(BW, N2) as well?
10865 if (isPowerOf2_32(Value: BitWidth)) {
10866 APInt ModuloBits(N2.getScalarValueSizeInBits(), BitWidth - 1);
10867 if (IsUndefOrZero(N0) && !IsFSHL && DAG.MaskedValueIsZero(Op: N2, Mask: ~ModuloBits))
10868 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1, N2);
10869 if (IsUndefOrZero(N1) && IsFSHL && DAG.MaskedValueIsZero(Op: N2, Mask: ~ModuloBits))
10870 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0, N2);
10871 }
10872
10873 // fold (fshl N0, N0, N2) -> (rotl N0, N2)
10874 // fold (fshr N0, N0, N2) -> (rotr N0, N2)
10875 // TODO: Investigate flipping this rotate if only one is legal.
10876 // If funnel shift is legal as well we might be better off avoiding
10877 // non-constant (BW - N2).
10878 unsigned RotOpc = IsFSHL ? ISD::ROTL : ISD::ROTR;
10879 if (N0 == N1 && hasOperation(Opcode: RotOpc, VT))
10880 return DAG.getNode(Opcode: RotOpc, DL, VT, N1: N0, N2);
10881
10882 // Simplify, based on bits shifted out of N0/N1.
10883 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
10884 return SDValue(N, 0);
10885
10886 return SDValue();
10887}
10888
10889SDValue DAGCombiner::visitSHLSAT(SDNode *N) {
10890 SDValue N0 = N->getOperand(Num: 0);
10891 SDValue N1 = N->getOperand(Num: 1);
10892 if (SDValue V = DAG.simplifyShift(X: N0, Y: N1))
10893 return V;
10894
10895 SDLoc DL(N);
10896 EVT VT = N0.getValueType();
10897
10898 // fold (*shlsat c1, c2) -> c1<<c2
10899 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: N->getOpcode(), DL, VT, Ops: {N0, N1}))
10900 return C;
10901
10902 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
10903
10904 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::SHL, VT)) {
10905 // fold (sshlsat x, c) -> (shl x, c)
10906 if (N->getOpcode() == ISD::SSHLSAT && N1C &&
10907 N1C->getAPIntValue().ult(RHS: DAG.ComputeNumSignBits(Op: N0)))
10908 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0, N2: N1);
10909
10910 // fold (ushlsat x, c) -> (shl x, c)
10911 if (N->getOpcode() == ISD::USHLSAT && N1C &&
10912 N1C->getAPIntValue().ule(
10913 RHS: DAG.computeKnownBits(Op: N0).countMinLeadingZeros()))
10914 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0, N2: N1);
10915 }
10916
10917 return SDValue();
10918}
10919
10920 // Given an ABS node, detect the following patterns:
10921// (ABS (SUB (EXTEND a), (EXTEND b))).
10922// (TRUNC (ABS (SUB (EXTEND a), (EXTEND b)))).
10923 // Generates a UABD/SABD instruction.
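// Illustrative example (added): (abs (sub (sext i8 $a to i32), (sext i8 $b to i32)))
// --> (zext (abds $a, $b)) when ABDS is available for i8.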
10924SDValue DAGCombiner::foldABSToABD(SDNode *N, const SDLoc &DL) {
10925 EVT SrcVT = N->getValueType(ResNo: 0);
10926
10927 if (N->getOpcode() == ISD::TRUNCATE)
10928 N = N->getOperand(Num: 0).getNode();
10929
10930 if (N->getOpcode() != ISD::ABS)
10931 return SDValue();
10932
10933 EVT VT = N->getValueType(ResNo: 0);
10934 SDValue AbsOp1 = N->getOperand(Num: 0);
10935 SDValue Op0, Op1;
10936
10937 if (AbsOp1.getOpcode() != ISD::SUB)
10938 return SDValue();
10939
10940 Op0 = AbsOp1.getOperand(i: 0);
10941 Op1 = AbsOp1.getOperand(i: 1);
10942
10943 unsigned Opc0 = Op0.getOpcode();
10944
10945 // Check if the operands of the sub are (zero|sign)-extended.
10946 // TODO: Should we use ValueTracking instead?
10947 if (Opc0 != Op1.getOpcode() ||
10948 (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND &&
10949 Opc0 != ISD::SIGN_EXTEND_INREG)) {
10950 // fold (abs (sub nsw x, y)) -> abds(x, y)
10951 if (AbsOp1->getFlags().hasNoSignedWrap() && hasOperation(Opcode: ISD::ABDS, VT) &&
10952 TLI.preferABDSToABSWithNSW(VT)) {
10953 SDValue ABD = DAG.getNode(Opcode: ISD::ABDS, DL, VT, N1: Op0, N2: Op1);
10954 return DAG.getZExtOrTrunc(Op: ABD, DL, VT: SrcVT);
10955 }
10956 return SDValue();
10957 }
10958
10959 EVT VT0, VT1;
10960 if (Opc0 == ISD::SIGN_EXTEND_INREG) {
10961 VT0 = cast<VTSDNode>(Val: Op0.getOperand(i: 1))->getVT();
10962 VT1 = cast<VTSDNode>(Val: Op1.getOperand(i: 1))->getVT();
10963 } else {
10964 VT0 = Op0.getOperand(i: 0).getValueType();
10965 VT1 = Op1.getOperand(i: 0).getValueType();
10966 }
10967 unsigned ABDOpcode = (Opc0 == ISD::ZERO_EXTEND) ? ISD::ABDU : ISD::ABDS;
10968
10969 // fold abs(sext(x) - sext(y)) -> zext(abds(x, y))
10970 // fold abs(zext(x) - zext(y)) -> zext(abdu(x, y))
10971 EVT MaxVT = VT0.bitsGT(VT: VT1) ? VT0 : VT1;
10972 if ((VT0 == MaxVT || Op0->hasOneUse()) &&
10973 (VT1 == MaxVT || Op1->hasOneUse()) && hasOperation(Opcode: ABDOpcode, VT: MaxVT)) {
10974 SDValue ABD = DAG.getNode(Opcode: ABDOpcode, DL, VT: MaxVT,
10975 N1: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MaxVT, Operand: Op0),
10976 N2: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MaxVT, Operand: Op1));
10977 ABD = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: ABD);
10978 return DAG.getZExtOrTrunc(Op: ABD, DL, VT: SrcVT);
10979 }
10980
10981 // fold abs(sext(x) - sext(y)) -> abds(sext(x), sext(y))
10982 // fold abs(zext(x) - zext(y)) -> abdu(zext(x), zext(y))
10983 if (hasOperation(Opcode: ABDOpcode, VT)) {
10984 SDValue ABD = DAG.getNode(Opcode: ABDOpcode, DL, VT, N1: Op0, N2: Op1);
10985 return DAG.getZExtOrTrunc(Op: ABD, DL, VT: SrcVT);
10986 }
10987
10988 return SDValue();
10989}
10990
10991SDValue DAGCombiner::visitABS(SDNode *N) {
10992 SDValue N0 = N->getOperand(Num: 0);
10993 EVT VT = N->getValueType(ResNo: 0);
10994 SDLoc DL(N);
10995
10996 // fold (abs c1) -> c2
10997 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::ABS, DL, VT, Ops: {N0}))
10998 return C;
10999 // fold (abs (abs x)) -> (abs x)
11000 if (N0.getOpcode() == ISD::ABS)
11001 return N0;
11002 // fold (abs x) -> x iff not-negative
11003 if (DAG.SignBitIsZero(Op: N0))
11004 return N0;
11005
11006 if (SDValue ABD = foldABSToABD(N, DL))
11007 return ABD;
11008
11009 // fold (abs (sign_extend_inreg x)) -> (zero_extend (abs (truncate x)))
11010 // iff zero_extend/truncate are free.
11011 if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
11012 EVT ExtVT = cast<VTSDNode>(Val: N0.getOperand(i: 1))->getVT();
11013 if (TLI.isTruncateFree(FromVT: VT, ToVT: ExtVT) && TLI.isZExtFree(FromTy: ExtVT, ToTy: VT) &&
11014 TLI.isTypeDesirableForOp(ISD::ABS, VT: ExtVT) &&
11015 hasOperation(Opcode: ISD::ABS, VT: ExtVT)) {
11016 return DAG.getNode(
11017 Opcode: ISD::ZERO_EXTEND, DL, VT,
11018 Operand: DAG.getNode(Opcode: ISD::ABS, DL, VT: ExtVT,
11019 Operand: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: ExtVT, Operand: N0.getOperand(i: 0))));
11020 }
11021 }
11022
11023 return SDValue();
11024}
11025
11026SDValue DAGCombiner::visitBSWAP(SDNode *N) {
11027 SDValue N0 = N->getOperand(Num: 0);
11028 EVT VT = N->getValueType(ResNo: 0);
11029 SDLoc DL(N);
11030
11031 // fold (bswap c1) -> c2
11032 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::BSWAP, DL, VT, Ops: {N0}))
11033 return C;
11034 // fold (bswap (bswap x)) -> x
11035 if (N0.getOpcode() == ISD::BSWAP)
11036 return N0.getOperand(i: 0);
11037
11038 // Canonicalize bswap(bitreverse(x)) -> bitreverse(bswap(x)). If bitreverse
11039 // isn't supported, it will be expanded to bswap followed by a manual reversal
11040 // of bits in each byte. By placing bswaps before bitreverse, we can remove
11041 // the two bswaps if the bitreverse gets expanded.
11042 if (N0.getOpcode() == ISD::BITREVERSE && N0.hasOneUse()) {
11043 SDValue BSwap = DAG.getNode(Opcode: ISD::BSWAP, DL, VT, Operand: N0.getOperand(i: 0));
11044 return DAG.getNode(Opcode: ISD::BITREVERSE, DL, VT, Operand: BSwap);
11045 }
11046
11047 // fold (bswap shl(x,c)) -> (zext(bswap(trunc(shl(x,sub(c,bw/2))))))
11048 // iff c >= bw/2 (i.e. lower half is known zero)
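// Illustrative example (added, i64 with c = 48):
// (bswap (shl X, 48)) --> (zext (bswap (trunc (shl X, 16) to i32))).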
11049 unsigned BW = VT.getScalarSizeInBits();
11050 if (BW >= 32 && N0.getOpcode() == ISD::SHL && N0.hasOneUse()) {
11051 auto *ShAmt = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
11052 EVT HalfVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: BW / 2);
11053 if (ShAmt && ShAmt->getAPIntValue().ult(RHS: BW) &&
11054 ShAmt->getZExtValue() >= (BW / 2) &&
11055 (ShAmt->getZExtValue() % 16) == 0 && TLI.isTypeLegal(VT: HalfVT) &&
11056 TLI.isTruncateFree(FromVT: VT, ToVT: HalfVT) &&
11057 (!LegalOperations || hasOperation(Opcode: ISD::BSWAP, VT: HalfVT))) {
11058 SDValue Res = N0.getOperand(i: 0);
11059 if (uint64_t NewShAmt = (ShAmt->getZExtValue() - (BW / 2)))
11060 Res = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Res,
11061 N2: DAG.getShiftAmountConstant(Val: NewShAmt, VT, DL));
11062 Res = DAG.getZExtOrTrunc(Op: Res, DL, VT: HalfVT);
11063 Res = DAG.getNode(Opcode: ISD::BSWAP, DL, VT: HalfVT, Operand: Res);
11064 return DAG.getZExtOrTrunc(Op: Res, DL, VT);
11065 }
11066 }
11067
11068 // Try to canonicalize bswap-of-logical-shift-by-8-bit-multiple as
11069 // inverse-shift-of-bswap:
11070 // bswap (X u<< C) --> (bswap X) u>> C
11071 // bswap (X u>> C) --> (bswap X) u<< C
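  // (Sanity check, illustrative for i32 and C = 8: shifting left by one byte and
  // then byte-swapping places the bytes exactly where byte-swapping first and
  // shifting right by one byte would, since bswap mirrors byte positions.)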
11072 if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
11073 N0.hasOneUse()) {
11074 auto *ShAmt = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
11075 if (ShAmt && ShAmt->getAPIntValue().ult(RHS: BW) &&
11076 ShAmt->getZExtValue() % 8 == 0) {
11077 SDValue NewSwap = DAG.getNode(Opcode: ISD::BSWAP, DL, VT, Operand: N0.getOperand(i: 0));
11078 unsigned InverseShift = N0.getOpcode() == ISD::SHL ? ISD::SRL : ISD::SHL;
11079 return DAG.getNode(Opcode: InverseShift, DL, VT, N1: NewSwap, N2: N0.getOperand(i: 1));
11080 }
11081 }
11082
11083 if (SDValue V = foldBitOrderCrossLogicOp(N, DAG))
11084 return V;
11085
11086 return SDValue();
11087}
11088
11089SDValue DAGCombiner::visitBITREVERSE(SDNode *N) {
11090 SDValue N0 = N->getOperand(Num: 0);
11091 EVT VT = N->getValueType(ResNo: 0);
11092 SDLoc DL(N);
11093
11094 // fold (bitreverse c1) -> c2
11095 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::BITREVERSE, DL, VT, Ops: {N0}))
11096 return C;
11097
11098 // fold (bitreverse (bitreverse x)) -> x
11099 if (N0.getOpcode() == ISD::BITREVERSE)
11100 return N0.getOperand(i: 0);
11101
11102 SDValue X, Y;
11103
11104 // fold (bitreverse (lshr (bitreverse x), y)) -> (shl x, y)
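  // (Reversing the bits, shifting right, and reversing again moves the same bits
  // left, so the two bitreverses cancel into a plain shift.)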
11105 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::SHL, VT)) &&
11106 sd_match(N, P: m_BitReverse(Op: m_Srl(L: m_BitReverse(Op: m_Value(N&: X)), R: m_Value(N&: Y)))))
11107 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: X, N2: Y);
11108
11109 // fold (bitreverse (shl (bitreverse x), y)) -> (lshr x, y)
11110 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::SRL, VT)) &&
11111 sd_match(N, P: m_BitReverse(Op: m_Shl(L: m_BitReverse(Op: m_Value(N&: X)), R: m_Value(N&: Y)))))
11112 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: X, N2: Y);
11113
11114 return SDValue();
11115}
11116
11117SDValue DAGCombiner::visitCTLZ(SDNode *N) {
11118 SDValue N0 = N->getOperand(Num: 0);
11119 EVT VT = N->getValueType(ResNo: 0);
11120 SDLoc DL(N);
11121
11122 // fold (ctlz c1) -> c2
11123 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::CTLZ, DL, VT, Ops: {N0}))
11124 return C;
11125
11126 // If the value is known never to be zero, switch to the undef version.
11127 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::CTLZ_ZERO_UNDEF, VT))
11128 if (DAG.isKnownNeverZero(Op: N0))
11129 return DAG.getNode(Opcode: ISD::CTLZ_ZERO_UNDEF, DL, VT, Operand: N0);
11130
11131 return SDValue();
11132}
11133
11134SDValue DAGCombiner::visitCTLZ_ZERO_UNDEF(SDNode *N) {
11135 SDValue N0 = N->getOperand(Num: 0);
11136 EVT VT = N->getValueType(ResNo: 0);
11137 SDLoc DL(N);
11138
11139 // fold (ctlz_zero_undef c1) -> c2
11140 if (SDValue C =
11141 DAG.FoldConstantArithmetic(Opcode: ISD::CTLZ_ZERO_UNDEF, DL, VT, Ops: {N0}))
11142 return C;
11143 return SDValue();
11144}
11145
11146SDValue DAGCombiner::visitCTTZ(SDNode *N) {
11147 SDValue N0 = N->getOperand(Num: 0);
11148 EVT VT = N->getValueType(ResNo: 0);
11149 SDLoc DL(N);
11150
11151 // fold (cttz c1) -> c2
11152 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::CTTZ, DL, VT, Ops: {N0}))
11153 return C;
11154
11155 // If the value is known never to be zero, switch to the undef version.
11156 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::CTTZ_ZERO_UNDEF, VT))
11157 if (DAG.isKnownNeverZero(Op: N0))
11158 return DAG.getNode(Opcode: ISD::CTTZ_ZERO_UNDEF, DL, VT, Operand: N0);
11159
11160 return SDValue();
11161}
11162
11163SDValue DAGCombiner::visitCTTZ_ZERO_UNDEF(SDNode *N) {
11164 SDValue N0 = N->getOperand(Num: 0);
11165 EVT VT = N->getValueType(ResNo: 0);
11166 SDLoc DL(N);
11167
11168 // fold (cttz_zero_undef c1) -> c2
11169 if (SDValue C =
11170 DAG.FoldConstantArithmetic(Opcode: ISD::CTTZ_ZERO_UNDEF, DL, VT, Ops: {N0}))
11171 return C;
11172 return SDValue();
11173}
11174
11175SDValue DAGCombiner::visitCTPOP(SDNode *N) {
11176 SDValue N0 = N->getOperand(Num: 0);
11177 EVT VT = N->getValueType(ResNo: 0);
11178 unsigned NumBits = VT.getScalarSizeInBits();
11179 SDLoc DL(N);
11180
11181 // fold (ctpop c1) -> c2
11182 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::CTPOP, DL, VT, Ops: {N0}))
11183 return C;
11184
11185  // If the source is being shifted, but the shift does not affect any active
11186  // bits, then we can call CTPOP on the shift source directly.
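  // Illustrative example: (ctpop (srl x, 8)) -> (ctpop x) when the low 8 bits of
  // x are known to be zero, since the shift only discards zero bits.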
11187 if (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SHL) {
11188 if (ConstantSDNode *AmtC = isConstOrConstSplat(N: N0.getOperand(i: 1))) {
11189 const APInt &Amt = AmtC->getAPIntValue();
11190 if (Amt.ult(RHS: NumBits)) {
11191 KnownBits KnownSrc = DAG.computeKnownBits(Op: N0.getOperand(i: 0));
11192 if ((N0.getOpcode() == ISD::SRL &&
11193 Amt.ule(RHS: KnownSrc.countMinTrailingZeros())) ||
11194 (N0.getOpcode() == ISD::SHL &&
11195 Amt.ule(RHS: KnownSrc.countMinLeadingZeros()))) {
11196 return DAG.getNode(Opcode: ISD::CTPOP, DL, VT, Operand: N0.getOperand(i: 0));
11197 }
11198 }
11199 }
11200 }
11201
11202  // If the upper bits are known to be zero, then see if it's profitable to
11203  // only count the lower bits.
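  // Illustrative example: a 64-bit ctpop whose operand has all upper 32 bits
  // known zero can become zext (ctpop (trunc i32 x)) when i32 ctpop is cheap.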
11204 if (VT.isScalarInteger() && NumBits > 8 && (NumBits & 1) == 0) {
11205 EVT HalfVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: NumBits / 2);
11206 if (hasOperation(Opcode: ISD::CTPOP, VT: HalfVT) &&
11207 TLI.isTypeDesirableForOp(ISD::CTPOP, VT: HalfVT) &&
11208 TLI.isTruncateFree(Val: N0, VT2: HalfVT) && TLI.isZExtFree(FromTy: HalfVT, ToTy: VT)) {
11209 APInt UpperBits = APInt::getHighBitsSet(numBits: NumBits, hiBitsSet: NumBits / 2);
11210 if (DAG.MaskedValueIsZero(Op: N0, Mask: UpperBits)) {
11211 SDValue PopCnt = DAG.getNode(Opcode: ISD::CTPOP, DL, VT: HalfVT,
11212 Operand: DAG.getZExtOrTrunc(Op: N0, DL, VT: HalfVT));
11213 return DAG.getZExtOrTrunc(Op: PopCnt, DL, VT);
11214 }
11215 }
11216 }
11217
11218 return SDValue();
11219}
11220
11221static bool isLegalToCombineMinNumMaxNum(SelectionDAG &DAG, SDValue LHS,
11222 SDValue RHS, const SDNodeFlags Flags,
11223 const TargetLowering &TLI) {
11224 EVT VT = LHS.getValueType();
11225 if (!VT.isFloatingPoint())
11226 return false;
11227
11228 const TargetOptions &Options = DAG.getTarget().Options;
11229
11230 return (Flags.hasNoSignedZeros() || Options.NoSignedZerosFPMath) &&
11231 TLI.isProfitableToCombineMinNumMaxNum(VT) &&
11232 (Flags.hasNoNaNs() ||
11233 (DAG.isKnownNeverNaN(Op: RHS) && DAG.isKnownNeverNaN(Op: LHS)));
11234}
11235
11236static SDValue combineMinNumMaxNumImpl(const SDLoc &DL, EVT VT, SDValue LHS,
11237 SDValue RHS, SDValue True, SDValue False,
11238 ISD::CondCode CC,
11239 const TargetLowering &TLI,
11240 SelectionDAG &DAG) {
11241 EVT TransformVT = TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT);
11242 switch (CC) {
11243 case ISD::SETOLT:
11244 case ISD::SETOLE:
11245 case ISD::SETLT:
11246 case ISD::SETLE:
11247 case ISD::SETULT:
11248 case ISD::SETULE: {
11249    // Since the operands are already known never to be NaN here, either fminnum
11250    // or fminnum_ieee is OK. Try the ieee version first, since fminnum is
11251    // expanded in terms of it.
11252 unsigned IEEEOpcode = (LHS == True) ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
11253 if (TLI.isOperationLegalOrCustom(Op: IEEEOpcode, VT))
11254 return DAG.getNode(Opcode: IEEEOpcode, DL, VT, N1: LHS, N2: RHS);
11255
11256 unsigned Opcode = (LHS == True) ? ISD::FMINNUM : ISD::FMAXNUM;
11257 if (TLI.isOperationLegalOrCustom(Op: Opcode, VT: TransformVT))
11258 return DAG.getNode(Opcode, DL, VT, N1: LHS, N2: RHS);
11259 return SDValue();
11260 }
11261 case ISD::SETOGT:
11262 case ISD::SETOGE:
11263 case ISD::SETGT:
11264 case ISD::SETGE:
11265 case ISD::SETUGT:
11266 case ISD::SETUGE: {
11267 unsigned IEEEOpcode = (LHS == True) ? ISD::FMAXNUM_IEEE : ISD::FMINNUM_IEEE;
11268 if (TLI.isOperationLegalOrCustom(Op: IEEEOpcode, VT))
11269 return DAG.getNode(Opcode: IEEEOpcode, DL, VT, N1: LHS, N2: RHS);
11270
11271 unsigned Opcode = (LHS == True) ? ISD::FMAXNUM : ISD::FMINNUM;
11272 if (TLI.isOperationLegalOrCustom(Op: Opcode, VT: TransformVT))
11273 return DAG.getNode(Opcode, DL, VT, N1: LHS, N2: RHS);
11274 return SDValue();
11275 }
11276 default:
11277 return SDValue();
11278 }
11279}
11280
11281/// Generate Min/Max node
11282SDValue DAGCombiner::combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
11283 SDValue RHS, SDValue True,
11284 SDValue False, ISD::CondCode CC) {
11285 if ((LHS == True && RHS == False) || (LHS == False && RHS == True))
11286 return combineMinNumMaxNumImpl(DL, VT, LHS, RHS, True, False, CC, TLI, DAG);
11287
11288 // If we can't directly match this, try to see if we can pull an fneg out of
11289 // the select.
11290 SDValue NegTrue = TLI.getCheaperOrNeutralNegatedExpression(
11291 Op: True, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize);
11292 if (!NegTrue)
11293 return SDValue();
11294
11295 HandleSDNode NegTrueHandle(NegTrue);
11296
11297 // Try to unfold an fneg from the select if we are comparing the negated
11298 // constant.
11299 //
11300 // select (setcc x, K) (fneg x), -K -> fneg(minnum(x, K))
11301 //
11302 // TODO: Handle fabs
11303 if (LHS == NegTrue) {
11304    // The true operand matches the negated LHS; now see if the false operand is
11305    // the negated RHS so the whole select can be expressed as fneg(min/max).
11306 SDValue NegRHS = TLI.getCheaperOrNeutralNegatedExpression(
11307 Op: RHS, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize);
11308 if (NegRHS) {
11309 HandleSDNode NegRHSHandle(NegRHS);
11310 if (NegRHS == False) {
11311 SDValue Combined = combineMinNumMaxNumImpl(DL, VT, LHS, RHS, True: NegTrue,
11312 False, CC, TLI, DAG);
11313 if (Combined)
11314 return DAG.getNode(Opcode: ISD::FNEG, DL, VT, Operand: Combined);
11315 }
11316 }
11317 }
11318
11319 return SDValue();
11320}
11321
11322/// If a (v)select has a condition value that is a sign-bit test, try to smear
11323/// the condition operand sign-bit across the value width and use it as a mask.
11324static SDValue foldSelectOfConstantsUsingSra(SDNode *N, const SDLoc &DL,
11325 SelectionDAG &DAG) {
11326 SDValue Cond = N->getOperand(Num: 0);
11327 SDValue C1 = N->getOperand(Num: 1);
11328 SDValue C2 = N->getOperand(Num: 2);
11329 if (!isConstantOrConstantVector(N: C1) || !isConstantOrConstantVector(N: C2))
11330 return SDValue();
11331
11332 EVT VT = N->getValueType(ResNo: 0);
11333 if (Cond.getOpcode() != ISD::SETCC || !Cond.hasOneUse() ||
11334 VT != Cond.getOperand(i: 0).getValueType())
11335 return SDValue();
11336
11337 // The inverted-condition + commuted-select variants of these patterns are
11338 // canonicalized to these forms in IR.
11339 SDValue X = Cond.getOperand(i: 0);
11340 SDValue CondC = Cond.getOperand(i: 1);
11341 ISD::CondCode CC = cast<CondCodeSDNode>(Val: Cond.getOperand(i: 2))->get();
11342 if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(V: CondC) &&
11343 isAllOnesOrAllOnesSplat(V: C2)) {
11344 // i32 X > -1 ? C1 : -1 --> (X >>s 31) | C1
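    // (Sanity check: for non-negative X the shift yields 0 and the OR keeps C1;
    // for negative X it yields all-ones, so the result is -1, as required.)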
11345 SDValue ShAmtC = DAG.getConstant(Val: X.getScalarValueSizeInBits() - 1, DL, VT);
11346 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: X, N2: ShAmtC);
11347 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: Sra, N2: C1);
11348 }
11349 if (CC == ISD::SETLT && isNullOrNullSplat(V: CondC) && isNullOrNullSplat(V: C2)) {
11350 // i8 X < 0 ? C1 : 0 --> (X >>s 7) & C1
11351 SDValue ShAmtC = DAG.getConstant(Val: X.getScalarValueSizeInBits() - 1, DL, VT);
11352 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: X, N2: ShAmtC);
11353 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Sra, N2: C1);
11354 }
11355 return SDValue();
11356}
11357
11358static bool shouldConvertSelectOfConstantsToMath(const SDValue &Cond, EVT VT,
11359 const TargetLowering &TLI) {
11360 if (!TLI.convertSelectOfConstantsToMath(VT))
11361 return false;
11362
11363 if (Cond.getOpcode() != ISD::SETCC || !Cond->hasOneUse())
11364 return true;
11365 if (!TLI.isOperationLegalOrCustom(Op: ISD::SELECT_CC, VT))
11366 return true;
11367
11368 ISD::CondCode CC = cast<CondCodeSDNode>(Val: Cond.getOperand(i: 2))->get();
11369 if (CC == ISD::SETLT && isNullOrNullSplat(V: Cond.getOperand(i: 1)))
11370 return true;
11371 if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(V: Cond.getOperand(i: 1)))
11372 return true;
11373
11374 return false;
11375}
11376
11377SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
11378 SDValue Cond = N->getOperand(Num: 0);
11379 SDValue N1 = N->getOperand(Num: 1);
11380 SDValue N2 = N->getOperand(Num: 2);
11381 EVT VT = N->getValueType(ResNo: 0);
11382 EVT CondVT = Cond.getValueType();
11383 SDLoc DL(N);
11384
11385 if (!VT.isInteger())
11386 return SDValue();
11387
11388 auto *C1 = dyn_cast<ConstantSDNode>(Val&: N1);
11389 auto *C2 = dyn_cast<ConstantSDNode>(Val&: N2);
11390 if (!C1 || !C2)
11391 return SDValue();
11392
11393 if (CondVT != MVT::i1 || LegalOperations) {
11394 // fold (select Cond, 0, 1) -> (xor Cond, 1)
11395    // We can't do this reliably if integer-based booleans have different
11396    // contents than floating-point-based booleans. This is because we can't tell
11397    // whether we have an integer-based or a floating-point-based boolean unless
11398    // we can find the SETCC that produced it and inspect its operands. This is
11399    // fairly easy if Cond is the SETCC node, but it can potentially be
11400    // undiscoverable (or not reasonably discoverable). For example, it could be
11401    // in another basic block or it could require searching a complicated
11402    // expression.
11403 if (CondVT.isInteger() &&
11404 TLI.getBooleanContents(/*isVec*/false, /*isFloat*/true) ==
11405 TargetLowering::ZeroOrOneBooleanContent &&
11406 TLI.getBooleanContents(/*isVec*/false, /*isFloat*/false) ==
11407 TargetLowering::ZeroOrOneBooleanContent &&
11408 C1->isZero() && C2->isOne()) {
11409 SDValue NotCond =
11410 DAG.getNode(Opcode: ISD::XOR, DL, VT: CondVT, N1: Cond, N2: DAG.getConstant(Val: 1, DL, VT: CondVT));
11411 if (VT.bitsEq(VT: CondVT))
11412 return NotCond;
11413 return DAG.getZExtOrTrunc(Op: NotCond, DL, VT);
11414 }
11415
11416 return SDValue();
11417 }
11418
11419 // Only do this before legalization to avoid conflicting with target-specific
11420 // transforms in the other direction (create a select from a zext/sext). There
11421 // is also a target-independent combine here in DAGCombiner in the other
11422 // direction for (select Cond, -1, 0) when the condition is not i1.
11423 assert(CondVT == MVT::i1 && !LegalOperations);
11424
11425 // select Cond, 1, 0 --> zext (Cond)
11426 if (C1->isOne() && C2->isZero())
11427 return DAG.getZExtOrTrunc(Op: Cond, DL, VT);
11428
11429 // select Cond, -1, 0 --> sext (Cond)
11430 if (C1->isAllOnes() && C2->isZero())
11431 return DAG.getSExtOrTrunc(Op: Cond, DL, VT);
11432
11433 // select Cond, 0, 1 --> zext (!Cond)
11434 if (C1->isZero() && C2->isOne()) {
11435 SDValue NotCond = DAG.getNOT(DL, Val: Cond, VT: MVT::i1);
11436 NotCond = DAG.getZExtOrTrunc(Op: NotCond, DL, VT);
11437 return NotCond;
11438 }
11439
11440 // select Cond, 0, -1 --> sext (!Cond)
11441 if (C1->isZero() && C2->isAllOnes()) {
11442 SDValue NotCond = DAG.getNOT(DL, Val: Cond, VT: MVT::i1);
11443 NotCond = DAG.getSExtOrTrunc(Op: NotCond, DL, VT);
11444 return NotCond;
11445 }
11446
11447 // Use a target hook because some targets may prefer to transform in the
11448 // other direction.
11449 if (!shouldConvertSelectOfConstantsToMath(Cond, VT, TLI))
11450 return SDValue();
11451
11452 // For any constants that differ by 1, we can transform the select into
11453 // an extend and add.
11454 const APInt &C1Val = C1->getAPIntValue();
11455 const APInt &C2Val = C2->getAPIntValue();
11456
11457 // select Cond, C1, C1-1 --> add (zext Cond), C1-1
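  // e.g. (illustrative) select Cond, 6, 5 -> add (zext Cond), 5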
11458 if (C1Val - 1 == C2Val) {
11459 Cond = DAG.getZExtOrTrunc(Op: Cond, DL, VT);
11460 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Cond, N2);
11461 }
11462
11463 // select Cond, C1, C1+1 --> add (sext Cond), C1+1
11464 if (C1Val + 1 == C2Val) {
11465 Cond = DAG.getSExtOrTrunc(Op: Cond, DL, VT);
11466 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Cond, N2);
11467 }
11468
11469 // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
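  // e.g. (illustrative) select Cond, 8, 0 -> (zext Cond) << 3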
11470 if (C1Val.isPowerOf2() && C2Val.isZero()) {
11471 Cond = DAG.getZExtOrTrunc(Op: Cond, DL, VT);
11472 SDValue ShAmtC =
11473 DAG.getShiftAmountConstant(Val: C1Val.exactLogBase2(), VT, DL);
11474 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Cond, N2: ShAmtC);
11475 }
11476
11477 // select Cond, -1, C --> or (sext Cond), C
11478 if (C1->isAllOnes()) {
11479 Cond = DAG.getSExtOrTrunc(Op: Cond, DL, VT);
11480 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: Cond, N2);
11481 }
11482
11483 // select Cond, C, -1 --> or (sext (not Cond)), C
11484 if (C2->isAllOnes()) {
11485 SDValue NotCond = DAG.getNOT(DL, Val: Cond, VT: MVT::i1);
11486 NotCond = DAG.getSExtOrTrunc(Op: NotCond, DL, VT);
11487 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: NotCond, N2: N1);
11488 }
11489
11490 if (SDValue V = foldSelectOfConstantsUsingSra(N, DL, DAG))
11491 return V;
11492
11493 return SDValue();
11494}
11495
11496template <class MatchContextClass>
11497static SDValue foldBoolSelectToLogic(SDNode *N, const SDLoc &DL,
11498 SelectionDAG &DAG) {
11499 assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT ||
11500 N->getOpcode() == ISD::VP_SELECT) &&
11501 "Expected a (v)(vp.)select");
11502 SDValue Cond = N->getOperand(Num: 0);
11503 SDValue T = N->getOperand(Num: 1), F = N->getOperand(Num: 2);
11504 EVT VT = N->getValueType(ResNo: 0);
11505 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
11506 MatchContextClass matcher(DAG, TLI, N);
11507
11508 if (VT != Cond.getValueType() || VT.getScalarSizeInBits() != 1)
11509 return SDValue();
11510
11511 // select Cond, Cond, F --> or Cond, freeze(F)
11512 // select Cond, 1, F --> or Cond, freeze(F)
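  // The freeze is needed for soundness: the original select ignores F whenever
  // Cond is true, but an OR would propagate poison/undef from F, so F must be
  // frozen before it feeds the logic op.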
11513 if (Cond == T || isOneOrOneSplat(V: T, /* AllowUndefs */ true))
11514 return matcher.getNode(ISD::OR, DL, VT, Cond, DAG.getFreeze(V: F));
11515
11516 // select Cond, T, Cond --> and Cond, freeze(T)
11517 // select Cond, T, 0 --> and Cond, freeze(T)
11518 if (Cond == F || isNullOrNullSplat(V: F, /* AllowUndefs */ true))
11519 return matcher.getNode(ISD::AND, DL, VT, Cond, DAG.getFreeze(V: T));
11520
11521 // select Cond, T, 1 --> or (not Cond), freeze(T)
11522 if (isOneOrOneSplat(V: F, /* AllowUndefs */ true)) {
11523 SDValue NotCond =
11524 matcher.getNode(ISD::XOR, DL, VT, Cond, DAG.getAllOnesConstant(DL, VT));
11525 return matcher.getNode(ISD::OR, DL, VT, NotCond, DAG.getFreeze(V: T));
11526 }
11527
11528 // select Cond, 0, F --> and (not Cond), freeze(F)
11529 if (isNullOrNullSplat(V: T, /* AllowUndefs */ true)) {
11530 SDValue NotCond =
11531 matcher.getNode(ISD::XOR, DL, VT, Cond, DAG.getAllOnesConstant(DL, VT));
11532 return matcher.getNode(ISD::AND, DL, VT, NotCond, DAG.getFreeze(V: F));
11533 }
11534
11535 return SDValue();
11536}
11537
11538static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) {
11539 SDValue N0 = N->getOperand(Num: 0);
11540 SDValue N1 = N->getOperand(Num: 1);
11541 SDValue N2 = N->getOperand(Num: 2);
11542 EVT VT = N->getValueType(ResNo: 0);
11543
11544 SDValue Cond0, Cond1;
11545 ISD::CondCode CC;
11546 if (!sd_match(N: N0, P: m_OneUse(P: m_SetCC(LHS: m_Value(N&: Cond0), RHS: m_Value(N&: Cond1),
11547 CC: m_CondCode(CC)))) ||
11548 VT != Cond0.getValueType())
11549 return SDValue();
11550
11551 // Match a signbit check of Cond0 as "Cond0 s<0". Swap select operands if the
11552 // compare is inverted from that pattern ("Cond0 s> -1").
11553 if (CC == ISD::SETLT && isNullOrNullSplat(V: Cond1))
11554 ; // This is the pattern we are looking for.
11555 else if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(V: Cond1))
11556 std::swap(a&: N1, b&: N2);
11557 else
11558 return SDValue();
11559
11560 // (Cond0 s< 0) ? N1 : 0 --> (Cond0 s>> BW-1) & freeze(N1)
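  // (The arithmetic shift smears the sign bit across the element: all-ones when
  // Cond0 is negative, zero otherwise, so the AND picks N1 or 0 per lane.)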
11561 if (isNullOrNullSplat(V: N2)) {
11562 SDLoc DL(N);
11563 SDValue ShiftAmt = DAG.getConstant(Val: VT.getScalarSizeInBits() - 1, DL, VT);
11564 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: Cond0, N2: ShiftAmt);
11565 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Sra, N2: DAG.getFreeze(V: N1));
11566 }
11567
11568 // (Cond0 s< 0) ? -1 : N2 --> (Cond0 s>> BW-1) | freeze(N2)
11569 if (isAllOnesOrAllOnesSplat(V: N1)) {
11570 SDLoc DL(N);
11571 SDValue ShiftAmt = DAG.getConstant(Val: VT.getScalarSizeInBits() - 1, DL, VT);
11572 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: Cond0, N2: ShiftAmt);
11573 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: Sra, N2: DAG.getFreeze(V: N2));
11574 }
11575
11576 // If we have to invert the sign bit mask, only do that transform if the
11577 // target has a bitwise 'and not' instruction (the invert is free).
11578  // (Cond0 s< 0) ? 0 : N2 --> ~(Cond0 s>> BW-1) & freeze(N2)
11579 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
11580 if (isNullOrNullSplat(V: N1) && TLI.hasAndNot(X: N1)) {
11581 SDLoc DL(N);
11582 SDValue ShiftAmt = DAG.getConstant(Val: VT.getScalarSizeInBits() - 1, DL, VT);
11583 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: Cond0, N2: ShiftAmt);
11584 SDValue Not = DAG.getNOT(DL, Val: Sra, VT);
11585 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Not, N2: DAG.getFreeze(V: N2));
11586 }
11587
11588 // TODO: There's another pattern in this family, but it may require
11589 // implementing hasOrNot() to check for profitability:
11590 // (Cond0 s> -1) ? -1 : N2 --> ~(Cond0 s>> BW-1) | freeze(N2)
11591
11592 return SDValue();
11593}
11594
11595SDValue DAGCombiner::visitSELECT(SDNode *N) {
11596 SDValue N0 = N->getOperand(Num: 0);
11597 SDValue N1 = N->getOperand(Num: 1);
11598 SDValue N2 = N->getOperand(Num: 2);
11599 EVT VT = N->getValueType(ResNo: 0);
11600 EVT VT0 = N0.getValueType();
11601 SDLoc DL(N);
11602 SDNodeFlags Flags = N->getFlags();
11603
11604 if (SDValue V = DAG.simplifySelect(Cond: N0, TVal: N1, FVal: N2))
11605 return V;
11606
11607 if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DL, DAG))
11608 return V;
11609
11610 // select (not Cond), N1, N2 -> select Cond, N2, N1
11611 if (SDValue F = extractBooleanFlip(V: N0, DAG, TLI, Force: false)) {
11612 SDValue SelectOp = DAG.getSelect(DL, VT, Cond: F, LHS: N2, RHS: N1);
11613 SelectOp->setFlags(Flags);
11614 return SelectOp;
11615 }
11616
11617 if (SDValue V = foldSelectOfConstants(N))
11618 return V;
11619
11620 // If we can fold this based on the true/false value, do so.
11621 if (SimplifySelectOps(SELECT: N, LHS: N1, RHS: N2))
11622 return SDValue(N, 0); // Don't revisit N.
11623
11624 if (VT0 == MVT::i1) {
11625 // The code in this block deals with the following 2 equivalences:
11626 // select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y))
11627 // select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y)
11628 // The target can specify its preferred form with the
11629    // shouldNormalizeToSelectSequence() callback. However, we always transform
11630    // to the right side if we find the inner select already exists in the DAG,
11631    // and we always transform to the left side if we know that we can further
11632    // optimize the combination of the conditions.
11633 bool normalizeToSequence =
11634 TLI.shouldNormalizeToSelectSequence(Context&: *DAG.getContext(), VT);
11635 // select (and Cond0, Cond1), X, Y
11636 // -> select Cond0, (select Cond1, X, Y), Y
11637 if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
11638 SDValue Cond0 = N0->getOperand(Num: 0);
11639 SDValue Cond1 = N0->getOperand(Num: 1);
11640 SDValue InnerSelect =
11641 DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Cond1, N2: N1, N3: N2, Flags);
11642 if (normalizeToSequence || !InnerSelect.use_empty())
11643 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Cond0,
11644 N2: InnerSelect, N3: N2, Flags);
11645 // Cleanup on failure.
11646 if (InnerSelect.use_empty())
11647 recursivelyDeleteUnusedNodes(N: InnerSelect.getNode());
11648 }
11649 // select (or Cond0, Cond1), X, Y -> select Cond0, X, (select Cond1, X, Y)
11650 if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
11651 SDValue Cond0 = N0->getOperand(Num: 0);
11652 SDValue Cond1 = N0->getOperand(Num: 1);
11653 SDValue InnerSelect = DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(),
11654 N1: Cond1, N2: N1, N3: N2, Flags);
11655 if (normalizeToSequence || !InnerSelect.use_empty())
11656 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Cond0, N2: N1,
11657 N3: InnerSelect, Flags);
11658 // Cleanup on failure.
11659 if (InnerSelect.use_empty())
11660 recursivelyDeleteUnusedNodes(N: InnerSelect.getNode());
11661 }
11662
11663 // select Cond0, (select Cond1, X, Y), Y -> select (and Cond0, Cond1), X, Y
11664 if (N1->getOpcode() == ISD::SELECT && N1->hasOneUse()) {
11665 SDValue N1_0 = N1->getOperand(Num: 0);
11666 SDValue N1_1 = N1->getOperand(Num: 1);
11667 SDValue N1_2 = N1->getOperand(Num: 2);
11668 if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
11669 // Create the actual and node if we can generate good code for it.
11670 if (!normalizeToSequence) {
11671 SDValue And = DAG.getNode(Opcode: ISD::AND, DL, VT: N0.getValueType(), N1: N0, N2: N1_0);
11672 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: And, N2: N1_1,
11673 N3: N2, Flags);
11674 }
11675 // Otherwise see if we can optimize the "and" to a better pattern.
11676 if (SDValue Combined = visitANDLike(N0, N1: N1_0, N)) {
11677 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Combined, N2: N1_1,
11678 N3: N2, Flags);
11679 }
11680 }
11681 }
11682 // select Cond0, X, (select Cond1, X, Y) -> select (or Cond0, Cond1), X, Y
11683 if (N2->getOpcode() == ISD::SELECT && N2->hasOneUse()) {
11684 SDValue N2_0 = N2->getOperand(Num: 0);
11685 SDValue N2_1 = N2->getOperand(Num: 1);
11686 SDValue N2_2 = N2->getOperand(Num: 2);
11687 if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
11688 // Create the actual or node if we can generate good code for it.
11689 if (!normalizeToSequence) {
11690 SDValue Or = DAG.getNode(Opcode: ISD::OR, DL, VT: N0.getValueType(), N1: N0, N2: N2_0);
11691 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Or, N2: N1,
11692 N3: N2_2, Flags);
11693 }
11694 // Otherwise see if we can optimize to a better pattern.
11695 if (SDValue Combined = visitORLike(N0, N1: N2_0, DL))
11696 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Combined, N2: N1,
11697 N3: N2_2, Flags);
11698 }
11699 }
11700 }
11701
11702 // Fold selects based on a setcc into other things, such as min/max/abs.
11703 if (N0.getOpcode() == ISD::SETCC) {
11704 SDValue Cond0 = N0.getOperand(i: 0), Cond1 = N0.getOperand(i: 1);
11705 ISD::CondCode CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get();
11706
11707 // select (fcmp lt x, y), x, y -> fminnum x, y
11708 // select (fcmp gt x, y), x, y -> fmaxnum x, y
11709 //
11710 // This is OK if we don't care what happens if either operand is a NaN.
11711 if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, LHS: N1, RHS: N2, Flags, TLI))
11712 if (SDValue FMinMax =
11713 combineMinNumMaxNum(DL, VT, LHS: Cond0, RHS: Cond1, True: N1, False: N2, CC))
11714 return FMinMax;
11715
11716 // Use 'unsigned add with overflow' to optimize an unsigned saturating add.
11717 // This is conservatively limited to pre-legal-operations to give targets
11718 // a chance to reverse the transform if they want to do that. Also, it is
11719 // unlikely that the pattern would be formed late, so it's probably not
11720 // worth going through the other checks.
11721 if (!LegalOperations && TLI.isOperationLegalOrCustom(Op: ISD::UADDO, VT) &&
11722 CC == ISD::SETUGT && N0.hasOneUse() && isAllOnesConstant(V: N1) &&
11723 N2.getOpcode() == ISD::ADD && Cond0 == N2.getOperand(i: 0)) {
11724 auto *C = dyn_cast<ConstantSDNode>(Val: N2.getOperand(i: 1));
11725 auto *NotC = dyn_cast<ConstantSDNode>(Val&: Cond1);
11726 if (C && NotC && C->getAPIntValue() == ~NotC->getAPIntValue()) {
11727 // select (setcc Cond0, ~C, ugt), -1, (add Cond0, C) -->
11728 // uaddo Cond0, C; select uaddo.1, -1, uaddo.0
11729 //
11730 // The IR equivalent of this transform would have this form:
11731 // %a = add %x, C
11732 // %c = icmp ugt %x, ~C
11733 // %r = select %c, -1, %a
11734 // =>
11735 // %u = call {iN,i1} llvm.uadd.with.overflow(%x, C)
11736 // %u0 = extractvalue %u, 0
11737 // %u1 = extractvalue %u, 1
11738 // %r = select %u1, -1, %u0
11739 SDVTList VTs = DAG.getVTList(VT1: VT, VT2: VT0);
11740 SDValue UAO = DAG.getNode(Opcode: ISD::UADDO, DL, VTList: VTs, N1: Cond0, N2: N2.getOperand(i: 1));
11741 return DAG.getSelect(DL, VT, Cond: UAO.getValue(R: 1), LHS: N1, RHS: UAO.getValue(R: 0));
11742 }
11743 }
11744
11745 if (TLI.isOperationLegal(Op: ISD::SELECT_CC, VT) ||
11746 (!LegalOperations &&
11747 TLI.isOperationLegalOrCustom(Op: ISD::SELECT_CC, VT))) {
11748      // Any flags available in a select/setcc fold will be on the setcc as they
11749      // migrated from fcmp.
11750 Flags = N0->getFlags();
11751 SDValue SelectNode = DAG.getNode(Opcode: ISD::SELECT_CC, DL, VT, N1: Cond0, N2: Cond1, N3: N1,
11752 N4: N2, N5: N0.getOperand(i: 2));
11753 SelectNode->setFlags(Flags);
11754 return SelectNode;
11755 }
11756
11757 if (SDValue NewSel = SimplifySelect(DL, N0, N1, N2))
11758 return NewSel;
11759 }
11760
11761 if (!VT.isVector())
11762 if (SDValue BinOp = foldSelectOfBinops(N))
11763 return BinOp;
11764
11765 if (SDValue R = combineSelectAsExtAnd(Cond: N0, T: N1, F: N2, DL, DAG))
11766 return R;
11767
11768 return SDValue();
11769}
11770
11771// This function assumes all the vselect's arguments are CONCAT_VECTOR
11772// nodes and that the condition is a BV of ConstantSDNodes (or undefs).
11773static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
11774 SDLoc DL(N);
11775 SDValue Cond = N->getOperand(Num: 0);
11776 SDValue LHS = N->getOperand(Num: 1);
11777 SDValue RHS = N->getOperand(Num: 2);
11778 EVT VT = N->getValueType(ResNo: 0);
11779 int NumElems = VT.getVectorNumElements();
11780 assert(LHS.getOpcode() == ISD::CONCAT_VECTORS &&
11781 RHS.getOpcode() == ISD::CONCAT_VECTORS &&
11782 Cond.getOpcode() == ISD::BUILD_VECTOR);
11783
11784  // CONCAT_VECTORS can take an arbitrary number of arguments. We only care
11785  // about binary ones here.
11786 if (LHS->getNumOperands() != 2 || RHS->getNumOperands() != 2)
11787 return SDValue();
11788
11789  // We're sure we have an even number of elements due to the
11790  // concat_vectors we have as arguments to vselect.
11791  // Skip BV elements until we find one that's not an UNDEF.
11792  // After we find a non-UNDEF element, keep looping until we get to half the
11793  // length of the BV and see if all the non-undef nodes are the same.
11794 ConstantSDNode *BottomHalf = nullptr;
11795 for (int i = 0; i < NumElems / 2; ++i) {
11796 if (Cond->getOperand(Num: i)->isUndef())
11797 continue;
11798
11799 if (BottomHalf == nullptr)
11800 BottomHalf = cast<ConstantSDNode>(Val: Cond.getOperand(i));
11801 else if (Cond->getOperand(Num: i).getNode() != BottomHalf)
11802 return SDValue();
11803 }
11804
11805 // Do the same for the second half of the BuildVector
11806 ConstantSDNode *TopHalf = nullptr;
11807 for (int i = NumElems / 2; i < NumElems; ++i) {
11808 if (Cond->getOperand(Num: i)->isUndef())
11809 continue;
11810
11811 if (TopHalf == nullptr)
11812 TopHalf = cast<ConstantSDNode>(Val: Cond.getOperand(i));
11813 else if (Cond->getOperand(Num: i).getNode() != TopHalf)
11814 return SDValue();
11815 }
11816
11817 assert(TopHalf && BottomHalf &&
11818 "One half of the selector was all UNDEFs and the other was all the "
11819 "same value. This should have been addressed before this function.");
11820 return DAG.getNode(
11821 Opcode: ISD::CONCAT_VECTORS, DL, VT,
11822 N1: BottomHalf->isZero() ? RHS->getOperand(Num: 0) : LHS->getOperand(Num: 0),
11823 N2: TopHalf->isZero() ? RHS->getOperand(Num: 1) : LHS->getOperand(Num: 1));
11824}
11825
11826bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
11827 SelectionDAG &DAG, const SDLoc &DL) {
11828
11829 // Only perform the transformation when existing operands can be reused.
11830 if (IndexIsScaled)
11831 return false;
11832
11833 if (!isNullConstant(V: BasePtr) && !Index.hasOneUse())
11834 return false;
11835
11836 EVT VT = BasePtr.getValueType();
11837
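  // If the index is a splat of a pointer-sized scalar, fold the splat value into
  // the base pointer and zero out the index; illustratively,
  // gather(base, splat(p)) becomes gather(base + p, splat(0)).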
11838 if (SDValue SplatVal = DAG.getSplatValue(V: Index);
11839 SplatVal && !isNullConstant(V: SplatVal) &&
11840 SplatVal.getValueType() == VT) {
11841 BasePtr = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: BasePtr, N2: SplatVal);
11842 Index = DAG.getSplat(VT: Index.getValueType(), DL, Op: DAG.getConstant(Val: 0, DL, VT));
11843 return true;
11844 }
11845
11846 if (Index.getOpcode() != ISD::ADD)
11847 return false;
11848
11849 if (SDValue SplatVal = DAG.getSplatValue(V: Index.getOperand(i: 0));
11850 SplatVal && SplatVal.getValueType() == VT) {
11851 BasePtr = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: BasePtr, N2: SplatVal);
11852 Index = Index.getOperand(i: 1);
11853 return true;
11854 }
11855 if (SDValue SplatVal = DAG.getSplatValue(V: Index.getOperand(i: 1));
11856 SplatVal && SplatVal.getValueType() == VT) {
11857 BasePtr = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: BasePtr, N2: SplatVal);
11858 Index = Index.getOperand(i: 0);
11859 return true;
11860 }
11861 return false;
11862}
11863
11864// Fold sext/zext of index into index type.
11865bool refineIndexType(SDValue &Index, ISD::MemIndexType &IndexType, EVT DataVT,
11866 SelectionDAG &DAG) {
11867 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
11868
11869 // It's always safe to look through zero extends.
11870 if (Index.getOpcode() == ISD::ZERO_EXTEND) {
11871 if (TLI.shouldRemoveExtendFromGSIndex(Extend: Index, DataVT)) {
11872 IndexType = ISD::UNSIGNED_SCALED;
11873 Index = Index.getOperand(i: 0);
11874 return true;
11875 }
11876 if (ISD::isIndexTypeSigned(IndexType)) {
11877 IndexType = ISD::UNSIGNED_SCALED;
11878 return true;
11879 }
11880 }
11881
11882 // It's only safe to look through sign extends when Index is signed.
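  // (With an unsigned index type the narrow index would be re-extended with
  // zeros, which would corrupt negative offsets.)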
11883 if (Index.getOpcode() == ISD::SIGN_EXTEND &&
11884 ISD::isIndexTypeSigned(IndexType) &&
11885 TLI.shouldRemoveExtendFromGSIndex(Extend: Index, DataVT)) {
11886 Index = Index.getOperand(i: 0);
11887 return true;
11888 }
11889
11890 return false;
11891}
11892
11893SDValue DAGCombiner::visitVPSCATTER(SDNode *N) {
11894 VPScatterSDNode *MSC = cast<VPScatterSDNode>(Val: N);
11895 SDValue Mask = MSC->getMask();
11896 SDValue Chain = MSC->getChain();
11897 SDValue Index = MSC->getIndex();
11898 SDValue Scale = MSC->getScale();
11899 SDValue StoreVal = MSC->getValue();
11900 SDValue BasePtr = MSC->getBasePtr();
11901 SDValue VL = MSC->getVectorLength();
11902 ISD::MemIndexType IndexType = MSC->getIndexType();
11903 SDLoc DL(N);
11904
11905 // Zap scatters with a zero mask.
11906 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
11907 return Chain;
11908
11909 if (refineUniformBase(BasePtr, Index, IndexIsScaled: MSC->isIndexScaled(), DAG, DL)) {
11910 SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
11911 return DAG.getScatterVP(VTs: DAG.getVTList(VT: MVT::Other), VT: MSC->getMemoryVT(),
11912 dl: DL, Ops, MMO: MSC->getMemOperand(), IndexType);
11913 }
11914
11915 if (refineIndexType(Index, IndexType, DataVT: StoreVal.getValueType(), DAG)) {
11916 SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
11917 return DAG.getScatterVP(VTs: DAG.getVTList(VT: MVT::Other), VT: MSC->getMemoryVT(),
11918 dl: DL, Ops, MMO: MSC->getMemOperand(), IndexType);
11919 }
11920
11921 return SDValue();
11922}
11923
11924SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
11925 MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(Val: N);
11926 SDValue Mask = MSC->getMask();
11927 SDValue Chain = MSC->getChain();
11928 SDValue Index = MSC->getIndex();
11929 SDValue Scale = MSC->getScale();
11930 SDValue StoreVal = MSC->getValue();
11931 SDValue BasePtr = MSC->getBasePtr();
11932 ISD::MemIndexType IndexType = MSC->getIndexType();
11933 SDLoc DL(N);
11934
11935 // Zap scatters with a zero mask.
11936 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
11937 return Chain;
11938
11939 if (refineUniformBase(BasePtr, Index, IndexIsScaled: MSC->isIndexScaled(), DAG, DL)) {
11940 SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
11941 return DAG.getMaskedScatter(VTs: DAG.getVTList(VT: MVT::Other), MemVT: MSC->getMemoryVT(),
11942 dl: DL, Ops, MMO: MSC->getMemOperand(), IndexType,
11943 IsTruncating: MSC->isTruncatingStore());
11944 }
11945
11946 if (refineIndexType(Index, IndexType, DataVT: StoreVal.getValueType(), DAG)) {
11947 SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
11948 return DAG.getMaskedScatter(VTs: DAG.getVTList(VT: MVT::Other), MemVT: MSC->getMemoryVT(),
11949 dl: DL, Ops, MMO: MSC->getMemOperand(), IndexType,
11950 IsTruncating: MSC->isTruncatingStore());
11951 }
11952
11953 return SDValue();
11954}
11955
11956SDValue DAGCombiner::visitMSTORE(SDNode *N) {
11957 MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(Val: N);
11958 SDValue Mask = MST->getMask();
11959 SDValue Chain = MST->getChain();
11960 SDValue Value = MST->getValue();
11961 SDValue Ptr = MST->getBasePtr();
11962 SDLoc DL(N);
11963
11964 // Zap masked stores with a zero mask.
11965 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
11966 return Chain;
11967
11968 // Remove a masked store if base pointers and masks are equal.
11969 if (MaskedStoreSDNode *MST1 = dyn_cast<MaskedStoreSDNode>(Val&: Chain)) {
11970 if (MST->isUnindexed() && MST->isSimple() && MST1->isUnindexed() &&
11971 MST1->isSimple() && MST1->getBasePtr() == Ptr &&
11972 !MST->getBasePtr().isUndef() &&
11973 ((Mask == MST1->getMask() && MST->getMemoryVT().getStoreSize() ==
11974 MST1->getMemoryVT().getStoreSize()) ||
11975 ISD::isConstantSplatVectorAllOnes(N: Mask.getNode())) &&
11976 TypeSize::isKnownLE(LHS: MST1->getMemoryVT().getStoreSize(),
11977 RHS: MST->getMemoryVT().getStoreSize())) {
11978 CombineTo(N: MST1, Res: MST1->getChain());
11979 if (N->getOpcode() != ISD::DELETED_NODE)
11980 AddToWorklist(N);
11981 return SDValue(N, 0);
11982 }
11983 }
11984
11985  // If this is a masked store with an all-ones mask, we can use an unmasked store.
11986 // FIXME: Can we do this for indexed, compressing, or truncating stores?
11987 if (ISD::isConstantSplatVectorAllOnes(N: Mask.getNode()) && MST->isUnindexed() &&
11988 !MST->isCompressingStore() && !MST->isTruncatingStore())
11989 return DAG.getStore(Chain: MST->getChain(), dl: SDLoc(N), Val: MST->getValue(),
11990 Ptr: MST->getBasePtr(), PtrInfo: MST->getPointerInfo(),
11991 Alignment: MST->getOriginalAlign(),
11992 MMOFlags: MST->getMemOperand()->getFlags(), AAInfo: MST->getAAInfo());
11993
11994 // Try transforming N to an indexed store.
11995 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
11996 return SDValue(N, 0);
11997
11998 if (MST->isTruncatingStore() && MST->isUnindexed() &&
11999 Value.getValueType().isInteger() &&
12000 (!isa<ConstantSDNode>(Val: Value) ||
12001 !cast<ConstantSDNode>(Val&: Value)->isOpaque())) {
12002 APInt TruncDemandedBits =
12003 APInt::getLowBitsSet(numBits: Value.getScalarValueSizeInBits(),
12004 loBitsSet: MST->getMemoryVT().getScalarSizeInBits());
12005
12006 // See if we can simplify the operation with
12007 // SimplifyDemandedBits, which only works if the value has a single use.
12008 if (SimplifyDemandedBits(Op: Value, DemandedBits: TruncDemandedBits)) {
12009      // Re-visit the store if anything changed and the store hasn't been merged
12010      // with another node (N is deleted). SimplifyDemandedBits will add Value's
12011 // node back to the worklist if necessary, but we also need to re-visit
12012 // the Store node itself.
12013 if (N->getOpcode() != ISD::DELETED_NODE)
12014 AddToWorklist(N);
12015 return SDValue(N, 0);
12016 }
12017 }
12018
12019 // If this is a TRUNC followed by a masked store, fold this into a masked
12020 // truncating store. We can do this even if this is already a masked
12021 // truncstore.
12022  // TODO: Try to combine to a masked compress store if possible.
12023 if ((Value.getOpcode() == ISD::TRUNCATE) && Value->hasOneUse() &&
12024 MST->isUnindexed() && !MST->isCompressingStore() &&
12025 TLI.canCombineTruncStore(ValVT: Value.getOperand(i: 0).getValueType(),
12026 MemVT: MST->getMemoryVT(), LegalOnly: LegalOperations)) {
12027 auto Mask = TLI.promoteTargetBoolean(DAG, Bool: MST->getMask(),
12028 ValVT: Value.getOperand(i: 0).getValueType());
12029 return DAG.getMaskedStore(Chain, dl: SDLoc(N), Val: Value.getOperand(i: 0), Base: Ptr,
12030 Offset: MST->getOffset(), Mask, MemVT: MST->getMemoryVT(),
12031 MMO: MST->getMemOperand(), AM: MST->getAddressingMode(),
12032 /*IsTruncating=*/true);
12033 }
12034
12035 return SDValue();
12036}
12037
12038SDValue DAGCombiner::visitVP_STRIDED_STORE(SDNode *N) {
12039 auto *SST = cast<VPStridedStoreSDNode>(Val: N);
12040 EVT EltVT = SST->getValue().getValueType().getVectorElementType();
12041 // Combine strided stores with unit-stride to a regular VP store.
12042 if (auto *CStride = dyn_cast<ConstantSDNode>(Val: SST->getStride());
12043 CStride && CStride->getZExtValue() == EltVT.getStoreSize()) {
12044 return DAG.getStoreVP(Chain: SST->getChain(), dl: SDLoc(N), Val: SST->getValue(),
12045 Ptr: SST->getBasePtr(), Offset: SST->getOffset(), Mask: SST->getMask(),
12046 EVL: SST->getVectorLength(), MemVT: SST->getMemoryVT(),
12047 MMO: SST->getMemOperand(), AM: SST->getAddressingMode(),
12048 IsTruncating: SST->isTruncatingStore(), IsCompressing: SST->isCompressingStore());
12049 }
12050 return SDValue();
12051}
12052
12053SDValue DAGCombiner::visitVECTOR_COMPRESS(SDNode *N) {
12054 SDLoc DL(N);
12055 SDValue Vec = N->getOperand(Num: 0);
12056 SDValue Mask = N->getOperand(Num: 1);
12057 SDValue Passthru = N->getOperand(Num: 2);
12058 EVT VecVT = Vec.getValueType();
12059
12060 bool HasPassthru = !Passthru.isUndef();
12061
12062 APInt SplatVal;
12063 if (ISD::isConstantSplatVector(N: Mask.getNode(), SplatValue&: SplatVal))
12064 return TLI.isConstTrueVal(N: Mask) ? Vec : Passthru;
12065
12066 if (Vec.isUndef() || Mask.isUndef())
12067 return Passthru;
12068
12069 // No need for potentially expensive compress if the mask is constant.
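  // Illustrative example: compressing <4 x i32> v with mask <1,0,1,0> and
  // passthru p builds <v[0], v[2], p[2], p[3]> directly.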
12070 if (ISD::isBuildVectorOfConstantSDNodes(N: Mask.getNode())) {
12071 SmallVector<SDValue, 16> Ops;
12072 EVT ScalarVT = VecVT.getVectorElementType();
12073 unsigned NumSelected = 0;
12074 unsigned NumElmts = VecVT.getVectorNumElements();
12075 for (unsigned I = 0; I < NumElmts; ++I) {
12076 SDValue MaskI = Mask.getOperand(i: I);
12077 // We treat undef mask entries as "false".
12078 if (MaskI.isUndef())
12079 continue;
12080
12081 if (TLI.isConstTrueVal(N: MaskI)) {
12082 SDValue VecI = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ScalarVT, N1: Vec,
12083 N2: DAG.getVectorIdxConstant(Val: I, DL));
12084 Ops.push_back(Elt: VecI);
12085 NumSelected++;
12086 }
12087 }
12088 for (unsigned Rest = NumSelected; Rest < NumElmts; ++Rest) {
12089 SDValue Val =
12090 HasPassthru
12091 ? DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ScalarVT, N1: Passthru,
12092 N2: DAG.getVectorIdxConstant(Val: Rest, DL))
12093 : DAG.getUNDEF(VT: ScalarVT);
12094 Ops.push_back(Elt: Val);
12095 }
12096 return DAG.getBuildVector(VT: VecVT, DL, Ops);
12097 }
12098
12099 return SDValue();
12100}
12101
12102SDValue DAGCombiner::visitVPGATHER(SDNode *N) {
12103 VPGatherSDNode *MGT = cast<VPGatherSDNode>(Val: N);
12104 SDValue Mask = MGT->getMask();
12105 SDValue Chain = MGT->getChain();
12106 SDValue Index = MGT->getIndex();
12107 SDValue Scale = MGT->getScale();
12108 SDValue BasePtr = MGT->getBasePtr();
12109 SDValue VL = MGT->getVectorLength();
12110 ISD::MemIndexType IndexType = MGT->getIndexType();
12111 SDLoc DL(N);
12112
12113 if (refineUniformBase(BasePtr, Index, IndexIsScaled: MGT->isIndexScaled(), DAG, DL)) {
12114 SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
12115 return DAG.getGatherVP(
12116 VTs: DAG.getVTList(VT1: N->getValueType(ResNo: 0), VT2: MVT::Other), VT: MGT->getMemoryVT(), dl: DL,
12117 Ops, MMO: MGT->getMemOperand(), IndexType);
12118 }
12119
12120 if (refineIndexType(Index, IndexType, DataVT: N->getValueType(ResNo: 0), DAG)) {
12121 SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
12122 return DAG.getGatherVP(
12123 VTs: DAG.getVTList(VT1: N->getValueType(ResNo: 0), VT2: MVT::Other), VT: MGT->getMemoryVT(), dl: DL,
12124 Ops, MMO: MGT->getMemOperand(), IndexType);
12125 }
12126
12127 return SDValue();
12128}
12129
12130SDValue DAGCombiner::visitMGATHER(SDNode *N) {
12131 MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(Val: N);
12132 SDValue Mask = MGT->getMask();
12133 SDValue Chain = MGT->getChain();
12134 SDValue Index = MGT->getIndex();
12135 SDValue Scale = MGT->getScale();
12136 SDValue PassThru = MGT->getPassThru();
12137 SDValue BasePtr = MGT->getBasePtr();
12138 ISD::MemIndexType IndexType = MGT->getIndexType();
12139 SDLoc DL(N);
12140
12141 // Zap gathers with a zero mask.
12142 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
12143 return CombineTo(N, Res0: PassThru, Res1: MGT->getChain());
12144
12145 if (refineUniformBase(BasePtr, Index, IndexIsScaled: MGT->isIndexScaled(), DAG, DL)) {
12146 SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
12147 return DAG.getMaskedGather(
12148 VTs: DAG.getVTList(VT1: N->getValueType(ResNo: 0), VT2: MVT::Other), MemVT: MGT->getMemoryVT(), dl: DL,
12149 Ops, MMO: MGT->getMemOperand(), IndexType, ExtTy: MGT->getExtensionType());
12150 }
12151
12152 if (refineIndexType(Index, IndexType, DataVT: N->getValueType(ResNo: 0), DAG)) {
12153 SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
12154 return DAG.getMaskedGather(
12155 VTs: DAG.getVTList(VT1: N->getValueType(ResNo: 0), VT2: MVT::Other), MemVT: MGT->getMemoryVT(), dl: DL,
12156 Ops, MMO: MGT->getMemOperand(), IndexType, ExtTy: MGT->getExtensionType());
12157 }
12158
12159 return SDValue();
12160}
12161
12162SDValue DAGCombiner::visitMLOAD(SDNode *N) {
12163 MaskedLoadSDNode *MLD = cast<MaskedLoadSDNode>(Val: N);
12164 SDValue Mask = MLD->getMask();
12165 SDLoc DL(N);
12166
12167 // Zap masked loads with a zero mask.
12168 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
12169 return CombineTo(N, Res0: MLD->getPassThru(), Res1: MLD->getChain());
12170
12171  // If this is a masked load with an all-ones mask, we can use an unmasked load.
12172 // FIXME: Can we do this for indexed, expanding, or extending loads?
12173 if (ISD::isConstantSplatVectorAllOnes(N: Mask.getNode()) && MLD->isUnindexed() &&
12174 !MLD->isExpandingLoad() && MLD->getExtensionType() == ISD::NON_EXTLOAD) {
12175 SDValue NewLd = DAG.getLoad(
12176 VT: N->getValueType(ResNo: 0), dl: SDLoc(N), Chain: MLD->getChain(), Ptr: MLD->getBasePtr(),
12177 PtrInfo: MLD->getPointerInfo(), Alignment: MLD->getOriginalAlign(),
12178 MMOFlags: MLD->getMemOperand()->getFlags(), AAInfo: MLD->getAAInfo(), Ranges: MLD->getRanges());
12179 return CombineTo(N, Res0: NewLd, Res1: NewLd.getValue(R: 1));
12180 }
12181
12182 // Try transforming N to an indexed load.
12183 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
12184 return SDValue(N, 0);
12185
12186 return SDValue();
12187}
12188
12189SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
12190 auto *SLD = cast<VPStridedLoadSDNode>(Val: N);
12191 EVT EltVT = SLD->getValueType(ResNo: 0).getVectorElementType();
12192 // Combine strided loads with unit-stride to a regular VP load.
12193 if (auto *CStride = dyn_cast<ConstantSDNode>(Val: SLD->getStride());
12194 CStride && CStride->getZExtValue() == EltVT.getStoreSize()) {
12195 SDValue NewLd = DAG.getLoadVP(
12196 AM: SLD->getAddressingMode(), ExtType: SLD->getExtensionType(), VT: SLD->getValueType(ResNo: 0),
12197 dl: SDLoc(N), Chain: SLD->getChain(), Ptr: SLD->getBasePtr(), Offset: SLD->getOffset(),
12198 Mask: SLD->getMask(), EVL: SLD->getVectorLength(), MemVT: SLD->getMemoryVT(),
12199 MMO: SLD->getMemOperand(), IsExpanding: SLD->isExpandingLoad());
12200 return CombineTo(N, Res0: NewLd, Res1: NewLd.getValue(R: 1));
12201 }
12202 return SDValue();
12203}
12204
12205/// A vector select of 2 constant vectors can be simplified to math/logic to
12206/// avoid a variable select instruction and possibly avoid constant loads.
12207SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) {
12208 SDValue Cond = N->getOperand(Num: 0);
12209 SDValue N1 = N->getOperand(Num: 1);
12210 SDValue N2 = N->getOperand(Num: 2);
12211 EVT VT = N->getValueType(ResNo: 0);
12212 if (!Cond.hasOneUse() || Cond.getScalarValueSizeInBits() != 1 ||
12213 !shouldConvertSelectOfConstantsToMath(Cond, VT, TLI) ||
12214 !ISD::isBuildVectorOfConstantSDNodes(N: N1.getNode()) ||
12215 !ISD::isBuildVectorOfConstantSDNodes(N: N2.getNode()))
12216 return SDValue();
12217
12218 // Check if we can use the condition value to increment/decrement a single
12219 // constant value. This simplifies a select to an add and removes a constant
12220 // load/materialization from the general case.
12221 bool AllAddOne = true;
12222 bool AllSubOne = true;
12223 unsigned Elts = VT.getVectorNumElements();
12224 for (unsigned i = 0; i != Elts; ++i) {
12225 SDValue N1Elt = N1.getOperand(i);
12226 SDValue N2Elt = N2.getOperand(i);
12227 if (N1Elt.isUndef() || N2Elt.isUndef())
12228 continue;
12229 if (N1Elt.getValueType() != N2Elt.getValueType()) {
12230 AllAddOne = false;
12231 AllSubOne = false;
12232 break;
12233 }
12234
12235 const APInt &C1 = N1Elt->getAsAPIntVal();
12236 const APInt &C2 = N2Elt->getAsAPIntVal();
12237 if (C1 != C2 + 1)
12238 AllAddOne = false;
12239 if (C1 != C2 - 1)
12240 AllSubOne = false;
12241 }
12242
12243 // Further simplifications for the extra-special cases where the constants are
12244 // all 0 or all -1 should be implemented as folds of these patterns.
12245 SDLoc DL(N);
12246 if (AllAddOne || AllSubOne) {
12247 // vselect <N x i1> Cond, C+1, C --> add (zext Cond), C
12248 // vselect <N x i1> Cond, C-1, C --> add (sext Cond), C
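    // e.g. (illustrative) vselect Cond, <6, 11>, <5, 10> -> add (zext Cond), <5, 10>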
12249 auto ExtendOpcode = AllAddOne ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND;
12250 SDValue ExtendedCond = DAG.getNode(Opcode: ExtendOpcode, DL, VT, Operand: Cond);
12251 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: ExtendedCond, N2);
12252 }
12253
12254 // select Cond, Pow2C, 0 --> (zext Cond) << log2(Pow2C)
12255 APInt Pow2C;
12256 if (ISD::isConstantSplatVector(N: N1.getNode(), SplatValue&: Pow2C) && Pow2C.isPowerOf2() &&
12257 isNullOrNullSplat(V: N2)) {
12258 SDValue ZextCond = DAG.getZExtOrTrunc(Op: Cond, DL, VT);
12259 SDValue ShAmtC = DAG.getConstant(Val: Pow2C.exactLogBase2(), DL, VT);
12260 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: ZextCond, N2: ShAmtC);
12261 }
12262
12263 if (SDValue V = foldSelectOfConstantsUsingSra(N, DL, DAG))
12264 return V;
12265
12266 // The general case for select-of-constants:
12267 // vselect <N x i1> Cond, C1, C2 --> xor (and (sext Cond), (C1^C2)), C2
12268 // ...but that only makes sense if a vselect is slower than 2 logic ops, so
12269 // leave that to a machine-specific pass.
12270 return SDValue();
12271}
12272
12273SDValue DAGCombiner::visitVP_SELECT(SDNode *N) {
12274 SDValue N0 = N->getOperand(Num: 0);
12275 SDValue N1 = N->getOperand(Num: 1);
12276 SDValue N2 = N->getOperand(Num: 2);
12277 SDLoc DL(N);
12278
12279 if (SDValue V = DAG.simplifySelect(Cond: N0, TVal: N1, FVal: N2))
12280 return V;
12281
12282 if (SDValue V = foldBoolSelectToLogic<VPMatchContext>(N, DL, DAG))
12283 return V;
12284
12285 return SDValue();
12286}
12287
12288SDValue DAGCombiner::visitVSELECT(SDNode *N) {
12289 SDValue N0 = N->getOperand(Num: 0);
12290 SDValue N1 = N->getOperand(Num: 1);
12291 SDValue N2 = N->getOperand(Num: 2);
12292 EVT VT = N->getValueType(ResNo: 0);
12293 SDLoc DL(N);
12294
12295 if (SDValue V = DAG.simplifySelect(Cond: N0, TVal: N1, FVal: N2))
12296 return V;
12297
12298 if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DL, DAG))
12299 return V;
12300
12301 // vselect (not Cond), N1, N2 -> vselect Cond, N2, N1
12302 if (SDValue F = extractBooleanFlip(V: N0, DAG, TLI, Force: false))
12303 return DAG.getSelect(DL, VT, Cond: F, LHS: N2, RHS: N1);
12304
12305  // select (sext m), (add X, C), X --> (add X, (and C, (sext m)))
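  // (m is a ZeroOrNegativeOne boolean, so (and C, (sext m)) is C where the mask
  // lane is true and 0 where it is false, matching the original select.)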
12306 if (N1.getOpcode() == ISD::ADD && N1.getOperand(i: 0) == N2 && N1->hasOneUse() &&
12307 DAG.isConstantIntBuildVectorOrConstantInt(N: N1.getOperand(i: 1)) &&
12308 N0.getScalarValueSizeInBits() == N1.getScalarValueSizeInBits() &&
12309 TLI.getBooleanContents(Type: N0.getValueType()) ==
12310 TargetLowering::ZeroOrNegativeOneBooleanContent) {
12311 return DAG.getNode(
12312 Opcode: ISD::ADD, DL, VT: N1.getValueType(), N1: N2,
12313 N2: DAG.getNode(Opcode: ISD::AND, DL, VT: N0.getValueType(), N1: N1.getOperand(i: 1), N2: N0));
12314 }
12315
12316 // Canonicalize integer abs.
12317 // vselect (setg[te] X, 0), X, -X ->
12318 // vselect (setgt X, -1), X, -X ->
12319 // vselect (setl[te] X, 0), -X, X ->
12320 // Y = sra (X, size(X)-1); xor (add (X, Y), Y)
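  // (Worked example for X = -5: Y = -1, the add gives -6, and xor with -1 gives
  // 5; for non-negative X, Y = 0 and the value passes through unchanged.)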
12321 if (N0.getOpcode() == ISD::SETCC) {
12322 SDValue LHS = N0.getOperand(i: 0), RHS = N0.getOperand(i: 1);
12323 ISD::CondCode CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get();
12324 bool isAbs = false;
12325 bool RHSIsAllZeros = ISD::isBuildVectorAllZeros(N: RHS.getNode());
12326
12327 if (((RHSIsAllZeros && (CC == ISD::SETGT || CC == ISD::SETGE)) ||
12328 (ISD::isBuildVectorAllOnes(N: RHS.getNode()) && CC == ISD::SETGT)) &&
12329 N1 == LHS && N2.getOpcode() == ISD::SUB && N1 == N2.getOperand(i: 1))
12330 isAbs = ISD::isBuildVectorAllZeros(N: N2.getOperand(i: 0).getNode());
12331 else if ((RHSIsAllZeros && (CC == ISD::SETLT || CC == ISD::SETLE)) &&
12332 N2 == LHS && N1.getOpcode() == ISD::SUB && N2 == N1.getOperand(i: 1))
12333 isAbs = ISD::isBuildVectorAllZeros(N: N1.getOperand(i: 0).getNode());
12334
12335 if (isAbs) {
12336 if (TLI.isOperationLegalOrCustom(Op: ISD::ABS, VT))
12337 return DAG.getNode(Opcode: ISD::ABS, DL, VT, Operand: LHS);
12338
12339 SDValue Shift = DAG.getNode(
12340 Opcode: ISD::SRA, DL, VT, N1: LHS,
12341 N2: DAG.getShiftAmountConstant(Val: VT.getScalarSizeInBits() - 1, VT, DL));
12342 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: LHS, N2: Shift);
12343 AddToWorklist(N: Shift.getNode());
12344 AddToWorklist(N: Add.getNode());
12345 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: Add, N2: Shift);
12346 }
12347
12348    // vselect (fcmp lt x, y), x, y -> fminnum x, y
12349    // vselect (fcmp gt x, y), x, y -> fmaxnum x, y
12350 //
12351 // This is OK if we don't care about what happens if either operand is a
12352 // NaN.
12353 //
12354 if (N0.hasOneUse() &&
12355 isLegalToCombineMinNumMaxNum(DAG, LHS, RHS, Flags: N->getFlags(), TLI)) {
12356 if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, LHS, RHS, True: N1, False: N2, CC))
12357 return FMinMax;
12358 }
12359
12360 if (SDValue S = PerformMinMaxFpToSatCombine(N0: LHS, N1: RHS, N2: N1, N3: N2, CC, DAG))
12361 return S;
12362 if (SDValue S = PerformUMinFpToSatCombine(N0: LHS, N1: RHS, N2: N1, N3: N2, CC, DAG))
12363 return S;
12364
12365 // If this select has a condition (setcc) with narrower operands than the
12366 // select, try to widen the compare to match the select width.
12367 // TODO: This should be extended to handle any constant.
12368 // TODO: This could be extended to handle non-loading patterns, but that
12369 // requires thorough testing to avoid regressions.
12370 if (isNullOrNullSplat(V: RHS)) {
12371 EVT NarrowVT = LHS.getValueType();
12372 EVT WideVT = N1.getValueType().changeVectorElementTypeToInteger();
12373 EVT SetCCVT = getSetCCResultType(VT: LHS.getValueType());
12374 unsigned SetCCWidth = SetCCVT.getScalarSizeInBits();
12375 unsigned WideWidth = WideVT.getScalarSizeInBits();
12376 bool IsSigned = isSignedIntSetCC(Code: CC);
12377 auto LoadExtOpcode = IsSigned ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
12378 if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() &&
12379 SetCCWidth != 1 && SetCCWidth < WideWidth &&
12380 TLI.isLoadExtLegalOrCustom(ExtType: LoadExtOpcode, ValVT: WideVT, MemVT: NarrowVT) &&
12381 TLI.isOperationLegalOrCustom(Op: ISD::SETCC, VT: WideVT)) {
12382 // Both compare operands can be widened for free. The LHS can use an
12383 // extended load, and the RHS is a constant:
12384 // vselect (ext (setcc load(X), C)), N1, N2 -->
12385 // vselect (setcc extload(X), C'), N1, N2
12386 auto ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
12387 SDValue WideLHS = DAG.getNode(Opcode: ExtOpcode, DL, VT: WideVT, Operand: LHS);
12388 SDValue WideRHS = DAG.getNode(Opcode: ExtOpcode, DL, VT: WideVT, Operand: RHS);
12389 EVT WideSetCCVT = getSetCCResultType(VT: WideVT);
12390 SDValue WideSetCC = DAG.getSetCC(DL, VT: WideSetCCVT, LHS: WideLHS, RHS: WideRHS, Cond: CC);
12391 return DAG.getSelect(DL, VT: N1.getValueType(), Cond: WideSetCC, LHS: N1, RHS: N2);
12392 }
12393 }
12394
12395 // Match VSELECTs with absolute difference patterns.
12396 // (vselect (setcc a, b, set?gt), (sub a, b), (sub b, a)) --> (abd? a, b)
12397 // (vselect (setcc a, b, set?ge), (sub a, b), (sub b, a)) --> (abd? a, b)
12398 // (vselect (setcc a, b, set?lt), (sub b, a), (sub a, b)) --> (abd? a, b)
12399 // (vselect (setcc a, b, set?le), (sub b, a), (sub a, b)) --> (abd? a, b)
12400 if (N1.getOpcode() == ISD::SUB && N2.getOpcode() == ISD::SUB &&
12401 N1.getOperand(i: 0) == N2.getOperand(i: 1) &&
12402 N1.getOperand(i: 1) == N2.getOperand(i: 0)) {
12403 bool IsSigned = isSignedIntSetCC(Code: CC);
12404 unsigned ABDOpc = IsSigned ? ISD::ABDS : ISD::ABDU;
12405 if (hasOperation(Opcode: ABDOpc, VT)) {
12406 switch (CC) {
12407 case ISD::SETGT:
12408 case ISD::SETGE:
12409 case ISD::SETUGT:
12410 case ISD::SETUGE:
12411 if (LHS == N1.getOperand(i: 0) && RHS == N1.getOperand(i: 1))
12412 return DAG.getNode(Opcode: ABDOpc, DL, VT, N1: LHS, N2: RHS);
12413 break;
12414 case ISD::SETLT:
12415 case ISD::SETLE:
12416 case ISD::SETULT:
12417 case ISD::SETULE:
12418 if (RHS == N1.getOperand(i: 0) && LHS == N1.getOperand(i: 1))
12419 return DAG.getNode(Opcode: ABDOpc, DL, VT, N1: LHS, N2: RHS);
12420 break;
12421 default:
12422 break;
12423 }
12424 }
12425 }
12426
12427 // Match VSELECTs into add with unsigned saturation.
12428 if (hasOperation(Opcode: ISD::UADDSAT, VT)) {
12429 // Check if one of the arms of the VSELECT is a vector with all bits set.
12430 // If it's on the left side, invert the predicate to simplify the logic below.
12431 SDValue Other;
12432 ISD::CondCode SatCC = CC;
12433 if (ISD::isConstantSplatVectorAllOnes(N: N1.getNode())) {
12434 Other = N2;
12435 SatCC = ISD::getSetCCInverse(Operation: SatCC, Type: VT.getScalarType());
12436 } else if (ISD::isConstantSplatVectorAllOnes(N: N2.getNode())) {
12437 Other = N1;
12438 }
12439
12440 if (Other && Other.getOpcode() == ISD::ADD) {
12441 SDValue CondLHS = LHS, CondRHS = RHS;
12442 SDValue OpLHS = Other.getOperand(i: 0), OpRHS = Other.getOperand(i: 1);
12443
12444 // Canonicalize condition operands.
12445 if (SatCC == ISD::SETUGE) {
12446 std::swap(a&: CondLHS, b&: CondRHS);
12447 SatCC = ISD::SETULE;
12448 }
12449
12450 // We can test against either of the addition operands.
12451 // x <= x+y ? x+y : ~0 --> uaddsat x, y
12452 // x+y >= x ? x+y : ~0 --> uaddsat x, y
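// For example, with i8 elements (illustrative): x = 200, y = 100 gives
// x+y = 44 after wrapping, the unsigned compare x <= x+y fails, and the
// select picks ~0 (255), which is exactly uaddsat(200, 100).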
12453 if (SatCC == ISD::SETULE && Other == CondRHS &&
12454 (OpLHS == CondLHS || OpRHS == CondLHS))
12455 return DAG.getNode(Opcode: ISD::UADDSAT, DL, VT, N1: OpLHS, N2: OpRHS);
12456
12457 if (OpRHS.getOpcode() == CondRHS.getOpcode() &&
12458 (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
12459 OpRHS.getOpcode() == ISD::SPLAT_VECTOR) &&
12460 CondLHS == OpLHS) {
12461 // If the RHS is a constant, we have to reverse the constant
12462 // canonicalization.
12463 // x >= ~C ? x+C : ~0 --> uaddsat x, C
12464 auto MatchUADDSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
12465 return Cond->getAPIntValue() == ~Op->getAPIntValue();
12466 };
12467 if (SatCC == ISD::SETULE &&
12468 ISD::matchBinaryPredicate(LHS: OpRHS, RHS: CondRHS, Match: MatchUADDSAT))
12469 return DAG.getNode(Opcode: ISD::UADDSAT, DL, VT, N1: OpLHS, N2: OpRHS);
12470 }
12471 }
12472 }
12473
12474 // Match VSELECTs into sub with unsigned saturation.
12475 if (hasOperation(Opcode: ISD::USUBSAT, VT)) {
12476 // Check if one of the arms of the VSELECT is a zero vector. If it's on
12477 // the left side, invert the predicate to simplify the logic below.
12478 SDValue Other;
12479 ISD::CondCode SatCC = CC;
12480 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode())) {
12481 Other = N2;
12482 SatCC = ISD::getSetCCInverse(Operation: SatCC, Type: VT.getScalarType());
12483 } else if (ISD::isConstantSplatVectorAllZeros(N: N2.getNode())) {
12484 Other = N1;
12485 }
12486
12487 // zext(x) >= y ? trunc(zext(x) - y) : 0
12488 // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
12489 // zext(x) > y ? trunc(zext(x) - y) : 0
12490 // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
12491 if (Other && Other.getOpcode() == ISD::TRUNCATE &&
12492 Other.getOperand(i: 0).getOpcode() == ISD::SUB &&
12493 (SatCC == ISD::SETUGE || SatCC == ISD::SETUGT)) {
12494 SDValue OpLHS = Other.getOperand(i: 0).getOperand(i: 0);
12495 SDValue OpRHS = Other.getOperand(i: 0).getOperand(i: 1);
12496 if (LHS == OpLHS && RHS == OpRHS && LHS.getOpcode() == ISD::ZERO_EXTEND)
12497 if (SDValue R = getTruncatedUSUBSAT(DstVT: VT, SrcVT: LHS.getValueType(), LHS, RHS,
12498 DAG, DL))
12499 return R;
12500 }
12501
12502 if (Other && Other.getNumOperands() == 2) {
12503 SDValue CondRHS = RHS;
12504 SDValue OpLHS = Other.getOperand(i: 0), OpRHS = Other.getOperand(i: 1);
12505
12506 if (OpLHS == LHS) {
12507 // Look for a general sub with unsigned saturation first.
12508 // x >= y ? x-y : 0 --> usubsat x, y
12509 // x > y ? x-y : 0 --> usubsat x, y
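// For example, with i8 elements (illustrative): x = 10, y = 20 gives
// x-y = 246 after wrapping, but x >= y fails, so the select yields 0,
// matching usubsat(10, 20).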
12510 if ((SatCC == ISD::SETUGE || SatCC == ISD::SETUGT) &&
12511 Other.getOpcode() == ISD::SUB && OpRHS == CondRHS)
12512 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: OpLHS, N2: OpRHS);
12513
12514 if (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
12515 OpRHS.getOpcode() == ISD::SPLAT_VECTOR) {
12516 if (CondRHS.getOpcode() == ISD::BUILD_VECTOR ||
12517 CondRHS.getOpcode() == ISD::SPLAT_VECTOR) {
12518 // If the RHS is a constant, we have to reverse the constant
12519 // canonicalization.
12520 // x > C-1 ? x + (-C) : 0 --> usubsat x, C
12521 auto MatchUSUBSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
12522 return (!Op && !Cond) ||
12523 (Op && Cond &&
12524 Cond->getAPIntValue() == (-Op->getAPIntValue() - 1));
12525 };
12526 if (SatCC == ISD::SETUGT && Other.getOpcode() == ISD::ADD &&
12527 ISD::matchBinaryPredicate(LHS: OpRHS, RHS: CondRHS, Match: MatchUSUBSAT,
12528 /*AllowUndefs*/ true)) {
12529 OpRHS = DAG.getNegative(Val: OpRHS, DL, VT);
12530 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: OpLHS, N2: OpRHS);
12531 }
12532
12533 // Another special case: If C was a sign bit, the sub has been
12534 // canonicalized into a xor.
12535 // FIXME: Would it be better to use computeKnownBits to
12536 // determine whether it's safe to decanonicalize the xor?
12537 // x s< 0 ? x^C : 0 --> usubsat x, C
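// For example, for i8 with C = 0x80 (illustrative): when x is negative,
// x ^ 0x80 == x - 128 == usubsat(x, 128); when x is non-negative,
// usubsat(x, 128) == 0, matching the select's false arm.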
12538 APInt SplatValue;
12539 if (SatCC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
12540 ISD::isConstantSplatVector(N: OpRHS.getNode(), SplatValue) &&
12541 ISD::isConstantSplatVectorAllZeros(N: CondRHS.getNode()) &&
12542 SplatValue.isSignMask()) {
12543 // Note that we have to rebuild the RHS constant here to
12544 // ensure we don't rely on particular values of undef lanes.
12545 OpRHS = DAG.getConstant(Val: SplatValue, DL, VT);
12546 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: OpLHS, N2: OpRHS);
12547 }
12548 }
12549 }
12550 }
12551 }
12552 }
12553 }
12554
12555 if (SimplifySelectOps(SELECT: N, LHS: N1, RHS: N2))
12556 return SDValue(N, 0); // Don't revisit N.
12557
12558 // Fold (vselect all_ones, N1, N2) -> N1
12559 if (ISD::isConstantSplatVectorAllOnes(N: N0.getNode()))
12560 return N1;
12561 // Fold (vselect all_zeros, N1, N2) -> N2
12562 if (ISD::isConstantSplatVectorAllZeros(N: N0.getNode()))
12563 return N2;
12564
12565 // The ConvertSelectToConcatVector function assumes that both of the above
12566 // checks for (vselect (build_vector all{ones,zeros}) ...) have already been
12567 // made and addressed.
12568 if (N1.getOpcode() == ISD::CONCAT_VECTORS &&
12569 N2.getOpcode() == ISD::CONCAT_VECTORS &&
12570 ISD::isBuildVectorOfConstantSDNodes(N: N0.getNode())) {
12571 if (SDValue CV = ConvertSelectToConcatVector(N, DAG))
12572 return CV;
12573 }
12574
12575 if (SDValue V = foldVSelectOfConstants(N))
12576 return V;
12577
12578 if (hasOperation(Opcode: ISD::SRA, VT))
12579 if (SDValue V = foldVSelectToSignBitSplatMask(N, DAG))
12580 return V;
12581
12582 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
12583 return SDValue(N, 0);
12584
12585 return SDValue();
12586}
12587
12588SDValue DAGCombiner::visitSELECT_CC(SDNode *N) {
12589 SDValue N0 = N->getOperand(Num: 0);
12590 SDValue N1 = N->getOperand(Num: 1);
12591 SDValue N2 = N->getOperand(Num: 2);
12592 SDValue N3 = N->getOperand(Num: 3);
12593 SDValue N4 = N->getOperand(Num: 4);
12594 ISD::CondCode CC = cast<CondCodeSDNode>(Val&: N4)->get();
12595 SDLoc DL(N);
12596
12597 // fold select_cc lhs, rhs, x, x, cc -> x
12598 if (N2 == N3)
12599 return N2;
12600
12601 // select_cc bool, 0, x, y, seteq -> select bool, y, x
12602 if (CC == ISD::SETEQ && !LegalTypes && N0.getValueType() == MVT::i1 &&
12603 isNullConstant(V: N1))
12604 return DAG.getSelect(DL, VT: N2.getValueType(), Cond: N0, LHS: N3, RHS: N2);
12605
12606 // Determine if the condition we're dealing with is constant
12607 if (SDValue SCC = SimplifySetCC(VT: getSetCCResultType(VT: N0.getValueType()), N0, N1,
12608 Cond: CC, DL, foldBooleans: false)) {
12609 AddToWorklist(N: SCC.getNode());
12610
12611 // cond always true -> true val
12612 // cond always false -> false val
12613 if (auto *SCCC = dyn_cast<ConstantSDNode>(Val: SCC.getNode()))
12614 return SCCC->isZero() ? N3 : N2;
12615
12616 // When the condition is UNDEF, just return the first operand. This is
12617 // consistent with DAG creation, where no setcc node is created in this case.
12618 if (SCC->isUndef())
12619 return N2;
12620
12621 // Fold to a simpler select_cc
12622 if (SCC.getOpcode() == ISD::SETCC) {
12623 SDValue SelectOp =
12624 DAG.getNode(Opcode: ISD::SELECT_CC, DL, VT: N2.getValueType(), N1: SCC.getOperand(i: 0),
12625 N2: SCC.getOperand(i: 1), N3: N2, N4: N3, N5: SCC.getOperand(i: 2));
12626 SelectOp->setFlags(SCC->getFlags());
12627 return SelectOp;
12628 }
12629 }
12630
12631 // If we can fold this based on the true/false value, do so.
12632 if (SimplifySelectOps(SELECT: N, LHS: N2, RHS: N3))
12633 return SDValue(N, 0); // Don't revisit N.
12634
12635 // fold select_cc into other things, such as min/max/abs
12636 return SimplifySelectCC(DL, N0, N1, N2, N3, CC);
12637}
12638
12639SDValue DAGCombiner::visitSETCC(SDNode *N) {
12640 // setcc is very commonly used as an argument to brcond. This pattern
12641 // also lends itself to numerous combines and, as a result, it is desirable
12642 // to keep the argument to a brcond as a setcc as much as possible.
12643 bool PreferSetCC =
12644 N->hasOneUse() && N->use_begin()->getOpcode() == ISD::BRCOND;
12645
12646 ISD::CondCode Cond = cast<CondCodeSDNode>(Val: N->getOperand(Num: 2))->get();
12647 EVT VT = N->getValueType(ResNo: 0);
12648 SDValue N0 = N->getOperand(Num: 0), N1 = N->getOperand(Num: 1);
12649 SDLoc DL(N);
12650
12651 if (SDValue Combined = SimplifySetCC(VT, N0, N1, Cond, DL, foldBooleans: !PreferSetCC)) {
12652 // If we prefer to have a setcc, and we don't, we'll try our best to
12653 // recreate one using rebuildSetCC.
12654 if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) {
12655 SDValue NewSetCC = rebuildSetCC(N: Combined);
12656
12657 // We don't have anything interesting to combine to.
12658 if (NewSetCC.getNode() == N)
12659 return SDValue();
12660
12661 if (NewSetCC)
12662 return NewSetCC;
12663 }
12664 return Combined;
12665 }
12666
12667 // Optimize
12668 // 1) (icmp eq/ne (and X, C0), (shift X, C1))
12669 // or
12670 // 2) (icmp eq/ne X, (rotate X, C1))
12671 // If C0 is a mask or shifted mask and the shift amount (C1) isolates the
12672 // remaining bits (i.e. something like `(x64 & UINT32_MAX) == (x64 >> 32)`),
12673 // then:
12674 // If C1 is a power of 2, the rotate and shift+and versions are
12675 // equivalent, so we can interchange them depending on target preference.
12676 // Otherwise, if we have the shift+and version, we can interchange srl/shl,
12677 // which in turn affects the constant C0. We can use this to get better
12678 // constants, again determined by target preference.
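// For example, on i16 (illustrative): (x & 0x00FF) == (x >> 8) can also be
// checked as (x & 0xFF00) == (x << 8); the target hook below picks whichever
// shift direction and mask constant it prefers.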
12679 if (Cond == ISD::SETNE || Cond == ISD::SETEQ) {
12680 auto IsAndWithShift = [](SDValue A, SDValue B) {
12681 return A.getOpcode() == ISD::AND &&
12682 (B.getOpcode() == ISD::SRL || B.getOpcode() == ISD::SHL) &&
12683 A.getOperand(i: 0) == B.getOperand(i: 0);
12684 };
12685 auto IsRotateWithOp = [](SDValue A, SDValue B) {
12686 return (B.getOpcode() == ISD::ROTL || B.getOpcode() == ISD::ROTR) &&
12687 B.getOperand(i: 0) == A;
12688 };
12689 SDValue AndOrOp = SDValue(), ShiftOrRotate = SDValue();
12690 bool IsRotate = false;
12691
12692 // Find either shift+and or rotate pattern.
12693 if (IsAndWithShift(N0, N1)) {
12694 AndOrOp = N0;
12695 ShiftOrRotate = N1;
12696 } else if (IsAndWithShift(N1, N0)) {
12697 AndOrOp = N1;
12698 ShiftOrRotate = N0;
12699 } else if (IsRotateWithOp(N0, N1)) {
12700 IsRotate = true;
12701 AndOrOp = N0;
12702 ShiftOrRotate = N1;
12703 } else if (IsRotateWithOp(N1, N0)) {
12704 IsRotate = true;
12705 AndOrOp = N1;
12706 ShiftOrRotate = N0;
12707 }
12708
12709 if (AndOrOp && ShiftOrRotate && ShiftOrRotate.hasOneUse() &&
12710 (IsRotate || AndOrOp.hasOneUse())) {
12711 EVT OpVT = N0.getValueType();
12712 // Get the constant shift/rotate amount and, for the shift+and variant,
12713 // the mask.
12714 auto GetAPIntValue = [](SDValue Op) -> std::optional<APInt> {
12715 ConstantSDNode *CNode = isConstOrConstSplat(N: Op, /*AllowUndefs*/ false,
12716 /*AllowTrunc*/ AllowTruncation: false);
12717 if (CNode == nullptr)
12718 return std::nullopt;
12719 return CNode->getAPIntValue();
12720 };
12721 std::optional<APInt> AndCMask =
12722 IsRotate ? std::nullopt : GetAPIntValue(AndOrOp.getOperand(i: 1));
12723 std::optional<APInt> ShiftCAmt =
12724 GetAPIntValue(ShiftOrRotate.getOperand(i: 1));
12725 unsigned NumBits = OpVT.getScalarSizeInBits();
12726
12727 // We found constants.
12728 if (ShiftCAmt && (IsRotate || AndCMask) && ShiftCAmt->ult(RHS: NumBits)) {
12729 unsigned ShiftOpc = ShiftOrRotate.getOpcode();
12730 // Check that the constants meet the constraints.
12731 bool CanTransform = IsRotate;
12732 if (!CanTransform) {
12733 // Check that the mask and shift complement each other.
12734 CanTransform = *ShiftCAmt == (~*AndCMask).popcount();
12735 // Check that we are comparing all bits
12736 CanTransform &= (*ShiftCAmt + AndCMask->popcount()) == NumBits;
12737 // Check that the and mask is correct for the shift
12738 CanTransform &=
12739 ShiftOpc == ISD::SHL ? (~*AndCMask).isMask() : AndCMask->isMask();
12740 }
12741
12742 // See if target prefers another shift/rotate opcode.
12743 unsigned NewShiftOpc = TLI.preferedOpcodeForCmpEqPiecesOfOperand(
12744 VT: OpVT, ShiftOpc, MayTransformRotate: ShiftCAmt->isPowerOf2(), ShiftOrRotateAmt: *ShiftCAmt, AndMask: AndCMask);
12745 // Transform is valid and we have a new preference.
12746 if (CanTransform && NewShiftOpc != ShiftOpc) {
12747 SDValue NewShiftOrRotate =
12748 DAG.getNode(Opcode: NewShiftOpc, DL, VT: OpVT, N1: ShiftOrRotate.getOperand(i: 0),
12749 N2: ShiftOrRotate.getOperand(i: 1));
12750 SDValue NewAndOrOp = SDValue();
12751
12752 if (NewShiftOpc == ISD::SHL || NewShiftOpc == ISD::SRL) {
12753 APInt NewMask =
12754 NewShiftOpc == ISD::SHL
12755 ? APInt::getHighBitsSet(numBits: NumBits,
12756 hiBitsSet: NumBits - ShiftCAmt->getZExtValue())
12757 : APInt::getLowBitsSet(numBits: NumBits,
12758 loBitsSet: NumBits - ShiftCAmt->getZExtValue());
12759 NewAndOrOp =
12760 DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: ShiftOrRotate.getOperand(i: 0),
12761 N2: DAG.getConstant(Val: NewMask, DL, VT: OpVT));
12762 } else {
12763 NewAndOrOp = ShiftOrRotate.getOperand(i: 0);
12764 }
12765
12766 return DAG.getSetCC(DL, VT, LHS: NewAndOrOp, RHS: NewShiftOrRotate, Cond);
12767 }
12768 }
12769 }
12770 }
12771 return SDValue();
12772}
12773
12774SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) {
12775 SDValue LHS = N->getOperand(Num: 0);
12776 SDValue RHS = N->getOperand(Num: 1);
12777 SDValue Carry = N->getOperand(Num: 2);
12778 SDValue Cond = N->getOperand(Num: 3);
12779
12780 // If Carry is false, fold to a regular SETCC.
12781 if (isNullConstant(V: Carry))
12782 return DAG.getNode(Opcode: ISD::SETCC, DL: SDLoc(N), VTList: N->getVTList(), N1: LHS, N2: RHS, N3: Cond);
12783
12784 return SDValue();
12785}
12786
12787 /// Check whether N satisfies all of the following:
12788 /// N has a single use.
12789 /// N is a load.
12790 /// The load is compatible with ExtOpcode, meaning that if the load has an
12791 /// explicit zero/sign extension, ExtOpcode must perform the same kind of
12792 /// extension; if the load has no explicit extension, any ExtOpcode is
12793 /// accepted.
12794static bool isCompatibleLoad(SDValue N, unsigned ExtOpcode) {
12795 if (!N.hasOneUse())
12796 return false;
12797
12798 if (!isa<LoadSDNode>(Val: N))
12799 return false;
12800
12801 LoadSDNode *Load = cast<LoadSDNode>(Val&: N);
12802 ISD::LoadExtType LoadExt = Load->getExtensionType();
12803 if (LoadExt == ISD::NON_EXTLOAD || LoadExt == ISD::EXTLOAD)
12804 return true;
12805
12806 // Now LoadExt is either SEXTLOAD or ZEXTLOAD, ExtOpcode must have the same
12807 // extension.
12808 if ((LoadExt == ISD::SEXTLOAD && ExtOpcode != ISD::SIGN_EXTEND) ||
12809 (LoadExt == ISD::ZEXTLOAD && ExtOpcode != ISD::ZERO_EXTEND))
12810 return false;
12811
12812 return true;
12813}
12814
12815/// Fold
12816/// (sext (select c, load x, load y)) -> (select c, sextload x, sextload y)
12817/// (zext (select c, load x, load y)) -> (select c, zextload x, zextload y)
12818/// (aext (select c, load x, load y)) -> (select c, extload x, extload y)
12819/// This function is called by the DAGCombiner when visiting sext/zext/aext
12820/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
12821static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
12822 SelectionDAG &DAG, const SDLoc &DL,
12823 CombineLevel Level) {
12824 unsigned Opcode = N->getOpcode();
12825 SDValue N0 = N->getOperand(Num: 0);
12826 EVT VT = N->getValueType(ResNo: 0);
12827 assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
12828 Opcode == ISD::ANY_EXTEND) &&
12829 "Expected EXTEND dag node in input!");
12830
12831 if (!(N0->getOpcode() == ISD::SELECT || N0->getOpcode() == ISD::VSELECT) ||
12832 !N0.hasOneUse())
12833 return SDValue();
12834
12835 SDValue Op1 = N0->getOperand(Num: 1);
12836 SDValue Op2 = N0->getOperand(Num: 2);
12837 if (!isCompatibleLoad(N: Op1, ExtOpcode: Opcode) || !isCompatibleLoad(N: Op2, ExtOpcode: Opcode))
12838 return SDValue();
12839
12840 auto ExtLoadOpcode = ISD::EXTLOAD;
12841 if (Opcode == ISD::SIGN_EXTEND)
12842 ExtLoadOpcode = ISD::SEXTLOAD;
12843 else if (Opcode == ISD::ZERO_EXTEND)
12844 ExtLoadOpcode = ISD::ZEXTLOAD;
12845
12846 // An illegal VSELECT may cause ISel to fail if it appears after legalization
12847 // (DAG Combine2), so we conservatively check the OperationAction here.
12848 LoadSDNode *Load1 = cast<LoadSDNode>(Val&: Op1);
12849 LoadSDNode *Load2 = cast<LoadSDNode>(Val&: Op2);
12850 if (!TLI.isLoadExtLegal(ExtType: ExtLoadOpcode, ValVT: VT, MemVT: Load1->getMemoryVT()) ||
12851 !TLI.isLoadExtLegal(ExtType: ExtLoadOpcode, ValVT: VT, MemVT: Load2->getMemoryVT()) ||
12852 (N0->getOpcode() == ISD::VSELECT && Level >= AfterLegalizeTypes &&
12853 TLI.getOperationAction(Op: ISD::VSELECT, VT) != TargetLowering::Legal))
12854 return SDValue();
12855
12856 SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Operand: Op1);
12857 SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Operand: Op2);
12858 return DAG.getSelect(DL, VT, Cond: N0->getOperand(Num: 0), LHS: Ext1, RHS: Ext2);
12859}
12860
12861/// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
12862/// a build_vector of constants.
12863/// This function is called by the DAGCombiner when visiting sext/zext/aext
12864/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
12865/// Vector extends are not folded if operations are legal; this is to
12866/// avoid introducing illegal build_vector dag nodes.
12867static SDValue tryToFoldExtendOfConstant(SDNode *N, const SDLoc &DL,
12868 const TargetLowering &TLI,
12869 SelectionDAG &DAG, bool LegalTypes) {
12870 unsigned Opcode = N->getOpcode();
12871 SDValue N0 = N->getOperand(Num: 0);
12872 EVT VT = N->getValueType(ResNo: 0);
12873
12874 assert((ISD::isExtOpcode(Opcode) || ISD::isExtVecInRegOpcode(Opcode)) &&
12875 "Expected EXTEND dag node in input!");
12876
12877 // fold (sext c1) -> c1
12878 // fold (zext c1) -> c1
12879 // fold (aext c1) -> c1
12880 if (isa<ConstantSDNode>(Val: N0))
12881 return DAG.getNode(Opcode, DL, VT, Operand: N0);
12882
12883 // fold (sext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
12884 // fold (zext (select cond, c1, c2)) -> (select cond, zext c1, zext c2)
12885 // fold (aext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
12886 if (N0->getOpcode() == ISD::SELECT) {
12887 SDValue Op1 = N0->getOperand(Num: 1);
12888 SDValue Op2 = N0->getOperand(Num: 2);
12889 if (isa<ConstantSDNode>(Val: Op1) && isa<ConstantSDNode>(Val: Op2) &&
12890 (Opcode != ISD::ZERO_EXTEND || !TLI.isZExtFree(FromTy: N0.getValueType(), ToTy: VT))) {
12891 // For any_extend, choose sign extension of the constants to allow a
12892 // possible further transform to sign_extend_inreg, i.e.
12893 //
12894 // t1: i8 = select t0, Constant:i8<-1>, Constant:i8<0>
12895 // t2: i64 = any_extend t1
12896 // -->
12897 // t3: i64 = select t0, Constant:i64<-1>, Constant:i64<0>
12898 // -->
12899 // t4: i64 = sign_extend_inreg t3
12900 unsigned FoldOpc = Opcode;
12901 if (FoldOpc == ISD::ANY_EXTEND)
12902 FoldOpc = ISD::SIGN_EXTEND;
12903 return DAG.getSelect(DL, VT, Cond: N0->getOperand(Num: 0),
12904 LHS: DAG.getNode(Opcode: FoldOpc, DL, VT, Operand: Op1),
12905 RHS: DAG.getNode(Opcode: FoldOpc, DL, VT, Operand: Op2));
12906 }
12907 }
12908
12909 // fold (sext (build_vector AllConstants)) -> (build_vector AllConstants)
12910 // fold (zext (build_vector AllConstants)) -> (build_vector AllConstants)
12911 // fold (aext (build_vector AllConstants)) -> (build_vector AllConstants)
12912 EVT SVT = VT.getScalarType();
12913 if (!(VT.isVector() && (!LegalTypes || TLI.isTypeLegal(VT: SVT)) &&
12914 ISD::isBuildVectorOfConstantSDNodes(N: N0.getNode())))
12915 return SDValue();
12916
12917 // We can fold this node into a build_vector.
12918 unsigned VTBits = SVT.getSizeInBits();
12919 unsigned EVTBits = N0->getValueType(ResNo: 0).getScalarSizeInBits();
12920 SmallVector<SDValue, 8> Elts;
12921 unsigned NumElts = VT.getVectorNumElements();
12922
12923 for (unsigned i = 0; i != NumElts; ++i) {
12924 SDValue Op = N0.getOperand(i);
12925 if (Op.isUndef()) {
12926 if (Opcode == ISD::ANY_EXTEND || Opcode == ISD::ANY_EXTEND_VECTOR_INREG)
12927 Elts.push_back(Elt: DAG.getUNDEF(VT: SVT));
12928 else
12929 Elts.push_back(Elt: DAG.getConstant(Val: 0, DL, VT: SVT));
12930 continue;
12931 }
12932
12933 SDLoc DL(Op);
12934 // Get the constant value and if needed trunc it to the size of the type.
12935 // Nodes like build_vector might have constants wider than the scalar type.
12936 APInt C = Op->getAsAPIntVal().zextOrTrunc(width: EVTBits);
12937 if (Opcode == ISD::SIGN_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG)
12938 Elts.push_back(Elt: DAG.getConstant(Val: C.sext(width: VTBits), DL, VT: SVT));
12939 else
12940 Elts.push_back(Elt: DAG.getConstant(Val: C.zext(width: VTBits), DL, VT: SVT));
12941 }
12942
12943 return DAG.getBuildVector(VT, DL, Ops: Elts);
12944}
12945
12946 // ExtendUsesToFormExtLoad - Try to extend the uses of a load to enable this:
12947 // "fold ({s|z|a}ext (load x)) -> ({s|z|a}ext (truncate ({s|z|a}extload x)))"
12948 // transformation. Returns true if the extensions are possible and the
12949 // above-mentioned transformation is profitable.
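// For example (illustrative), if (load x) also feeds (setcc (load x), C), that
// setcc can be rewritten to compare the extended load against an extended
// constant, so the original narrow value no longer needs to be materialized;
// the actual candidates are collected in ExtendNodes.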
12950static bool ExtendUsesToFormExtLoad(EVT VT, SDNode *N, SDValue N0,
12951 unsigned ExtOpc,
12952 SmallVectorImpl<SDNode *> &ExtendNodes,
12953 const TargetLowering &TLI) {
12954 bool HasCopyToRegUses = false;
12955 bool isTruncFree = TLI.isTruncateFree(FromVT: VT, ToVT: N0.getValueType());
12956 for (SDNode::use_iterator UI = N0->use_begin(), UE = N0->use_end(); UI != UE;
12957 ++UI) {
12958 SDNode *User = *UI;
12959 if (User == N)
12960 continue;
12961 if (UI.getUse().getResNo() != N0.getResNo())
12962 continue;
12963 // FIXME: Only extend SETCC N, N and SETCC N, c for now.
12964 if (ExtOpc != ISD::ANY_EXTEND && User->getOpcode() == ISD::SETCC) {
12965 ISD::CondCode CC = cast<CondCodeSDNode>(Val: User->getOperand(Num: 2))->get();
12966 if (ExtOpc == ISD::ZERO_EXTEND && ISD::isSignedIntSetCC(Code: CC))
12967 // Sign bits will be lost after a zext.
12968 return false;
12969 bool Add = false;
12970 for (unsigned i = 0; i != 2; ++i) {
12971 SDValue UseOp = User->getOperand(Num: i);
12972 if (UseOp == N0)
12973 continue;
12974 if (!isa<ConstantSDNode>(Val: UseOp))
12975 return false;
12976 Add = true;
12977 }
12978 if (Add)
12979 ExtendNodes.push_back(Elt: User);
12980 continue;
12981 }
12982 // If truncates aren't free and there are users we can't
12983 // extend, it isn't worthwhile.
12984 if (!isTruncFree)
12985 return false;
12986 // Remember if this value is live-out.
12987 if (User->getOpcode() == ISD::CopyToReg)
12988 HasCopyToRegUses = true;
12989 }
12990
12991 if (HasCopyToRegUses) {
12992 bool BothLiveOut = false;
12993 for (SDNode::use_iterator UI = N->use_begin(), UE = N->use_end();
12994 UI != UE; ++UI) {
12995 SDUse &Use = UI.getUse();
12996 if (Use.getResNo() == 0 && Use.getUser()->getOpcode() == ISD::CopyToReg) {
12997 BothLiveOut = true;
12998 break;
12999 }
13000 }
13001 if (BothLiveOut)
13002 // Both unextended and extended values are live out. There had better be
13003 // a good reason for the transformation.
13004 return !ExtendNodes.empty();
13005 }
13006 return true;
13007}
13008
13009void DAGCombiner::ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
13010 SDValue OrigLoad, SDValue ExtLoad,
13011 ISD::NodeType ExtType) {
13012 // Extend SetCC uses if necessary.
13013 SDLoc DL(ExtLoad);
13014 for (SDNode *SetCC : SetCCs) {
13015 SmallVector<SDValue, 4> Ops;
13016
13017 for (unsigned j = 0; j != 2; ++j) {
13018 SDValue SOp = SetCC->getOperand(Num: j);
13019 if (SOp == OrigLoad)
13020 Ops.push_back(Elt: ExtLoad);
13021 else
13022 Ops.push_back(Elt: DAG.getNode(Opcode: ExtType, DL, VT: ExtLoad->getValueType(ResNo: 0), Operand: SOp));
13023 }
13024
13025 Ops.push_back(Elt: SetCC->getOperand(Num: 2));
13026 CombineTo(N: SetCC, Res: DAG.getNode(Opcode: ISD::SETCC, DL, VT: SetCC->getValueType(ResNo: 0), Ops));
13027 }
13028}
13029
13030// FIXME: Bring more similar combines here, common to sext/zext (maybe aext?).
13031SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
13032 SDValue N0 = N->getOperand(Num: 0);
13033 EVT DstVT = N->getValueType(ResNo: 0);
13034 EVT SrcVT = N0.getValueType();
13035
13036 assert((N->getOpcode() == ISD::SIGN_EXTEND ||
13037 N->getOpcode() == ISD::ZERO_EXTEND) &&
13038 "Unexpected node type (not an extend)!");
13039
13040 // fold (sext (load x)) to multiple smaller sextloads; same for zext.
13041 // For example, on a target with legal v4i32, but illegal v8i32, turn:
13042 // (v8i32 (sext (v8i16 (load x))))
13043 // into:
13044 // (v8i32 (concat_vectors (v4i32 (sextload x)),
13045 // (v4i32 (sextload (x + 16)))))
13046 // Where uses of the original load, i.e.:
13047 // (v8i16 (load x))
13048 // are replaced with:
13049 // (v8i16 (truncate
13050 // (v8i32 (concat_vectors (v4i32 (sextload x)),
13051 // (v4i32 (sextload (x + 16)))))))
13052 //
13053 // This combine is only applicable to illegal, but splittable, vectors.
13054 // All legal types, and illegal non-vector types, are handled elsewhere.
13055 // This combine is controlled by TargetLowering::isVectorLoadExtDesirable.
13056 //
13057 if (N0->getOpcode() != ISD::LOAD)
13058 return SDValue();
13059
13060 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
13061
13062 if (!ISD::isNON_EXTLoad(N: LN0) || !ISD::isUNINDEXEDLoad(N: LN0) ||
13063 !N0.hasOneUse() || !LN0->isSimple() ||
13064 !DstVT.isVector() || !DstVT.isPow2VectorType() ||
13065 !TLI.isVectorLoadExtDesirable(ExtVal: SDValue(N, 0)))
13066 return SDValue();
13067
13068 SmallVector<SDNode *, 4> SetCCs;
13069 if (!ExtendUsesToFormExtLoad(VT: DstVT, N, N0, ExtOpc: N->getOpcode(), ExtendNodes&: SetCCs, TLI))
13070 return SDValue();
13071
13072 ISD::LoadExtType ExtType =
13073 N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
13074
13075 // Try to split the vector types to get down to legal types.
13076 EVT SplitSrcVT = SrcVT;
13077 EVT SplitDstVT = DstVT;
13078 while (!TLI.isLoadExtLegalOrCustom(ExtType, ValVT: SplitDstVT, MemVT: SplitSrcVT) &&
13079 SplitSrcVT.getVectorNumElements() > 1) {
13080 SplitDstVT = DAG.GetSplitDestVTs(VT: SplitDstVT).first;
13081 SplitSrcVT = DAG.GetSplitDestVTs(VT: SplitSrcVT).first;
13082 }
13083
13084 if (!TLI.isLoadExtLegalOrCustom(ExtType, ValVT: SplitDstVT, MemVT: SplitSrcVT))
13085 return SDValue();
13086
13087 assert(!DstVT.isScalableVector() && "Unexpected scalable vector type");
13088
13089 SDLoc DL(N);
13090 const unsigned NumSplits =
13091 DstVT.getVectorNumElements() / SplitDstVT.getVectorNumElements();
13092 const unsigned Stride = SplitSrcVT.getStoreSize();
13093 SmallVector<SDValue, 4> Loads;
13094 SmallVector<SDValue, 4> Chains;
13095
13096 SDValue BasePtr = LN0->getBasePtr();
13097 for (unsigned Idx = 0; Idx < NumSplits; Idx++) {
13098 const unsigned Offset = Idx * Stride;
13099
13100 SDValue SplitLoad =
13101 DAG.getExtLoad(ExtType, dl: SDLoc(LN0), VT: SplitDstVT, Chain: LN0->getChain(),
13102 Ptr: BasePtr, PtrInfo: LN0->getPointerInfo().getWithOffset(O: Offset),
13103 MemVT: SplitSrcVT, Alignment: LN0->getOriginalAlign(),
13104 MMOFlags: LN0->getMemOperand()->getFlags(), AAInfo: LN0->getAAInfo());
13105
13106 BasePtr = DAG.getMemBasePlusOffset(Base: BasePtr, Offset: TypeSize::getFixed(ExactSize: Stride), DL);
13107
13108 Loads.push_back(Elt: SplitLoad.getValue(R: 0));
13109 Chains.push_back(Elt: SplitLoad.getValue(R: 1));
13110 }
13111
13112 SDValue NewChain = DAG.getNode(Opcode: ISD::TokenFactor, DL, VT: MVT::Other, Ops: Chains);
13113 SDValue NewValue = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT: DstVT, Ops: Loads);
13114
13115 // Simplify TF.
13116 AddToWorklist(N: NewChain.getNode());
13117
13118 CombineTo(N, Res: NewValue);
13119
13120 // Replace uses of the original load (before extension)
13121 // with a truncate of the concatenated sextloaded vectors.
13122 SDValue Trunc =
13123 DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N0), VT: N0.getValueType(), Operand: NewValue);
13124 ExtendSetCCUses(SetCCs, OrigLoad: N0, ExtLoad: NewValue, ExtType: (ISD::NodeType)N->getOpcode());
13125 CombineTo(N: N0.getNode(), Res0: Trunc, Res1: NewChain);
13126 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13127}
13128
13129// fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
13130// (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
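// For example (illustrative, i32 -> i64):
//   (zext (and (srl (load x), 8), 255))
//     --> (and (srl (zextload x), 8), 255)
// where the zextload produces an i64 whose high bits are already zero.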
13131SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) {
13132 assert(N->getOpcode() == ISD::ZERO_EXTEND);
13133 EVT VT = N->getValueType(ResNo: 0);
13134 EVT OrigVT = N->getOperand(Num: 0).getValueType();
13135 if (TLI.isZExtFree(FromTy: OrigVT, ToTy: VT))
13136 return SDValue();
13137
13138 // and/or/xor
13139 SDValue N0 = N->getOperand(Num: 0);
13140 if (!ISD::isBitwiseLogicOp(Opcode: N0.getOpcode()) ||
13141 N0.getOperand(i: 1).getOpcode() != ISD::Constant ||
13142 (LegalOperations && !TLI.isOperationLegal(Op: N0.getOpcode(), VT)))
13143 return SDValue();
13144
13145 // shl/shr
13146 SDValue N1 = N0->getOperand(Num: 0);
13147 if (!(N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) ||
13148 N1.getOperand(i: 1).getOpcode() != ISD::Constant ||
13149 (LegalOperations && !TLI.isOperationLegal(Op: N1.getOpcode(), VT)))
13150 return SDValue();
13151
13152 // load
13153 if (!isa<LoadSDNode>(Val: N1.getOperand(i: 0)))
13154 return SDValue();
13155 LoadSDNode *Load = cast<LoadSDNode>(Val: N1.getOperand(i: 0));
13156 EVT MemVT = Load->getMemoryVT();
13157 if (!TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: VT, MemVT) ||
13158 Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed())
13159 return SDValue();
13160
13161
13162 // If the shift op is SHL, the logic op must be AND, otherwise the result
13163 // will be wrong.
13164 if (N1.getOpcode() == ISD::SHL && N0.getOpcode() != ISD::AND)
13165 return SDValue();
13166
13167 if (!N0.hasOneUse() || !N1.hasOneUse())
13168 return SDValue();
13169
13170 SmallVector<SDNode*, 4> SetCCs;
13171 if (!ExtendUsesToFormExtLoad(VT, N: N1.getNode(), N0: N1.getOperand(i: 0),
13172 ExtOpc: ISD::ZERO_EXTEND, ExtendNodes&: SetCCs, TLI))
13173 return SDValue();
13174
13175 // Actually do the transformation.
13176 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl: SDLoc(Load), VT,
13177 Chain: Load->getChain(), Ptr: Load->getBasePtr(),
13178 MemVT: Load->getMemoryVT(), MMO: Load->getMemOperand());
13179
13180 SDLoc DL1(N1);
13181 SDValue Shift = DAG.getNode(Opcode: N1.getOpcode(), DL: DL1, VT, N1: ExtLoad,
13182 N2: N1.getOperand(i: 1));
13183
13184 APInt Mask = N0.getConstantOperandAPInt(i: 1).zext(width: VT.getSizeInBits());
13185 SDLoc DL0(N0);
13186 SDValue And = DAG.getNode(Opcode: N0.getOpcode(), DL: DL0, VT, N1: Shift,
13187 N2: DAG.getConstant(Val: Mask, DL: DL0, VT));
13188
13189 ExtendSetCCUses(SetCCs, OrigLoad: N1.getOperand(i: 0), ExtLoad, ExtType: ISD::ZERO_EXTEND);
13190 CombineTo(N, Res: And);
13191 if (SDValue(Load, 0).hasOneUse()) {
13192 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Load, 1), To: ExtLoad.getValue(R: 1));
13193 } else {
13194 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(Load),
13195 VT: Load->getValueType(ResNo: 0), Operand: ExtLoad);
13196 CombineTo(N: Load, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
13197 }
13198
13199 // N0 is dead at this point.
13200 recursivelyDeleteUnusedNodes(N: N0.getNode());
13201
13202 return SDValue(N,0); // Return N so it doesn't get rechecked!
13203}
13204
13205/// If we're narrowing or widening the result of a vector select and the final
13206/// size is the same size as a setcc (compare) feeding the select, then try to
13207/// apply the cast operation to the select's operands because matching vector
13208/// sizes for a select condition and other operands should be more efficient.
13209SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
13210 unsigned CastOpcode = Cast->getOpcode();
13211 assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
13212 CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
13213 CastOpcode == ISD::FP_ROUND) &&
13214 "Unexpected opcode for vector select narrowing/widening");
13215
13216 // We only do this transform before legal ops because the pattern may be
13217 // obfuscated by target-specific operations after legalization. Do not create
13218 // an illegal select op, however, because that may be difficult to lower.
13219 EVT VT = Cast->getValueType(ResNo: 0);
13220 if (LegalOperations || !TLI.isOperationLegalOrCustom(Op: ISD::VSELECT, VT))
13221 return SDValue();
13222
13223 SDValue VSel = Cast->getOperand(Num: 0);
13224 if (VSel.getOpcode() != ISD::VSELECT || !VSel.hasOneUse() ||
13225 VSel.getOperand(i: 0).getOpcode() != ISD::SETCC)
13226 return SDValue();
13227
13228 // Does the setcc have the same vector size as the casted select?
13229 SDValue SetCC = VSel.getOperand(i: 0);
13230 EVT SetCCVT = getSetCCResultType(VT: SetCC.getOperand(i: 0).getValueType());
13231 if (SetCCVT.getSizeInBits() != VT.getSizeInBits())
13232 return SDValue();
13233
13234 // cast (vsel (setcc X), A, B) --> vsel (setcc X), (cast A), (cast B)
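// For example (illustrative types):
//   (v4i32 trunc (vselect (v4i32 setcc X, Y), v4i64 A, v4i64 B))
//     --> (vselect (v4i32 setcc X, Y), (v4i32 trunc A), (v4i32 trunc B))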
13235 SDValue A = VSel.getOperand(i: 1);
13236 SDValue B = VSel.getOperand(i: 2);
13237 SDValue CastA, CastB;
13238 SDLoc DL(Cast);
13239 if (CastOpcode == ISD::FP_ROUND) {
13240 // FP_ROUND (fptrunc) has an extra flag operand to pass along.
13241 CastA = DAG.getNode(Opcode: CastOpcode, DL, VT, N1: A, N2: Cast->getOperand(Num: 1));
13242 CastB = DAG.getNode(Opcode: CastOpcode, DL, VT, N1: B, N2: Cast->getOperand(Num: 1));
13243 } else {
13244 CastA = DAG.getNode(Opcode: CastOpcode, DL, VT, Operand: A);
13245 CastB = DAG.getNode(Opcode: CastOpcode, DL, VT, Operand: B);
13246 }
13247 return DAG.getNode(Opcode: ISD::VSELECT, DL, VT, N1: SetCC, N2: CastA, N3: CastB);
13248}
13249
13250// fold ([s|z]ext ([s|z]extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
13251// fold ([s|z]ext ( extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
13252static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner,
13253 const TargetLowering &TLI, EVT VT,
13254 bool LegalOperations, SDNode *N,
13255 SDValue N0, ISD::LoadExtType ExtLoadType) {
13256 SDNode *N0Node = N0.getNode();
13257 bool isAExtLoad = (ExtLoadType == ISD::SEXTLOAD) ? ISD::isSEXTLoad(N: N0Node)
13258 : ISD::isZEXTLoad(N: N0Node);
13259 if ((!isAExtLoad && !ISD::isEXTLoad(N: N0Node)) ||
13260 !ISD::isUNINDEXEDLoad(N: N0Node) || !N0.hasOneUse())
13261 return SDValue();
13262
13263 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
13264 EVT MemVT = LN0->getMemoryVT();
13265 if ((LegalOperations || !LN0->isSimple() ||
13266 VT.isVector()) &&
13267 !TLI.isLoadExtLegal(ExtType: ExtLoadType, ValVT: VT, MemVT))
13268 return SDValue();
13269
13270 SDValue ExtLoad =
13271 DAG.getExtLoad(ExtType: ExtLoadType, dl: SDLoc(LN0), VT, Chain: LN0->getChain(),
13272 Ptr: LN0->getBasePtr(), MemVT, MMO: LN0->getMemOperand());
13273 Combiner.CombineTo(N, Res: ExtLoad);
13274 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN0, 1), To: ExtLoad.getValue(R: 1));
13275 if (LN0->use_empty())
13276 Combiner.recursivelyDeleteUnusedNodes(N: LN0);
13277 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13278}
13279
13280// fold ([s|z]ext (load x)) -> ([s|z]ext (truncate ([s|z]extload x)))
13281// Only generate vector extloads when 1) they're legal, and 2) they are
13282// deemed desirable by the target. NonNegZExt can be set to true if a zero
13283// extend has the nonneg flag to allow use of sextload if profitable.
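// For example (illustrative), (zext nneg (load x)) where the load also feeds a
// signed compare can instead use a sextload, since the nonneg flag guarantees
// that sign- and zero-extension produce the same value.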
13284static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
13285 const TargetLowering &TLI, EVT VT,
13286 bool LegalOperations, SDNode *N, SDValue N0,
13287 ISD::LoadExtType ExtLoadType,
13288 ISD::NodeType ExtOpc,
13289 bool NonNegZExt = false) {
13290 if (!ISD::isNON_EXTLoad(N: N0.getNode()) || !ISD::isUNINDEXEDLoad(N: N0.getNode()))
13291 return {};
13292
13293 // If this is zext nneg, see if it would make sense to treat it as a sext.
13294 if (NonNegZExt) {
13295 assert(ExtLoadType == ISD::ZEXTLOAD && ExtOpc == ISD::ZERO_EXTEND &&
13296 "Unexpected load type or opcode");
13297 for (SDNode *User : N0->uses()) {
13298 if (User->getOpcode() == ISD::SETCC) {
13299 ISD::CondCode CC = cast<CondCodeSDNode>(Val: User->getOperand(Num: 2))->get();
13300 if (ISD::isSignedIntSetCC(Code: CC)) {
13301 ExtLoadType = ISD::SEXTLOAD;
13302 ExtOpc = ISD::SIGN_EXTEND;
13303 break;
13304 }
13305 }
13306 }
13307 }
13308
13309 // TODO: The isFixedLengthVector() check should be removed; any negative
13310 // effects on code generation should then be the result of that target's
13311 // implementation of isVectorLoadExtDesirable().
13312 if ((LegalOperations || VT.isFixedLengthVector() ||
13313 !cast<LoadSDNode>(Val&: N0)->isSimple()) &&
13314 !TLI.isLoadExtLegal(ExtType: ExtLoadType, ValVT: VT, MemVT: N0.getValueType()))
13315 return {};
13316
13317 bool DoXform = true;
13318 SmallVector<SDNode *, 4> SetCCs;
13319 if (!N0.hasOneUse())
13320 DoXform = ExtendUsesToFormExtLoad(VT, N, N0, ExtOpc, ExtendNodes&: SetCCs, TLI);
13321 if (VT.isVector())
13322 DoXform &= TLI.isVectorLoadExtDesirable(ExtVal: SDValue(N, 0));
13323 if (!DoXform)
13324 return {};
13325
13326 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
13327 SDValue ExtLoad = DAG.getExtLoad(ExtType: ExtLoadType, dl: SDLoc(LN0), VT, Chain: LN0->getChain(),
13328 Ptr: LN0->getBasePtr(), MemVT: N0.getValueType(),
13329 MMO: LN0->getMemOperand());
13330 Combiner.ExtendSetCCUses(SetCCs, OrigLoad: N0, ExtLoad, ExtType: ExtOpc);
13331 // If the load value is used only by N, replace it via CombineTo N.
13332 bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse();
13333 Combiner.CombineTo(N, Res: ExtLoad);
13334 if (NoReplaceTrunc) {
13335 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN0, 1), To: ExtLoad.getValue(R: 1));
13336 Combiner.recursivelyDeleteUnusedNodes(N: LN0);
13337 } else {
13338 SDValue Trunc =
13339 DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N0), VT: N0.getValueType(), Operand: ExtLoad);
13340 Combiner.CombineTo(N: LN0, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
13341 }
13342 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13343}
13344
13345static SDValue
13346tryToFoldExtOfMaskedLoad(SelectionDAG &DAG, const TargetLowering &TLI, EVT VT,
13347 bool LegalOperations, SDNode *N, SDValue N0,
13348 ISD::LoadExtType ExtLoadType, ISD::NodeType ExtOpc) {
13349 if (!N0.hasOneUse())
13350 return SDValue();
13351
13352 MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(Val&: N0);
13353 if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD)
13354 return SDValue();
13355
13356 if ((LegalOperations || !cast<MaskedLoadSDNode>(Val&: N0)->isSimple()) &&
13357 !TLI.isLoadExtLegalOrCustom(ExtType: ExtLoadType, ValVT: VT, MemVT: Ld->getValueType(ResNo: 0)))
13358 return SDValue();
13359
13360 if (!TLI.isVectorLoadExtDesirable(ExtVal: SDValue(N, 0)))
13361 return SDValue();
13362
13363 SDLoc dl(Ld);
13364 SDValue PassThru = DAG.getNode(Opcode: ExtOpc, DL: dl, VT, Operand: Ld->getPassThru());
13365 SDValue NewLoad = DAG.getMaskedLoad(
13366 VT, dl, Chain: Ld->getChain(), Base: Ld->getBasePtr(), Offset: Ld->getOffset(), Mask: Ld->getMask(),
13367 Src0: PassThru, MemVT: Ld->getMemoryVT(), MMO: Ld->getMemOperand(), AM: Ld->getAddressingMode(),
13368 ExtLoadType, IsExpanding: Ld->isExpandingLoad());
13369 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Ld, 1), To: SDValue(NewLoad.getNode(), 1));
13370 return NewLoad;
13371}
13372
13373// fold ([s|z]ext (atomic_load)) -> ([s|z]ext (truncate ([s|z]ext atomic_load)))
13374static SDValue tryToFoldExtOfAtomicLoad(SelectionDAG &DAG,
13375 const TargetLowering &TLI, EVT VT,
13376 SDValue N0,
13377 ISD::LoadExtType ExtLoadType) {
13378 auto *ALoad = dyn_cast<AtomicSDNode>(Val&: N0);
13379 if (!ALoad || ALoad->getOpcode() != ISD::ATOMIC_LOAD)
13380 return {};
13381 EVT MemoryVT = ALoad->getMemoryVT();
13382 if (!TLI.isAtomicLoadExtLegal(ExtType: ExtLoadType, ValVT: VT, MemVT: MemoryVT))
13383 return {};
13384 // Can't fold into ALoad if it is already extending differently.
13385 ISD::LoadExtType ALoadExtTy = ALoad->getExtensionType();
13386 if ((ALoadExtTy == ISD::ZEXTLOAD && ExtLoadType == ISD::SEXTLOAD) ||
13387 (ALoadExtTy == ISD::SEXTLOAD && ExtLoadType == ISD::ZEXTLOAD))
13388 return {};
13389
13390 EVT OrigVT = ALoad->getValueType(ResNo: 0);
13391 assert(OrigVT.getSizeInBits() < VT.getSizeInBits() && "VT should be wider.");
13392 auto *NewALoad = cast<AtomicSDNode>(Val: DAG.getAtomic(
13393 Opcode: ISD::ATOMIC_LOAD, dl: SDLoc(ALoad), MemVT: MemoryVT, VT, Chain: ALoad->getChain(),
13394 Ptr: ALoad->getBasePtr(), MMO: ALoad->getMemOperand()));
13395 NewALoad->setExtensionType(ExtLoadType);
13396 DAG.ReplaceAllUsesOfValueWith(
13397 From: SDValue(ALoad, 0),
13398 To: DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(ALoad), VT: OrigVT, Operand: SDValue(NewALoad, 0)));
13399 // Update the chain uses.
13400 DAG.ReplaceAllUsesOfValueWith(From: SDValue(ALoad, 1), To: SDValue(NewALoad, 1));
13401 return SDValue(NewALoad, 0);
13402}
13403
13404static SDValue foldExtendedSignBitTest(SDNode *N, SelectionDAG &DAG,
13405 bool LegalOperations) {
13406 assert((N->getOpcode() == ISD::SIGN_EXTEND ||
13407 N->getOpcode() == ISD::ZERO_EXTEND) && "Expected sext or zext");
13408
13409 SDValue SetCC = N->getOperand(Num: 0);
13410 if (LegalOperations || SetCC.getOpcode() != ISD::SETCC ||
13411 !SetCC.hasOneUse() || SetCC.getValueType() != MVT::i1)
13412 return SDValue();
13413
13414 SDValue X = SetCC.getOperand(i: 0);
13415 SDValue Ones = SetCC.getOperand(i: 1);
13416 ISD::CondCode CC = cast<CondCodeSDNode>(Val: SetCC.getOperand(i: 2))->get();
13417 EVT VT = N->getValueType(ResNo: 0);
13418 EVT XVT = X.getValueType();
13419 // setge X, C is canonicalized to setgt, so we do not need to match that
13420 // pattern. The setlt sibling is folded in SimplifySelectCC() because it does
13421 // not require the 'not' op.
13422 if (CC == ISD::SETGT && isAllOnesConstant(V: Ones) && VT == XVT) {
13423 // Invert and smear/shift the sign bit:
13424 // sext i1 (setgt iN X, -1) --> sra (not X), (N - 1)
13425 // zext i1 (setgt iN X, -1) --> srl (not X), (N - 1)
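// For example, for i32 (illustrative): X >s -1 means X is non-negative, i.e.
// the sign bit of ~X is set, so (~X) >>s 31 is all-ones exactly when the
// compare is true (sext case) and (~X) >>u 31 is 1 (zext case).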
13426 SDLoc DL(N);
13427 unsigned ShCt = VT.getSizeInBits() - 1;
13428 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
13429 if (!TLI.shouldAvoidTransformToShift(VT, Amount: ShCt)) {
13430 SDValue NotX = DAG.getNOT(DL, Val: X, VT);
13431 SDValue ShiftAmount = DAG.getConstant(Val: ShCt, DL, VT);
13432 auto ShiftOpcode =
13433 N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SRA : ISD::SRL;
13434 return DAG.getNode(Opcode: ShiftOpcode, DL, VT, N1: NotX, N2: ShiftAmount);
13435 }
13436 }
13437 return SDValue();
13438}
13439
13440SDValue DAGCombiner::foldSextSetcc(SDNode *N) {
13441 SDValue N0 = N->getOperand(Num: 0);
13442 if (N0.getOpcode() != ISD::SETCC)
13443 return SDValue();
13444
13445 SDValue N00 = N0.getOperand(i: 0);
13446 SDValue N01 = N0.getOperand(i: 1);
13447 ISD::CondCode CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get();
13448 EVT VT = N->getValueType(ResNo: 0);
13449 EVT N00VT = N00.getValueType();
13450 SDLoc DL(N);
13451
13452 // Propagate fast-math-flags.
13453 SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
13454
13455 // On some architectures (such as SSE/NEON/etc) the SETCC result type is
13456 // the same size as the compared operands. Try to optimize sext(setcc())
13457 // if this is the case.
13458 if (VT.isVector() && !LegalOperations &&
13459 TLI.getBooleanContents(Type: N00VT) ==
13460 TargetLowering::ZeroOrNegativeOneBooleanContent) {
13461 EVT SVT = getSetCCResultType(VT: N00VT);
13462
13463 // If we already have the desired type, don't change it.
13464 if (SVT != N0.getValueType()) {
13465 // We know that the # elements of the results is the same as the
13466 // # elements of the compare (and the # elements of the compare result
13467 // for that matter). Check to see that they are the same size. If so,
13468 // we know that the element size of the sext'd result matches the
13469 // element size of the compare operands.
13470 if (VT.getSizeInBits() == SVT.getSizeInBits())
13471 return DAG.getSetCC(DL, VT, LHS: N00, RHS: N01, Cond: CC);
13472
13473 // If the desired elements are smaller or larger than the source
13474 // elements, we can use a matching integer vector type and then
13475 // truncate/sign extend.
13476 EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger();
13477 if (SVT == MatchingVecType) {
13478 SDValue VsetCC = DAG.getSetCC(DL, VT: MatchingVecType, LHS: N00, RHS: N01, Cond: CC);
13479 return DAG.getSExtOrTrunc(Op: VsetCC, DL, VT);
13480 }
13481 }
13482
13483 // Try to eliminate the sext of a setcc by zexting the compare operands.
13484 if (N0.hasOneUse() && TLI.isOperationLegalOrCustom(Op: ISD::SETCC, VT) &&
13485 !TLI.isOperationLegalOrCustom(Op: ISD::SETCC, VT: SVT)) {
13486 bool IsSignedCmp = ISD::isSignedIntSetCC(Code: CC);
13487 unsigned LoadOpcode = IsSignedCmp ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
13488 unsigned ExtOpcode = IsSignedCmp ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
13489
13490 // We have an unsupported narrow vector compare op that would be legal
13491 // if extended to the destination type. See if the compare operands
13492 // can be freely extended to the destination type.
13493 auto IsFreeToExtend = [&](SDValue V) {
13494 if (isConstantOrConstantVector(N: V, /*NoOpaques*/ true))
13495 return true;
13496 // Match a simple, non-extended load that can be converted to a
13497 // legal {z/s}ext-load.
13498 // TODO: Allow widening of an existing {z/s}ext-load?
13499 if (!(ISD::isNON_EXTLoad(N: V.getNode()) &&
13500 ISD::isUNINDEXEDLoad(N: V.getNode()) &&
13501 cast<LoadSDNode>(Val&: V)->isSimple() &&
13502 TLI.isLoadExtLegal(ExtType: LoadOpcode, ValVT: VT, MemVT: V.getValueType())))
13503 return false;
13504
13505 // Non-chain users of this value must either be the setcc in this
13506 // sequence or extends that can be folded into the new {z/s}ext-load.
13507 for (SDNode::use_iterator UI = V->use_begin(), UE = V->use_end();
13508 UI != UE; ++UI) {
13509 // Skip uses of the chain and the setcc.
13510 SDNode *User = *UI;
13511 if (UI.getUse().getResNo() != 0 || User == N0.getNode())
13512 continue;
13513 // Extra users must have exactly the same cast we are about to create.
13514 // TODO: This restriction could be eased if ExtendUsesToFormExtLoad()
13515 // is enhanced similarly.
13516 if (User->getOpcode() != ExtOpcode || User->getValueType(ResNo: 0) != VT)
13517 return false;
13518 }
13519 return true;
13520 };
13521
13522 if (IsFreeToExtend(N00) && IsFreeToExtend(N01)) {
13523 SDValue Ext0 = DAG.getNode(Opcode: ExtOpcode, DL, VT, Operand: N00);
13524 SDValue Ext1 = DAG.getNode(Opcode: ExtOpcode, DL, VT, Operand: N01);
13525 return DAG.getSetCC(DL, VT, LHS: Ext0, RHS: Ext1, Cond: CC);
13526 }
13527 }
13528 }
13529
13530 // sext(setcc x, y, cc) -> (select (setcc x, y, cc), T, 0)
13531 // Here, T can be 1 or -1, depending on the type of the setcc and
13532 // getBooleanContents().
13533 unsigned SetCCWidth = N0.getScalarValueSizeInBits();
13534
13535 // To determine the "true" side of the select, we need to know the high bit
13536 // of the value returned by the setcc if it evaluates to true.
13537 // If the type of the setcc is i1, then the true case of the select is just
13538 // sext(i1 1), that is, -1.
13539 // If the type of the setcc is larger (say, i8) then the value of the high
13540 // bit depends on getBooleanContents(), so ask TLI for a real "true" value
13541 // of the appropriate width.
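// For example (illustrative, scalar i32 result):
//   (i32 sext (i1 setcc X, Y, cc)) --> (select (setcc X, Y, cc), -1, 0)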
13542 SDValue ExtTrueVal = (SetCCWidth == 1)
13543 ? DAG.getAllOnesConstant(DL, VT)
13544 : DAG.getBoolConstant(V: true, DL, VT, OpVT: N00VT);
13545 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
13546 if (SDValue SCC = SimplifySelectCC(DL, N0: N00, N1: N01, N2: ExtTrueVal, N3: Zero, CC, NotExtCompare: true))
13547 return SCC;
13548
13549 if (!VT.isVector() && !shouldConvertSelectOfConstantsToMath(Cond: N0, VT, TLI)) {
13550 EVT SetCCVT = getSetCCResultType(VT: N00VT);
13551 // Don't do this transform for i1 because there's a select transform
13552 // that would reverse it.
13553 // TODO: We should not do this transform at all without a target hook
13554 // because a sext is likely cheaper than a select?
13555 if (SetCCVT.getScalarSizeInBits() != 1 &&
13556 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SETCC, VT: N00VT))) {
13557 SDValue SetCC = DAG.getSetCC(DL, VT: SetCCVT, LHS: N00, RHS: N01, Cond: CC);
13558 return DAG.getSelect(DL, VT, Cond: SetCC, LHS: ExtTrueVal, RHS: Zero);
13559 }
13560 }
13561
13562 return SDValue();
13563}
13564
13565SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
13566 SDValue N0 = N->getOperand(Num: 0);
13567 EVT VT = N->getValueType(ResNo: 0);
13568 SDLoc DL(N);
13569
13570 if (VT.isVector())
13571 if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
13572 return FoldedVOp;
13573
13574 // sext(undef) = 0 because the top bits will all be the same.
13575 if (N0.isUndef())
13576 return DAG.getConstant(Val: 0, DL, VT);
13577
13578 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
13579 return Res;
13580
13581 // fold (sext (sext x)) -> (sext x)
13582 // fold (sext (aext x)) -> (sext x)
13583 if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
13584 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: N0.getOperand(i: 0));
13585
13586 // fold (sext (aext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
13587 // fold (sext (sext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
13588 if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
13589 N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG)
13590 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_VECTOR_INREG, DL: SDLoc(N), VT,
13591 Operand: N0.getOperand(i: 0));
13592
13593 // fold (sext (sext_inreg x)) -> (sext (trunc x))
13594 if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
13595 SDValue N00 = N0.getOperand(i: 0);
13596 EVT ExtVT = cast<VTSDNode>(Val: N0->getOperand(Num: 1))->getVT();
13597 if ((N00.getOpcode() == ISD::TRUNCATE || TLI.isTruncateFree(Val: N00, VT2: ExtVT)) &&
13598 (!LegalTypes || TLI.isTypeLegal(VT: ExtVT))) {
13599 SDValue T = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: ExtVT, Operand: N00);
13600 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: T);
13601 }
13602 }
13603
13604 if (N0.getOpcode() == ISD::TRUNCATE) {
13605 // fold (sext (truncate (load x))) -> (sext (smaller load x))
13606 // fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
13607 if (SDValue NarrowLoad = reduceLoadWidth(N: N0.getNode())) {
13608 SDNode *oye = N0.getOperand(i: 0).getNode();
13609 if (NarrowLoad.getNode() != N0.getNode()) {
13610 CombineTo(N: N0.getNode(), Res: NarrowLoad);
13611 // CombineTo deleted the truncate, if needed, but not what's under it.
13612 AddToWorklist(N: oye);
13613 }
13614 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13615 }
13616
13617 // See if the value being truncated is already sign extended. If so, just
13618 // eliminate the trunc/sext pair.
13619 SDValue Op = N0.getOperand(i: 0);
13620 unsigned OpBits = Op.getScalarValueSizeInBits();
13621 unsigned MidBits = N0.getScalarValueSizeInBits();
13622 unsigned DestBits = VT.getScalarSizeInBits();
13623 unsigned NumSignBits = DAG.ComputeNumSignBits(Op);
13624
13625 if (OpBits == DestBits) {
13626 // Op is i32, Mid is i8, and Dest is i32. If Op has more than 24 sign
13627 // bits, it is already ready.
13628 if (NumSignBits > DestBits-MidBits)
13629 return Op;
13630 } else if (OpBits < DestBits) {
13631 // Op is i32, Mid is i8, and Dest is i64. If Op has more than 24 sign
13632 // bits, just sext from i32.
13633 if (NumSignBits > OpBits-MidBits)
13634 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: Op);
13635 } else {
13636 // Op is i64, Mid is i8, and Dest is i32. If Op has more than 56 sign
13637 // bits, just truncate to i32.
13638 if (NumSignBits > OpBits-MidBits)
13639 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Op);
13640 }
13641
13642 // fold (sext (truncate x)) -> (sextinreg x).
13643 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::SIGN_EXTEND_INREG,
13644 VT: N0.getValueType())) {
13645 if (OpBits < DestBits)
13646 Op = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL: SDLoc(N0), VT, Operand: Op);
13647 else if (OpBits > DestBits)
13648 Op = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N0), VT, Operand: Op);
13649 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT, N1: Op,
13650 N2: DAG.getValueType(N0.getValueType()));
13651 }
13652 }
13653
13654 // Try to simplify (sext (load x)).
13655 if (SDValue foldedExt =
13656 tryToFoldExtOfLoad(DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0,
13657 ExtLoadType: ISD::SEXTLOAD, ExtOpc: ISD::SIGN_EXTEND))
13658 return foldedExt;
13659
13660 if (SDValue foldedExt =
13661 tryToFoldExtOfMaskedLoad(DAG, TLI, VT, LegalOperations, N, N0,
13662 ExtLoadType: ISD::SEXTLOAD, ExtOpc: ISD::SIGN_EXTEND))
13663 return foldedExt;
13664
13665 // fold (sext (load x)) to multiple smaller sextloads.
13666 // Only on illegal but splittable vectors.
13667 if (SDValue ExtLoad = CombineExtLoad(N))
13668 return ExtLoad;
13669
13670 // Try to simplify (sext (sextload x)).
13671 if (SDValue foldedExt = tryToFoldExtOfExtload(
13672 DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0, ExtLoadType: ISD::SEXTLOAD))
13673 return foldedExt;
13674
13675 // Try to simplify (sext (atomic_load x)).
13676 if (SDValue foldedExt =
13677 tryToFoldExtOfAtomicLoad(DAG, TLI, VT, N0, ExtLoadType: ISD::SEXTLOAD))
13678 return foldedExt;
13679
13680 // fold (sext (and/or/xor (load x), cst)) ->
13681 // (and/or/xor (sextload x), (sext cst))
13682 if (ISD::isBitwiseLogicOp(Opcode: N0.getOpcode()) &&
13683 isa<LoadSDNode>(Val: N0.getOperand(i: 0)) &&
13684 N0.getOperand(i: 1).getOpcode() == ISD::Constant &&
13685 (!LegalOperations && TLI.isOperationLegal(Op: N0.getOpcode(), VT))) {
13686 LoadSDNode *LN00 = cast<LoadSDNode>(Val: N0.getOperand(i: 0));
13687 EVT MemVT = LN00->getMemoryVT();
13688 if (TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: VT, MemVT) &&
13689 LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
13690 SmallVector<SDNode*, 4> SetCCs;
13691 bool DoXform = ExtendUsesToFormExtLoad(VT, N: N0.getNode(), N0: N0.getOperand(i: 0),
13692 ExtOpc: ISD::SIGN_EXTEND, ExtendNodes&: SetCCs, TLI);
13693 if (DoXform) {
13694 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::SEXTLOAD, dl: SDLoc(LN00), VT,
13695 Chain: LN00->getChain(), Ptr: LN00->getBasePtr(),
13696 MemVT: LN00->getMemoryVT(),
13697 MMO: LN00->getMemOperand());
13698 APInt Mask = N0.getConstantOperandAPInt(i: 1).sext(width: VT.getSizeInBits());
13699 SDValue And = DAG.getNode(Opcode: N0.getOpcode(), DL, VT,
13700 N1: ExtLoad, N2: DAG.getConstant(Val: Mask, DL, VT));
13701 ExtendSetCCUses(SetCCs, OrigLoad: N0.getOperand(i: 0), ExtLoad, ExtType: ISD::SIGN_EXTEND);
13702 bool NoReplaceTruncAnd = !N0.hasOneUse();
13703 bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
13704 CombineTo(N, Res: And);
13705 // If N0 has multiple uses, change other uses as well.
13706 if (NoReplaceTruncAnd) {
13707 SDValue TruncAnd =
13708 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: N0.getValueType(), Operand: And);
13709 CombineTo(N: N0.getNode(), Res: TruncAnd);
13710 }
13711 if (NoReplaceTrunc) {
13712 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN00, 1), To: ExtLoad.getValue(R: 1));
13713 } else {
13714 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(LN00),
13715 VT: LN00->getValueType(ResNo: 0), Operand: ExtLoad);
13716 CombineTo(N: LN00, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
13717 }
13718 return SDValue(N,0); // Return N so it doesn't get rechecked!
13719 }
13720 }
13721 }
13722
13723 if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
13724 return V;
13725
13726 if (SDValue V = foldSextSetcc(N))
13727 return V;
13728
13729 // fold (sext x) -> (zext x) if the sign bit is known zero.
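// For example (illustrative): if X = (and Y, 0x7fff) : i32, the sign bit of X
// is known zero, so (sext X to i64) can become (zext nneg X to i64), provided
// zext is not considered more expensive than sext on the target.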
13730 if (!TLI.isSExtCheaperThanZExt(FromTy: N0.getValueType(), ToTy: VT) &&
13731 (!LegalOperations || TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT)) &&
13732 DAG.SignBitIsZero(Op: N0)) {
13733 SDNodeFlags Flags;
13734 Flags.setNonNeg(true);
13735 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0, Flags);
13736 }
13737
13738 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
13739 return NewVSel;
13740
13741 // Eliminate this sign extend by doing a negation in the destination type:
13742 // sext i32 (0 - (zext i8 X to i32)) to i64 --> 0 - (zext i8 X to i64)
13743 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
13744 isNullOrNullSplat(V: N0.getOperand(i: 0)) &&
13745 N0.getOperand(i: 1).getOpcode() == ISD::ZERO_EXTEND &&
13746 TLI.isOperationLegalOrCustom(Op: ISD::SUB, VT)) {
13747 SDValue Zext = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1).getOperand(i: 0), DL, VT);
13748 return DAG.getNegative(Val: Zext, DL, VT);
13749 }
13750 // Eliminate this sign extend by doing a decrement in the destination type:
13751 // sext i32 ((zext i8 X to i32) + (-1)) to i64 --> (zext i8 X to i64) + (-1)
13752 if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
13753 isAllOnesOrAllOnesSplat(V: N0.getOperand(i: 1)) &&
13754 N0.getOperand(i: 0).getOpcode() == ISD::ZERO_EXTEND &&
13755 TLI.isOperationLegalOrCustom(Op: ISD::ADD, VT)) {
13756 SDValue Zext = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 0).getOperand(i: 0), DL, VT);
13757 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Zext, N2: DAG.getAllOnesConstant(DL, VT));
13758 }
13759
13760 // fold sext (not i1 X) -> add (zext i1 X), -1
13761 // TODO: This could be extended to handle bool vectors.
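// Sanity check of the identity (illustrative): for X = 1 (i1 true), (not X) = 0
// and sext gives 0; on the other side (zext X) = 1 and 1 + (-1) = 0. For X = 0,
// sext (not X) = -1 and (zext X) + (-1) = -1, so both forms agree.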
13762 if (N0.getValueType() == MVT::i1 && isBitwiseNot(V: N0) && N0.hasOneUse() &&
13763 (!LegalOperations || (TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT) &&
13764 TLI.isOperationLegal(Op: ISD::ADD, VT)))) {
13765 // If we can eliminate the 'not', the sext form should be better
13766 if (SDValue NewXor = visitXOR(N: N0.getNode())) {
13767 // Returning N0 is a form of in-visit replacement that may have
13768 // invalidated N0.
13769 if (NewXor.getNode() == N0.getNode()) {
13770 // Return SDValue here as the xor should have already been replaced in
13771 // this sext.
13772 return SDValue();
13773 }
13774
13775 // Return a new sext with the new xor.
13776 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: NewXor);
13777 }
13778
13779 SDValue Zext = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0.getOperand(i: 0));
13780 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Zext, N2: DAG.getAllOnesConstant(DL, VT));
13781 }
13782
13783 if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, DL, Level))
13784 return Res;
13785
13786 return SDValue();
13787}
13788
13789/// Given an extending node with a pop-count operand, if the target does not
13790/// support a pop-count in the narrow source type but does support it in the
13791/// destination type, widen the pop-count to the destination type.
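/// Illustrative example (target capabilities here are hypothetical): if CTPOP
/// is not legal for i16 but is legal for i32, (zext (ctpop X:i16) to i32)
/// becomes (ctpop (zext X to i32)).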
13792static SDValue widenCtPop(SDNode *Extend, SelectionDAG &DAG, const SDLoc &DL) {
13793 assert((Extend->getOpcode() == ISD::ZERO_EXTEND ||
13794 Extend->getOpcode() == ISD::ANY_EXTEND) &&
13795 "Expected extend op");
13796
13797 SDValue CtPop = Extend->getOperand(Num: 0);
13798 if (CtPop.getOpcode() != ISD::CTPOP || !CtPop.hasOneUse())
13799 return SDValue();
13800
13801 EVT VT = Extend->getValueType(ResNo: 0);
13802 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
13803 if (TLI.isOperationLegalOrCustom(Op: ISD::CTPOP, VT: CtPop.getValueType()) ||
13804 !TLI.isOperationLegalOrCustom(Op: ISD::CTPOP, VT))
13805 return SDValue();
13806
13807 // zext (ctpop X) --> ctpop (zext X)
13808 SDValue NewZext = DAG.getZExtOrTrunc(Op: CtPop.getOperand(i: 0), DL, VT);
13809 return DAG.getNode(Opcode: ISD::CTPOP, DL, VT, Operand: NewZext);
13810}
13811
13812// If we have (zext (abs X)) where X is a type that will be promoted by type
13813// legalization, convert to (abs (sext X)). But don't extend past a legal type.
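// Illustrative example (assumes i8 is promoted to i32 on the target):
//   (zext (abs X:i8) to i64) -> (zext (abs (sext X to i32)) to i64)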
13814static SDValue widenAbs(SDNode *Extend, SelectionDAG &DAG) {
13815 assert(Extend->getOpcode() == ISD::ZERO_EXTEND && "Expected zero extend.");
13816
13817 EVT VT = Extend->getValueType(ResNo: 0);
13818 if (VT.isVector())
13819 return SDValue();
13820
13821 SDValue Abs = Extend->getOperand(Num: 0);
13822 if (Abs.getOpcode() != ISD::ABS || !Abs.hasOneUse())
13823 return SDValue();
13824
13825 EVT AbsVT = Abs.getValueType();
13826 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
13827 if (TLI.getTypeAction(Context&: *DAG.getContext(), VT: AbsVT) !=
13828 TargetLowering::TypePromoteInteger)
13829 return SDValue();
13830
13831 EVT LegalVT = TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: AbsVT);
13832
13833 SDValue SExt =
13834 DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL: SDLoc(Abs), VT: LegalVT, Operand: Abs.getOperand(i: 0));
13835 SDValue NewAbs = DAG.getNode(Opcode: ISD::ABS, DL: SDLoc(Abs), VT: LegalVT, Operand: SExt);
13836 return DAG.getZExtOrTrunc(Op: NewAbs, DL: SDLoc(Extend), VT);
13837}
13838
13839SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
13840 SDValue N0 = N->getOperand(Num: 0);
13841 EVT VT = N->getValueType(ResNo: 0);
13842 SDLoc DL(N);
13843
13844 if (VT.isVector())
13845 if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
13846 return FoldedVOp;
13847
13848 // zext(undef) = 0
13849 if (N0.isUndef())
13850 return DAG.getConstant(Val: 0, DL, VT);
13851
13852 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
13853 return Res;
13854
13855 // fold (zext (zext x)) -> (zext x)
13856 // fold (zext (aext x)) -> (zext x)
13857 if (N0.getOpcode() == ISD::ZERO_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
13858 SDNodeFlags Flags;
13859 if (N0.getOpcode() == ISD::ZERO_EXTEND)
13860 Flags.setNonNeg(N0->getFlags().hasNonNeg());
13861 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0.getOperand(i: 0), Flags);
13862 }
13863
13864 // fold (zext (aext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
13865 // fold (zext (zext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
13866 if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
13867 N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG)
13868 return DAG.getNode(Opcode: ISD::ZERO_EXTEND_VECTOR_INREG, DL, VT, Operand: N0.getOperand(i: 0));
13869
13870 // fold (zext (truncate x)) -> (zext x) or
13871 // (zext (truncate x)) -> (truncate x)
13872 // This is valid when the truncated bits of x are already zero.
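// For example (illustrative): if X : i64 has its top 32 bits known zero,
// (zext (trunc X to i32) to i64) can simply be replaced with X.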
13873 SDValue Op;
13874 KnownBits Known;
13875 if (isTruncateOf(DAG, N: N0, Op, Known)) {
13876 APInt TruncatedBits =
13877 (Op.getScalarValueSizeInBits() == N0.getScalarValueSizeInBits()) ?
13878 APInt(Op.getScalarValueSizeInBits(), 0) :
13879 APInt::getBitsSet(numBits: Op.getScalarValueSizeInBits(),
13880 loBit: N0.getScalarValueSizeInBits(),
13881 hiBit: std::min(a: Op.getScalarValueSizeInBits(),
13882 b: VT.getScalarSizeInBits()));
13883 if (TruncatedBits.isSubsetOf(RHS: Known.Zero)) {
13884 SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, DL, VT);
13885 DAG.salvageDebugInfo(N&: *N0.getNode());
13886
13887 return ZExtOrTrunc;
13888 }
13889 }
13890
13891 // fold (zext (truncate x)) -> (and x, mask)
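// For example (illustrative): (zext (trunc X:i32 to i8) to i32) -> (and X, 255).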
13892 if (N0.getOpcode() == ISD::TRUNCATE) {
13893 // fold (zext (truncate (load x))) -> (zext (smaller load x))
13894 // fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
13895 if (SDValue NarrowLoad = reduceLoadWidth(N: N0.getNode())) {
13896 SDNode *oye = N0.getOperand(i: 0).getNode();
13897 if (NarrowLoad.getNode() != N0.getNode()) {
13898 CombineTo(N: N0.getNode(), Res: NarrowLoad);
13899 // CombineTo deleted the truncate, if needed, but not what's under it.
13900 AddToWorklist(N: oye);
13901 }
13902 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13903 }
13904
13905 EVT SrcVT = N0.getOperand(i: 0).getValueType();
13906 EVT MinVT = N0.getValueType();
13907
13908 if (N->getFlags().hasNonNeg()) {
13909 SDValue Op = N0.getOperand(i: 0);
13910 unsigned OpBits = SrcVT.getScalarSizeInBits();
13911 unsigned MidBits = MinVT.getScalarSizeInBits();
13912 unsigned DestBits = VT.getScalarSizeInBits();
13913 unsigned NumSignBits = DAG.ComputeNumSignBits(Op);
13914
13915 if (OpBits == DestBits) {
13916 // Op is i32, Mid is i8, and Dest is i32. If Op has more than 24 sign
13917 // bits, it is already sign extended from the Mid type and can be used as is.
13918 if (NumSignBits > DestBits - MidBits)
13919 return Op;
13920 } else if (OpBits < DestBits) {
13921 // Op is i32, Mid is i8, and Dest is i64. If Op has more than 24 sign
13922 // bits, just sext from i32.
13923 // FIXME: This can probably be ZERO_EXTEND nneg?
13924 if (NumSignBits > OpBits - MidBits)
13925 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: Op);
13926 } else {
13927 // Op is i64, Mid is i8, and Dest is i32. If Op has more than 56 sign
13928 // bits, just truncate to i32.
13929 if (NumSignBits > OpBits - MidBits)
13930 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Op);
13931 }
13932 }
13933
13934 // Try to mask before the extension to avoid having to generate a larger mask,
13935 // possibly over several sub-vectors.
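// Illustrative example (vector types chosen for exposition): for
//   (zext (trunc X:v4i16 to v4i8) to v4i32)
// it is cheaper to form (zext (and X, splat(0xFF)) to v4i32) than to extend
// first and mask with a wider v4i32 splat.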
13936 if (SrcVT.bitsLT(VT) && VT.isVector()) {
13937 if (!LegalOperations || (TLI.isOperationLegal(Op: ISD::AND, VT: SrcVT) &&
13938 TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT))) {
13939 SDValue Op = N0.getOperand(i: 0);
13940 Op = DAG.getZeroExtendInReg(Op, DL, VT: MinVT);
13941 AddToWorklist(N: Op.getNode());
13942 SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, DL, VT);
13943 // Transfer the debug info; the new node is equivalent to N0.
13944 DAG.transferDbgValues(From: N0, To: ZExtOrTrunc);
13945 return ZExtOrTrunc;
13946 }
13947 }
13948
13949 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::AND, VT)) {
13950 SDValue Op = DAG.getAnyExtOrTrunc(Op: N0.getOperand(i: 0), DL, VT);
13951 AddToWorklist(N: Op.getNode());
13952 SDValue And = DAG.getZeroExtendInReg(Op, DL, VT: MinVT);
13953 // We may safely transfer the debug info describing the truncate node over
13954 // to the equivalent and operation.
13955 DAG.transferDbgValues(From: N0, To: And);
13956 return And;
13957 }
13958 }
13959
13960 // Fold (zext (and (trunc x), cst)) -> (and x, cst),
13961 // if either of the casts is not free.
13962 if (N0.getOpcode() == ISD::AND &&
13963 N0.getOperand(i: 0).getOpcode() == ISD::TRUNCATE &&
13964 N0.getOperand(i: 1).getOpcode() == ISD::Constant &&
13965 (!TLI.isTruncateFree(Val: N0.getOperand(i: 0).getOperand(i: 0), VT2: N0.getValueType()) ||
13966 !TLI.isZExtFree(FromTy: N0.getValueType(), ToTy: VT))) {
13967 SDValue X = N0.getOperand(i: 0).getOperand(i: 0);
13968 X = DAG.getAnyExtOrTrunc(Op: X, DL: SDLoc(X), VT);
13969 APInt Mask = N0.getConstantOperandAPInt(i: 1).zext(width: VT.getSizeInBits());
13970 return DAG.getNode(Opcode: ISD::AND, DL, VT,
13971 N1: X, N2: DAG.getConstant(Val: Mask, DL, VT));
13972 }
13973
13974 // Try to simplify (zext (load x)).
13975 if (SDValue foldedExt = tryToFoldExtOfLoad(
13976 DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0, ExtLoadType: ISD::ZEXTLOAD,
13977 ExtOpc: ISD::ZERO_EXTEND, NonNegZExt: N->getFlags().hasNonNeg()))
13978 return foldedExt;
13979
13980 if (SDValue foldedExt =
13981 tryToFoldExtOfMaskedLoad(DAG, TLI, VT, LegalOperations, N, N0,
13982 ExtLoadType: ISD::ZEXTLOAD, ExtOpc: ISD::ZERO_EXTEND))
13983 return foldedExt;
13984
13985 // fold (zext (load x)) to multiple smaller zextloads.
13986 // Only on illegal but splittable vectors.
13987 if (SDValue ExtLoad = CombineExtLoad(N))
13988 return ExtLoad;
13989
13990 // Try to simplify (zext (atomic_load x)).
13991 if (SDValue foldedExt =
13992 tryToFoldExtOfAtomicLoad(DAG, TLI, VT, N0, ExtLoadType: ISD::ZEXTLOAD))
13993 return foldedExt;
13994
13995 // fold (zext (and/or/xor (load x), cst)) ->
13996 // (and/or/xor (zextload x), (zext cst))
13997 // Unless (and (load x) cst) will match as a zextload already and has
13998 // additional users, or the zext is already free.
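// Illustrative example (assumes a legal i32 -> i64 zextload):
//   (zext (or (load i32 p), 0x80000000) to i64)
//     -> (or (zextload p from i32 to i64), 0x80000000)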
13999 if (ISD::isBitwiseLogicOp(Opcode: N0.getOpcode()) && !TLI.isZExtFree(Val: N0, VT2: VT) &&
14000 isa<LoadSDNode>(Val: N0.getOperand(i: 0)) &&
14001 N0.getOperand(i: 1).getOpcode() == ISD::Constant &&
14002 (!LegalOperations && TLI.isOperationLegal(Op: N0.getOpcode(), VT))) {
14003 LoadSDNode *LN00 = cast<LoadSDNode>(Val: N0.getOperand(i: 0));
14004 EVT MemVT = LN00->getMemoryVT();
14005 if (TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: VT, MemVT) &&
14006 LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) {
14007 bool DoXform = true;
14008 SmallVector<SDNode*, 4> SetCCs;
14009 if (!N0.hasOneUse()) {
14010 if (N0.getOpcode() == ISD::AND) {
14011 auto *AndC = cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
14012 EVT LoadResultTy = AndC->getValueType(ResNo: 0);
14013 EVT ExtVT;
14014 if (isAndLoadExtLoad(AndC, LoadN: LN00, LoadResultTy, ExtVT))
14015 DoXform = false;
14016 }
14017 }
14018 if (DoXform)
14019 DoXform = ExtendUsesToFormExtLoad(VT, N: N0.getNode(), N0: N0.getOperand(i: 0),
14020 ExtOpc: ISD::ZERO_EXTEND, ExtendNodes&: SetCCs, TLI);
14021 if (DoXform) {
14022 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl: SDLoc(LN00), VT,
14023 Chain: LN00->getChain(), Ptr: LN00->getBasePtr(),
14024 MemVT: LN00->getMemoryVT(),
14025 MMO: LN00->getMemOperand());
14026 APInt Mask = N0.getConstantOperandAPInt(i: 1).zext(width: VT.getSizeInBits());
14027 SDValue And = DAG.getNode(Opcode: N0.getOpcode(), DL, VT,
14028 N1: ExtLoad, N2: DAG.getConstant(Val: Mask, DL, VT));
14029 ExtendSetCCUses(SetCCs, OrigLoad: N0.getOperand(i: 0), ExtLoad, ExtType: ISD::ZERO_EXTEND);
14030 bool NoReplaceTruncAnd = !N0.hasOneUse();
14031 bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
14032 CombineTo(N, Res: And);
14033 // If N0 has multiple uses, change other uses as well.
14034 if (NoReplaceTruncAnd) {
14035 SDValue TruncAnd =
14036 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: N0.getValueType(), Operand: And);
14037 CombineTo(N: N0.getNode(), Res: TruncAnd);
14038 }
14039 if (NoReplaceTrunc) {
14040 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN00, 1), To: ExtLoad.getValue(R: 1));
14041 } else {
14042 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(LN00),
14043 VT: LN00->getValueType(ResNo: 0), Operand: ExtLoad);
14044 CombineTo(N: LN00, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
14045 }
14046 return SDValue(N,0); // Return N so it doesn't get rechecked!
14047 }
14048 }
14049 }
14050
14051 // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
14052 // (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
14053 if (SDValue ZExtLoad = CombineZExtLogicopShiftLoad(N))
14054 return ZExtLoad;
14055
14056 // Try to simplify (zext (zextload x)).
14057 if (SDValue foldedExt = tryToFoldExtOfExtload(
14058 DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0, ExtLoadType: ISD::ZEXTLOAD))
14059 return foldedExt;
14060
14061 if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
14062 return V;
14063
14064 if (N0.getOpcode() == ISD::SETCC) {
14065 // Propagate fast-math-flags.
14066 SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
14067
14068 // Only do this before legalize for now.
14069 if (!LegalOperations && VT.isVector() &&
14070 N0.getValueType().getVectorElementType() == MVT::i1) {
14071 EVT N00VT = N0.getOperand(i: 0).getValueType();
14072 if (getSetCCResultType(VT: N00VT) == N0.getValueType())
14073 return SDValue();
14074
14075 // We know that the # elements of the result is the same as the #
14076 // elements of the compare (and the # elements of the compare result for
14077 // that matter). Check to see that they are the same size. If so, we know
14078 // that the element size of the zext'd result matches the element size of
14079 // the compare operands.
14080 if (VT.getSizeInBits() == N00VT.getSizeInBits()) {
14081 // zext(setcc) -> zext_in_reg(vsetcc) for vectors.
14082 SDValue VSetCC = DAG.getNode(Opcode: ISD::SETCC, DL, VT, N1: N0.getOperand(i: 0),
14083 N2: N0.getOperand(i: 1), N3: N0.getOperand(i: 2));
14084 return DAG.getZeroExtendInReg(Op: VSetCC, DL, VT: N0.getValueType());
14085 }
14086
14087 // If the desired elements are smaller or larger than the source
14088 // elements we can use a matching integer vector type and then
14089 // truncate/any extend followed by zext_in_reg.
14090 EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
14091 SDValue VsetCC =
14092 DAG.getNode(Opcode: ISD::SETCC, DL, VT: MatchingVectorType, N1: N0.getOperand(i: 0),
14093 N2: N0.getOperand(i: 1), N3: N0.getOperand(i: 2));
14094 return DAG.getZeroExtendInReg(Op: DAG.getAnyExtOrTrunc(Op: VsetCC, DL, VT), DL,
14095 VT: N0.getValueType());
14096 }
14097
14098 // zext(setcc x,y,cc) -> zext(select x, y, true, false, cc)
14099 EVT N0VT = N0.getValueType();
14100 EVT N00VT = N0.getOperand(i: 0).getValueType();
14101 if (SDValue SCC = SimplifySelectCC(
14102 DL, N0: N0.getOperand(i: 0), N1: N0.getOperand(i: 1),
14103 N2: DAG.getBoolConstant(V: true, DL, VT: N0VT, OpVT: N00VT),
14104 N3: DAG.getBoolConstant(V: false, DL, VT: N0VT, OpVT: N00VT),
14105 CC: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get(), NotExtCompare: true))
14106 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: SCC);
14107 }
14108
14109 // (zext (shl/srl (zext x), cst)) -> (shl/srl (zext x), cst)
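// For example (illustrative): with X : i8,
//   (zext (shl (zext X to i32), 2) to i64)
//     -> (shl (zext (zext X to i32) to i64), 2)
// which is safe because no set bits can be shifted out of the i32 value.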
14110 if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
14111 !TLI.isZExtFree(Val: N0, VT2: VT)) {
14112 SDValue ShVal = N0.getOperand(i: 0);
14113 SDValue ShAmt = N0.getOperand(i: 1);
14114 if (auto *ShAmtC = dyn_cast<ConstantSDNode>(Val&: ShAmt)) {
14115 if (ShVal.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse()) {
14116 if (N0.getOpcode() == ISD::SHL) {
14117 // If the original shl may be shifting out bits, do not perform this
14118 // transformation.
14119 unsigned KnownZeroBits = ShVal.getValueSizeInBits() -
14120 ShVal.getOperand(i: 0).getValueSizeInBits();
14121 if (ShAmtC->getAPIntValue().ugt(RHS: KnownZeroBits)) {
14122 // If the shift is too large, then see if we can deduce that the
14123 // shift is safe anyway.
14124 // Create a mask that has ones for the bits being shifted out.
14125 APInt ShiftOutMask =
14126 APInt::getHighBitsSet(numBits: ShVal.getValueSizeInBits(),
14127 hiBitsSet: ShAmtC->getAPIntValue().getZExtValue());
14128
14129 // Check if the bits being shifted out are known to be zero.
14130 if (!DAG.MaskedValueIsZero(Op: ShVal, Mask: ShiftOutMask))
14131 return SDValue();
14132 }
14133 }
14134
14135 // Ensure that the shift amount is wide enough for the shifted value.
14136 if (Log2_32_Ceil(Value: VT.getSizeInBits()) > ShAmt.getValueSizeInBits())
14137 ShAmt = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: MVT::i32, Operand: ShAmt);
14138
14139 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT,
14140 N1: DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: ShVal), N2: ShAmt);
14141 }
14142 }
14143 }
14144
14145 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
14146 return NewVSel;
14147
14148 if (SDValue NewCtPop = widenCtPop(Extend: N, DAG, DL))
14149 return NewCtPop;
14150
14151 if (SDValue V = widenAbs(Extend: N, DAG))
14152 return V;
14153
14154 if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, DL, Level))
14155 return Res;
14156
14157 // CSE zext nneg with sext if the zext is not free.
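// For example (illustrative): if the DAG already contains (sext X:i32 to i64)
// and we now see (zext nneg X to i64), the two nodes compute the same value
// (the nneg flag says the sign bit of X is zero), so reuse the existing sext.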
14158 if (N->getFlags().hasNonNeg() && !TLI.isZExtFree(FromTy: N0.getValueType(), ToTy: VT)) {
14159 SDNode *CSENode = DAG.getNodeIfExists(Opcode: ISD::SIGN_EXTEND, VTList: N->getVTList(), Ops: N0);
14160 if (CSENode)
14161 return SDValue(CSENode, 0);
14162 }
14163
14164 return SDValue();
14165}
14166
14167SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
14168 SDValue N0 = N->getOperand(Num: 0);
14169 EVT VT = N->getValueType(ResNo: 0);
14170 SDLoc DL(N);
14171
14172 // aext(undef) = undef
14173 if (N0.isUndef())
14174 return DAG.getUNDEF(VT);
14175
14176 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
14177 return Res;
14178
14179 // fold (aext (aext x)) -> (aext x)
14180 // fold (aext (zext x)) -> (zext x)
14181 // fold (aext (sext x)) -> (sext x)
14182 if (N0.getOpcode() == ISD::ANY_EXTEND || N0.getOpcode() == ISD::ZERO_EXTEND ||
14183 N0.getOpcode() == ISD::SIGN_EXTEND) {
14184 SDNodeFlags Flags;
14185 if (N0.getOpcode() == ISD::ZERO_EXTEND)
14186 Flags.setNonNeg(N0->getFlags().hasNonNeg());
14187 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, Operand: N0.getOperand(i: 0), Flags);
14188 }
14189
14190 // fold (aext (aext_extend_vector_inreg x)) -> (aext_extend_vector_inreg x)
14191 // fold (aext (zext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
14192 // fold (aext (sext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
14193 if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
14194 N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG ||
14195 N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG)
14196 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, Operand: N0.getOperand(i: 0));
14197
14198 // fold (aext (truncate (load x))) -> (aext (smaller load x))
14199 // fold (aext (truncate (srl (load x), c))) -> (aext (small load (x+c/n)))
14200 if (N0.getOpcode() == ISD::TRUNCATE) {
14201 if (SDValue NarrowLoad = reduceLoadWidth(N: N0.getNode())) {
14202 SDNode *oye = N0.getOperand(i: 0).getNode();
14203 if (NarrowLoad.getNode() != N0.getNode()) {
14204 CombineTo(N: N0.getNode(), Res: NarrowLoad);
14205 // CombineTo deleted the truncate, if needed, but not what's under it.
14206 AddToWorklist(N: oye);
14207 }
14208 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14209 }
14210 }
14211
14212 // fold (aext (truncate x))
14213 if (N0.getOpcode() == ISD::TRUNCATE)
14214 return DAG.getAnyExtOrTrunc(Op: N0.getOperand(i: 0), DL, VT);
14215
14216 // Fold (aext (and (trunc x), cst)) -> (and x, cst)
14217 // if the trunc is not free.
14218 if (N0.getOpcode() == ISD::AND &&
14219 N0.getOperand(i: 0).getOpcode() == ISD::TRUNCATE &&
14220 N0.getOperand(i: 1).getOpcode() == ISD::Constant &&
14221 !TLI.isTruncateFree(Val: N0.getOperand(i: 0).getOperand(i: 0), VT2: N0.getValueType())) {
14222 SDValue X = DAG.getAnyExtOrTrunc(Op: N0.getOperand(i: 0).getOperand(i: 0), DL, VT);
14223 SDValue Y = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT, Operand: N0.getOperand(i: 1));
14224 assert(isa<ConstantSDNode>(Y) && "Expected constant to be folded!");
14225 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: X, N2: Y);
14226 }
14227
14228 // fold (aext (load x)) -> (aext (truncate (extload x)))
14229 // None of the supported targets knows how to perform load and any_ext
14230 // on vectors in one instruction, so attempt to fold to zext instead.
14231 if (VT.isVector()) {
14232 // Try to simplify (zext (load x)).
14233 if (SDValue foldedExt =
14234 tryToFoldExtOfLoad(DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0,
14235 ExtLoadType: ISD::ZEXTLOAD, ExtOpc: ISD::ZERO_EXTEND))
14236 return foldedExt;
14237 } else if (ISD::isNON_EXTLoad(N: N0.getNode()) &&
14238 ISD::isUNINDEXEDLoad(N: N0.getNode()) &&
14239 TLI.isLoadExtLegal(ExtType: ISD::EXTLOAD, ValVT: VT, MemVT: N0.getValueType())) {
14240 bool DoXform = true;
14241 SmallVector<SDNode *, 4> SetCCs;
14242 if (!N0.hasOneUse())
14243 DoXform =
14244 ExtendUsesToFormExtLoad(VT, N, N0, ExtOpc: ISD::ANY_EXTEND, ExtendNodes&: SetCCs, TLI);
14245 if (DoXform) {
14246 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
14247 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::EXTLOAD, dl: DL, VT, Chain: LN0->getChain(),
14248 Ptr: LN0->getBasePtr(), MemVT: N0.getValueType(),
14249 MMO: LN0->getMemOperand());
14250 ExtendSetCCUses(SetCCs, OrigLoad: N0, ExtLoad, ExtType: ISD::ANY_EXTEND);
14251 // If the load value is used only by N, replace it via CombineTo N.
14252 bool NoReplaceTrunc = N0.hasOneUse();
14253 CombineTo(N, Res: ExtLoad);
14254 if (NoReplaceTrunc) {
14255 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN0, 1), To: ExtLoad.getValue(R: 1));
14256 recursivelyDeleteUnusedNodes(N: LN0);
14257 } else {
14258 SDValue Trunc =
14259 DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N0), VT: N0.getValueType(), Operand: ExtLoad);
14260 CombineTo(N: LN0, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
14261 }
14262 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14263 }
14264 }
14265
14266 // fold (aext (zextload x)) -> (aext (truncate (zextload x)))
14267 // fold (aext (sextload x)) -> (aext (truncate (sextload x)))
14268 // fold (aext ( extload x)) -> (aext (truncate (extload x)))
14269 if (N0.getOpcode() == ISD::LOAD && !ISD::isNON_EXTLoad(N: N0.getNode()) &&
14270 ISD::isUNINDEXEDLoad(N: N0.getNode()) && N0.hasOneUse()) {
14271 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
14272 ISD::LoadExtType ExtType = LN0->getExtensionType();
14273 EVT MemVT = LN0->getMemoryVT();
14274 if (!LegalOperations || TLI.isLoadExtLegal(ExtType, ValVT: VT, MemVT)) {
14275 SDValue ExtLoad =
14276 DAG.getExtLoad(ExtType, dl: DL, VT, Chain: LN0->getChain(), Ptr: LN0->getBasePtr(),
14277 MemVT, MMO: LN0->getMemOperand());
14278 CombineTo(N, Res: ExtLoad);
14279 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN0, 1), To: ExtLoad.getValue(R: 1));
14280 recursivelyDeleteUnusedNodes(N: LN0);
14281 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14282 }
14283 }
14284
14285 if (N0.getOpcode() == ISD::SETCC) {
14286 // Propagate fast-math-flags.
14287 SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
14288
14289 // For vectors:
14290 // aext(setcc) -> vsetcc
14291 // aext(setcc) -> truncate(vsetcc)
14292 // aext(setcc) -> aext(vsetcc)
14293 // Only do this before legalize for now.
14294 if (VT.isVector() && !LegalOperations) {
14295 EVT N00VT = N0.getOperand(i: 0).getValueType();
14296 if (getSetCCResultType(VT: N00VT) == N0.getValueType())
14297 return SDValue();
14298
14299 // We know that the # elements of the result is the same as the
14300 // # elements of the compare (and the # elements of the compare result
14301 // for that matter). Check to see that they are the same size. If so,
14302 // we know that the element size of the extended result matches the
14303 // element size of the compare operands.
14304 if (VT.getSizeInBits() == N00VT.getSizeInBits())
14305 return DAG.getSetCC(DL, VT, LHS: N0.getOperand(i: 0), RHS: N0.getOperand(i: 1),
14306 Cond: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get());
14307
14308 // If the desired elements are smaller or larger than the source
14309 // elements we can use a matching integer vector type and then
14310 // truncate/any extend
14311 EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
14312 SDValue VsetCC = DAG.getSetCC(
14313 DL, VT: MatchingVectorType, LHS: N0.getOperand(i: 0), RHS: N0.getOperand(i: 1),
14314 Cond: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get());
14315 return DAG.getAnyExtOrTrunc(Op: VsetCC, DL, VT);
14316 }
14317
14318 // aext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc
14319 if (SDValue SCC = SimplifySelectCC(
14320 DL, N0: N0.getOperand(i: 0), N1: N0.getOperand(i: 1), N2: DAG.getConstant(Val: 1, DL, VT),
14321 N3: DAG.getConstant(Val: 0, DL, VT),
14322 CC: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get(), NotExtCompare: true))
14323 return SCC;
14324 }
14325
14326 if (SDValue NewCtPop = widenCtPop(Extend: N, DAG, DL))
14327 return NewCtPop;
14328
14329 if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, DL, Level))
14330 return Res;
14331
14332 return SDValue();
14333}
14334
14335SDValue DAGCombiner::visitAssertExt(SDNode *N) {
14336 unsigned Opcode = N->getOpcode();
14337 SDValue N0 = N->getOperand(Num: 0);
14338 SDValue N1 = N->getOperand(Num: 1);
14339 EVT AssertVT = cast<VTSDNode>(Val&: N1)->getVT();
14340
14341 // fold (assert?ext (assert?ext x, vt), vt) -> (assert?ext x, vt)
14342 if (N0.getOpcode() == Opcode &&
14343 AssertVT == cast<VTSDNode>(Val: N0.getOperand(i: 1))->getVT())
14344 return N0;
14345
14346 if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
14347 N0.getOperand(i: 0).getOpcode() == Opcode) {
14348 // We have an assert, truncate, assert sandwich. Make one stronger assert
14349 // by applying the smaller of the two asserted types to the larger source
14350 // value. This eliminates the later assert:
14351 // assert (trunc (assert X, i8) to iN), i1 --> trunc (assert X, i1) to iN
14352 // assert (trunc (assert X, i1) to iN), i8 --> trunc (assert X, i1) to iN
14353 SDLoc DL(N);
14354 SDValue BigA = N0.getOperand(i: 0);
14355 EVT BigA_AssertVT = cast<VTSDNode>(Val: BigA.getOperand(i: 1))->getVT();
14356 EVT MinAssertVT = AssertVT.bitsLT(VT: BigA_AssertVT) ? AssertVT : BigA_AssertVT;
14357 SDValue MinAssertVTVal = DAG.getValueType(MinAssertVT);
14358 SDValue NewAssert = DAG.getNode(Opcode, DL, VT: BigA.getValueType(),
14359 N1: BigA.getOperand(i: 0), N2: MinAssertVTVal);
14360 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: N->getValueType(ResNo: 0), Operand: NewAssert);
14361 }
14362
14363 // If we have (AssertZext (truncate (AssertSext X, iX)), iY) and Y is smaller
14364 // than X, just move the AssertZext in front of the truncate and drop the
14365 // AssertSext.
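// Illustrative example: with X : i64,
//   (AssertZext (trunc (AssertSext X, i32) to i32), i8)
//     -> (trunc (AssertZext X, i8) to i32)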
14366 if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
14367 N0.getOperand(i: 0).getOpcode() == ISD::AssertSext &&
14368 Opcode == ISD::AssertZext) {
14369 SDValue BigA = N0.getOperand(i: 0);
14370 EVT BigA_AssertVT = cast<VTSDNode>(Val: BigA.getOperand(i: 1))->getVT();
14371 if (AssertVT.bitsLT(VT: BigA_AssertVT)) {
14372 SDLoc DL(N);
14373 SDValue NewAssert = DAG.getNode(Opcode, DL, VT: BigA.getValueType(),
14374 N1: BigA.getOperand(i: 0), N2: N1);
14375 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: N->getValueType(ResNo: 0), Operand: NewAssert);
14376 }
14377 }
14378
14379 return SDValue();
14380}
14381
14382SDValue DAGCombiner::visitAssertAlign(SDNode *N) {
14383 SDLoc DL(N);
14384
14385 Align AL = cast<AssertAlignSDNode>(Val: N)->getAlign();
14386 SDValue N0 = N->getOperand(Num: 0);
14387
14388 // Fold (assertalign (assertalign x, AL0), AL1) ->
14389 // (assertalign x, max(AL0, AL1))
14390 if (auto *AAN = dyn_cast<AssertAlignSDNode>(Val&: N0))
14391 return DAG.getAssertAlign(DL, V: N0.getOperand(i: 0),
14392 A: std::max(a: AL, b: AAN->getAlign()));
14393
14394 // In rare cases, there are trivial arithmetic ops in source operands. Sink
14395 // this assert down to the source operands so that those arithmetic ops can
14396 // be exposed to DAG combining.
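// For example (illustrative): (assertalign (add X, 4), align 4) can become
// (add (assertalign X, align 4), 4), since the constant operand already
// satisfies the alignment and the assertion can be pushed onto X.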
14397 switch (N0.getOpcode()) {
14398 default:
14399 break;
14400 case ISD::ADD:
14401 case ISD::SUB: {
14402 unsigned AlignShift = Log2(A: AL);
14403 SDValue LHS = N0.getOperand(i: 0);
14404 SDValue RHS = N0.getOperand(i: 1);
14405 unsigned LHSAlignShift = DAG.computeKnownBits(Op: LHS).countMinTrailingZeros();
14406 unsigned RHSAlignShift = DAG.computeKnownBits(Op: RHS).countMinTrailingZeros();
14407 if (LHSAlignShift >= AlignShift || RHSAlignShift >= AlignShift) {
14408 if (LHSAlignShift < AlignShift)
14409 LHS = DAG.getAssertAlign(DL, V: LHS, A: AL);
14410 if (RHSAlignShift < AlignShift)
14411 RHS = DAG.getAssertAlign(DL, V: RHS, A: AL);
14412 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT: N0.getValueType(), N1: LHS, N2: RHS);
14413 }
14414 break;
14415 }
14416 }
14417
14418 return SDValue();
14419}
14420
14421/// If the result of a load is shifted/masked/truncated to an effectively
14422/// narrower type, try to transform the load to a narrower type and/or
14423/// use an extending load.
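/// Illustrative example (little-endian; the offset is for exposition only):
///   (trunc (srl (load i32 p), 16) to i16) -> (load i16 from p+2)
/// assuming the narrowed load is legal for the target.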
14424SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
14425 unsigned Opc = N->getOpcode();
14426
14427 ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
14428 SDValue N0 = N->getOperand(Num: 0);
14429 EVT VT = N->getValueType(ResNo: 0);
14430 EVT ExtVT = VT;
14431
14432 // This transformation isn't valid for vector loads.
14433 if (VT.isVector())
14434 return SDValue();
14435
14436 // The ShAmt variable is used to indicate that we've consumed a right
14437 // shift. I.e. we want to narrow the width of the load by not loading the
14438 // ShAmt least significant bits.
14439 unsigned ShAmt = 0;
14440 // A special case is when the least significant bits from the load are masked
14441 // away, but using an AND rather than a right shift. ShiftedOffset is used to
14442 // record how many bits the narrowed load should be left-shifted to
14443 // reconstruct the result.
14444 unsigned ShiftedOffset = 0;
14445 // Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT and then
14446 // sign extending back to VT.
14447 if (Opc == ISD::SIGN_EXTEND_INREG) {
14448 ExtType = ISD::SEXTLOAD;
14449 ExtVT = cast<VTSDNode>(Val: N->getOperand(Num: 1))->getVT();
14450 } else if (Opc == ISD::SRL || Opc == ISD::SRA) {
14451 // Another special-case: SRL/SRA is basically zero/sign-extending a narrower
14452 // value, or it may be shifting a higher subword, half or byte into the
14453 // lowest bits.
14454
14455 // Only handle shift with constant shift amount, and the shiftee must be a
14456 // load.
14457 auto *LN = dyn_cast<LoadSDNode>(Val&: N0);
14458 auto *N1C = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
14459 if (!N1C || !LN)
14460 return SDValue();
14461 // If the shift amount is larger than the memory type then we're not
14462 // accessing any of the loaded bytes.
14463 ShAmt = N1C->getZExtValue();
14464 uint64_t MemoryWidth = LN->getMemoryVT().getScalarSizeInBits();
14465 if (MemoryWidth <= ShAmt)
14466 return SDValue();
14467 // Attempt to fold away the SRL by using ZEXTLOAD and SRA by using SEXTLOAD.
14468 ExtType = Opc == ISD::SRL ? ISD::ZEXTLOAD : ISD::SEXTLOAD;
14469 ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: MemoryWidth - ShAmt);
14470 // If original load is a SEXTLOAD then we can't simply replace it by a
14471 // ZEXTLOAD (we could potentially replace it by a more narrow SEXTLOAD
14472 // followed by a ZEXT, but that is not handled at the moment). Similarly if
14473 // the original load is a ZEXTLOAD and we want to use a SEXTLOAD.
14474 if ((LN->getExtensionType() == ISD::SEXTLOAD ||
14475 LN->getExtensionType() == ISD::ZEXTLOAD) &&
14476 LN->getExtensionType() != ExtType)
14477 return SDValue();
14478 } else if (Opc == ISD::AND) {
14479 // An AND with a constant mask is the same as a truncate + zero-extend.
14480 auto AndC = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
14481 if (!AndC)
14482 return SDValue();
14483
14484 const APInt &Mask = AndC->getAPIntValue();
14485 unsigned ActiveBits = 0;
14486 if (Mask.isMask()) {
14487 ActiveBits = Mask.countr_one();
14488 } else if (Mask.isShiftedMask(MaskIdx&: ShAmt, MaskLen&: ActiveBits)) {
14489 ShiftedOffset = ShAmt;
14490 } else {
14491 return SDValue();
14492 }
14493
14494 ExtType = ISD::ZEXTLOAD;
14495 ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ActiveBits);
14496 }
14497
14498 // In case Opc==SRL we've already prepared ExtVT/ExtType/ShAmt based on doing
14499 // a right shift. Here we redo some of those checks, to possibly adjust the
14500 // ExtVT even further based on "a masking AND". We could also end up here for
14501 // other reasons (e.g. based on Opc==TRUNCATE) and that is why some checks
14502 // need to be done here as well.
14503 if (Opc == ISD::SRL || N0.getOpcode() == ISD::SRL) {
14504 SDValue SRL = Opc == ISD::SRL ? SDValue(N, 0) : N0;
14505 // Bail out when the SRL has more than one use. This is done for historical
14506 // (undocumented) reasons. Maybe the intent was to guard the AND-masking
14507 // check below? And maybe it could be non-profitable to do the transform when
14508 // the SRL has multiple uses and we get here with Opc!=ISD::SRL?
14509 // FIXME: Can't we just skip this check for the Opc==ISD::SRL case?
14510 if (!SRL.hasOneUse())
14511 return SDValue();
14512
14513 // Only handle shift with constant shift amount, and the shiftee must be a
14514 // load.
14515 auto *LN = dyn_cast<LoadSDNode>(Val: SRL.getOperand(i: 0));
14516 auto *SRL1C = dyn_cast<ConstantSDNode>(Val: SRL.getOperand(i: 1));
14517 if (!SRL1C || !LN)
14518 return SDValue();
14519
14520 // If the shift amount is larger than the input type then we're not
14521 // accessing any of the loaded bytes. If the load was a zextload/extload
14522 // then the result of the shift+trunc is zero/undef (handled elsewhere).
14523 ShAmt = SRL1C->getZExtValue();
14524 uint64_t MemoryWidth = LN->getMemoryVT().getSizeInBits();
14525 if (ShAmt >= MemoryWidth)
14526 return SDValue();
14527
14528 // Because a SRL must be assumed to *need* to zero-extend the high bits
14529 // (as opposed to anyext the high bits), we can't combine the zextload
14530 // lowering of SRL and an sextload.
14531 if (LN->getExtensionType() == ISD::SEXTLOAD)
14532 return SDValue();
14533
14534 // Avoid reading outside the memory accessed by the original load (which
14535 // could happen if we only adjusted the load base pointer by ShAmt). Instead
14536 // we try to narrow the load even further. The typical scenario here is:
14537 // (i64 (truncate (i96 (srl (load x), 64)))) ->
14538 // (i64 (truncate (i96 (zextload (load i32 + offset) from i32))))
14539 if (ExtVT.getScalarSizeInBits() > MemoryWidth - ShAmt) {
14540 // Don't replace sextload by zextload.
14541 if (ExtType == ISD::SEXTLOAD)
14542 return SDValue();
14543 // Narrow the load.
14544 ExtType = ISD::ZEXTLOAD;
14545 ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: MemoryWidth - ShAmt);
14546 }
14547
14548 // If the SRL is only used by a masking AND, we may be able to adjust
14549 // the ExtVT to make the AND redundant.
14550 SDNode *Mask = *(SRL->use_begin());
14551 if (SRL.hasOneUse() && Mask->getOpcode() == ISD::AND &&
14552 isa<ConstantSDNode>(Val: Mask->getOperand(Num: 1))) {
14553 unsigned Offset, ActiveBits;
14554 const APInt& ShiftMask = Mask->getConstantOperandAPInt(Num: 1);
14555 if (ShiftMask.isMask()) {
14556 EVT MaskedVT =
14557 EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ShiftMask.countr_one());
14558 // If the mask is smaller, recompute the type.
14559 if ((ExtVT.getScalarSizeInBits() > MaskedVT.getScalarSizeInBits()) &&
14560 TLI.isLoadExtLegal(ExtType, ValVT: SRL.getValueType(), MemVT: MaskedVT))
14561 ExtVT = MaskedVT;
14562 } else if (ExtType == ISD::ZEXTLOAD &&
14563 ShiftMask.isShiftedMask(MaskIdx&: Offset, MaskLen&: ActiveBits) &&
14564 (Offset + ShAmt) < VT.getScalarSizeInBits()) {
14565 EVT MaskedVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ActiveBits);
14566 // If the mask is shifted we can use a narrower load and a shl to insert
14567 // the trailing zeros.
14568 if (((Offset + ActiveBits) <= ExtVT.getScalarSizeInBits()) &&
14569 TLI.isLoadExtLegal(ExtType, ValVT: SRL.getValueType(), MemVT: MaskedVT)) {
14570 ExtVT = MaskedVT;
14571 ShAmt = Offset + ShAmt;
14572 ShiftedOffset = Offset;
14573 }
14574 }
14575 }
14576
14577 N0 = SRL.getOperand(i: 0);
14578 }
14579
14580 // If the load is shifted left (and the result isn't shifted back right), we
14581 // can fold a truncate through the shift. The typical scenario is that N
14582 // points at a TRUNCATE here so the attempted fold is:
14583 // (truncate (shl (load x), c))) -> (shl (narrow load x), c)
14584 // ShLeftAmt will indicate how much a narrowed load should be shifted left.
14585 unsigned ShLeftAmt = 0;
14586 if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
14587 ExtVT == VT && TLI.isNarrowingProfitable(SrcVT: N0.getValueType(), DestVT: VT)) {
14588 if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1))) {
14589 ShLeftAmt = N01->getZExtValue();
14590 N0 = N0.getOperand(i: 0);
14591 }
14592 }
14593
14594 // If we haven't found a load, we can't narrow it.
14595 if (!isa<LoadSDNode>(Val: N0))
14596 return SDValue();
14597
14598 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
14599 // Reducing the width of a volatile load is illegal. For atomics, we may be
14600 // able to reduce the width provided we never widen again. (see D66309)
14601 if (!LN0->isSimple() ||
14602 !isLegalNarrowLdSt(LDST: LN0, ExtType, MemVT&: ExtVT, ShAmt))
14603 return SDValue();
14604
14605 auto AdjustBigEndianShift = [&](unsigned ShAmt) {
14606 unsigned LVTStoreBits =
14607 LN0->getMemoryVT().getStoreSizeInBits().getFixedValue();
14608 unsigned EVTStoreBits = ExtVT.getStoreSizeInBits().getFixedValue();
14609 return LVTStoreBits - EVTStoreBits - ShAmt;
14610 };
14611
14612 // We need to adjust the pointer to the load by ShAmt bits in order to load
14613 // the correct bytes.
14614 unsigned PtrAdjustmentInBits =
14615 DAG.getDataLayout().isBigEndian() ? AdjustBigEndianShift(ShAmt) : ShAmt;
14616
14617 uint64_t PtrOff = PtrAdjustmentInBits / 8;
14618 SDLoc DL(LN0);
14619 // The original load itself didn't wrap, so an offset within it doesn't.
14620 SDNodeFlags Flags;
14621 Flags.setNoUnsignedWrap(true);
14622 SDValue NewPtr = DAG.getMemBasePlusOffset(
14623 Base: LN0->getBasePtr(), Offset: TypeSize::getFixed(ExactSize: PtrOff), DL, Flags);
14624 AddToWorklist(N: NewPtr.getNode());
14625
14626 SDValue Load;
14627 if (ExtType == ISD::NON_EXTLOAD)
14628 Load = DAG.getLoad(VT, dl: DL, Chain: LN0->getChain(), Ptr: NewPtr,
14629 PtrInfo: LN0->getPointerInfo().getWithOffset(O: PtrOff),
14630 Alignment: LN0->getOriginalAlign(),
14631 MMOFlags: LN0->getMemOperand()->getFlags(), AAInfo: LN0->getAAInfo());
14632 else
14633 Load = DAG.getExtLoad(ExtType, dl: DL, VT, Chain: LN0->getChain(), Ptr: NewPtr,
14634 PtrInfo: LN0->getPointerInfo().getWithOffset(O: PtrOff), MemVT: ExtVT,
14635 Alignment: LN0->getOriginalAlign(),
14636 MMOFlags: LN0->getMemOperand()->getFlags(), AAInfo: LN0->getAAInfo());
14637
14638 // Replace the old load's chain with the new load's chain.
14639 WorklistRemover DeadNodes(*this);
14640 DAG.ReplaceAllUsesOfValueWith(From: N0.getValue(R: 1), To: Load.getValue(R: 1));
14641
14642 // Shift the result left, if we've swallowed a left shift.
14643 SDValue Result = Load;
14644 if (ShLeftAmt != 0) {
14645 // If the shift amount is as large as the result size (but, presumably,
14646 // no larger than the source) then the useful bits of the result are
14647 // zero; we can't simply return the shortened shift, because the result
14648 // of that operation is undefined.
14649 if (ShLeftAmt >= VT.getScalarSizeInBits())
14650 Result = DAG.getConstant(Val: 0, DL, VT);
14651 else
14652 Result = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Result,
14653 N2: DAG.getShiftAmountConstant(Val: ShLeftAmt, VT, DL));
14654 }
14655
14656 if (ShiftedOffset != 0) {
14657 // We're using a shifted mask, so the load now has an offset. This means
14658 // the data has been loaded into lower bits of the register than it would
14659 // have been otherwise, so we need to shl the loaded data back into the
14660 // correct position in the register.
14661 SDValue ShiftC = DAG.getConstant(Val: ShiftedOffset, DL, VT);
14662 Result = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Result, N2: ShiftC);
14663 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result);
14664 }
14665
14666 // Return the new loaded value.
14667 return Result;
14668}
14669
14670SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
14671 SDValue N0 = N->getOperand(Num: 0);
14672 SDValue N1 = N->getOperand(Num: 1);
14673 EVT VT = N->getValueType(ResNo: 0);
14674 EVT ExtVT = cast<VTSDNode>(Val&: N1)->getVT();
14675 unsigned VTBits = VT.getScalarSizeInBits();
14676 unsigned ExtVTBits = ExtVT.getScalarSizeInBits();
14677
14678 // sext_in_reg(undef) = 0 because the top bits will all be the same.
14679 if (N0.isUndef())
14680 return DAG.getConstant(Val: 0, DL: SDLoc(N), VT);
14681
14682 // fold (sext_in_reg c1) -> c1
14683 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0))
14684 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL: SDLoc(N), VT, N1: N0, N2: N1);
14685
14686 // If the input is already sign extended, just drop the extension.
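// For example (illustrative): (sext_in_reg (sra X, 24), i8) with X : i32 is
// already sign extended from bit 7, so the node can be returned unchanged.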
14687 if (ExtVTBits >= DAG.ComputeMaxSignificantBits(Op: N0))
14688 return N0;
14689
14690 // fold (sext_in_reg (sext_in_reg x, VT2), VT1) -> (sext_in_reg x, minVT) pt2
14691 if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
14692 ExtVT.bitsLT(VT: cast<VTSDNode>(Val: N0.getOperand(i: 1))->getVT()))
14693 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL: SDLoc(N), VT, N1: N0.getOperand(i: 0),
14694 N2: N1);
14695
14696 // fold (sext_in_reg (sext x)) -> (sext x)
14697 // fold (sext_in_reg (aext x)) -> (sext x)
14698 // if x is small enough or if we know that x has more than 1 sign bit and the
14699 // sign_extend_inreg is extending from one of them.
14700 if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
14701 SDValue N00 = N0.getOperand(i: 0);
14702 unsigned N00Bits = N00.getScalarValueSizeInBits();
14703 if ((N00Bits <= ExtVTBits ||
14704 DAG.ComputeMaxSignificantBits(Op: N00) <= ExtVTBits) &&
14705 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SIGN_EXTEND, VT)))
14706 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL: SDLoc(N), VT, Operand: N00);
14707 }
14708
14709 // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_inreg x)
14710 // if x is small enough or if we know that x has more than 1 sign bit and the
14711 // sign_extend_inreg is extending from one of them.
14712 if (ISD::isExtVecInRegOpcode(Opcode: N0.getOpcode())) {
14713 SDValue N00 = N0.getOperand(i: 0);
14714 unsigned N00Bits = N00.getScalarValueSizeInBits();
14715 unsigned DstElts = N0.getValueType().getVectorMinNumElements();
14716 unsigned SrcElts = N00.getValueType().getVectorMinNumElements();
14717 bool IsZext = N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG;
14718 APInt DemandedSrcElts = APInt::getLowBitsSet(numBits: SrcElts, loBitsSet: DstElts);
14719 if ((N00Bits == ExtVTBits ||
14720 (!IsZext && (N00Bits < ExtVTBits ||
14721 DAG.ComputeMaxSignificantBits(Op: N00) <= ExtVTBits))) &&
14722 (!LegalOperations ||
14723 TLI.isOperationLegal(Op: ISD::SIGN_EXTEND_VECTOR_INREG, VT)))
14724 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_VECTOR_INREG, DL: SDLoc(N), VT, Operand: N00);
14725 }
14726
14727 // fold (sext_in_reg (zext x)) -> (sext x)
14728 // iff we are extending the source sign bit.
14729 if (N0.getOpcode() == ISD::ZERO_EXTEND) {
14730 SDValue N00 = N0.getOperand(i: 0);
14731 if (N00.getScalarValueSizeInBits() == ExtVTBits &&
14732 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SIGN_EXTEND, VT)))
14733 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL: SDLoc(N), VT, Operand: N00);
14734 }
14735
14736 // fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero.
14737 if (DAG.MaskedValueIsZero(Op: N0, Mask: APInt::getOneBitSet(numBits: VTBits, BitNo: ExtVTBits - 1)))
14738 return DAG.getZeroExtendInReg(Op: N0, DL: SDLoc(N), VT: ExtVT);
14739
14740 // fold operands of sext_in_reg based on knowledge that the top bits are not
14741 // demanded.
14742 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
14743 return SDValue(N, 0);
14744
14745 // fold (sext_in_reg (load x)) -> (smaller sextload x)
14746 // fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits))
14747 if (SDValue NarrowLoad = reduceLoadWidth(N))
14748 return NarrowLoad;
14749
14750 // fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
14751 // fold (sext_in_reg (srl X, 23), i8) -> (sra X, 23) iff possible.
14752 // We already fold "(sext_in_reg (srl X, 25), i8) -> srl X, 25" above.
14753 if (N0.getOpcode() == ISD::SRL) {
14754 if (auto *ShAmt = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1)))
14755 if (ShAmt->getAPIntValue().ule(RHS: VTBits - ExtVTBits)) {
14756 // We can turn this into an SRA iff the input to the SRL is already sign
14757 // extended enough.
14758 unsigned InSignBits = DAG.ComputeNumSignBits(Op: N0.getOperand(i: 0));
14759 if (((VTBits - ExtVTBits) - ShAmt->getZExtValue()) < InSignBits)
14760 return DAG.getNode(Opcode: ISD::SRA, DL: SDLoc(N), VT, N1: N0.getOperand(i: 0),
14761 N2: N0.getOperand(i: 1));
14762 }
14763 }
14764
14765 // fold (sext_inreg (extload x)) -> (sextload x)
14766 // If sextload is not supported by target, we can only do the combine when
14767 // load has one use. Doing otherwise can block folding the extload with other
14768 // extends that the target does support.
14769 if (ISD::isEXTLoad(N: N0.getNode()) &&
14770 ISD::isUNINDEXEDLoad(N: N0.getNode()) &&
14771 ExtVT == cast<LoadSDNode>(Val&: N0)->getMemoryVT() &&
14772 ((!LegalOperations && cast<LoadSDNode>(Val&: N0)->isSimple() &&
14773 N0.hasOneUse()) ||
14774 TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: VT, MemVT: ExtVT))) {
14775 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
14776 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::SEXTLOAD, dl: SDLoc(N), VT,
14777 Chain: LN0->getChain(),
14778 Ptr: LN0->getBasePtr(), MemVT: ExtVT,
14779 MMO: LN0->getMemOperand());
14780 CombineTo(N, Res: ExtLoad);
14781 CombineTo(N: N0.getNode(), Res0: ExtLoad, Res1: ExtLoad.getValue(R: 1));
14782 AddToWorklist(N: ExtLoad.getNode());
14783 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14784 }
14785
14786 // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use
14787 if (ISD::isZEXTLoad(N: N0.getNode()) && ISD::isUNINDEXEDLoad(N: N0.getNode()) &&
14788 N0.hasOneUse() &&
14789 ExtVT == cast<LoadSDNode>(Val&: N0)->getMemoryVT() &&
14790 ((!LegalOperations && cast<LoadSDNode>(Val&: N0)->isSimple()) &&
14791 TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: VT, MemVT: ExtVT))) {
14792 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
14793 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::SEXTLOAD, dl: SDLoc(N), VT,
14794 Chain: LN0->getChain(),
14795 Ptr: LN0->getBasePtr(), MemVT: ExtVT,
14796 MMO: LN0->getMemOperand());
14797 CombineTo(N, Res: ExtLoad);
14798 CombineTo(N: N0.getNode(), Res0: ExtLoad, Res1: ExtLoad.getValue(R: 1));
14799 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14800 }
14801
14802 // fold (sext_inreg (masked_load x)) -> (sext_masked_load x)
14803 // Ignore it if the masked load is already sign extended.
14804 if (MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(Val&: N0)) {
14805 if (ExtVT == Ld->getMemoryVT() && N0.hasOneUse() &&
14806 Ld->getExtensionType() != ISD::LoadExtType::NON_EXTLOAD &&
14807 TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: VT, MemVT: ExtVT)) {
14808 SDValue ExtMaskedLoad = DAG.getMaskedLoad(
14809 VT, dl: SDLoc(N), Chain: Ld->getChain(), Base: Ld->getBasePtr(), Offset: Ld->getOffset(),
14810 Mask: Ld->getMask(), Src0: Ld->getPassThru(), MemVT: ExtVT, MMO: Ld->getMemOperand(),
14811 AM: Ld->getAddressingMode(), ISD::SEXTLOAD, IsExpanding: Ld->isExpandingLoad());
14812 CombineTo(N, Res: ExtMaskedLoad);
14813 CombineTo(N: N0.getNode(), Res0: ExtMaskedLoad, Res1: ExtMaskedLoad.getValue(R: 1));
14814 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14815 }
14816 }
14817
14818 // fold (sext_inreg (masked_gather x)) -> (sext_masked_gather x)
14819 if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(Val&: N0)) {
14820 if (SDValue(GN0, 0).hasOneUse() &&
14821 ExtVT == GN0->getMemoryVT() &&
14822 TLI.isVectorLoadExtDesirable(ExtVal: SDValue(SDValue(GN0, 0)))) {
14823 SDValue Ops[] = {GN0->getChain(), GN0->getPassThru(), GN0->getMask(),
14824 GN0->getBasePtr(), GN0->getIndex(), GN0->getScale()};
14825
14826 SDValue ExtLoad = DAG.getMaskedGather(
14827 VTs: DAG.getVTList(VT1: VT, VT2: MVT::Other), MemVT: ExtVT, dl: SDLoc(N), Ops,
14828 MMO: GN0->getMemOperand(), IndexType: GN0->getIndexType(), ExtTy: ISD::SEXTLOAD);
14829
14830 CombineTo(N, Res: ExtLoad);
14831 CombineTo(N: N0.getNode(), Res0: ExtLoad, Res1: ExtLoad.getValue(R: 1));
14832 AddToWorklist(N: ExtLoad.getNode());
14833 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14834 }
14835 }
14836
14837 // Form (sext_inreg (bswap >> 16)) or (sext_inreg (rotl (bswap) 16))
14838 if (ExtVTBits <= 16 && N0.getOpcode() == ISD::OR) {
14839 if (SDValue BSwap = MatchBSwapHWordLow(N: N0.getNode(), N0: N0.getOperand(i: 0),
14840 N1: N0.getOperand(i: 1), DemandHighBits: false))
14841 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL: SDLoc(N), VT, N1: BSwap, N2: N1);
14842 }
14843
14844 // Fold (iM_signext_inreg
14845 // (extract_subvector (zext|anyext|sext iN_v to _) _)
14846 // from iN)
14847 // -> (extract_subvector (signext iN_v to iM))
14848 if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() &&
14849 ISD::isExtOpcode(Opcode: N0.getOperand(i: 0).getOpcode())) {
14850 SDValue InnerExt = N0.getOperand(i: 0);
14851 EVT InnerExtVT = InnerExt->getValueType(ResNo: 0);
14852 SDValue Extendee = InnerExt->getOperand(Num: 0);
14853
14854 if (ExtVTBits == Extendee.getValueType().getScalarSizeInBits() &&
14855 (!LegalOperations ||
14856 TLI.isOperationLegal(Op: ISD::SIGN_EXTEND, VT: InnerExtVT))) {
14857 SDValue SignExtExtendee =
14858 DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL: SDLoc(N), VT: InnerExtVT, Operand: Extendee);
14859 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N), VT, N1: SignExtExtendee,
14860 N2: N0.getOperand(i: 1));
14861 }
14862 }
14863
14864 return SDValue();
14865}
14866
14867static SDValue foldExtendVectorInregToExtendOfSubvector(
14868 SDNode *N, const SDLoc &DL, const TargetLowering &TLI, SelectionDAG &DAG,
14869 bool LegalOperations) {
14870 unsigned InregOpcode = N->getOpcode();
14871 unsigned Opcode = DAG.getOpcode_EXTEND(Opcode: InregOpcode);
14872
14873 SDValue Src = N->getOperand(Num: 0);
14874 EVT VT = N->getValueType(ResNo: 0);
14875 EVT SrcVT = EVT::getVectorVT(Context&: *DAG.getContext(),
14876 VT: Src.getValueType().getVectorElementType(),
14877 EC: VT.getVectorElementCount());
14878
14879 assert(ISD::isExtVecInRegOpcode(InregOpcode) &&
14880 "Expected EXTEND_VECTOR_INREG dag node in input!");
14881
14882 // Profitability check: our operand must be a one-use CONCAT_VECTORS.
14883 // FIXME: the one-use check may be overly restrictive.
14884 if (!Src.hasOneUse() || Src.getOpcode() != ISD::CONCAT_VECTORS)
14885 return SDValue();
14886
14887 // Profitability check: we must be extending exactly one of its operands.
14888 // FIXME: this is probably overly restrictive.
14889 Src = Src.getOperand(i: 0);
14890 if (Src.getValueType() != SrcVT)
14891 return SDValue();
14892
14893 if (LegalOperations && !TLI.isOperationLegal(Op: Opcode, VT))
14894 return SDValue();
14895
14896 return DAG.getNode(Opcode, DL, VT, Operand: Src);
14897}
14898
14899SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
14900 SDValue N0 = N->getOperand(Num: 0);
14901 EVT VT = N->getValueType(ResNo: 0);
14902 SDLoc DL(N);
14903
14904 if (N0.isUndef()) {
14905 // aext_vector_inreg(undef) = undef because the top bits are undefined.
14906 // {s/z}ext_vector_inreg(undef) = 0 because the top bits must be the same.
14907 return N->getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG
14908 ? DAG.getUNDEF(VT)
14909 : DAG.getConstant(Val: 0, DL, VT);
14910 }
14911
14912 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
14913 return Res;
14914
14915 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
14916 return SDValue(N, 0);
14917
14918 if (SDValue R = foldExtendVectorInregToExtendOfSubvector(N, DL, TLI, DAG,
14919 LegalOperations))
14920 return R;
14921
14922 return SDValue();
14923}
14924
14925SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
14926 SDValue N0 = N->getOperand(Num: 0);
14927 EVT VT = N->getValueType(ResNo: 0);
14928 EVT SrcVT = N0.getValueType();
14929 bool isLE = DAG.getDataLayout().isLittleEndian();
14930 SDLoc DL(N);
14931
14932 // trunc(undef) = undef
14933 if (N0.isUndef())
14934 return DAG.getUNDEF(VT);
14935
14936 // fold (truncate (truncate x)) -> (truncate x)
14937 if (N0.getOpcode() == ISD::TRUNCATE)
14938 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 0));
14939
14940 // fold (truncate c1) -> c1
14941 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::TRUNCATE, DL, VT, Ops: {N0}))
14942 return C;
14943
14944 // fold (truncate (ext x)) -> (ext x) or (truncate x) or x
14945 if (N0.getOpcode() == ISD::ZERO_EXTEND ||
14946 N0.getOpcode() == ISD::SIGN_EXTEND ||
14947 N0.getOpcode() == ISD::ANY_EXTEND) {
14948 // if the source is smaller than the dest, we still need an extend.
14949 if (N0.getOperand(i: 0).getValueType().bitsLT(VT))
14950 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, Operand: N0.getOperand(i: 0));
    // if the source is larger than the dest, then we just need the truncate.
14952 if (N0.getOperand(i: 0).getValueType().bitsGT(VT))
14953 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 0));
14954 // if the source and dest are the same type, we can drop both the extend
14955 // and the truncate.
14956 return N0.getOperand(i: 0);
14957 }
14958
14959 // Try to narrow a truncate-of-sext_in_reg to the destination type:
14960 // trunc (sign_ext_inreg X, iM) to iN --> sign_ext_inreg (trunc X to iN), iM
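  // A concrete sketch (illustrative types; the target hook below decides
  // whether the narrow form is actually preferred):
  //   trunc (sign_ext_inreg i64:X, i8) to i32
  //     --> sign_ext_inreg (trunc X to i32), i8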
14961 if (!LegalTypes && N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
14962 N0.hasOneUse()) {
14963 SDValue X = N0.getOperand(i: 0);
14964 SDValue ExtVal = N0.getOperand(i: 1);
14965 EVT ExtVT = cast<VTSDNode>(Val&: ExtVal)->getVT();
14966 if (ExtVT.bitsLT(VT) && TLI.preferSextInRegOfTruncate(TruncVT: VT, VT: SrcVT, ExtVT)) {
14967 SDValue TrX = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: X);
14968 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT, N1: TrX, N2: ExtVal);
14969 }
14970 }
14971
14972 // If this is anyext(trunc), don't fold it, allow ourselves to be folded.
14973 if (N->hasOneUse() && (N->use_begin()->getOpcode() == ISD::ANY_EXTEND))
14974 return SDValue();
14975
14976 // Fold extract-and-trunc into a narrow extract. For example:
14977 // i64 x = EXTRACT_VECTOR_ELT(v2i64 val, i32 1)
14978 // i32 y = TRUNCATE(i64 x)
14979 // -- becomes --
14980 // v16i8 b = BITCAST (v2i64 val)
14981 // i8 x = EXTRACT_VECTOR_ELT(v16i8 b, i32 8)
14982 //
14983 // Note: We only run this optimization after type legalization (which often
14984 // creates this pattern) and before operation legalization after which
14985 // we need to be more careful about the vector instructions that we generate.
14986 if (N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
14987 LegalTypes && !LegalOperations && N0->hasOneUse() && VT != MVT::i1) {
14988 EVT VecTy = N0.getOperand(i: 0).getValueType();
14989 EVT ExTy = N0.getValueType();
14990 EVT TrTy = N->getValueType(ResNo: 0);
14991
14992 auto EltCnt = VecTy.getVectorElementCount();
14993 unsigned SizeRatio = ExTy.getSizeInBits()/TrTy.getSizeInBits();
14994 auto NewEltCnt = EltCnt * SizeRatio;
14995
14996 EVT NVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: TrTy, EC: NewEltCnt);
14997 assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");
14998
14999 SDValue EltNo = N0->getOperand(Num: 1);
15000 if (isa<ConstantSDNode>(Val: EltNo) && isTypeLegal(VT: NVT)) {
15001 int Elt = EltNo->getAsZExtVal();
15002 int Index = isLE ? (Elt*SizeRatio) : (Elt*SizeRatio + (SizeRatio-1));
15003 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: TrTy,
15004 N1: DAG.getBitcast(VT: NVT, V: N0.getOperand(i: 0)),
15005 N2: DAG.getVectorIdxConstant(Val: Index, DL));
15006 }
15007 }
15008
15009 // trunc (select c, a, b) -> select c, (trunc a), (trunc b)
15010 if (N0.getOpcode() == ISD::SELECT && N0.hasOneUse()) {
15011 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::SELECT, VT: SrcVT)) &&
15012 TLI.isTruncateFree(FromVT: SrcVT, ToVT: VT)) {
15013 SDLoc SL(N0);
15014 SDValue Cond = N0.getOperand(i: 0);
15015 SDValue TruncOp0 = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SL, VT, Operand: N0.getOperand(i: 1));
15016 SDValue TruncOp1 = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SL, VT, Operand: N0.getOperand(i: 2));
15017 return DAG.getNode(Opcode: ISD::SELECT, DL, VT, N1: Cond, N2: TruncOp0, N3: TruncOp1);
15018 }
15019 }
15020
15021 // trunc (shl x, K) -> shl (trunc x), K => K < VT.getScalarSizeInBits()
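  // For illustration (assuming SHL on the narrow type is legal/desirable):
  //   (i32 (trunc (shl i64:X, 7))) -> (shl (i32 (trunc X)), 7)
  // because the shift amount 7 is known to be less than the 32-bit width.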
15022 if (N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
15023 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SHL, VT)) &&
15024 TLI.isTypeDesirableForOp(ISD::SHL, VT)) {
15025 SDValue Amt = N0.getOperand(i: 1);
15026 KnownBits Known = DAG.computeKnownBits(Op: Amt);
15027 unsigned Size = VT.getScalarSizeInBits();
15028 if (Known.countMaxActiveBits() <= Log2_32(Value: Size)) {
15029 EVT AmtVT = TLI.getShiftAmountTy(LHSTy: VT, DL: DAG.getDataLayout());
15030 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 0));
15031 if (AmtVT != Amt.getValueType()) {
15032 Amt = DAG.getZExtOrTrunc(Op: Amt, DL, VT: AmtVT);
15033 AddToWorklist(N: Amt.getNode());
15034 }
15035 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Trunc, N2: Amt);
15036 }
15037 }
15038
15039 if (SDValue V = foldSubToUSubSat(DstVT: VT, N: N0.getNode(), DL))
15040 return V;
15041
15042 if (SDValue ABD = foldABSToABD(N, DL))
15043 return ABD;
15044
15045 // Attempt to pre-truncate BUILD_VECTOR sources.
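  // For example (assuming i32 -> i16 truncation is free on the target):
  //   (v4i16 (trunc (v4i32 build_vector a, b, c, d)))
  //     -> (v4i16 build_vector (trunc a), (trunc b), (trunc c), (trunc d))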
15046 if (N0.getOpcode() == ISD::BUILD_VECTOR && !LegalOperations &&
15047 N0.hasOneUse() &&
15048 TLI.isTruncateFree(FromVT: SrcVT.getScalarType(), ToVT: VT.getScalarType()) &&
15049 // Avoid creating illegal types if running after type legalizer.
15050 (!LegalTypes || TLI.isTypeLegal(VT: VT.getScalarType()))) {
15051 EVT SVT = VT.getScalarType();
15052 SmallVector<SDValue, 8> TruncOps;
15053 for (const SDValue &Op : N0->op_values()) {
15054 SDValue TruncOp = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: SVT, Operand: Op);
15055 TruncOps.push_back(Elt: TruncOp);
15056 }
15057 return DAG.getBuildVector(VT, DL, Ops: TruncOps);
15058 }
15059
15060 // trunc (splat_vector x) -> splat_vector (trunc x)
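  // For instance (scalable types shown; scalar and vector legality permitting):
  //   (nxv4i16 (trunc (nxv4i32 splat_vector i32:X)))
  //     -> (nxv4i16 splat_vector (i16 (trunc X)))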
15061 if (N0.getOpcode() == ISD::SPLAT_VECTOR &&
15062 (!LegalTypes || TLI.isTypeLegal(VT: VT.getScalarType())) &&
15063 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SPLAT_VECTOR, VT))) {
15064 EVT SVT = VT.getScalarType();
15065 return DAG.getSplatVector(
15066 VT, DL, Op: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: SVT, Operand: N0->getOperand(Num: 0)));
15067 }
15068
15069 // Fold a series of buildvector, bitcast, and truncate if possible.
15070 // For example fold
15071 // (2xi32 trunc (bitcast ((4xi32)buildvector x, x, y, y) 2xi64)) to
15072 // (2xi32 (buildvector x, y)).
15073 if (Level == AfterLegalizeVectorOps && VT.isVector() &&
15074 N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
15075 N0.getOperand(i: 0).getOpcode() == ISD::BUILD_VECTOR &&
15076 N0.getOperand(i: 0).hasOneUse()) {
15077 SDValue BuildVect = N0.getOperand(i: 0);
15078 EVT BuildVectEltTy = BuildVect.getValueType().getVectorElementType();
15079 EVT TruncVecEltTy = VT.getVectorElementType();
15080
15081 // Check that the element types match.
15082 if (BuildVectEltTy == TruncVecEltTy) {
15083 // Now we only need to compute the offset of the truncated elements.
15084 unsigned BuildVecNumElts = BuildVect.getNumOperands();
15085 unsigned TruncVecNumElts = VT.getVectorNumElements();
15086 unsigned TruncEltOffset = BuildVecNumElts / TruncVecNumElts;
15087
15088 assert((BuildVecNumElts % TruncVecNumElts) == 0 &&
15089 "Invalid number of elements");
15090
15091 SmallVector<SDValue, 8> Opnds;
15092 for (unsigned i = 0, e = BuildVecNumElts; i != e; i += TruncEltOffset)
15093 Opnds.push_back(Elt: BuildVect.getOperand(i));
15094
15095 return DAG.getBuildVector(VT, DL, Ops: Opnds);
15096 }
15097 }
15098
15099 // fold (truncate (load x)) -> (smaller load x)
15100 // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
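  // For example, on a little-endian target where the narrow load is legal and
  // fast:
  //   (i32 (trunc (i64 (load p))))           -> (i32 (load p))
  //   (i32 (trunc (srl (i64 (load p)), 32))) -> (i32 (load p+4))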
15101 if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
15102 if (SDValue Reduced = reduceLoadWidth(N))
15103 return Reduced;
15104
15105 // Handle the case where the truncated result is at least as wide as the
15106 // loaded type.
15107 if (N0.hasOneUse() && ISD::isUNINDEXEDLoad(N: N0.getNode())) {
15108 auto *LN0 = cast<LoadSDNode>(Val&: N0);
15109 if (LN0->isSimple() && LN0->getMemoryVT().bitsLE(VT)) {
15110 SDValue NewLoad = DAG.getExtLoad(
15111 ExtType: LN0->getExtensionType(), dl: SDLoc(LN0), VT, Chain: LN0->getChain(),
15112 Ptr: LN0->getBasePtr(), MemVT: LN0->getMemoryVT(), MMO: LN0->getMemOperand());
15113 DAG.ReplaceAllUsesOfValueWith(From: N0.getValue(R: 1), To: NewLoad.getValue(R: 1));
15114 return NewLoad;
15115 }
15116 }
15117 }
15118
  // fold (trunc (concat ... x ...)) -> (concat ..., (trunc x), ...),
15120 // where ... are all 'undef'.
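  // For example (only done before type legalization):
  //   (v8i16 (trunc (concat_vectors v4i32:X, v4i32 undef)))
  //     -> (concat_vectors (v4i16 (trunc X)), v4i16 undef)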
15121 if (N0.getOpcode() == ISD::CONCAT_VECTORS && !LegalTypes) {
15122 SmallVector<EVT, 8> VTs;
15123 SDValue V;
15124 unsigned Idx = 0;
15125 unsigned NumDefs = 0;
15126
15127 for (unsigned i = 0, e = N0.getNumOperands(); i != e; ++i) {
15128 SDValue X = N0.getOperand(i);
15129 if (!X.isUndef()) {
15130 V = X;
15131 Idx = i;
15132 NumDefs++;
15133 }
      // Stop if more than one member is non-undef.
15135 if (NumDefs > 1)
15136 break;
15137
15138 VTs.push_back(Elt: EVT::getVectorVT(Context&: *DAG.getContext(),
15139 VT: VT.getVectorElementType(),
15140 EC: X.getValueType().getVectorElementCount()));
15141 }
15142
15143 if (NumDefs == 0)
15144 return DAG.getUNDEF(VT);
15145
15146 if (NumDefs == 1) {
15147 assert(V.getNode() && "The single defined operand is empty!");
15148 SmallVector<SDValue, 8> Opnds;
15149 for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
15150 if (i != Idx) {
15151 Opnds.push_back(Elt: DAG.getUNDEF(VT: VTs[i]));
15152 continue;
15153 }
15154 SDValue NV = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(V), VT: VTs[i], Operand: V);
15155 AddToWorklist(N: NV.getNode());
15156 Opnds.push_back(Elt: NV);
15157 }
15158 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT, Ops: Opnds);
15159 }
15160 }
15161
15162 // Fold truncate of a bitcast of a vector to an extract of the low vector
15163 // element.
15164 //
15165 // e.g. trunc (i64 (bitcast v2i32:x)) -> extract_vector_elt v2i32:x, idx
15166 if (N0.getOpcode() == ISD::BITCAST && !VT.isVector()) {
15167 SDValue VecSrc = N0.getOperand(i: 0);
15168 EVT VecSrcVT = VecSrc.getValueType();
15169 if (VecSrcVT.isVector() && VecSrcVT.getScalarType() == VT &&
15170 (!LegalOperations ||
15171 TLI.isOperationLegal(Op: ISD::EXTRACT_VECTOR_ELT, VT: VecSrcVT))) {
15172 unsigned Idx = isLE ? 0 : VecSrcVT.getVectorNumElements() - 1;
15173 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT, N1: VecSrc,
15174 N2: DAG.getVectorIdxConstant(Val: Idx, DL));
15175 }
15176 }
15177
15178 // Simplify the operands using demanded-bits information.
15179 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
15180 return SDValue(N, 0);
15181
15182 // fold (truncate (extract_subvector(ext x))) ->
15183 // (extract_subvector x)
15184 // TODO: This can be generalized to cover cases where the truncate and extract
15185 // do not fully cancel each other out.
15186 if (!LegalTypes && N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
15187 SDValue N00 = N0.getOperand(i: 0);
15188 if (N00.getOpcode() == ISD::SIGN_EXTEND ||
15189 N00.getOpcode() == ISD::ZERO_EXTEND ||
15190 N00.getOpcode() == ISD::ANY_EXTEND) {
15191 if (N00.getOperand(i: 0)->getValueType(ResNo: 0).getVectorElementType() ==
15192 VT.getVectorElementType())
15193 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N0->getOperand(Num: 0)), VT,
15194 N1: N00.getOperand(i: 0), N2: N0.getOperand(i: 1));
15195 }
15196 }
15197
15198 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
15199 return NewVSel;
15200
15201 // Narrow a suitable binary operation with a non-opaque constant operand by
15202 // moving it ahead of the truncate. This is limited to pre-legalization
15203 // because targets may prefer a wider type during later combines and invert
15204 // this transform.
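  // For example (scalar case, before operation legalization):
  //   (i32 (trunc (add i64:X, 42))) -> (add (i32 (trunc X)), 42)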
15205 switch (N0.getOpcode()) {
15206 case ISD::ADD:
15207 case ISD::SUB:
15208 case ISD::MUL:
15209 case ISD::AND:
15210 case ISD::OR:
15211 case ISD::XOR:
15212 if (!LegalOperations && N0.hasOneUse() &&
15213 (isConstantOrConstantVector(N: N0.getOperand(i: 0), NoOpaques: true) ||
15214 isConstantOrConstantVector(N: N0.getOperand(i: 1), NoOpaques: true))) {
15215 // TODO: We already restricted this to pre-legalization, but for vectors
15216 // we are extra cautious to not create an unsupported operation.
15217 // Target-specific changes are likely needed to avoid regressions here.
15218 if (VT.isScalarInteger() || TLI.isOperationLegal(Op: N0.getOpcode(), VT)) {
15219 SDValue NarrowL = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 0));
15220 SDValue NarrowR = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 1));
15221 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: NarrowL, N2: NarrowR);
15222 }
15223 }
15224 break;
15225 case ISD::ADDE:
15226 case ISD::UADDO_CARRY:
15227 // (trunc adde(X, Y, Carry)) -> (adde trunc(X), trunc(Y), Carry)
15228 // (trunc uaddo_carry(X, Y, Carry)) ->
15229 // (uaddo_carry trunc(X), trunc(Y), Carry)
15230 // When the adde's carry is not used.
    // We only do this for uaddo_carry before operation legalization.
15232 if (((!LegalOperations && N0.getOpcode() == ISD::UADDO_CARRY) ||
15233 TLI.isOperationLegal(Op: N0.getOpcode(), VT)) &&
15234 N0.hasOneUse() && !N0->hasAnyUseOfValue(Value: 1)) {
15235 SDValue X = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 0));
15236 SDValue Y = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 1));
15237 SDVTList VTs = DAG.getVTList(VT1: VT, VT2: N0->getValueType(ResNo: 1));
15238 return DAG.getNode(Opcode: N0.getOpcode(), DL, VTList: VTs, N1: X, N2: Y, N3: N0.getOperand(i: 2));
15239 }
15240 break;
15241 case ISD::USUBSAT:
    // Truncate the USUBSAT only if LHS is a known zero-extension; it's not
    // enough to know that the upper bits are zero, as we must also ensure
    // that we don't introduce an extra truncate.
15245 if (!LegalOperations && N0.hasOneUse() &&
15246 N0.getOperand(i: 0).getOpcode() == ISD::ZERO_EXTEND &&
15247 N0.getOperand(i: 0).getOperand(i: 0).getScalarValueSizeInBits() <=
15248 VT.getScalarSizeInBits() &&
15249 hasOperation(Opcode: N0.getOpcode(), VT)) {
15250 return getTruncatedUSUBSAT(DstVT: VT, SrcVT, LHS: N0.getOperand(i: 0), RHS: N0.getOperand(i: 1),
15251 DAG, DL);
15252 }
15253 break;
15254 }
15255
15256 return SDValue();
15257}
15258
15259static SDNode *getBuildPairElt(SDNode *N, unsigned i) {
15260 SDValue Elt = N->getOperand(Num: i);
15261 if (Elt.getOpcode() != ISD::MERGE_VALUES)
15262 return Elt.getNode();
15263 return Elt.getOperand(i: Elt.getResNo()).getNode();
15264}
15265
15266/// build_pair (load, load) -> load
15267/// if load locations are consecutive.
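/// For example, on a little-endian target where the wide load is legal and
/// fast:
///   (i64 (build_pair (i32 load [p]), (i32 load [p+4]))) -> (i64 load [p])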
15268SDValue DAGCombiner::CombineConsecutiveLoads(SDNode *N, EVT VT) {
15269 assert(N->getOpcode() == ISD::BUILD_PAIR);
15270
15271 auto *LD1 = dyn_cast<LoadSDNode>(Val: getBuildPairElt(N, i: 0));
15272 auto *LD2 = dyn_cast<LoadSDNode>(Val: getBuildPairElt(N, i: 1));
15273
  // A BUILD_PAIR always has the least significant part in elt 0 and the
15275 // most significant part in elt 1. So when combining into one large load, we
15276 // need to consider the endianness.
15277 if (DAG.getDataLayout().isBigEndian())
15278 std::swap(a&: LD1, b&: LD2);
15279
15280 if (!LD1 || !LD2 || !ISD::isNON_EXTLoad(N: LD1) || !ISD::isNON_EXTLoad(N: LD2) ||
15281 !LD1->hasOneUse() || !LD2->hasOneUse() ||
15282 LD1->getAddressSpace() != LD2->getAddressSpace())
15283 return SDValue();
15284
15285 unsigned LD1Fast = 0;
15286 EVT LD1VT = LD1->getValueType(ResNo: 0);
15287 unsigned LD1Bytes = LD1VT.getStoreSize();
15288 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::LOAD, VT)) &&
15289 DAG.areNonVolatileConsecutiveLoads(LD: LD2, Base: LD1, Bytes: LD1Bytes, Dist: 1) &&
15290 TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT,
15291 MMO: *LD1->getMemOperand(), Fast: &LD1Fast) && LD1Fast)
15292 return DAG.getLoad(VT, dl: SDLoc(N), Chain: LD1->getChain(), Ptr: LD1->getBasePtr(),
15293 PtrInfo: LD1->getPointerInfo(), Alignment: LD1->getAlign());
15294
15295 return SDValue();
15296}
15297
15298static unsigned getPPCf128HiElementSelector(const SelectionDAG &DAG) {
15299 // On little-endian machines, bitcasting from ppcf128 to i128 does swap the Hi
15300 // and Lo parts; on big-endian machines it doesn't.
15301 return DAG.getDataLayout().isBigEndian() ? 1 : 0;
15302}
15303
15304SDValue DAGCombiner::foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
15305 const TargetLowering &TLI) {
15306 // If this is not a bitcast to an FP type or if the target doesn't have
15307 // IEEE754-compliant FP logic, we're done.
15308 EVT VT = N->getValueType(ResNo: 0);
15309 SDValue N0 = N->getOperand(Num: 0);
15310 EVT SourceVT = N0.getValueType();
15311
15312 if (!VT.isFloatingPoint())
15313 return SDValue();
15314
15315 // TODO: Handle cases where the integer constant is a different scalar
15316 // bitwidth to the FP.
15317 if (VT.getScalarSizeInBits() != SourceVT.getScalarSizeInBits())
15318 return SDValue();
15319
15320 unsigned FPOpcode;
15321 APInt SignMask;
15322 switch (N0.getOpcode()) {
15323 case ISD::AND:
15324 FPOpcode = ISD::FABS;
15325 SignMask = ~APInt::getSignMask(BitWidth: SourceVT.getScalarSizeInBits());
15326 break;
15327 case ISD::XOR:
15328 FPOpcode = ISD::FNEG;
15329 SignMask = APInt::getSignMask(BitWidth: SourceVT.getScalarSizeInBits());
15330 break;
15331 case ISD::OR:
15332 FPOpcode = ISD::FABS;
15333 SignMask = APInt::getSignMask(BitWidth: SourceVT.getScalarSizeInBits());
15334 break;
15335 default:
15336 return SDValue();
15337 }
15338
15339 if (LegalOperations && !TLI.isOperationLegal(Op: FPOpcode, VT))
15340 return SDValue();
15341
15342 // This needs to be the inverse of logic in foldSignChangeInBitcast.
15343 // FIXME: I don't think looking for bitcast intrinsically makes sense, but
15344 // removing this would require more changes.
15345 auto IsBitCastOrFree = [&TLI, FPOpcode](SDValue Op, EVT VT) {
15346 if (Op.getOpcode() == ISD::BITCAST && Op.getOperand(i: 0).getValueType() == VT)
15347 return true;
15348
15349 return FPOpcode == ISD::FABS ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);
15350 };
15351
15352 // Fold (bitcast int (and (bitcast fp X to int), 0x7fff...) to fp) -> fabs X
15353 // Fold (bitcast int (xor (bitcast fp X to int), 0x8000...) to fp) -> fneg X
15354 // Fold (bitcast int (or (bitcast fp X to int), 0x8000...) to fp) ->
15355 // fneg (fabs X)
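  // Concretely, for f32 (sign bit 0x80000000), legality permitting:
  //   (f32 (bitcast (and (bitcast f32:X to i32), 0x7fffffff))) -> (fabs X)
  //   (f32 (bitcast (xor (bitcast f32:X to i32), 0x80000000))) -> (fneg X)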
15356 SDValue LogicOp0 = N0.getOperand(i: 0);
15357 ConstantSDNode *LogicOp1 = isConstOrConstSplat(N: N0.getOperand(i: 1), AllowUndefs: true);
15358 if (LogicOp1 && LogicOp1->getAPIntValue() == SignMask &&
15359 IsBitCastOrFree(LogicOp0, VT)) {
15360 SDValue CastOp0 = DAG.getNode(Opcode: ISD::BITCAST, DL: SDLoc(N), VT, Operand: LogicOp0);
15361 SDValue FPOp = DAG.getNode(Opcode: FPOpcode, DL: SDLoc(N), VT, Operand: CastOp0);
15362 NumFPLogicOpsConv++;
15363 if (N0.getOpcode() == ISD::OR)
15364 return DAG.getNode(Opcode: ISD::FNEG, DL: SDLoc(N), VT, Operand: FPOp);
15365 return FPOp;
15366 }
15367
15368 return SDValue();
15369}
15370
15371SDValue DAGCombiner::visitBITCAST(SDNode *N) {
15372 SDValue N0 = N->getOperand(Num: 0);
15373 EVT VT = N->getValueType(ResNo: 0);
15374
15375 if (N0.isUndef())
15376 return DAG.getUNDEF(VT);
15377
15378 // If the input is a BUILD_VECTOR with all constant elements, fold this now.
15379 // Only do this before legalize types, unless both types are integer and the
  // scalar type is legal. Only do this before legalize ops, since the target
  // may be depending on the bitcast.
15382 // First check to see if this is all constant.
15383 // TODO: Support FP bitcasts after legalize types.
15384 if (VT.isVector() &&
15385 (!LegalTypes ||
15386 (!LegalOperations && VT.isInteger() && N0.getValueType().isInteger() &&
15387 TLI.isTypeLegal(VT: VT.getVectorElementType()))) &&
15388 N0.getOpcode() == ISD::BUILD_VECTOR && N0->hasOneUse() &&
15389 cast<BuildVectorSDNode>(Val&: N0)->isConstant())
15390 return ConstantFoldBITCASTofBUILD_VECTOR(N0.getNode(),
15391 VT.getVectorElementType());
15392
15393 // If the input is a constant, let getNode fold it.
15394 if (isIntOrFPConstant(V: N0)) {
15395 // If we can't allow illegal operations, we need to check that this is just
    // an fp -> int or int -> fp conversion and that the resulting operation will
15397 // be legal.
15398 if (!LegalOperations ||
15399 (isa<ConstantSDNode>(Val: N0) && VT.isFloatingPoint() && !VT.isVector() &&
15400 TLI.isOperationLegal(Op: ISD::ConstantFP, VT)) ||
15401 (isa<ConstantFPSDNode>(Val: N0) && VT.isInteger() && !VT.isVector() &&
15402 TLI.isOperationLegal(Op: ISD::Constant, VT))) {
15403 SDValue C = DAG.getBitcast(VT, V: N0);
15404 if (C.getNode() != N)
15405 return C;
15406 }
15407 }
15408
15409 // (conv (conv x, t1), t2) -> (conv x, t2)
15410 if (N0.getOpcode() == ISD::BITCAST)
15411 return DAG.getBitcast(VT, V: N0.getOperand(i: 0));
15412
15413 // fold (conv (logicop (conv x), (c))) -> (logicop x, (conv c))
15414 // iff the current bitwise logicop type isn't legal
15415 if (ISD::isBitwiseLogicOp(Opcode: N0.getOpcode()) && VT.isInteger() &&
15416 !TLI.isTypeLegal(VT: N0.getOperand(i: 0).getValueType())) {
15417 auto IsFreeBitcast = [VT](SDValue V) {
15418 return (V.getOpcode() == ISD::BITCAST &&
15419 V.getOperand(i: 0).getValueType() == VT) ||
15420 (ISD::isBuildVectorOfConstantSDNodes(N: V.getNode()) &&
15421 V->hasOneUse());
15422 };
15423 if (IsFreeBitcast(N0.getOperand(i: 0)) && IsFreeBitcast(N0.getOperand(i: 1)))
15424 return DAG.getNode(Opcode: N0.getOpcode(), DL: SDLoc(N), VT,
15425 N1: DAG.getBitcast(VT, V: N0.getOperand(i: 0)),
15426 N2: DAG.getBitcast(VT, V: N0.getOperand(i: 1)));
15427 }
15428
15429 // fold (conv (load x)) -> (load (conv*)x)
15430 // If the resultant load doesn't need a higher alignment than the original!
15431 if (ISD::isNormalLoad(N: N0.getNode()) && N0.hasOneUse() &&
15432 // Do not remove the cast if the types differ in endian layout.
15433 TLI.hasBigEndianPartOrdering(VT: N0.getValueType(), DL: DAG.getDataLayout()) ==
15434 TLI.hasBigEndianPartOrdering(VT, DL: DAG.getDataLayout()) &&
15435 // If the load is volatile, we only want to change the load type if the
15436 // resulting load is legal. Otherwise we might increase the number of
15437 // memory accesses. We don't care if the original type was legal or not
15438 // as we assume software couldn't rely on the number of accesses of an
15439 // illegal type.
15440 ((!LegalOperations && cast<LoadSDNode>(Val&: N0)->isSimple()) ||
15441 TLI.isOperationLegal(Op: ISD::LOAD, VT))) {
15442 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
15443
15444 if (TLI.isLoadBitCastBeneficial(LoadVT: N0.getValueType(), BitcastVT: VT, DAG,
15445 MMO: *LN0->getMemOperand())) {
15446 SDValue Load =
15447 DAG.getLoad(VT, dl: SDLoc(N), Chain: LN0->getChain(), Ptr: LN0->getBasePtr(),
15448 MMO: LN0->getMemOperand());
15449 DAG.ReplaceAllUsesOfValueWith(From: N0.getValue(R: 1), To: Load.getValue(R: 1));
15450 return Load;
15451 }
15452 }
15453
15454 if (SDValue V = foldBitcastedFPLogic(N, DAG, TLI))
15455 return V;
15456
15457 // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit)
15458 // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit))
15459 //
15460 // For ppc_fp128:
15461 // fold (bitcast (fneg x)) ->
15462 // flipbit = signbit
15463 // (xor (bitcast x) (build_pair flipbit, flipbit))
15464 //
15465 // fold (bitcast (fabs x)) ->
15466 // flipbit = (and (extract_element (bitcast x), 0), signbit)
15467 // (xor (bitcast x) (build_pair flipbit, flipbit))
15468 // This often reduces constant pool loads.
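  // e.g. for a scalar f64 bitcast to i64:
  //   (i64 (bitcast (fneg f64:X))) -> (xor (i64 (bitcast X)), 0x8000000000000000)
  //   (i64 (bitcast (fabs f64:X))) -> (and (i64 (bitcast X)), 0x7fffffffffffffff)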
15469 if (((N0.getOpcode() == ISD::FNEG && !TLI.isFNegFree(VT: N0.getValueType())) ||
15470 (N0.getOpcode() == ISD::FABS && !TLI.isFAbsFree(VT: N0.getValueType()))) &&
15471 N0->hasOneUse() && VT.isInteger() && !VT.isVector() &&
15472 !N0.getValueType().isVector()) {
15473 SDValue NewConv = DAG.getBitcast(VT, V: N0.getOperand(i: 0));
15474 AddToWorklist(N: NewConv.getNode());
15475
15476 SDLoc DL(N);
15477 if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
15478 assert(VT.getSizeInBits() == 128);
15479 SDValue SignBit = DAG.getConstant(
15480 Val: APInt::getSignMask(BitWidth: VT.getSizeInBits() / 2), DL: SDLoc(N0), VT: MVT::i64);
15481 SDValue FlipBit;
15482 if (N0.getOpcode() == ISD::FNEG) {
15483 FlipBit = SignBit;
15484 AddToWorklist(N: FlipBit.getNode());
15485 } else {
15486 assert(N0.getOpcode() == ISD::FABS);
15487 SDValue Hi =
15488 DAG.getNode(Opcode: ISD::EXTRACT_ELEMENT, DL: SDLoc(NewConv), VT: MVT::i64, N1: NewConv,
15489 N2: DAG.getIntPtrConstant(Val: getPPCf128HiElementSelector(DAG),
15490 DL: SDLoc(NewConv)));
15491 AddToWorklist(N: Hi.getNode());
15492 FlipBit = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(N0), VT: MVT::i64, N1: Hi, N2: SignBit);
15493 AddToWorklist(N: FlipBit.getNode());
15494 }
15495 SDValue FlipBits =
15496 DAG.getNode(Opcode: ISD::BUILD_PAIR, DL: SDLoc(N0), VT, N1: FlipBit, N2: FlipBit);
15497 AddToWorklist(N: FlipBits.getNode());
15498 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: NewConv, N2: FlipBits);
15499 }
15500 APInt SignBit = APInt::getSignMask(BitWidth: VT.getSizeInBits());
15501 if (N0.getOpcode() == ISD::FNEG)
15502 return DAG.getNode(Opcode: ISD::XOR, DL, VT,
15503 N1: NewConv, N2: DAG.getConstant(Val: SignBit, DL, VT));
15504 assert(N0.getOpcode() == ISD::FABS);
15505 return DAG.getNode(Opcode: ISD::AND, DL, VT,
15506 N1: NewConv, N2: DAG.getConstant(Val: ~SignBit, DL, VT));
15507 }
15508
15509 // fold (bitconvert (fcopysign cst, x)) ->
15510 // (or (and (bitconvert x), sign), (and cst, (not sign)))
15511 // Note that we don't handle (copysign x, cst) because this can always be
15512 // folded to an fneg or fabs.
15513 //
15514 // For ppc_fp128:
15515 // fold (bitcast (fcopysign cst, x)) ->
15516 // flipbit = (and (extract_element
15517 // (xor (bitcast cst), (bitcast x)), 0),
15518 // signbit)
15519 // (xor (bitcast cst) (build_pair flipbit, flipbit))
15520 if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() &&
15521 isa<ConstantFPSDNode>(Val: N0.getOperand(i: 0)) && VT.isInteger() &&
15522 !VT.isVector()) {
15523 unsigned OrigXWidth = N0.getOperand(i: 1).getValueSizeInBits();
15524 EVT IntXVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: OrigXWidth);
15525 if (isTypeLegal(VT: IntXVT)) {
15526 SDValue X = DAG.getBitcast(VT: IntXVT, V: N0.getOperand(i: 1));
15527 AddToWorklist(N: X.getNode());
15528
15529 // If X has a different width than the result/lhs, sext it or truncate it.
15530 unsigned VTWidth = VT.getSizeInBits();
15531 if (OrigXWidth < VTWidth) {
15532 X = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL: SDLoc(N), VT, Operand: X);
15533 AddToWorklist(N: X.getNode());
15534 } else if (OrigXWidth > VTWidth) {
15535 // To get the sign bit in the right place, we have to shift it right
15536 // before truncating.
15537 SDLoc DL(X);
15538 X = DAG.getNode(Opcode: ISD::SRL, DL,
15539 VT: X.getValueType(), N1: X,
15540 N2: DAG.getConstant(Val: OrigXWidth-VTWidth, DL,
15541 VT: X.getValueType()));
15542 AddToWorklist(N: X.getNode());
15543 X = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(X), VT, Operand: X);
15544 AddToWorklist(N: X.getNode());
15545 }
15546
15547 if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
15548 APInt SignBit = APInt::getSignMask(BitWidth: VT.getSizeInBits() / 2);
15549 SDValue Cst = DAG.getBitcast(VT, V: N0.getOperand(i: 0));
15550 AddToWorklist(N: Cst.getNode());
15551 SDValue X = DAG.getBitcast(VT, V: N0.getOperand(i: 1));
15552 AddToWorklist(N: X.getNode());
15553 SDValue XorResult = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N0), VT, N1: Cst, N2: X);
15554 AddToWorklist(N: XorResult.getNode());
15555 SDValue XorResult64 = DAG.getNode(
15556 Opcode: ISD::EXTRACT_ELEMENT, DL: SDLoc(XorResult), VT: MVT::i64, N1: XorResult,
15557 N2: DAG.getIntPtrConstant(Val: getPPCf128HiElementSelector(DAG),
15558 DL: SDLoc(XorResult)));
15559 AddToWorklist(N: XorResult64.getNode());
15560 SDValue FlipBit =
15561 DAG.getNode(Opcode: ISD::AND, DL: SDLoc(XorResult64), VT: MVT::i64, N1: XorResult64,
15562 N2: DAG.getConstant(Val: SignBit, DL: SDLoc(XorResult64), VT: MVT::i64));
15563 AddToWorklist(N: FlipBit.getNode());
15564 SDValue FlipBits =
15565 DAG.getNode(Opcode: ISD::BUILD_PAIR, DL: SDLoc(N0), VT, N1: FlipBit, N2: FlipBit);
15566 AddToWorklist(N: FlipBits.getNode());
15567 return DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N), VT, N1: Cst, N2: FlipBits);
15568 }
15569 APInt SignBit = APInt::getSignMask(BitWidth: VT.getSizeInBits());
15570 X = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(X), VT,
15571 N1: X, N2: DAG.getConstant(Val: SignBit, DL: SDLoc(X), VT));
15572 AddToWorklist(N: X.getNode());
15573
15574 SDValue Cst = DAG.getBitcast(VT, V: N0.getOperand(i: 0));
15575 Cst = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(Cst), VT,
15576 N1: Cst, N2: DAG.getConstant(Val: ~SignBit, DL: SDLoc(Cst), VT));
15577 AddToWorklist(N: Cst.getNode());
15578
15579 return DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N), VT, N1: X, N2: Cst);
15580 }
15581 }
15582
15583 // bitconvert(build_pair(ld, ld)) -> ld iff load locations are consecutive.
15584 if (N0.getOpcode() == ISD::BUILD_PAIR)
15585 if (SDValue CombineLD = CombineConsecutiveLoads(N: N0.getNode(), VT))
15586 return CombineLD;
15587
15588 // Remove double bitcasts from shuffles - this is often a legacy of
15589 // XformToShuffleWithZero being used to combine bitmaskings (of
15590 // float vectors bitcast to integer vectors) into shuffles.
15591 // bitcast(shuffle(bitcast(s0),bitcast(s1))) -> shuffle(s0,s1)
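  // For example (assuming a legal shuffle can be built for the wider mask):
  //   (v4i32 (bitcast (shuffle<1,0> (v2i64 (bitcast v4i32:s0)),
  //                                 (v2i64 (bitcast v4i32:s1)))))
  //     -> (shuffle<2,3,0,1> s0, s1)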
15592 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT) && VT.isVector() &&
15593 N0->getOpcode() == ISD::VECTOR_SHUFFLE && N0.hasOneUse() &&
15594 VT.getVectorNumElements() >= N0.getValueType().getVectorNumElements() &&
15595 !(VT.getVectorNumElements() % N0.getValueType().getVectorNumElements())) {
15596 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Val&: N0);
15597
15598 // If operands are a bitcast, peek through if it casts the original VT.
15599 // If operands are a constant, just bitcast back to original VT.
15600 auto PeekThroughBitcast = [&](SDValue Op) {
15601 if (Op.getOpcode() == ISD::BITCAST &&
15602 Op.getOperand(i: 0).getValueType() == VT)
15603 return SDValue(Op.getOperand(i: 0));
15604 if (Op.isUndef() || isAnyConstantBuildVector(V: Op))
15605 return DAG.getBitcast(VT, V: Op);
15606 return SDValue();
15607 };
15608
15609 // FIXME: If either input vector is bitcast, try to convert the shuffle to
15610 // the result type of this bitcast. This would eliminate at least one
15611 // bitcast. See the transform in InstCombine.
15612 SDValue SV0 = PeekThroughBitcast(N0->getOperand(Num: 0));
15613 SDValue SV1 = PeekThroughBitcast(N0->getOperand(Num: 1));
15614 if (!(SV0 && SV1))
15615 return SDValue();
15616
15617 int MaskScale =
15618 VT.getVectorNumElements() / N0.getValueType().getVectorNumElements();
15619 SmallVector<int, 8> NewMask;
15620 for (int M : SVN->getMask())
15621 for (int i = 0; i != MaskScale; ++i)
15622 NewMask.push_back(Elt: M < 0 ? -1 : M * MaskScale + i);
15623
15624 SDValue LegalShuffle =
15625 TLI.buildLegalVectorShuffle(VT, DL: SDLoc(N), N0: SV0, N1: SV1, Mask: NewMask, DAG);
15626 if (LegalShuffle)
15627 return LegalShuffle;
15628 }
15629
15630 return SDValue();
15631}
15632
15633SDValue DAGCombiner::visitBUILD_PAIR(SDNode *N) {
15634 EVT VT = N->getValueType(ResNo: 0);
15635 return CombineConsecutiveLoads(N, VT);
15636}
15637
15638SDValue DAGCombiner::visitFREEZE(SDNode *N) {
15639 SDValue N0 = N->getOperand(Num: 0);
15640
15641 if (DAG.isGuaranteedNotToBeUndefOrPoison(Op: N0, /*PoisonOnly*/ false))
15642 return N0;
15643
15644 // We currently avoid folding freeze over SRA/SRL, due to the problems seen
15645 // with (freeze (assert ext)) blocking simplifications of SRA/SRL. See for
15646 // example https://reviews.llvm.org/D136529#4120959.
15647 if (N0.getOpcode() == ISD::SRA || N0.getOpcode() == ISD::SRL)
15648 return SDValue();
15649
15650 // Fold freeze(op(x, ...)) -> op(freeze(x), ...).
15651 // Try to push freeze through instructions that propagate but don't produce
  // poison as far as possible. If the operand of the freeze satisfies three
  // conditions: 1) it has one use, 2) it does not produce poison, and 3) all
  // but one of its operands are guaranteed non-poison (or it is a BUILD_VECTOR
  // or similar), then push the freeze through to the operands that are not
  // guaranteed non-poison.
15656 // NOTE: we will strip poison-generating flags, so ignore them here.
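  // For example, if y is known to be neither undef nor poison (e.g. a
  // constant), then:
  //   freeze (add x, y) -> add (freeze x), y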
15657 if (DAG.canCreateUndefOrPoison(Op: N0, /*PoisonOnly*/ false,
15658 /*ConsiderFlags*/ false) ||
15659 N0->getNumValues() != 1 || !N0->hasOneUse())
15660 return SDValue();
15661
15662 bool AllowMultipleMaybePoisonOperands =
15663 N0.getOpcode() == ISD::SELECT_CC ||
15664 N0.getOpcode() == ISD::SETCC ||
15665 N0.getOpcode() == ISD::BUILD_VECTOR ||
15666 N0.getOpcode() == ISD::BUILD_PAIR ||
15667 N0.getOpcode() == ISD::VECTOR_SHUFFLE ||
15668 N0.getOpcode() == ISD::CONCAT_VECTORS;
15669
15670 // Avoid turning a BUILD_VECTOR that can be recognized as "all zeros", "all
15671 // ones" or "constant" into something that depends on FrozenUndef. We can
15672 // instead pick undef values to keep those properties, while at the same time
15673 // folding away the freeze.
15674 // If we implement a more general solution for folding away freeze(undef) in
15675 // the future, then this special handling can be removed.
15676 if (N0.getOpcode() == ISD::BUILD_VECTOR) {
15677 SDLoc DL(N0);
15678 EVT VT = N0.getValueType();
15679 if (llvm::ISD::isBuildVectorAllOnes(N: N0.getNode()))
15680 return DAG.getAllOnesConstant(DL, VT);
15681 if (llvm::ISD::isBuildVectorOfConstantSDNodes(N: N0.getNode())) {
15682 SmallVector<SDValue, 8> NewVecC;
15683 for (const SDValue &Op : N0->op_values())
15684 NewVecC.push_back(
15685 Elt: Op.isUndef() ? DAG.getConstant(Val: 0, DL, VT: Op.getValueType()) : Op);
15686 return DAG.getBuildVector(VT, DL, Ops: NewVecC);
15687 }
15688 }
15689
15690 SmallSet<SDValue, 8> MaybePoisonOperands;
15691 SmallVector<unsigned, 8> MaybePoisonOperandNumbers;
15692 for (auto [OpNo, Op] : enumerate(First: N0->ops())) {
15693 if (DAG.isGuaranteedNotToBeUndefOrPoison(Op, /*PoisonOnly*/ false,
15694 /*Depth*/ 1))
15695 continue;
15696 bool HadMaybePoisonOperands = !MaybePoisonOperands.empty();
15697 bool IsNewMaybePoisonOperand = MaybePoisonOperands.insert(V: Op).second;
15698 if (IsNewMaybePoisonOperand)
15699 MaybePoisonOperandNumbers.push_back(Elt: OpNo);
15700 if (!HadMaybePoisonOperands)
15701 continue;
15702 if (IsNewMaybePoisonOperand && !AllowMultipleMaybePoisonOperands) {
15703 // Multiple maybe-poison ops when not allowed - bail out.
15704 return SDValue();
15705 }
15706 }
  // NOTE: the whole op may still not be guaranteed to be free of undef or
  // poison, because it could create undef or poison due to its
  // poison-generating flags. So not finding any maybe-poison operands is fine.
15710
15711 for (unsigned OpNo : MaybePoisonOperandNumbers) {
15712 // N0 can mutate during iteration, so make sure to refetch the maybe poison
15713 // operands via the operand numbers. The typical scenario is that we have
15714 // something like this
15715 // t262: i32 = freeze t181
15716 // t150: i32 = ctlz_zero_undef t262
15717 // t184: i32 = ctlz_zero_undef t181
15718 // t268: i32 = select_cc t181, Constant:i32<0>, t184, t186, setne:ch
15719 // When freezing the t181 operand we get t262 back, and then the
15720 // ReplaceAllUsesOfValueWith call will not only replace t181 by t262, but
15721 // also recursively replace t184 by t150.
15722 SDValue MaybePoisonOperand = N->getOperand(Num: 0).getOperand(i: OpNo);
15723 // Don't replace every single UNDEF everywhere with frozen UNDEF, though.
15724 if (MaybePoisonOperand.getOpcode() == ISD::UNDEF)
15725 continue;
15726 // First, freeze each offending operand.
15727 SDValue FrozenMaybePoisonOperand = DAG.getFreeze(V: MaybePoisonOperand);
15728 // Then, change all other uses of unfrozen operand to use frozen operand.
15729 DAG.ReplaceAllUsesOfValueWith(From: MaybePoisonOperand, To: FrozenMaybePoisonOperand);
15730 if (FrozenMaybePoisonOperand.getOpcode() == ISD::FREEZE &&
15731 FrozenMaybePoisonOperand.getOperand(i: 0) == FrozenMaybePoisonOperand) {
15732 // But, that also updated the use in the freeze we just created, thus
15733 // creating a cycle in a DAG. Let's undo that by mutating the freeze.
15734 DAG.UpdateNodeOperands(N: FrozenMaybePoisonOperand.getNode(),
15735 Op: MaybePoisonOperand);
15736 }
15737 }
15738
15739 // This node has been merged with another.
15740 if (N->getOpcode() == ISD::DELETED_NODE)
15741 return SDValue(N, 0);
15742
15743 // The whole node may have been updated, so the value we were holding
15744 // may no longer be valid. Re-fetch the operand we're `freeze`ing.
15745 N0 = N->getOperand(Num: 0);
15746
  // Finally, recreate the node; its operands were updated to use
  // frozen operands, so we just need to use its "original" operands.
15749 SmallVector<SDValue> Ops(N0->op_begin(), N0->op_end());
  // Special-handle ISD::UNDEF; each single one of them can be its own thing.
15751 for (SDValue &Op : Ops) {
15752 if (Op.getOpcode() == ISD::UNDEF)
15753 Op = DAG.getFreeze(V: Op);
15754 }
15755
15756 SDValue R;
15757 if (auto *SVN = dyn_cast<ShuffleVectorSDNode>(Val&: N0)) {
15758 // Special case handling for ShuffleVectorSDNode nodes.
15759 R = DAG.getVectorShuffle(VT: N0.getValueType(), dl: SDLoc(N0), N1: Ops[0], N2: Ops[1],
15760 Mask: SVN->getMask());
15761 } else {
15762 // NOTE: this strips poison generating flags.
15763 R = DAG.getNode(Opcode: N0.getOpcode(), DL: SDLoc(N0), VTList: N0->getVTList(), Ops);
15764 }
15765 assert(DAG.isGuaranteedNotToBeUndefOrPoison(R, /*PoisonOnly*/ false) &&
15766 "Can't create node that may be undef/poison!");
15767 return R;
15768}
15769
15770/// We know that BV is a build_vector node with Constant, ConstantFP or Undef
15771/// operands. DstEltVT indicates the destination element value type.
15772SDValue DAGCombiner::
15773ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
15774 EVT SrcEltVT = BV->getValueType(ResNo: 0).getVectorElementType();
15775
15776 // If this is already the right type, we're done.
15777 if (SrcEltVT == DstEltVT) return SDValue(BV, 0);
15778
15779 unsigned SrcBitSize = SrcEltVT.getSizeInBits();
15780 unsigned DstBitSize = DstEltVT.getSizeInBits();
15781
15782 // If this is a conversion of N elements of one type to N elements of another
15783 // type, convert each element. This handles FP<->INT cases.
15784 if (SrcBitSize == DstBitSize) {
15785 SmallVector<SDValue, 8> Ops;
15786 for (SDValue Op : BV->op_values()) {
15787 // If the vector element type is not legal, the BUILD_VECTOR operands
15788 // are promoted and implicitly truncated. Make that explicit here.
15789 if (Op.getValueType() != SrcEltVT)
15790 Op = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(BV), VT: SrcEltVT, Operand: Op);
15791 Ops.push_back(Elt: DAG.getBitcast(VT: DstEltVT, V: Op));
15792 AddToWorklist(N: Ops.back().getNode());
15793 }
15794 EVT VT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: DstEltVT,
15795 NumElements: BV->getValueType(ResNo: 0).getVectorNumElements());
15796 return DAG.getBuildVector(VT, DL: SDLoc(BV), Ops);
15797 }
15798
15799 // Otherwise, we're growing or shrinking the elements. To avoid having to
15800 // handle annoying details of growing/shrinking FP values, we convert them to
15801 // int first.
15802 if (SrcEltVT.isFloatingPoint()) {
    // Convert the input float vector to an int vector whose elements have the
    // same size.
15805 EVT IntVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SrcEltVT.getSizeInBits());
15806 BV = ConstantFoldBITCASTofBUILD_VECTOR(BV, DstEltVT: IntVT).getNode();
15807 SrcEltVT = IntVT;
15808 }
15809
15810 // Now we know the input is an integer vector. If the output is a FP type,
15811 // convert to integer first, then to FP of the right size.
15812 if (DstEltVT.isFloatingPoint()) {
15813 EVT TmpVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: DstEltVT.getSizeInBits());
15814 SDNode *Tmp = ConstantFoldBITCASTofBUILD_VECTOR(BV, DstEltVT: TmpVT).getNode();
15815
15816 // Next, convert to FP elements of the same size.
15817 return ConstantFoldBITCASTofBUILD_VECTOR(BV: Tmp, DstEltVT);
15818 }
15819
15820 // Okay, we know the src/dst types are both integers of differing types.
15821 assert(SrcEltVT.isInteger() && DstEltVT.isInteger());
15822
15823 // TODO: Should ConstantFoldBITCASTofBUILD_VECTOR always take a
15824 // BuildVectorSDNode?
15825 auto *BVN = cast<BuildVectorSDNode>(Val: BV);
15826
15827 // Extract the constant raw bit data.
15828 BitVector UndefElements;
15829 SmallVector<APInt> RawBits;
15830 bool IsLE = DAG.getDataLayout().isLittleEndian();
15831 if (!BVN->getConstantRawBits(IsLittleEndian: IsLE, DstEltSizeInBits: DstBitSize, RawBitElements&: RawBits, UndefElements))
15832 return SDValue();
15833
15834 SDLoc DL(BV);
15835 SmallVector<SDValue, 8> Ops;
15836 for (unsigned I = 0, E = RawBits.size(); I != E; ++I) {
15837 if (UndefElements[I])
15838 Ops.push_back(Elt: DAG.getUNDEF(VT: DstEltVT));
15839 else
15840 Ops.push_back(Elt: DAG.getConstant(Val: RawBits[I], DL, VT: DstEltVT));
15841 }
15842
15843 EVT VT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: DstEltVT, NumElements: Ops.size());
15844 return DAG.getBuildVector(VT, DL, Ops);
15845}
15846
15847// Returns true if floating point contraction is allowed on the FMUL-SDValue
15848// `N`
15849static bool isContractableFMUL(const TargetOptions &Options, SDValue N) {
15850 assert(N.getOpcode() == ISD::FMUL);
15851
15852 return Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath ||
15853 N->getFlags().hasAllowContract();
15854}
15855
15856// Returns true if `N` can assume no infinities involved in its computation.
15857static bool hasNoInfs(const TargetOptions &Options, SDValue N) {
15858 return Options.NoInfsFPMath || N->getFlags().hasNoInfs();
15859}
15860
15861/// Try to perform FMA combining on a given FADD node.
15862template <class MatchContextClass>
15863SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
15864 SDValue N0 = N->getOperand(Num: 0);
15865 SDValue N1 = N->getOperand(Num: 1);
15866 EVT VT = N->getValueType(ResNo: 0);
15867 SDLoc SL(N);
15868 MatchContextClass matcher(DAG, TLI, N);
15869 const TargetOptions &Options = DAG.getTarget().Options;
15870
15871 bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
15872
15873 // Floating-point multiply-add with intermediate rounding.
15874 // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext.
15875 // FIXME: Add VP_FMAD opcode.
15876 bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N));
15877
15878 // Floating-point multiply-add without intermediate rounding.
15879 bool HasFMA =
15880 TLI.isFMAFasterThanFMulAndFAdd(MF: DAG.getMachineFunction(), VT) &&
15881 (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT));
15882
15883 // No valid opcode, do not combine.
15884 if (!HasFMAD && !HasFMA)
15885 return SDValue();
15886
15887 bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
15888 Options.UnsafeFPMath || HasFMAD);
15889 // If the addition is not contractable, do not combine.
15890 if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
15891 return SDValue();
15892
15893 // Folding fadd (fmul x, y), (fmul x, y) -> fma x, y, (fmul x, y) is never
15894 // beneficial. It does not reduce latency. It increases register pressure. It
15895 // replaces an fadd with an fma which is a more complex instruction, so is
15896 // likely to have a larger encoding, use more functional units, etc.
15897 if (N0 == N1)
15898 return SDValue();
15899
15900 if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
15901 return SDValue();
15902
15903 // Always prefer FMAD to FMA for precision.
15904 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
15905 bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
15906
15907 auto isFusedOp = [&](SDValue N) {
15908 return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD);
15909 };
15910
15911 // Is the node an FMUL and contractable either due to global flags or
15912 // SDNodeFlags.
15913 auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) {
15914 if (!matcher.match(N, ISD::FMUL))
15915 return false;
15916 return AllowFusionGlobally || N->getFlags().hasAllowContract();
15917 };
15918 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
15919 // prefer to fold the multiply with fewer uses.
15920 if (Aggressive && isContractableFMUL(N0) && isContractableFMUL(N1)) {
15921 if (N0->use_size() > N1->use_size())
15922 std::swap(a&: N0, b&: N1);
15923 }
15924
15925 // fold (fadd (fmul x, y), z) -> (fma x, y, z)
15926 if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
15927 return matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(i: 0),
15928 N0.getOperand(i: 1), N1);
15929 }
15930
15931 // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
15932 // Note: Commutes FADD operands.
15933 if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
15934 return matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(i: 0),
15935 N1.getOperand(i: 1), N0);
15936 }
15937
15938 // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
15939 // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
15940 // This also works with nested fma instructions:
  // fadd (fma A, B, (fma C, D, (fmul E, F))), G -->
  //   fma A, B, (fma C, D, (fma E, F, G))
  // fadd G, (fma A, B, (fma C, D, (fmul E, F))) -->
  //   fma A, B, (fma C, D, (fma E, F, G)).
15945 // This requires reassociation because it changes the order of operations.
15946 bool CanReassociate =
15947 Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
15948 if (CanReassociate) {
15949 SDValue FMA, E;
15950 if (isFusedOp(N0) && N0.hasOneUse()) {
15951 FMA = N0;
15952 E = N1;
15953 } else if (isFusedOp(N1) && N1.hasOneUse()) {
15954 FMA = N1;
15955 E = N0;
15956 }
15957
15958 SDValue TmpFMA = FMA;
15959 while (E && isFusedOp(TmpFMA) && TmpFMA.hasOneUse()) {
15960 SDValue FMul = TmpFMA->getOperand(Num: 2);
15961 if (matcher.match(FMul, ISD::FMUL) && FMul.hasOneUse()) {
15962 SDValue C = FMul.getOperand(i: 0);
15963 SDValue D = FMul.getOperand(i: 1);
15964 SDValue CDE = matcher.getNode(PreferredFusedOpcode, SL, VT, C, D, E);
15965 DAG.ReplaceAllUsesOfValueWith(From: FMul, To: CDE);
15966 // Replacing the inner FMul could cause the outer FMA to be simplified
15967 // away.
15968 return FMA.getOpcode() == ISD::DELETED_NODE ? SDValue(N, 0) : FMA;
15969 }
15970
15971 TmpFMA = TmpFMA->getOperand(Num: 2);
15972 }
15973 }
15974
15975 // Look through FP_EXTEND nodes to do more combining.
15976
15977 // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
15978 if (matcher.match(N0, ISD::FP_EXTEND)) {
15979 SDValue N00 = N0.getOperand(i: 0);
15980 if (isContractableFMUL(N00) &&
15981 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
15982 SrcVT: N00.getValueType())) {
15983 return matcher.getNode(
15984 PreferredFusedOpcode, SL, VT,
15985 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 0)),
15986 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 1)), N1);
15987 }
15988 }
15989
15990 // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x)
15991 // Note: Commutes FADD operands.
15992 if (matcher.match(N1, ISD::FP_EXTEND)) {
15993 SDValue N10 = N1.getOperand(i: 0);
15994 if (isContractableFMUL(N10) &&
15995 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
15996 SrcVT: N10.getValueType())) {
15997 return matcher.getNode(
15998 PreferredFusedOpcode, SL, VT,
15999 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(i: 0)),
16000 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(i: 1)), N0);
16001 }
16002 }
16003
16004 // More folding opportunities when target permits.
16005 if (Aggressive) {
16006 // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
16007 // -> (fma x, y, (fma (fpext u), (fpext v), z))
16008 auto FoldFAddFMAFPExtFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
16009 SDValue Z) {
16010 return matcher.getNode(
16011 PreferredFusedOpcode, SL, VT, X, Y,
16012 matcher.getNode(PreferredFusedOpcode, SL, VT,
16013 matcher.getNode(ISD::FP_EXTEND, SL, VT, U),
16014 matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
16015 };
16016 if (isFusedOp(N0)) {
16017 SDValue N02 = N0.getOperand(i: 2);
16018 if (matcher.match(N02, ISD::FP_EXTEND)) {
16019 SDValue N020 = N02.getOperand(i: 0);
16020 if (isContractableFMUL(N020) &&
16021 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
16022 SrcVT: N020.getValueType())) {
16023 return FoldFAddFMAFPExtFMul(N0.getOperand(i: 0), N0.getOperand(i: 1),
16024 N020.getOperand(i: 0), N020.getOperand(i: 1),
16025 N1);
16026 }
16027 }
16028 }
16029
16030 // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
16031 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
16032 // FIXME: This turns two single-precision and one double-precision
16033 // operation into two double-precision operations, which might not be
16034 // interesting for all targets, especially GPUs.
16035 auto FoldFAddFPExtFMAFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
16036 SDValue Z) {
16037 return matcher.getNode(
16038 PreferredFusedOpcode, SL, VT,
16039 matcher.getNode(ISD::FP_EXTEND, SL, VT, X),
16040 matcher.getNode(ISD::FP_EXTEND, SL, VT, Y),
16041 matcher.getNode(PreferredFusedOpcode, SL, VT,
16042 matcher.getNode(ISD::FP_EXTEND, SL, VT, U),
16043 matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
16044 };
16045 if (N0.getOpcode() == ISD::FP_EXTEND) {
16046 SDValue N00 = N0.getOperand(i: 0);
16047 if (isFusedOp(N00)) {
16048 SDValue N002 = N00.getOperand(i: 2);
16049 if (isContractableFMUL(N002) &&
16050 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
16051 SrcVT: N00.getValueType())) {
16052 return FoldFAddFPExtFMAFMul(N00.getOperand(i: 0), N00.getOperand(i: 1),
16053 N002.getOperand(i: 0), N002.getOperand(i: 1),
16054 N1);
16055 }
16056 }
16057 }
16058
    // fold (fadd x, (fma y, z, (fpext (fmul u, v))))
16060 // -> (fma y, z, (fma (fpext u), (fpext v), x))
16061 if (isFusedOp(N1)) {
16062 SDValue N12 = N1.getOperand(i: 2);
16063 if (N12.getOpcode() == ISD::FP_EXTEND) {
16064 SDValue N120 = N12.getOperand(i: 0);
16065 if (isContractableFMUL(N120) &&
16066 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
16067 SrcVT: N120.getValueType())) {
16068 return FoldFAddFMAFPExtFMul(N1.getOperand(i: 0), N1.getOperand(i: 1),
16069 N120.getOperand(i: 0), N120.getOperand(i: 1),
16070 N0);
16071 }
16072 }
16073 }
16074
    // fold (fadd x, (fpext (fma y, z, (fmul u, v))))
16076 // -> (fma (fpext y), (fpext z), (fma (fpext u), (fpext v), x))
16077 // FIXME: This turns two single-precision and one double-precision
16078 // operation into two double-precision operations, which might not be
16079 // interesting for all targets, especially GPUs.
16080 if (N1.getOpcode() == ISD::FP_EXTEND) {
16081 SDValue N10 = N1.getOperand(i: 0);
16082 if (isFusedOp(N10)) {
16083 SDValue N102 = N10.getOperand(i: 2);
16084 if (isContractableFMUL(N102) &&
16085 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
16086 SrcVT: N10.getValueType())) {
16087 return FoldFAddFPExtFMAFMul(N10.getOperand(i: 0), N10.getOperand(i: 1),
16088 N102.getOperand(i: 0), N102.getOperand(i: 1),
16089 N0);
16090 }
16091 }
16092 }
16093 }
16094
16095 return SDValue();
16096}
16097
16098/// Try to perform FMA combining on a given FSUB node.
16099template <class MatchContextClass>
16100SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
16101 SDValue N0 = N->getOperand(Num: 0);
16102 SDValue N1 = N->getOperand(Num: 1);
16103 EVT VT = N->getValueType(ResNo: 0);
16104 SDLoc SL(N);
16105 MatchContextClass matcher(DAG, TLI, N);
16106 const TargetOptions &Options = DAG.getTarget().Options;
16107
16108 bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
16109
16110 // Floating-point multiply-add with intermediate rounding.
16111 // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext.
16112 // FIXME: Add VP_FMAD opcode.
16113 bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N));
16114
16115 // Floating-point multiply-add without intermediate rounding.
16116 bool HasFMA =
16117 TLI.isFMAFasterThanFMulAndFAdd(MF: DAG.getMachineFunction(), VT) &&
16118 (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT));
16119
16120 // No valid opcode, do not combine.
16121 if (!HasFMAD && !HasFMA)
16122 return SDValue();
16123
16124 const SDNodeFlags Flags = N->getFlags();
16125 bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
16126 Options.UnsafeFPMath || HasFMAD);
16127
16128 // If the subtraction is not contractable, do not combine.
16129 if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
16130 return SDValue();
16131
16132 if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
16133 return SDValue();
16134
16135 // Always prefer FMAD to FMA for precision.
16136 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
16137 bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
16138 bool NoSignedZero = Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros();
16139
16140 // Is the node an FMUL and contractable either due to global flags or
16141 // SDNodeFlags.
16142 auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) {
16143 if (!matcher.match(N, ISD::FMUL))
16144 return false;
16145 return AllowFusionGlobally || N->getFlags().hasAllowContract();
16146 };
16147
16148 // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
16149 auto tryToFoldXYSubZ = [&](SDValue XY, SDValue Z) {
16150 if (isContractableFMUL(XY) && (Aggressive || XY->hasOneUse())) {
16151 return matcher.getNode(PreferredFusedOpcode, SL, VT, XY.getOperand(i: 0),
16152 XY.getOperand(i: 1),
16153 matcher.getNode(ISD::FNEG, SL, VT, Z));
16154 }
16155 return SDValue();
16156 };
16157
16158 // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
16159 // Note: Commutes FSUB operands.
16160 auto tryToFoldXSubYZ = [&](SDValue X, SDValue YZ) {
16161 if (isContractableFMUL(YZ) && (Aggressive || YZ->hasOneUse())) {
16162 return matcher.getNode(
16163 PreferredFusedOpcode, SL, VT,
16164 matcher.getNode(ISD::FNEG, SL, VT, YZ.getOperand(i: 0)),
16165 YZ.getOperand(i: 1), X);
16166 }
16167 return SDValue();
16168 };
16169
16170 // If we have two choices trying to fold (fsub (fmul u, v), (fmul x, y)),
16171 // prefer to fold the multiply with fewer uses.
16172 if (isContractableFMUL(N0) && isContractableFMUL(N1) &&
16173 (N0->use_size() > N1->use_size())) {
16174 // fold (fsub (fmul a, b), (fmul c, d)) -> (fma (fneg c), d, (fmul a, b))
16175 if (SDValue V = tryToFoldXSubYZ(N0, N1))
16176 return V;
16177 // fold (fsub (fmul a, b), (fmul c, d)) -> (fma a, b, (fneg (fmul c, d)))
16178 if (SDValue V = tryToFoldXYSubZ(N0, N1))
16179 return V;
16180 } else {
16181 // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
16182 if (SDValue V = tryToFoldXYSubZ(N0, N1))
16183 return V;
16184 // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
16185 if (SDValue V = tryToFoldXSubYZ(N0, N1))
16186 return V;
16187 }
16188
  // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
16190 if (matcher.match(N0, ISD::FNEG) && isContractableFMUL(N0.getOperand(i: 0)) &&
16191 (Aggressive || (N0->hasOneUse() && N0.getOperand(i: 0).hasOneUse()))) {
16192 SDValue N00 = N0.getOperand(i: 0).getOperand(i: 0);
16193 SDValue N01 = N0.getOperand(i: 0).getOperand(i: 1);
16194 return matcher.getNode(PreferredFusedOpcode, SL, VT,
16195 matcher.getNode(ISD::FNEG, SL, VT, N00), N01,
16196 matcher.getNode(ISD::FNEG, SL, VT, N1));
16197 }
16198
16199 // Look through FP_EXTEND nodes to do more combining.
16200
16201 // fold (fsub (fpext (fmul x, y)), z)
16202 // -> (fma (fpext x), (fpext y), (fneg z))
16203 if (matcher.match(N0, ISD::FP_EXTEND)) {
16204 SDValue N00 = N0.getOperand(i: 0);
16205 if (isContractableFMUL(N00) &&
16206 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
16207 SrcVT: N00.getValueType())) {
16208 return matcher.getNode(
16209 PreferredFusedOpcode, SL, VT,
16210 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 0)),
16211 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 1)),
16212 matcher.getNode(ISD::FNEG, SL, VT, N1));
16213 }
16214 }
16215
16216 // fold (fsub x, (fpext (fmul y, z)))
16217 // -> (fma (fneg (fpext y)), (fpext z), x)
16218 // Note: Commutes FSUB operands.
16219 if (matcher.match(N1, ISD::FP_EXTEND)) {
16220 SDValue N10 = N1.getOperand(i: 0);
16221 if (isContractableFMUL(N10) &&
16222 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
16223 SrcVT: N10.getValueType())) {
16224 return matcher.getNode(
16225 PreferredFusedOpcode, SL, VT,
16226 matcher.getNode(
16227 ISD::FNEG, SL, VT,
16228 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(i: 0))),
16229 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(i: 1)), N0);
16230 }
16231 }
16232
  // fold (fsub (fpext (fneg (fmul x, y))), z)
  //   -> (fneg (fma (fpext x), (fpext y), z))
  // Note: This could be removed with appropriate canonicalization of the
  // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
  // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
  // us from implementing the canonicalization in visitFSUB.
16239 if (matcher.match(N0, ISD::FP_EXTEND)) {
16240 SDValue N00 = N0.getOperand(i: 0);
16241 if (matcher.match(N00, ISD::FNEG)) {
16242 SDValue N000 = N00.getOperand(i: 0);
16243 if (isContractableFMUL(N000) &&
16244 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
16245 SrcVT: N00.getValueType())) {
16246 return matcher.getNode(
16247 ISD::FNEG, SL, VT,
16248 matcher.getNode(
16249 PreferredFusedOpcode, SL, VT,
16250 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(i: 0)),
16251 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(i: 1)),
16252 N1));
16253 }
16254 }
16255 }
16256
  // fold (fsub (fneg (fpext (fmul x, y))), z)
  //   -> (fneg (fma (fpext x), (fpext y), z))
  // Note: This could be removed with appropriate canonicalization of the
  // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
  // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
  // us from implementing the canonicalization in visitFSUB.
16263 if (matcher.match(N0, ISD::FNEG)) {
16264 SDValue N00 = N0.getOperand(i: 0);
16265 if (matcher.match(N00, ISD::FP_EXTEND)) {
16266 SDValue N000 = N00.getOperand(i: 0);
16267 if (isContractableFMUL(N000) &&
16268 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
16269 SrcVT: N000.getValueType())) {
16270 return matcher.getNode(
16271 ISD::FNEG, SL, VT,
16272 matcher.getNode(
16273 PreferredFusedOpcode, SL, VT,
16274 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(i: 0)),
16275 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(i: 1)),
16276 N1));
16277 }
16278 }
16279 }
16280
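  // Reassociation is permitted either globally via UnsafeFPMath or via the
  // node's own 'reassoc' flag.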
16281 auto isReassociable = [&Options](SDNode *N) {
16282 return Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
16283 };
16284
16285 auto isContractableAndReassociableFMUL = [&isContractableFMUL,
16286 &isReassociable](SDValue N) {
16287 return isContractableFMUL(N) && isReassociable(N.getNode());
16288 };
16289
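  // Match nodes that are already fused multiply-add operations.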
16290 auto isFusedOp = [&](SDValue N) {
16291 return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD);
16292 };
16293
16294 // More folding opportunities when target permits.
16295 if (Aggressive && isReassociable(N)) {
16296 bool CanFuse = Options.UnsafeFPMath || N->getFlags().hasAllowContract();
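    // Re-fusing the fmul operand of an existing fused op with the outer FSUB
    // additionally requires contraction, either globally or via this node's
    // 'contract' flag.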
16297 // fold (fsub (fma x, y, (fmul u, v)), z)
    //   -> (fma x, y, (fma u, v, (fneg z)))
16299 if (CanFuse && isFusedOp(N0) &&
16300 isContractableAndReassociableFMUL(N0.getOperand(i: 2)) &&
16301 N0->hasOneUse() && N0.getOperand(i: 2)->hasOneUse()) {
16302 return matcher.getNode(
16303 PreferredFusedOpcode, SL, VT, N0.getOperand(i: 0), N0.getOperand(i: 1),
16304 matcher.getNode(PreferredFusedOpcode, SL, VT,
16305 N0.getOperand(i: 2).getOperand(i: 0),
16306 N0.getOperand(i: 2).getOperand(i: 1),
16307 matcher.getNode(ISD::FNEG, SL, VT, N1)));
16308 }
16309
16310 // fold (fsub x, (fma y, z, (fmul u, v)))
16311 // -> (fma (fneg y), z, (fma (fneg u), v, x))
16312 if (CanFuse && isFusedOp(N1) &&
16313 isContractableAndReassociableFMUL(N1.getOperand(i: 2)) &&
16314 N1->hasOneUse() && NoSignedZero) {
16315 SDValue N20 = N1.getOperand(i: 2).getOperand(i: 0);
16316 SDValue N21 = N1.getOperand(i: 2).getOperand(i: 1);
16317 return matcher.getNode(
16318 PreferredFusedOpcode, SL, VT,
16319 matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(i: 0)),
16320 N1.getOperand(i: 1),
16321 matcher.getNode(PreferredFusedOpcode, SL, VT,
16322 matcher.getNode(ISD::FNEG, SL, VT, N20), N21, N0));
16323 }
16324
16325 // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
    //   -> (fma x, y, (fma (fpext u), (fpext v), (fneg z)))
16327 if (isFusedOp(N0) && N0->hasOneUse()) {
16328 SDValue N02 = N0.getOperand(i: 2);
16329 if (matcher.match(N02, ISD::FP_EXTEND)) {
16330 SDValue N020 = N02.getOperand(i: 0);
16331 if (isContractableAndReassociableFMUL(N020) &&
16332 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
16333 SrcVT: N020.getValueType())) {
16334 return matcher.getNode(
16335 PreferredFusedOpcode, SL, VT, N0.getOperand(i: 0), N0.getOperand(i: 1),
16336 matcher.getNode(
16337 PreferredFusedOpcode, SL, VT,
16338 matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(i: 0)),
16339 matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(i: 1)),
16340 matcher.getNode(ISD::FNEG, SL, VT, N1)));
16341 }
16342 }
16343 }
16344
16345 // fold (fsub (fpext (fma x, y, (fmul u, v))), z)
16346 // -> (fma (fpext x), (fpext y),
16347 // (fma (fpext u), (fpext v), (fneg z)))
16348 // FIXME: This turns two single-precision and one double-precision
16349 // operation into two double-precision operations, which might not be
16350 // interesting for all targets, especially GPUs.
16351 if (matcher.match(N0, ISD::FP_EXTEND)) {
16352 SDValue N00 = N0.getOperand(i: 0);
16353 if (isFusedOp(N00)) {
16354 SDValue N002 = N00.getOperand(i: 2);
16355 if (isContractableAndReassociableFMUL(N002) &&
16356 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
16357 SrcVT: N00.getValueType())) {
16358 return matcher.getNode(
16359 PreferredFusedOpcode, SL, VT,
16360 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 0)),
16361 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 1)),
16362 matcher.getNode(
16363 PreferredFusedOpcode, SL, VT,
16364 matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(i: 0)),
16365 matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(i: 1)),
16366 matcher.getNode(ISD::FNEG, SL, VT, N1)));
16367 }
16368 }
16369 }
16370
16371 // fold (fsub x, (fma y, z, (fpext (fmul u, v))))
16372 // -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x))
16373 if (isFusedOp(N1) && matcher.match(N1.getOperand(i: 2), ISD::FP_EXTEND) &&
16374 N1->hasOneUse()) {
16375 SDValue N120 = N1.getOperand(i: 2).getOperand(i: 0);
16376 if (isContractableAndReassociableFMUL(N120) &&
16377 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
16378 SrcVT: N120.getValueType())) {
16379 SDValue N1200 = N120.getOperand(i: 0);
16380 SDValue N1201 = N120.getOperand(i: 1);
16381 return matcher.getNode(
16382 PreferredFusedOpcode, SL, VT,
16383 matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(i: 0)),
16384 N1.getOperand(i: 1),
16385 matcher.getNode(
16386 PreferredFusedOpcode, SL, VT,
16387 matcher.getNode(ISD::FNEG, SL, VT,
16388 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1200)),
16389 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1201), N0));
16390 }
16391 }
16392
16393 // fold (fsub x, (fpext (fma y, z, (fmul u, v))))
16394 // -> (fma (fneg (fpext y)), (fpext z),
16395 // (fma (fneg (fpext u)), (fpext v), x))
16396 // FIXME: This turns two single-precision and one double-precision
16397 // operation into two double-precision operations, which might not be
16398 // interesting for all targets, especially GPUs.
16399 if (matcher.match(N1, ISD::FP_EXTEND) && isFusedOp(N1.getOperand(i: 0))) {
16400 SDValue CvtSrc = N1.getOperand(i: 0);
16401 SDValue N100 = CvtSrc.getOperand(i: 0);
16402 SDValue N101 = CvtSrc.getOperand(i: 1);
16403 SDValue N102 = CvtSrc.getOperand(i: 2);
16404 if (isContractableAndReassociableFMUL(N102) &&
16405 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
16406 SrcVT: CvtSrc.getValueType())) {
16407 SDValue N1020 = N102.getOperand(i: 0);
16408 SDValue N1021 = N102.getOperand(i: 1);
16409 return matcher.getNode(
16410 PreferredFusedOpcode, SL, VT,
16411 matcher.getNode(ISD::FNEG, SL, VT,
16412 matcher.getNode(ISD::FP_EXTEND, SL, VT, N100)),
16413 matcher.getNode(ISD::FP_EXTEND, SL, VT, N101),
16414 matcher.getNode(
16415 PreferredFusedOpcode, SL, VT,
16416 matcher.getNode(ISD::FNEG, SL, VT,
16417 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1020)),
16418 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1021), N0));
16419 }
16420 }
16421 }
16422
16423 return SDValue();
16424}
16425
16426/// Try to perform FMA combining on a given FMUL node based on the distributive
16427/// law x * (y + 1) = x * y + x and variants thereof (commuted versions,
16428/// subtraction instead of addition).
16429SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {
16430 SDValue N0 = N->getOperand(Num: 0);
16431 SDValue N1 = N->getOperand(Num: 1);
16432 EVT VT = N->getValueType(ResNo: 0);
16433 SDLoc SL(N);
16434
16435 assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation");
16436
16437 const TargetOptions &Options = DAG.getTarget().Options;
16438
16439 // The transforms below are incorrect when x == 0 and y == inf, because the
16440 // intermediate multiplication produces a nan.
16441 SDValue FAdd = N0.getOpcode() == ISD::FADD ? N0 : N1;
16442 if (!hasNoInfs(Options, N: FAdd))
16443 return SDValue();
16444
16445 // Floating-point multiply-add without intermediate rounding.
16446 bool HasFMA =
16447 isContractableFMUL(Options, N: SDValue(N, 0)) &&
16448 TLI.isFMAFasterThanFMulAndFAdd(MF: DAG.getMachineFunction(), VT) &&
16449 (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::FMA, VT));
16450
16451 // Floating-point multiply-add with intermediate rounding. This can result
16452 // in a less precise result due to the changed rounding order.
16453 bool HasFMAD = Options.UnsafeFPMath &&
16454 (LegalOperations && TLI.isFMADLegal(DAG, N));
16455
16456 // No valid opcode, do not combine.
16457 if (!HasFMAD && !HasFMA)
16458 return SDValue();
16459
16460 // Always prefer FMAD to FMA for precision.
16461 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
16462 bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
16463
16464 // fold (fmul (fadd x0, +1.0), y) -> (fma x0, y, y)
16465 // fold (fmul (fadd x0, -1.0), y) -> (fma x0, y, (fneg y))
16466 auto FuseFADD = [&](SDValue X, SDValue Y) {
16467 if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) {
16468 if (auto *C = isConstOrConstSplatFP(N: X.getOperand(i: 1), AllowUndefs: true)) {
16469 if (C->isExactlyValue(V: +1.0))
16470 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT, N1: X.getOperand(i: 0), N2: Y,
16471 N3: Y);
16472 if (C->isExactlyValue(V: -1.0))
16473 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT, N1: X.getOperand(i: 0), N2: Y,
16474 N3: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: Y));
16475 }
16476 }
16477 return SDValue();
16478 };
16479
16480 if (SDValue FMA = FuseFADD(N0, N1))
16481 return FMA;
16482 if (SDValue FMA = FuseFADD(N1, N0))
16483 return FMA;
16484
16485 // fold (fmul (fsub +1.0, x1), y) -> (fma (fneg x1), y, y)
16486 // fold (fmul (fsub -1.0, x1), y) -> (fma (fneg x1), y, (fneg y))
16487 // fold (fmul (fsub x0, +1.0), y) -> (fma x0, y, (fneg y))
16488 // fold (fmul (fsub x0, -1.0), y) -> (fma x0, y, y)
16489 auto FuseFSUB = [&](SDValue X, SDValue Y) {
16490 if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) {
16491 if (auto *C0 = isConstOrConstSplatFP(N: X.getOperand(i: 0), AllowUndefs: true)) {
16492 if (C0->isExactlyValue(V: +1.0))
16493 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT,
16494 N1: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: X.getOperand(i: 1)), N2: Y,
16495 N3: Y);
16496 if (C0->isExactlyValue(V: -1.0))
16497 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT,
16498 N1: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: X.getOperand(i: 1)), N2: Y,
16499 N3: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: Y));
16500 }
16501 if (auto *C1 = isConstOrConstSplatFP(N: X.getOperand(i: 1), AllowUndefs: true)) {
16502 if (C1->isExactlyValue(V: +1.0))
16503 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT, N1: X.getOperand(i: 0), N2: Y,
16504 N3: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: Y));
16505 if (C1->isExactlyValue(V: -1.0))
16506 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT, N1: X.getOperand(i: 0), N2: Y,
16507 N3: Y);
16508 }
16509 }
16510 return SDValue();
16511 };
16512
16513 if (SDValue FMA = FuseFSUB(N0, N1))
16514 return FMA;
16515 if (SDValue FMA = FuseFSUB(N1, N0))
16516 return FMA;
16517
16518 return SDValue();
16519}
16520
16521SDValue DAGCombiner::visitVP_FADD(SDNode *N) {
16522 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16523
16524 // FADD -> FMA combines:
16525 if (SDValue Fused = visitFADDForFMACombine<VPMatchContext>(N)) {
16526 if (Fused.getOpcode() != ISD::DELETED_NODE)
16527 AddToWorklist(N: Fused.getNode());
16528 return Fused;
16529 }
16530 return SDValue();
16531}
16532
16533SDValue DAGCombiner::visitFADD(SDNode *N) {
16534 SDValue N0 = N->getOperand(Num: 0);
16535 SDValue N1 = N->getOperand(Num: 1);
16536 SDNode *N0CFP = DAG.isConstantFPBuildVectorOrConstantFP(N: N0);
16537 SDNode *N1CFP = DAG.isConstantFPBuildVectorOrConstantFP(N: N1);
16538 EVT VT = N->getValueType(ResNo: 0);
16539 SDLoc DL(N);
16540 const TargetOptions &Options = DAG.getTarget().Options;
16541 SDNodeFlags Flags = N->getFlags();
16542 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16543
16544 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
16545 return R;
16546
16547 // fold (fadd c1, c2) -> c1 + c2
16548 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FADD, DL, VT, Ops: {N0, N1}))
16549 return C;
16550
16551 // canonicalize constant to RHS
16552 if (N0CFP && !N1CFP)
16553 return DAG.getNode(Opcode: ISD::FADD, DL, VT, N1, N2: N0);
16554
16555 // fold vector ops
16556 if (VT.isVector())
16557 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
16558 return FoldedVOp;
16559
16560 // N0 + -0.0 --> N0 (also allowed with +0.0 and fast-math)
16561 ConstantFPSDNode *N1C = isConstOrConstSplatFP(N: N1, AllowUndefs: true);
16562 if (N1C && N1C->isZero())
16563 if (N1C->isNegative() || Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())
16564 return N0;
16565
16566 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
16567 return NewSel;
16568
16569 // fold (fadd A, (fneg B)) -> (fsub A, B)
16570 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::FSUB, VT))
16571 if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
16572 Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
16573 return DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1: N0, N2: NegN1);
16574
16575 // fold (fadd (fneg A), B) -> (fsub B, A)
16576 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::FSUB, VT))
16577 if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
16578 Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
16579 return DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1, N2: NegN0);
16580
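  // Match a single-use (fmul B, -2.0), where the constant may be a splat.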
16581 auto isFMulNegTwo = [](SDValue FMul) {
16582 if (!FMul.hasOneUse() || FMul.getOpcode() != ISD::FMUL)
16583 return false;
16584 auto *C = isConstOrConstSplatFP(N: FMul.getOperand(i: 1), AllowUndefs: true);
16585 return C && C->isExactlyValue(V: -2.0);
16586 };
16587
16588 // fadd (fmul B, -2.0), A --> fsub A, (fadd B, B)
16589 if (isFMulNegTwo(N0)) {
16590 SDValue B = N0.getOperand(i: 0);
16591 SDValue Add = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: B, N2: B);
16592 return DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1, N2: Add);
16593 }
16594 // fadd A, (fmul B, -2.0) --> fsub A, (fadd B, B)
16595 if (isFMulNegTwo(N1)) {
16596 SDValue B = N1.getOperand(i: 0);
16597 SDValue Add = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: B, N2: B);
16598 return DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1: N0, N2: Add);
16599 }
16600
  // No FP constant should be created after legalization as the Instruction
  // Selection pass has a hard time dealing with FP constants.
16603 bool AllowNewConst = (Level < AfterLegalizeDAG);
16604
16605 // If nnan is enabled, fold lots of things.
16606 if ((Options.NoNaNsFPMath || Flags.hasNoNaNs()) && AllowNewConst) {
16607 // If allowed, fold (fadd (fneg x), x) -> 0.0
16608 if (N0.getOpcode() == ISD::FNEG && N0.getOperand(i: 0) == N1)
16609 return DAG.getConstantFP(Val: 0.0, DL, VT);
16610
16611 // If allowed, fold (fadd x, (fneg x)) -> 0.0
16612 if (N1.getOpcode() == ISD::FNEG && N1.getOperand(i: 0) == N0)
16613 return DAG.getConstantFP(Val: 0.0, DL, VT);
16614 }
16615
16616 // If 'unsafe math' or reassoc and nsz, fold lots of things.
16617 // TODO: break out portions of the transformations below for which Unsafe is
16618 // considered and which do not require both nsz and reassoc
16619 if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
16620 (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
16621 AllowNewConst) {
16622 // fadd (fadd x, c1), c2 -> fadd x, c1 + c2
16623 if (N1CFP && N0.getOpcode() == ISD::FADD &&
16624 DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 1))) {
16625 SDValue NewC = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0.getOperand(i: 1), N2: N1);
16626 return DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0.getOperand(i: 0), N2: NewC);
16627 }
16628
16629 // We can fold chains of FADD's of the same value into multiplications.
16630 // This transform is not safe in general because we are reducing the number
16631 // of rounding steps.
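    // For example, (fadd (fmul x, c), x) becomes (fmul x, c+1), which rounds
    // once at run time instead of twice.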
16632 if (TLI.isOperationLegalOrCustom(Op: ISD::FMUL, VT) && !N0CFP && !N1CFP) {
16633 if (N0.getOpcode() == ISD::FMUL) {
16634 SDNode *CFP00 =
16635 DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 0));
16636 SDNode *CFP01 =
16637 DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 1));
16638
16639 // (fadd (fmul x, c), x) -> (fmul x, c+1)
16640 if (CFP01 && !CFP00 && N0.getOperand(i: 0) == N1) {
16641 SDValue NewCFP = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0.getOperand(i: 1),
16642 N2: DAG.getConstantFP(Val: 1.0, DL, VT));
16643 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1, N2: NewCFP);
16644 }
16645
16646 // (fadd (fmul x, c), (fadd x, x)) -> (fmul x, c+2)
16647 if (CFP01 && !CFP00 && N1.getOpcode() == ISD::FADD &&
16648 N1.getOperand(i: 0) == N1.getOperand(i: 1) &&
16649 N0.getOperand(i: 0) == N1.getOperand(i: 0)) {
16650 SDValue NewCFP = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0.getOperand(i: 1),
16651 N2: DAG.getConstantFP(Val: 2.0, DL, VT));
16652 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0.getOperand(i: 0), N2: NewCFP);
16653 }
16654 }
16655
16656 if (N1.getOpcode() == ISD::FMUL) {
16657 SDNode *CFP10 =
16658 DAG.isConstantFPBuildVectorOrConstantFP(N: N1.getOperand(i: 0));
16659 SDNode *CFP11 =
16660 DAG.isConstantFPBuildVectorOrConstantFP(N: N1.getOperand(i: 1));
16661
16662 // (fadd x, (fmul x, c)) -> (fmul x, c+1)
16663 if (CFP11 && !CFP10 && N1.getOperand(i: 0) == N0) {
16664 SDValue NewCFP = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N1.getOperand(i: 1),
16665 N2: DAG.getConstantFP(Val: 1.0, DL, VT));
16666 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: NewCFP);
16667 }
16668
16669 // (fadd (fadd x, x), (fmul x, c)) -> (fmul x, c+2)
16670 if (CFP11 && !CFP10 && N0.getOpcode() == ISD::FADD &&
16671 N0.getOperand(i: 0) == N0.getOperand(i: 1) &&
16672 N1.getOperand(i: 0) == N0.getOperand(i: 0)) {
16673 SDValue NewCFP = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N1.getOperand(i: 1),
16674 N2: DAG.getConstantFP(Val: 2.0, DL, VT));
16675 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N1.getOperand(i: 0), N2: NewCFP);
16676 }
16677 }
16678
16679 if (N0.getOpcode() == ISD::FADD) {
16680 SDNode *CFP00 =
16681 DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 0));
16682 // (fadd (fadd x, x), x) -> (fmul x, 3.0)
16683 if (!CFP00 && N0.getOperand(i: 0) == N0.getOperand(i: 1) &&
16684 (N0.getOperand(i: 0) == N1)) {
16685 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1,
16686 N2: DAG.getConstantFP(Val: 3.0, DL, VT));
16687 }
16688 }
16689
16690 if (N1.getOpcode() == ISD::FADD) {
16691 SDNode *CFP10 =
16692 DAG.isConstantFPBuildVectorOrConstantFP(N: N1.getOperand(i: 0));
16693 // (fadd x, (fadd x, x)) -> (fmul x, 3.0)
16694 if (!CFP10 && N1.getOperand(i: 0) == N1.getOperand(i: 1) &&
16695 N1.getOperand(i: 0) == N0) {
16696 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0,
16697 N2: DAG.getConstantFP(Val: 3.0, DL, VT));
16698 }
16699 }
16700
16701 // (fadd (fadd x, x), (fadd x, x)) -> (fmul x, 4.0)
16702 if (N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD &&
16703 N0.getOperand(i: 0) == N0.getOperand(i: 1) &&
16704 N1.getOperand(i: 0) == N1.getOperand(i: 1) &&
16705 N0.getOperand(i: 0) == N1.getOperand(i: 0)) {
16706 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0.getOperand(i: 0),
16707 N2: DAG.getConstantFP(Val: 4.0, DL, VT));
16708 }
16709 }
16710
16711 // Fold fadd(vecreduce(x), vecreduce(y)) -> vecreduce(fadd(x, y))
16712 if (SDValue SD = reassociateReduction(RedOpc: ISD::VECREDUCE_FADD, Opc: ISD::FADD, DL,
16713 VT, N0, N1, Flags))
16714 return SD;
16715 } // enable-unsafe-fp-math
16716
16717 // FADD -> FMA combines:
16718 if (SDValue Fused = visitFADDForFMACombine<EmptyMatchContext>(N)) {
16719 if (Fused.getOpcode() != ISD::DELETED_NODE)
16720 AddToWorklist(N: Fused.getNode());
16721 return Fused;
16722 }
16723 return SDValue();
16724}
16725
16726SDValue DAGCombiner::visitSTRICT_FADD(SDNode *N) {
16727 SDValue Chain = N->getOperand(Num: 0);
16728 SDValue N0 = N->getOperand(Num: 1);
16729 SDValue N1 = N->getOperand(Num: 2);
16730 EVT VT = N->getValueType(ResNo: 0);
16731 EVT ChainVT = N->getValueType(ResNo: 1);
16732 SDLoc DL(N);
16733 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16734
16735 // fold (strict_fadd A, (fneg B)) -> (strict_fsub A, B)
16736 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::STRICT_FSUB, VT))
16737 if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
16738 Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize)) {
16739 return DAG.getNode(Opcode: ISD::STRICT_FSUB, DL, VTList: DAG.getVTList(VT1: VT, VT2: ChainVT),
16740 Ops: {Chain, N0, NegN1});
16741 }
16742
16743 // fold (strict_fadd (fneg A), B) -> (strict_fsub B, A)
16744 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::STRICT_FSUB, VT))
16745 if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
16746 Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize)) {
16747 return DAG.getNode(Opcode: ISD::STRICT_FSUB, DL, VTList: DAG.getVTList(VT1: VT, VT2: ChainVT),
16748 Ops: {Chain, N1, NegN0});
16749 }
16750 return SDValue();
16751}
16752
16753SDValue DAGCombiner::visitFSUB(SDNode *N) {
16754 SDValue N0 = N->getOperand(Num: 0);
16755 SDValue N1 = N->getOperand(Num: 1);
16756 ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N: N0, AllowUndefs: true);
16757 ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N: N1, AllowUndefs: true);
16758 EVT VT = N->getValueType(ResNo: 0);
16759 SDLoc DL(N);
16760 const TargetOptions &Options = DAG.getTarget().Options;
16761 const SDNodeFlags Flags = N->getFlags();
16762 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16763
16764 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
16765 return R;
16766
16767 // fold (fsub c1, c2) -> c1-c2
16768 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FSUB, DL, VT, Ops: {N0, N1}))
16769 return C;
16770
16771 // fold vector ops
16772 if (VT.isVector())
16773 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
16774 return FoldedVOp;
16775
16776 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
16777 return NewSel;
16778
16779 // (fsub A, 0) -> A
16780 if (N1CFP && N1CFP->isZero()) {
16781 if (!N1CFP->isNegative() || Options.NoSignedZerosFPMath ||
16782 Flags.hasNoSignedZeros()) {
16783 return N0;
16784 }
16785 }
16786
16787 if (N0 == N1) {
16788 // (fsub x, x) -> 0.0
16789 if (Options.NoNaNsFPMath || Flags.hasNoNaNs())
16790 return DAG.getConstantFP(Val: 0.0f, DL, VT);
16791 }
16792
16793 // (fsub -0.0, N1) -> -N1
16794 if (N0CFP && N0CFP->isZero()) {
16795 if (N0CFP->isNegative() ||
16796 (Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())) {
16797 // We cannot replace an FSUB(+-0.0,X) with FNEG(X) when denormals are
16798 // flushed to zero, unless all users treat denorms as zero (DAZ).
16799 // FIXME: This transform will change the sign of a NaN and the behavior
16800 // of a signaling NaN. It is only valid when a NoNaN flag is present.
16801 DenormalMode DenormMode = DAG.getDenormalMode(VT);
16802 if (DenormMode == DenormalMode::getIEEE()) {
16803 if (SDValue NegN1 =
16804 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
16805 return NegN1;
16806 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::FNEG, VT))
16807 return DAG.getNode(Opcode: ISD::FNEG, DL, VT, Operand: N1);
16808 }
16809 }
16810 }
16811
16812 if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
16813 (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
16814 N1.getOpcode() == ISD::FADD) {
16815 // X - (X + Y) -> -Y
16816 if (N0 == N1->getOperand(Num: 0))
16817 return DAG.getNode(Opcode: ISD::FNEG, DL, VT, Operand: N1->getOperand(Num: 1));
16818 // X - (Y + X) -> -Y
16819 if (N0 == N1->getOperand(Num: 1))
16820 return DAG.getNode(Opcode: ISD::FNEG, DL, VT, Operand: N1->getOperand(Num: 0));
16821 }
16822
16823 // fold (fsub A, (fneg B)) -> (fadd A, B)
16824 if (SDValue NegN1 =
16825 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
16826 return DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0, N2: NegN1);
16827
16828 // FSUB -> FMA combines:
16829 if (SDValue Fused = visitFSUBForFMACombine<EmptyMatchContext>(N)) {
16830 AddToWorklist(N: Fused.getNode());
16831 return Fused;
16832 }
16833
16834 return SDValue();
16835}
16836
16837// Transform IEEE Floats:
16838// (fmul C, (uitofp Pow2))
16839// -> (bitcast_to_FP (add (bitcast_to_INT C), Log2(Pow2) << mantissa))
16840// (fdiv C, (uitofp Pow2))
16841// -> (bitcast_to_FP (sub (bitcast_to_INT C), Log2(Pow2) << mantissa))
16842//
// The rationale is that fmul/fdiv by a power of 2 just changes the exponent,
// so there is no need for more than an add/sub.
16845//
16846// This is valid under the following circumstances:
16847// 1) We are dealing with IEEE floats
16848// 2) C is normal
16849// 3) The fmul/fdiv add/sub will not go outside of min/max exponent bounds.
// TODO: Much of this could also be used for generating `ldexp` on targets that
// prefer it.
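// For example, with f32 (23 mantissa bits), multiplying a normal constant C by
// 8 = 2^3 is equivalent to adding (3 << 23) to C's bit pattern, provided the
// exponent stays within the normal range.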
16852SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {
16853 EVT VT = N->getValueType(ResNo: 0);
16854 SDValue ConstOp, Pow2Op;
16855
16856 std::optional<int> Mantissa;
16857 auto GetConstAndPow2Ops = [&](unsigned ConstOpIdx) {
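    // FDIV is not commutative, so only the (fdiv C, Pow2) form (constant
    // dividend) is handled.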
16858 if (ConstOpIdx == 1 && N->getOpcode() == ISD::FDIV)
16859 return false;
16860
16861 ConstOp = peekThroughBitcasts(V: N->getOperand(Num: ConstOpIdx));
16862 Pow2Op = N->getOperand(Num: 1 - ConstOpIdx);
16863 if (Pow2Op.getOpcode() != ISD::UINT_TO_FP &&
16864 (Pow2Op.getOpcode() != ISD::SINT_TO_FP ||
16865 !DAG.computeKnownBits(Op: Pow2Op).isNonNegative()))
16866 return false;
16867
16868 Pow2Op = Pow2Op.getOperand(i: 0);
16869
16870 // `Log2(Pow2Op) < Pow2Op.getScalarSizeInBits()`.
16871 // TODO: We could use knownbits to make this bound more precise.
16872 int MaxExpChange = Pow2Op.getValueType().getScalarSizeInBits();
16873
16874 auto IsFPConstValid = [N, MaxExpChange, &Mantissa](ConstantFPSDNode *CFP) {
16875 if (CFP == nullptr)
16876 return false;
16877
16878 const APFloat &APF = CFP->getValueAPF();
16879
      // Make sure we have a normal IEEE constant.
16881 if (!APF.isNormal() || !APF.isIEEE())
16882 return false;
16883
      // Make sure the float's exponent is within the bounds for which this
      // transform produces a bitwise-identical value.
16886 int CurExp = ilogb(Arg: APF);
16887 // FMul by pow2 will only increase exponent.
16888 int MinExp =
16889 N->getOpcode() == ISD::FMUL ? CurExp : (CurExp - MaxExpChange);
16890 // FDiv by pow2 will only decrease exponent.
16891 int MaxExp =
16892 N->getOpcode() == ISD::FDIV ? CurExp : (CurExp + MaxExpChange);
16893 if (MinExp <= APFloat::semanticsMinExponent(APF.getSemantics()) ||
16894 MaxExp >= APFloat::semanticsMaxExponent(APF.getSemantics()))
16895 return false;
16896
16897 // Finally make sure we actually know the mantissa for the float type.
16898 int ThisMantissa = APFloat::semanticsPrecision(APF.getSemantics()) - 1;
16899 if (!Mantissa)
16900 Mantissa = ThisMantissa;
16901
16902 return *Mantissa == ThisMantissa && ThisMantissa > 0;
16903 };
16904
16905 // TODO: We may be able to include undefs.
16906 return ISD::matchUnaryFpPredicate(Op: ConstOp, Match: IsFPConstValid);
16907 };
16908
16909 if (!GetConstAndPow2Ops(0) && !GetConstAndPow2Ops(1))
16910 return SDValue();
16911
16912 if (!TLI.optimizeFMulOrFDivAsShiftAddBitcast(N, FPConst: ConstOp, IntPow2: Pow2Op))
16913 return SDValue();
16914
16915 // Get log2 after all other checks have taken place. This is because
16916 // BuildLogBase2 may create a new node.
16917 SDLoc DL(N);
16918 // Get Log2 type with same bitwidth as the float type (VT).
16919 EVT NewIntVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: VT.getScalarSizeInBits());
16920 if (VT.isVector())
16921 NewIntVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NewIntVT,
16922 EC: VT.getVectorElementCount());
16923
16924 SDValue Log2 = BuildLogBase2(V: Pow2Op, DL, KnownNeverZero: DAG.isKnownNeverZero(Op: Pow2Op),
16925 /*InexpensiveOnly*/ true, OutVT: NewIntVT);
16926 if (!Log2)
16927 return SDValue();
16928
16929 // Perform actual transform.
16930 SDValue MantissaShiftCnt =
16931 DAG.getShiftAmountConstant(Val: *Mantissa, VT: NewIntVT, DL);
  // TODO: Sometimes Log2 is of the form `(X + C)`. `(X + C) << C1` should fold
  // to `(X << C1) + (C << C1)`, but that isn't always the case because of the
  // cast. We could handle that here by looking through the casts.
16935 SDValue Shift = DAG.getNode(Opcode: ISD::SHL, DL, VT: NewIntVT, N1: Log2, N2: MantissaShiftCnt);
16936 SDValue ResAsInt =
16937 DAG.getNode(Opcode: N->getOpcode() == ISD::FMUL ? ISD::ADD : ISD::SUB, DL,
16938 VT: NewIntVT, N1: DAG.getBitcast(VT: NewIntVT, V: ConstOp), N2: Shift);
16939 SDValue ResAsFP = DAG.getBitcast(VT, V: ResAsInt);
16940 return ResAsFP;
16941}
16942
16943SDValue DAGCombiner::visitFMUL(SDNode *N) {
16944 SDValue N0 = N->getOperand(Num: 0);
16945 SDValue N1 = N->getOperand(Num: 1);
16946 ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N: N1, AllowUndefs: true);
16947 EVT VT = N->getValueType(ResNo: 0);
16948 SDLoc DL(N);
16949 const TargetOptions &Options = DAG.getTarget().Options;
16950 const SDNodeFlags Flags = N->getFlags();
16951 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16952
16953 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
16954 return R;
16955
16956 // fold (fmul c1, c2) -> c1*c2
16957 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FMUL, DL, VT, Ops: {N0, N1}))
16958 return C;
16959
16960 // canonicalize constant to RHS
16961 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0) &&
16962 !DAG.isConstantFPBuildVectorOrConstantFP(N: N1))
16963 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1, N2: N0);
16964
16965 // fold vector ops
16966 if (VT.isVector())
16967 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
16968 return FoldedVOp;
16969
16970 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
16971 return NewSel;
16972
16973 if (Options.UnsafeFPMath || Flags.hasAllowReassociation()) {
16974 // fmul (fmul X, C1), C2 -> fmul X, C1 * C2
16975 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N1) &&
16976 N0.getOpcode() == ISD::FMUL) {
16977 SDValue N00 = N0.getOperand(i: 0);
16978 SDValue N01 = N0.getOperand(i: 1);
16979 // Avoid an infinite loop by making sure that N00 is not a constant
16980 // (the inner multiply has not been constant folded yet).
16981 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N01) &&
16982 !DAG.isConstantFPBuildVectorOrConstantFP(N: N00)) {
16983 SDValue MulConsts = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N01, N2: N1);
16984 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N00, N2: MulConsts);
16985 }
16986 }
16987
16988 // Match a special-case: we convert X * 2.0 into fadd.
16989 // fmul (fadd X, X), C -> fmul X, 2.0 * C
16990 if (N0.getOpcode() == ISD::FADD && N0.hasOneUse() &&
16991 N0.getOperand(i: 0) == N0.getOperand(i: 1)) {
16992 const SDValue Two = DAG.getConstantFP(Val: 2.0, DL, VT);
16993 SDValue MulConsts = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Two, N2: N1);
16994 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0.getOperand(i: 0), N2: MulConsts);
16995 }
16996
16997 // Fold fmul(vecreduce(x), vecreduce(y)) -> vecreduce(fmul(x, y))
16998 if (SDValue SD = reassociateReduction(RedOpc: ISD::VECREDUCE_FMUL, Opc: ISD::FMUL, DL,
16999 VT, N0, N1, Flags))
17000 return SD;
17001 }
17002
17003 // fold (fmul X, 2.0) -> (fadd X, X)
17004 if (N1CFP && N1CFP->isExactlyValue(V: +2.0))
17005 return DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0, N2: N0);
17006
17007 // fold (fmul X, -1.0) -> (fsub -0.0, X)
17008 if (N1CFP && N1CFP->isExactlyValue(V: -1.0)) {
17009 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::FSUB, VT)) {
17010 return DAG.getNode(Opcode: ISD::FSUB, DL, VT,
17011 N1: DAG.getConstantFP(Val: -0.0, DL, VT), N2: N0, Flags);
17012 }
17013 }
17014
17015 // -N0 * -N1 --> N0 * N1
17016 TargetLowering::NegatibleCost CostN0 =
17017 TargetLowering::NegatibleCost::Expensive;
17018 TargetLowering::NegatibleCost CostN1 =
17019 TargetLowering::NegatibleCost::Expensive;
17020 SDValue NegN0 =
17021 TLI.getNegatedExpression(Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN0);
17022 if (NegN0) {
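    // Keep NegN0 alive via a handle while computing the negation of N1, so it
    // is not deleted if building N1's negation replaces or removes nodes.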
17023 HandleSDNode NegN0Handle(NegN0);
17024 SDValue NegN1 =
17025 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN1);
17026 if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
17027 CostN1 == TargetLowering::NegatibleCost::Cheaper))
17028 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: NegN0, N2: NegN1);
17029 }
17030
17031 // fold (fmul X, (select (fcmp X > 0.0), -1.0, 1.0)) -> (fneg (fabs X))
17032 // fold (fmul X, (select (fcmp X > 0.0), 1.0, -1.0)) -> (fabs X)
17033 if (Flags.hasNoNaNs() && Flags.hasNoSignedZeros() &&
17034 (N0.getOpcode() == ISD::SELECT || N1.getOpcode() == ISD::SELECT) &&
17035 TLI.isOperationLegal(Op: ISD::FABS, VT)) {
17036 SDValue Select = N0, X = N1;
17037 if (Select.getOpcode() != ISD::SELECT)
17038 std::swap(a&: Select, b&: X);
17039
17040 SDValue Cond = Select.getOperand(i: 0);
17041 auto TrueOpnd = dyn_cast<ConstantFPSDNode>(Val: Select.getOperand(i: 1));
17042 auto FalseOpnd = dyn_cast<ConstantFPSDNode>(Val: Select.getOperand(i: 2));
17043
17044 if (TrueOpnd && FalseOpnd &&
17045 Cond.getOpcode() == ISD::SETCC && Cond.getOperand(i: 0) == X &&
17046 isa<ConstantFPSDNode>(Val: Cond.getOperand(i: 1)) &&
17047 cast<ConstantFPSDNode>(Val: Cond.getOperand(i: 1))->isExactlyValue(V: 0.0)) {
17048 ISD::CondCode CC = cast<CondCodeSDNode>(Val: Cond.getOperand(i: 2))->get();
17049 switch (CC) {
17050 default: break;
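      // For the "less than" predicates, swap the select arms and fall through
      // so the "greater than" handling below covers both orientations.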
17051 case ISD::SETOLT:
17052 case ISD::SETULT:
17053 case ISD::SETOLE:
17054 case ISD::SETULE:
17055 case ISD::SETLT:
17056 case ISD::SETLE:
17057 std::swap(a&: TrueOpnd, b&: FalseOpnd);
17058 [[fallthrough]];
17059 case ISD::SETOGT:
17060 case ISD::SETUGT:
17061 case ISD::SETOGE:
17062 case ISD::SETUGE:
17063 case ISD::SETGT:
17064 case ISD::SETGE:
17065 if (TrueOpnd->isExactlyValue(V: -1.0) && FalseOpnd->isExactlyValue(V: 1.0) &&
17066 TLI.isOperationLegal(Op: ISD::FNEG, VT))
17067 return DAG.getNode(Opcode: ISD::FNEG, DL, VT,
17068 Operand: DAG.getNode(Opcode: ISD::FABS, DL, VT, Operand: X));
17069 if (TrueOpnd->isExactlyValue(V: 1.0) && FalseOpnd->isExactlyValue(V: -1.0))
17070 return DAG.getNode(Opcode: ISD::FABS, DL, VT, Operand: X);
17071
17072 break;
17073 }
17074 }
17075 }
17076
17077 // FMUL -> FMA combines:
17078 if (SDValue Fused = visitFMULForFMADistributiveCombine(N)) {
17079 AddToWorklist(N: Fused.getNode());
17080 return Fused;
17081 }
17082
17083 // Don't do `combineFMulOrFDivWithIntPow2` until after FMUL -> FMA has been
17084 // able to run.
17085 if (SDValue R = combineFMulOrFDivWithIntPow2(N))
17086 return R;
17087
17088 return SDValue();
17089}
17090
17091template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
17092 SDValue N0 = N->getOperand(Num: 0);
17093 SDValue N1 = N->getOperand(Num: 1);
17094 SDValue N2 = N->getOperand(Num: 2);
17095 ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(Val&: N0);
17096 ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(Val&: N1);
17097 EVT VT = N->getValueType(ResNo: 0);
17098 SDLoc DL(N);
17099 const TargetOptions &Options = DAG.getTarget().Options;
17100 // FMA nodes have flags that propagate to the created nodes.
17101 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17102 MatchContextClass matcher(DAG, TLI, N);
17103
17104 // Constant fold FMA.
17105 if (isa<ConstantFPSDNode>(Val: N0) &&
17106 isa<ConstantFPSDNode>(Val: N1) &&
17107 isa<ConstantFPSDNode>(Val: N2)) {
17108 return matcher.getNode(ISD::FMA, DL, VT, N0, N1, N2);
17109 }
17110
17111 // (-N0 * -N1) + N2 --> (N0 * N1) + N2
17112 TargetLowering::NegatibleCost CostN0 =
17113 TargetLowering::NegatibleCost::Expensive;
17114 TargetLowering::NegatibleCost CostN1 =
17115 TargetLowering::NegatibleCost::Expensive;
17116 SDValue NegN0 =
17117 TLI.getNegatedExpression(Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN0);
17118 if (NegN0) {
17119 HandleSDNode NegN0Handle(NegN0);
17120 SDValue NegN1 =
17121 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN1);
17122 if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
17123 CostN1 == TargetLowering::NegatibleCost::Cheaper))
17124 return matcher.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2);
17125 }
17126
17127 // FIXME: use fast math flags instead of Options.UnsafeFPMath
17128 if (Options.UnsafeFPMath) {
17129 if (N0CFP && N0CFP->isZero())
17130 return N2;
17131 if (N1CFP && N1CFP->isZero())
17132 return N2;
17133 }
17134
17135 // FIXME: Support splat of constant.
17136 if (N0CFP && N0CFP->isExactlyValue(V: 1.0))
17137 return matcher.getNode(ISD::FADD, SDLoc(N), VT, N1, N2);
17138 if (N1CFP && N1CFP->isExactlyValue(V: 1.0))
17139 return matcher.getNode(ISD::FADD, SDLoc(N), VT, N0, N2);
17140
17141 // Canonicalize (fma c, x, y) -> (fma x, c, y)
17142 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0) &&
17143 !DAG.isConstantFPBuildVectorOrConstantFP(N: N1))
17144 return matcher.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2);
17145
17146 bool CanReassociate =
17147 Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
17148 if (CanReassociate) {
17149 // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
17150 if (matcher.match(N2, ISD::FMUL) && N0 == N2.getOperand(i: 0) &&
17151 DAG.isConstantFPBuildVectorOrConstantFP(N: N1) &&
17152 DAG.isConstantFPBuildVectorOrConstantFP(N: N2.getOperand(i: 1))) {
17153 return matcher.getNode(
17154 ISD::FMUL, DL, VT, N0,
17155 matcher.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(i: 1)));
17156 }
17157
17158 // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
17159 if (matcher.match(N0, ISD::FMUL) &&
17160 DAG.isConstantFPBuildVectorOrConstantFP(N: N1) &&
17161 DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 1))) {
17162 return matcher.getNode(
17163 ISD::FMA, DL, VT, N0.getOperand(i: 0),
17164 matcher.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(i: 1)), N2);
17165 }
17166 }
17167
  // (fma x, 1, y)  -> (fadd x, y)
  // (fma x, -1, y) -> (fadd (fneg x), y)
17169 // FIXME: Support splat of constant.
17170 if (N1CFP) {
17171 if (N1CFP->isExactlyValue(V: 1.0))
17172 return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
17173
17174 if (N1CFP->isExactlyValue(V: -1.0) &&
17175 (!LegalOperations || TLI.isOperationLegal(Op: ISD::FNEG, VT))) {
17176 SDValue RHSNeg = matcher.getNode(ISD::FNEG, DL, VT, N0);
17177 AddToWorklist(N: RHSNeg.getNode());
17178 return matcher.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
17179 }
17180
    // fma (fneg x), K, y -> fma x, -K, y
17182 if (matcher.match(N0, ISD::FNEG) &&
17183 (TLI.isOperationLegal(Op: ISD::ConstantFP, VT) ||
17184 (N1.hasOneUse() &&
17185 !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
17186 return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(i: 0),
17187 matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
17188 }
17189 }
17190
17191 // FIXME: Support splat of constant.
17192 if (CanReassociate) {
17193 // (fma x, c, x) -> (fmul x, (c+1))
17194 if (N1CFP && N0 == N2) {
17195 return matcher.getNode(ISD::FMUL, DL, VT, N0,
17196 matcher.getNode(ISD::FADD, DL, VT, N1,
17197 DAG.getConstantFP(Val: 1.0, DL, VT)));
17198 }
17199
17200 // (fma x, c, (fneg x)) -> (fmul x, (c-1))
17201 if (N1CFP && matcher.match(N2, ISD::FNEG) && N2.getOperand(i: 0) == N0) {
17202 return matcher.getNode(ISD::FMUL, DL, VT, N0,
17203 matcher.getNode(ISD::FADD, DL, VT, N1,
17204 DAG.getConstantFP(Val: -1.0, DL, VT)));
17205 }
17206 }
17207
17208 // fold ((fma (fneg X), Y, (fneg Z)) -> fneg (fma X, Y, Z))
17209 // fold ((fma X, (fneg Y), (fneg Z)) -> fneg (fma X, Y, Z))
17210 if (!TLI.isFNegFree(VT))
17211 if (SDValue Neg = TLI.getCheaperNegatedExpression(
17212 Op: SDValue(N, 0), DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
17213 return matcher.getNode(ISD::FNEG, DL, VT, Neg);
17214 return SDValue();
17215}
17216
17217SDValue DAGCombiner::visitFMAD(SDNode *N) {
17218 SDValue N0 = N->getOperand(Num: 0);
17219 SDValue N1 = N->getOperand(Num: 1);
17220 SDValue N2 = N->getOperand(Num: 2);
17221 EVT VT = N->getValueType(ResNo: 0);
17222 SDLoc DL(N);
17223
17224 // Constant fold FMAD.
17225 if (isa<ConstantFPSDNode>(Val: N0) && isa<ConstantFPSDNode>(Val: N1) &&
17226 isa<ConstantFPSDNode>(Val: N2))
17227 return DAG.getNode(Opcode: ISD::FMAD, DL, VT, N1: N0, N2: N1, N3: N2);
17228
17229 return SDValue();
17230}
17231
17232// Combine multiple FDIVs with the same divisor into multiple FMULs by the
17233// reciprocal.
17234// E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)
// Notice that this is not always beneficial. One reason is that different
// targets may have different costs for FDIV and FMUL, so sometimes the cost of
// two FDIVs may be lower than the cost of one FDIV and two FMULs. Another
// reason is that the critical path is increased from "one FDIV" to "one FDIV +
// one FMUL".
17239SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) {
17240 // TODO: Limit this transform based on optsize/minsize - it always creates at
17241 // least 1 extra instruction. But the perf win may be substantial enough
17242 // that only minsize should restrict this.
17243 bool UnsafeMath = DAG.getTarget().Options.UnsafeFPMath;
17244 const SDNodeFlags Flags = N->getFlags();
17245 if (LegalDAG || (!UnsafeMath && !Flags.hasAllowReciprocal()))
17246 return SDValue();
17247
17248 // Skip if current node is a reciprocal/fneg-reciprocal.
17249 SDValue N0 = N->getOperand(Num: 0), N1 = N->getOperand(Num: 1);
17250 ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N: N0, /* AllowUndefs */ true);
17251 if (N0CFP && (N0CFP->isExactlyValue(V: 1.0) || N0CFP->isExactlyValue(V: -1.0)))
17252 return SDValue();
17253
17254 // Exit early if the target does not want this transform or if there can't
17255 // possibly be enough uses of the divisor to make the transform worthwhile.
17256 unsigned MinUses = TLI.combineRepeatedFPDivisors();
17257
17258 // For splat vectors, scale the number of uses by the splat factor. If we can
17259 // convert the division into a scalar op, that will likely be much faster.
17260 unsigned NumElts = 1;
17261 EVT VT = N->getValueType(ResNo: 0);
17262 if (VT.isVector() && DAG.isSplatValue(V: N1))
17263 NumElts = VT.getVectorMinNumElements();
17264
17265 if (!MinUses || (N1->use_size() * NumElts) < MinUses)
17266 return SDValue();
17267
17268 // Find all FDIV users of the same divisor.
17269 // Use a set because duplicates may be present in the user list.
17270 SetVector<SDNode *> Users;
17271 for (auto *U : N1->uses()) {
17272 if (U->getOpcode() == ISD::FDIV && U->getOperand(Num: 1) == N1) {
17273 // Skip X/sqrt(X) that has not been simplified to sqrt(X) yet.
17274 if (U->getOperand(Num: 1).getOpcode() == ISD::FSQRT &&
17275 U->getOperand(Num: 0) == U->getOperand(Num: 1).getOperand(i: 0) &&
17276 U->getFlags().hasAllowReassociation() &&
17277 U->getFlags().hasNoSignedZeros())
17278 continue;
17279
17280 // This division is eligible for optimization only if global unsafe math
17281 // is enabled or if this division allows reciprocal formation.
17282 if (UnsafeMath || U->getFlags().hasAllowReciprocal())
17283 Users.insert(X: U);
17284 }
17285 }
17286
17287 // Now that we have the actual number of divisor uses, make sure it meets
17288 // the minimum threshold specified by the target.
17289 if ((Users.size() * NumElts) < MinUses)
17290 return SDValue();
17291
17292 SDLoc DL(N);
17293 SDValue FPOne = DAG.getConstantFP(Val: 1.0, DL, VT);
17294 SDValue Reciprocal = DAG.getNode(Opcode: ISD::FDIV, DL, VT, N1: FPOne, N2: N1, Flags);
17295
17296 // Dividend / Divisor -> Dividend * Reciprocal
17297 for (auto *U : Users) {
17298 SDValue Dividend = U->getOperand(Num: 0);
17299 if (Dividend != FPOne) {
17300 SDValue NewNode = DAG.getNode(Opcode: ISD::FMUL, DL: SDLoc(U), VT, N1: Dividend,
17301 N2: Reciprocal, Flags);
17302 CombineTo(N: U, Res: NewNode);
17303 } else if (U != Reciprocal.getNode()) {
17304 // In the absence of fast-math-flags, this user node is always the
17305 // same node as Reciprocal, but with FMF they may be different nodes.
17306 CombineTo(N: U, Res: Reciprocal);
17307 }
17308 }
17309 return SDValue(N, 0); // N was replaced.
17310}
17311
17312SDValue DAGCombiner::visitFDIV(SDNode *N) {
17313 SDValue N0 = N->getOperand(Num: 0);
17314 SDValue N1 = N->getOperand(Num: 1);
17315 EVT VT = N->getValueType(ResNo: 0);
17316 SDLoc DL(N);
17317 const TargetOptions &Options = DAG.getTarget().Options;
17318 SDNodeFlags Flags = N->getFlags();
17319 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17320
17321 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
17322 return R;
17323
17324 // fold (fdiv c1, c2) -> c1/c2
17325 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FDIV, DL, VT, Ops: {N0, N1}))
17326 return C;
17327
17328 // fold vector ops
17329 if (VT.isVector())
17330 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
17331 return FoldedVOp;
17332
17333 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
17334 return NewSel;
17335
17336 if (SDValue V = combineRepeatedFPDivisors(N))
17337 return V;
17338
17339 // fold (fdiv X, c2) -> (fmul X, 1/c2) if there is no loss in precision, or
17340 // the loss is acceptable with AllowReciprocal.
17341 if (auto *N1CFP = isConstOrConstSplatFP(N: N1, AllowUndefs: true)) {
17342 // Compute the reciprocal 1.0 / c2.
17343 const APFloat &N1APF = N1CFP->getValueAPF();
17344 APFloat Recip = APFloat::getOne(Sem: N1APF.getSemantics());
17345 APFloat::opStatus st = Recip.divide(RHS: N1APF, RM: APFloat::rmNearestTiesToEven);
17346 // Only do the transform if the reciprocal is a legal fp immediate that
17347 // isn't too nasty (eg NaN, denormal, ...).
17348 if (((st == APFloat::opOK && !Recip.isDenormal()) ||
17349 (st == APFloat::opInexact &&
17350 (Options.UnsafeFPMath || Flags.hasAllowReciprocal()))) &&
17351 (!LegalOperations ||
17352 // FIXME: custom lowering of ConstantFP might fail (see e.g. ARM
17353 // backend)... we should handle this gracefully after Legalize.
17354 // TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT) ||
17355 TLI.isOperationLegal(Op: ISD::ConstantFP, VT) ||
17356 TLI.isFPImmLegal(Recip, VT, ForCodeSize)))
17357 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0,
17358 N2: DAG.getConstantFP(Val: Recip, DL, VT));
17359 }
17360
17361 if (Options.UnsafeFPMath || Flags.hasAllowReciprocal()) {
17362 // If this FDIV is part of a reciprocal square root, it may be folded
17363 // into a target-specific square root estimate instruction.
17364 if (N1.getOpcode() == ISD::FSQRT) {
17365 if (SDValue RV = buildRsqrtEstimate(Op: N1.getOperand(i: 0), Flags))
17366 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: RV);
17367 } else if (N1.getOpcode() == ISD::FP_EXTEND &&
17368 N1.getOperand(i: 0).getOpcode() == ISD::FSQRT) {
17369 if (SDValue RV =
17370 buildRsqrtEstimate(Op: N1.getOperand(i: 0).getOperand(i: 0), Flags)) {
17371 RV = DAG.getNode(Opcode: ISD::FP_EXTEND, DL: SDLoc(N1), VT, Operand: RV);
17372 AddToWorklist(N: RV.getNode());
17373 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: RV);
17374 }
17375 } else if (N1.getOpcode() == ISD::FP_ROUND &&
17376 N1.getOperand(i: 0).getOpcode() == ISD::FSQRT) {
17377 if (SDValue RV =
17378 buildRsqrtEstimate(Op: N1.getOperand(i: 0).getOperand(i: 0), Flags)) {
17379 RV = DAG.getNode(Opcode: ISD::FP_ROUND, DL: SDLoc(N1), VT, N1: RV, N2: N1.getOperand(i: 1));
17380 AddToWorklist(N: RV.getNode());
17381 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: RV);
17382 }
17383 } else if (N1.getOpcode() == ISD::FMUL) {
17384 // Look through an FMUL. Even though this won't remove the FDIV directly,
17385 // it's still worthwhile to get rid of the FSQRT if possible.
17386 SDValue Sqrt, Y;
17387 if (N1.getOperand(i: 0).getOpcode() == ISD::FSQRT) {
17388 Sqrt = N1.getOperand(i: 0);
17389 Y = N1.getOperand(i: 1);
17390 } else if (N1.getOperand(i: 1).getOpcode() == ISD::FSQRT) {
17391 Sqrt = N1.getOperand(i: 1);
17392 Y = N1.getOperand(i: 0);
17393 }
17394 if (Sqrt.getNode()) {
17395 // If the other multiply operand is known positive, pull it into the
17396 // sqrt. That will eliminate the division if we convert to an estimate.
17397 if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
17398 N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse()) {
17399 SDValue A;
17400 if (Y.getOpcode() == ISD::FABS && Y.hasOneUse())
17401 A = Y.getOperand(i: 0);
17402 else if (Y == Sqrt.getOperand(i: 0))
17403 A = Y;
17404 if (A) {
17405 // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
17406 // X / (A * sqrt(A)) --> X / sqrt(A*A*A) --> X * rsqrt(A*A*A)
17407 SDValue AA = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: A, N2: A);
17408 SDValue AAZ =
17409 DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: AA, N2: Sqrt.getOperand(i: 0));
17410 if (SDValue Rsqrt = buildRsqrtEstimate(Op: AAZ, Flags))
17411 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: Rsqrt);
17412
17413 // Estimate creation failed. Clean up speculatively created nodes.
17414 recursivelyDeleteUnusedNodes(N: AAZ.getNode());
17415 }
17416 }
17417
17418 // We found a FSQRT, so try to make this fold:
17419 // X / (Y * sqrt(Z)) -> X * (rsqrt(Z) / Y)
17420 if (SDValue Rsqrt = buildRsqrtEstimate(Op: Sqrt.getOperand(i: 0), Flags)) {
17421 SDValue Div = DAG.getNode(Opcode: ISD::FDIV, DL: SDLoc(N1), VT, N1: Rsqrt, N2: Y);
17422 AddToWorklist(N: Div.getNode());
17423 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: Div);
17424 }
17425 }
17426 }
17427
17428 // Fold into a reciprocal estimate and multiply instead of a real divide.
17429 if (Options.NoInfsFPMath || Flags.hasNoInfs())
17430 if (SDValue RV = BuildDivEstimate(N: N0, Op: N1, Flags))
17431 return RV;
17432 }
17433
17434 // Fold X/Sqrt(X) -> Sqrt(X)
17435 if ((Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros()) &&
17436 (Options.UnsafeFPMath || Flags.hasAllowReassociation()))
17437 if (N1.getOpcode() == ISD::FSQRT && N0 == N1.getOperand(i: 0))
17438 return N1;
17439
17440 // (fdiv (fneg X), (fneg Y)) -> (fdiv X, Y)
17441 TargetLowering::NegatibleCost CostN0 =
17442 TargetLowering::NegatibleCost::Expensive;
17443 TargetLowering::NegatibleCost CostN1 =
17444 TargetLowering::NegatibleCost::Expensive;
17445 SDValue NegN0 =
17446 TLI.getNegatedExpression(Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN0);
17447 if (NegN0) {
17448 HandleSDNode NegN0Handle(NegN0);
17449 SDValue NegN1 =
17450 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN1);
17451 if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
17452 CostN1 == TargetLowering::NegatibleCost::Cheaper))
17453 return DAG.getNode(Opcode: ISD::FDIV, DL: SDLoc(N), VT, N1: NegN0, N2: NegN1);
17454 }
17455
17456 if (SDValue R = combineFMulOrFDivWithIntPow2(N))
17457 return R;
17458
17459 return SDValue();
17460}
17461
17462SDValue DAGCombiner::visitFREM(SDNode *N) {
17463 SDValue N0 = N->getOperand(Num: 0);
17464 SDValue N1 = N->getOperand(Num: 1);
17465 EVT VT = N->getValueType(ResNo: 0);
17466 SDNodeFlags Flags = N->getFlags();
17467 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17468 SDLoc DL(N);
17469
17470 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
17471 return R;
17472
17473 // fold (frem c1, c2) -> fmod(c1,c2)
17474 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FREM, DL, VT, Ops: {N0, N1}))
17475 return C;
17476
17477 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
17478 return NewSel;
17479
  // Lower frem N0, N1 => N0 - trunc(N0 / N1) * N1, provided N1 is an integer
  // power of 2.
17482 if (!TLI.isOperationLegal(Op: ISD::FREM, VT) &&
17483 TLI.isOperationLegalOrCustom(Op: ISD::FMUL, VT) &&
17484 TLI.isOperationLegalOrCustom(Op: ISD::FDIV, VT) &&
17485 TLI.isOperationLegalOrCustom(Op: ISD::FTRUNC, VT) &&
17486 DAG.isKnownToBeAPowerOfTwoFP(Val: N1)) {
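    // frem takes the sign of N0, but the expansion below can produce +0.0 when
    // N0 is a negative exact multiple of N1, so restore the sign with
    // FCOPYSIGN unless signed zeros do not matter or N0 is known non-negative.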
17487 bool NeedsCopySign =
17488 !Flags.hasNoSignedZeros() && !DAG.cannotBeOrderedNegativeFP(Op: N0);
17489 SDValue Div = DAG.getNode(Opcode: ISD::FDIV, DL, VT, N1: N0, N2: N1);
17490 SDValue Rnd = DAG.getNode(Opcode: ISD::FTRUNC, DL, VT, Operand: Div);
17491 SDValue MLA;
17492 if (TLI.isFMAFasterThanFMulAndFAdd(MF: DAG.getMachineFunction(), VT)) {
17493 MLA = DAG.getNode(Opcode: ISD::FMA, DL, VT, N1: DAG.getNode(Opcode: ISD::FNEG, DL, VT, Operand: Rnd),
17494 N2: N1, N3: N0);
17495 } else {
17496 SDValue Mul = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Rnd, N2: N1);
17497 MLA = DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1: N0, N2: Mul);
17498 }
17499 return NeedsCopySign ? DAG.getNode(Opcode: ISD::FCOPYSIGN, DL, VT, N1: MLA, N2: N0) : MLA;
17500 }
17501
17502 return SDValue();
17503}
17504
17505SDValue DAGCombiner::visitFSQRT(SDNode *N) {
17506 SDNodeFlags Flags = N->getFlags();
17507 const TargetOptions &Options = DAG.getTarget().Options;
17508
17509 // Require 'ninf' flag since sqrt(+Inf) = +Inf, but the estimation goes as:
17510 // sqrt(+Inf) == rsqrt(+Inf) * +Inf = 0 * +Inf = NaN
17511 if (!Flags.hasApproximateFuncs() ||
17512 (!Options.NoInfsFPMath && !Flags.hasNoInfs()))
17513 return SDValue();
17514
17515 SDValue N0 = N->getOperand(Num: 0);
17516 if (TLI.isFsqrtCheap(X: N0, DAG))
17517 return SDValue();
17518
17519 // FSQRT nodes have flags that propagate to the created nodes.
17520 // TODO: If this is N0/sqrt(N0), and we reach this node before trying to
17521 // transform the fdiv, we may produce a sub-optimal estimate sequence
17522 // because the reciprocal calculation may not have to filter out a
17523 // 0.0 input.
17524 return buildSqrtEstimate(Op: N0, Flags);
17525}
17526
17527/// copysign(x, fp_extend(y)) -> copysign(x, y)
17528/// copysign(x, fp_round(y)) -> copysign(x, y)
/// The operands to this function are the types of X and Y respectively.
17530static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(EVT XTy, EVT YTy) {
17531 // Always fold no-op FP casts.
17532 if (XTy == YTy)
17533 return true;
17534
17535 // Do not optimize out type conversion of f128 type yet.
  // For some targets like x86_64, the configuration keeps one f128 value in
  // one SSE register, but instruction selection cannot handle FCOPYSIGN on
  // SSE registers yet.
17539 if (YTy == MVT::f128)
17540 return false;
17541
17542 return !YTy.isVector() || EnableVectorFCopySignExtendRound;
17543}
17544
17545static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) {
17546 SDValue N1 = N->getOperand(Num: 1);
17547 if (N1.getOpcode() != ISD::FP_EXTEND &&
17548 N1.getOpcode() != ISD::FP_ROUND)
17549 return false;
17550 EVT N1VT = N1->getValueType(ResNo: 0);
17551 EVT N1Op0VT = N1->getOperand(Num: 0).getValueType();
17552 return CanCombineFCOPYSIGN_EXTEND_ROUND(XTy: N1VT, YTy: N1Op0VT);
17553}
17554
17555SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) {
17556 SDValue N0 = N->getOperand(Num: 0);
17557 SDValue N1 = N->getOperand(Num: 1);
17558 EVT VT = N->getValueType(ResNo: 0);
17559 SDLoc DL(N);
17560
17561 // fold (fcopysign c1, c2) -> fcopysign(c1,c2)
17562 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FCOPYSIGN, DL, VT, Ops: {N0, N1}))
17563 return C;
17564
17565 if (ConstantFPSDNode *N1C = isConstOrConstSplatFP(N: N->getOperand(Num: 1))) {
17566 const APFloat &V = N1C->getValueAPF();
17567 // copysign(x, c1) -> fabs(x) iff ispos(c1)
17568 // copysign(x, c1) -> fneg(fabs(x)) iff isneg(c1)
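// For example: copysign(x, 2.0) -> fabs(x), and copysign(x, -0.5) ->
// fneg(fabs(x)).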
17569 if (!V.isNegative()) {
17570 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::FABS, VT))
17571 return DAG.getNode(Opcode: ISD::FABS, DL, VT, Operand: N0);
17572 } else {
17573 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::FNEG, VT))
17574 return DAG.getNode(Opcode: ISD::FNEG, DL, VT,
17575 Operand: DAG.getNode(Opcode: ISD::FABS, DL: SDLoc(N0), VT, Operand: N0));
17576 }
17577 }
17578
17579 // copysign(fabs(x), y) -> copysign(x, y)
17580 // copysign(fneg(x), y) -> copysign(x, y)
17581 // copysign(copysign(x,z), y) -> copysign(x, y)
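// Only the magnitude of the first operand is consumed, so sign-changing
// wrappers (fabs/fneg/fcopysign) around it can be stripped, e.g.
// copysign(fneg(a), b) == copysign(a, b).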
17582 if (N0.getOpcode() == ISD::FABS || N0.getOpcode() == ISD::FNEG ||
17583 N0.getOpcode() == ISD::FCOPYSIGN)
17584 return DAG.getNode(Opcode: ISD::FCOPYSIGN, DL, VT, N1: N0.getOperand(i: 0), N2: N1);
17585
17586 // copysign(x, abs(y)) -> abs(x)
17587 if (N1.getOpcode() == ISD::FABS)
17588 return DAG.getNode(Opcode: ISD::FABS, DL, VT, Operand: N0);
17589
17590 // copysign(x, copysign(y,z)) -> copysign(x, z)
17591 if (N1.getOpcode() == ISD::FCOPYSIGN)
17592 return DAG.getNode(Opcode: ISD::FCOPYSIGN, DL, VT, N1: N0, N2: N1.getOperand(i: 1));
17593
17594 // copysign(x, fp_extend(y)) -> copysign(x, y)
17595 // copysign(x, fp_round(y)) -> copysign(x, y)
17596 if (CanCombineFCOPYSIGN_EXTEND_ROUND(N))
17597 return DAG.getNode(Opcode: ISD::FCOPYSIGN, DL, VT, N1: N0, N2: N1.getOperand(i: 0));
17598
17599 // We only take the sign bit from the sign operand.
17600 EVT SignVT = N1.getValueType();
17601 if (SimplifyDemandedBits(Op: N1,
17602 DemandedBits: APInt::getSignMask(BitWidth: SignVT.getScalarSizeInBits())))
17603 return SDValue(N, 0);
17604
17605 // We only take the non-sign bits from the value operand
17606 if (SimplifyDemandedBits(Op: N0,
17607 DemandedBits: APInt::getSignedMaxValue(numBits: VT.getScalarSizeInBits())))
17608 return SDValue(N, 0);
17609
17610 return SDValue();
17611}
17612
17613SDValue DAGCombiner::visitFPOW(SDNode *N) {
17614 ConstantFPSDNode *ExponentC = isConstOrConstSplatFP(N: N->getOperand(Num: 1));
17615 if (!ExponentC)
17616 return SDValue();
17617 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17618
17619 // Try to convert x ** (1/3) into cube root.
17620 // TODO: Handle the various flavors of long double.
17621 // TODO: Since we're approximating, we don't need an exact 1/3 exponent.
17622 // Some range near 1/3 should be fine.
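// Note that, as written, the transform fires only when the exponent is
// exactly the f32/f64 rounding of 1/3 (0x3EAAAAAB for f32), per the
// isExactlyValue checks below.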
17623 EVT VT = N->getValueType(ResNo: 0);
17624 if ((VT == MVT::f32 && ExponentC->getValueAPF().isExactlyValue(V: 1.0f/3.0f)) ||
17625 (VT == MVT::f64 && ExponentC->getValueAPF().isExactlyValue(V: 1.0/3.0))) {
17626 // pow(-0.0, 1/3) = +0.0; cbrt(-0.0) = -0.0.
17627 // pow(-inf, 1/3) = +inf; cbrt(-inf) = -inf.
17628 // pow(-val, 1/3) = nan; cbrt(-val) = -cbrt(val).
17629 // For regular numbers, rounding may cause the results to differ.
17630 // Therefore, we require { nsz ninf nnan afn } for this transform.
17631 // TODO: We could select out the special cases if we don't have nsz/ninf.
17632 SDNodeFlags Flags = N->getFlags();
17633 if (!Flags.hasNoSignedZeros() || !Flags.hasNoInfs() || !Flags.hasNoNaNs() ||
17634 !Flags.hasApproximateFuncs())
17635 return SDValue();
17636
17637 // Do not create a cbrt() libcall if the target does not have it, and do not
17638 // turn a pow that has lowering support into a cbrt() libcall.
17639 if (!DAG.getLibInfo().has(F: LibFunc_cbrt) ||
17640 (!DAG.getTargetLoweringInfo().isOperationExpand(Op: ISD::FPOW, VT) &&
17641 DAG.getTargetLoweringInfo().isOperationExpand(Op: ISD::FCBRT, VT)))
17642 return SDValue();
17643
17644 return DAG.getNode(Opcode: ISD::FCBRT, DL: SDLoc(N), VT, Operand: N->getOperand(Num: 0));
17645 }
17646
17647 // Try to convert x ** (1/4) and x ** (3/4) into square roots.
17648 // x ** (1/2) is canonicalized to sqrt, so we do not bother with that case.
17649 // TODO: This could be extended (using a target hook) to handle smaller
17650 // power-of-2 fractional exponents.
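// These decompositions work because x**0.25 == sqrt(sqrt(x)) and
// x**0.75 == x**0.5 * x**0.25 == sqrt(x) * sqrt(sqrt(x)).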
17651 bool ExponentIs025 = ExponentC->getValueAPF().isExactlyValue(V: 0.25);
17652 bool ExponentIs075 = ExponentC->getValueAPF().isExactlyValue(V: 0.75);
17653 if (ExponentIs025 || ExponentIs075) {
17654 // pow(-0.0, 0.25) = +0.0; sqrt(sqrt(-0.0)) = -0.0.
17655 // pow(-inf, 0.25) = +inf; sqrt(sqrt(-inf)) = NaN.
17656 // pow(-0.0, 0.75) = +0.0; sqrt(-0.0) * sqrt(sqrt(-0.0)) = +0.0.
17657 // pow(-inf, 0.75) = +inf; sqrt(-inf) * sqrt(sqrt(-inf)) = NaN.
17658 // For regular numbers, rounding may cause the results to differ.
17659 // Therefore, we require { nsz ninf afn } for this transform.
17660 // TODO: We could select out the special cases if we don't have nsz/ninf.
17661 SDNodeFlags Flags = N->getFlags();
17662
17663 // We only need no signed zeros for the 0.25 case.
17664 if ((!Flags.hasNoSignedZeros() && ExponentIs025) || !Flags.hasNoInfs() ||
17665 !Flags.hasApproximateFuncs())
17666 return SDValue();
17667
17668 // Don't double the number of libcalls. We are trying to inline fast code.
17669 if (!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(Op: ISD::FSQRT, VT))
17670 return SDValue();
17671
17672 // Assume that libcalls are the smallest code.
17673 // TODO: This restriction should probably be lifted for vectors.
17674 if (ForCodeSize)
17675 return SDValue();
17676
17677 // pow(X, 0.25) --> sqrt(sqrt(X))
17678 SDLoc DL(N);
17679 SDValue Sqrt = DAG.getNode(Opcode: ISD::FSQRT, DL, VT, Operand: N->getOperand(Num: 0));
17680 SDValue SqrtSqrt = DAG.getNode(Opcode: ISD::FSQRT, DL, VT, Operand: Sqrt);
17681 if (ExponentIs025)
17682 return SqrtSqrt;
17683 // pow(X, 0.75) --> sqrt(X) * sqrt(sqrt(X))
17684 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Sqrt, N2: SqrtSqrt);
17685 }
17686
17687 return SDValue();
17688}
17689
17690static SDValue foldFPToIntToFP(SDNode *N, SelectionDAG &DAG,
17691 const TargetLowering &TLI) {
17692 // We only do this if the target has legal ftrunc. Otherwise, we'd likely be
17693 // replacing casts with a libcall. We also must be allowed to ignore -0.0
17694 // because FTRUNC will return -0.0 for (-1.0, -0.0), but using integer
17695 // conversions would return +0.0.
17696 // FIXME: We should be able to use node-level FMF here.
17697 // TODO: If strict math, should we use FABS (+ range check for signed cast)?
17698 EVT VT = N->getValueType(ResNo: 0);
17699 if (!TLI.isOperationLegal(Op: ISD::FTRUNC, VT) ||
17700 !DAG.getTarget().Options.NoSignedZerosFPMath)
17701 return SDValue();
17702
17703 // fptosi/fptoui round towards zero, so converting from FP to integer and
17704 // back is the same as an 'ftrunc': [us]itofp (fpto[us]i X) --> ftrunc X
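// E.g. for X = 2.7, fptosi gives 2 and sitofp gives 2.0 == ftrunc(2.7). For
// X = -0.3 the round trip yields +0.0 while ftrunc gives -0.0, which is why
// NoSignedZerosFPMath is required above.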
17705 SDValue N0 = N->getOperand(Num: 0);
17706 if (N->getOpcode() == ISD::SINT_TO_FP && N0.getOpcode() == ISD::FP_TO_SINT &&
17707 N0.getOperand(i: 0).getValueType() == VT)
17708 return DAG.getNode(Opcode: ISD::FTRUNC, DL: SDLoc(N), VT, Operand: N0.getOperand(i: 0));
17709
17710 if (N->getOpcode() == ISD::UINT_TO_FP && N0.getOpcode() == ISD::FP_TO_UINT &&
17711 N0.getOperand(i: 0).getValueType() == VT)
17712 return DAG.getNode(Opcode: ISD::FTRUNC, DL: SDLoc(N), VT, Operand: N0.getOperand(i: 0));
17713
17714 return SDValue();
17715}
17716
17717SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) {
17718 SDValue N0 = N->getOperand(Num: 0);
17719 EVT VT = N->getValueType(ResNo: 0);
17720 EVT OpVT = N0.getValueType();
17721
17722 // [us]itofp(undef) = 0, because the result value is bounded.
17723 if (N0.isUndef())
17724 return DAG.getConstantFP(Val: 0.0, DL: SDLoc(N), VT);
17725
17726 // fold (sint_to_fp c1) -> c1fp
17727 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
17728 // ...but only if the target supports immediate floating-point values
17729 (!LegalOperations ||
17730 TLI.isOperationLegalOrCustom(Op: ISD::ConstantFP, VT)))
17731 return DAG.getNode(Opcode: ISD::SINT_TO_FP, DL: SDLoc(N), VT, Operand: N0);
17732
17733 // If the input is a legal type, and SINT_TO_FP is not legal on this target,
17734 // but UINT_TO_FP is legal on this target, try to convert.
17735 if (!hasOperation(Opcode: ISD::SINT_TO_FP, VT: OpVT) &&
17736 hasOperation(Opcode: ISD::UINT_TO_FP, VT: OpVT)) {
17737 // If the sign bit is known to be zero, we can change this to UINT_TO_FP.
17738 if (DAG.SignBitIsZero(Op: N0))
17739 return DAG.getNode(Opcode: ISD::UINT_TO_FP, DL: SDLoc(N), VT, Operand: N0);
17740 }
17741
17742 // The next optimizations are desirable only if SELECT_CC can be lowered.
17743 // fold (sint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), -1.0, 0.0)
17744 if (N0.getOpcode() == ISD::SETCC && N0.getValueType() == MVT::i1 &&
17745 !VT.isVector() &&
17746 (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::ConstantFP, VT))) {
17747 SDLoc DL(N);
17748 return DAG.getSelect(DL, VT, Cond: N0, LHS: DAG.getConstantFP(Val: -1.0, DL, VT),
17749 RHS: DAG.getConstantFP(Val: 0.0, DL, VT));
17750 }
17751
17752 // fold (sint_to_fp (zext (setcc x, y, cc))) ->
17753 // (select (setcc x, y, cc), 1.0, 0.0)
17754 if (N0.getOpcode() == ISD::ZERO_EXTEND &&
17755 N0.getOperand(i: 0).getOpcode() == ISD::SETCC && !VT.isVector() &&
17756 (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::ConstantFP, VT))) {
17757 SDLoc DL(N);
17758 return DAG.getSelect(DL, VT, Cond: N0.getOperand(i: 0),
17759 LHS: DAG.getConstantFP(Val: 1.0, DL, VT),
17760 RHS: DAG.getConstantFP(Val: 0.0, DL, VT));
17761 }
17762
17763 if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
17764 return FTrunc;
17765
17766 return SDValue();
17767}
17768
17769SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) {
17770 SDValue N0 = N->getOperand(Num: 0);
17771 EVT VT = N->getValueType(ResNo: 0);
17772 EVT OpVT = N0.getValueType();
17773
17774 // [us]itofp(undef) = 0, because the result value is bounded.
17775 if (N0.isUndef())
17776 return DAG.getConstantFP(Val: 0.0, DL: SDLoc(N), VT);
17777
17778 // fold (uint_to_fp c1) -> c1fp
17779 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
17780 // ...but only if the target supports immediate floating-point values
17781 (!LegalOperations ||
17782 TLI.isOperationLegalOrCustom(Op: ISD::ConstantFP, VT)))
17783 return DAG.getNode(Opcode: ISD::UINT_TO_FP, DL: SDLoc(N), VT, Operand: N0);
17784
17785 // If the input is a legal type, and UINT_TO_FP is not legal on this target,
17786 // but SINT_TO_FP is legal on this target, try to convert.
17787 if (!hasOperation(Opcode: ISD::UINT_TO_FP, VT: OpVT) &&
17788 hasOperation(Opcode: ISD::SINT_TO_FP, VT: OpVT)) {
17789 // If the sign bit is known to be zero, we can change this to SINT_TO_FP.
17790 if (DAG.SignBitIsZero(Op: N0))
17791 return DAG.getNode(Opcode: ISD::SINT_TO_FP, DL: SDLoc(N), VT, Operand: N0);
17792 }
17793
17794 // fold (uint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), 1.0, 0.0)
17795 if (N0.getOpcode() == ISD::SETCC && !VT.isVector() &&
17796 (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::ConstantFP, VT))) {
17797 SDLoc DL(N);
17798 return DAG.getSelect(DL, VT, Cond: N0, LHS: DAG.getConstantFP(Val: 1.0, DL, VT),
17799 RHS: DAG.getConstantFP(Val: 0.0, DL, VT));
17800 }
17801
17802 if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
17803 return FTrunc;
17804
17805 return SDValue();
17806}
17807
17808// Fold (fp_to_{s/u}int ({s/u}int_to_fpx)) -> zext x, sext x, trunc x, or x
17809static SDValue FoldIntToFPToInt(SDNode *N, SelectionDAG &DAG) {
17810 SDValue N0 = N->getOperand(Num: 0);
17811 EVT VT = N->getValueType(ResNo: 0);
17812
17813 if (N0.getOpcode() != ISD::UINT_TO_FP && N0.getOpcode() != ISD::SINT_TO_FP)
17814 return SDValue();
17815
17816 SDValue Src = N0.getOperand(i: 0);
17817 EVT SrcVT = Src.getValueType();
17818 bool IsInputSigned = N0.getOpcode() == ISD::SINT_TO_FP;
17819 bool IsOutputSigned = N->getOpcode() == ISD::FP_TO_SINT;
17820
17821 // We can safely assume the conversion won't overflow the output range,
17822 // because (for example) (uint8_t)18293.f is undefined behavior.
17823
17824 // Since we can assume the conversion won't overflow, our decision as to
17825 // whether the input will fit in the float should depend on the minimum
17826 // of the input range and output range.
17827
17828 // This means this is also safe for a signed input and unsigned output, since
17829 // a negative input would lead to undefined behavior.
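// Worked example: fp_to_sint i32 (sint_to_fp f32 (i16 x)) has
// InputSize = 15 and OutputSize = 32, so ActualSize = 15, which fits in
// f32's 24-bit significand; the round trip folds to (sign_extend i32 x).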
17830 unsigned InputSize = (int)SrcVT.getScalarSizeInBits() - IsInputSigned;
17831 unsigned OutputSize = (int)VT.getScalarSizeInBits();
17832 unsigned ActualSize = std::min(a: InputSize, b: OutputSize);
17833 const fltSemantics &sem = DAG.EVTToAPFloatSemantics(VT: N0.getValueType());
17834
17835 // We can only fold away the float conversion if the input range can be
17836 // represented exactly in the float range.
17837 if (APFloat::semanticsPrecision(sem) >= ActualSize) {
17838 if (VT.getScalarSizeInBits() > SrcVT.getScalarSizeInBits()) {
17839 unsigned ExtOp = IsInputSigned && IsOutputSigned ? ISD::SIGN_EXTEND
17840 : ISD::ZERO_EXTEND;
17841 return DAG.getNode(Opcode: ExtOp, DL: SDLoc(N), VT, Operand: Src);
17842 }
17843 if (VT.getScalarSizeInBits() < SrcVT.getScalarSizeInBits())
17844 return DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N), VT, Operand: Src);
17845 return DAG.getBitcast(VT, V: Src);
17846 }
17847 return SDValue();
17848}
17849
17850SDValue DAGCombiner::visitFP_TO_SINT(SDNode *N) {
17851 SDValue N0 = N->getOperand(Num: 0);
17852 EVT VT = N->getValueType(ResNo: 0);
17853
17854 // fold (fp_to_sint undef) -> undef
17855 if (N0.isUndef())
17856 return DAG.getUNDEF(VT);
17857
17858 // fold (fp_to_sint c1fp) -> c1
17859 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
17860 return DAG.getNode(Opcode: ISD::FP_TO_SINT, DL: SDLoc(N), VT, Operand: N0);
17861
17862 return FoldIntToFPToInt(N, DAG);
17863}
17864
17865SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
17866 SDValue N0 = N->getOperand(Num: 0);
17867 EVT VT = N->getValueType(ResNo: 0);
17868
17869 // fold (fp_to_uint undef) -> undef
17870 if (N0.isUndef())
17871 return DAG.getUNDEF(VT);
17872
17873 // fold (fp_to_uint c1fp) -> c1
17874 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
17875 return DAG.getNode(Opcode: ISD::FP_TO_UINT, DL: SDLoc(N), VT, Operand: N0);
17876
17877 return FoldIntToFPToInt(N, DAG);
17878}
17879
17880SDValue DAGCombiner::visitXRINT(SDNode *N) {
17881 SDValue N0 = N->getOperand(Num: 0);
17882 EVT VT = N->getValueType(ResNo: 0);
17883
17884 // fold (lrint|llrint undef) -> undef
17885 if (N0.isUndef())
17886 return DAG.getUNDEF(VT);
17887
17888 // fold (lrint|llrint c1fp) -> c1
17889 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
17890 return DAG.getNode(Opcode: N->getOpcode(), DL: SDLoc(N), VT, Operand: N0);
17891
17892 return SDValue();
17893}
17894
17895SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
17896 SDValue N0 = N->getOperand(Num: 0);
17897 SDValue N1 = N->getOperand(Num: 1);
17898 EVT VT = N->getValueType(ResNo: 0);
17899
17900 // fold (fp_round c1fp) -> c1fp
17901 if (SDValue C =
17902 DAG.FoldConstantArithmetic(Opcode: ISD::FP_ROUND, DL: SDLoc(N), VT, Ops: {N0, N1}))
17903 return C;
17904
17905 // fold (fp_round (fp_extend x)) -> x
17906 if (N0.getOpcode() == ISD::FP_EXTEND && VT == N0.getOperand(i: 0).getValueType())
17907 return N0.getOperand(i: 0);
17908
17909 // fold (fp_round (fp_round x)) -> (fp_round x)
17910 if (N0.getOpcode() == ISD::FP_ROUND) {
17911 const bool NIsTrunc = N->getConstantOperandVal(Num: 1) == 1;
17912 const bool N0IsTrunc = N0.getConstantOperandVal(i: 1) == 1;
17913
17914 // Avoid folding legal fp_rounds into non-legal ones.
17915 if (!hasOperation(Opcode: ISD::FP_ROUND, VT))
17916 return SDValue();
17917
17918 // Skip this folding if it results in an fp_round from f80 to f16.
17919 //
17920 // f80 to f16 always generates an expensive (and as yet, unimplemented)
17921 // libcall to __truncxfhf2 instead of selecting native f16 conversion
17922 // instructions from f32 or f64. Moreover, the first (value-preserving)
17923 // fp_round from f80 to either f32 or f64 may become a NOP in platforms like
17924 // x86.
17925 if (N0.getOperand(i: 0).getValueType() == MVT::f80 && VT == MVT::f16)
17926 return SDValue();
17927
17928 // If the first fp_round isn't a value-preserving truncation, it might
17929 // introduce a tie in the second fp_round that wouldn't occur in the
17930 // single-step fp_round we want to fold to.
17931 // In other words, double rounding isn't the same as rounding once.
17932 // Also, the combined fp_round is value-preserving iff both fp_rounds are.
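// (Classic double-rounding hazard: a value just above the midpoint between
// two narrow-type values can be rounded down to that midpoint by the first
// fp_round and then rounded to even by the second, whereas a single rounding
// would have rounded it up.)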
17933 if (DAG.getTarget().Options.UnsafeFPMath || N0IsTrunc) {
17934 SDLoc DL(N);
17935 return DAG.getNode(
17936 Opcode: ISD::FP_ROUND, DL, VT, N1: N0.getOperand(i: 0),
17937 N2: DAG.getIntPtrConstant(Val: NIsTrunc && N0IsTrunc, DL, /*isTarget=*/true));
17938 }
17939 }
17940
17941 // fold (fp_round (copysign X, Y)) -> (copysign (fp_round X), Y)
17942 // Note: From a legality perspective, this is a two step transform. First,
17943 // we duplicate the fp_round to the arguments of the copysign, then we
17944 // eliminate the fp_round on Y. The second step requires an additional
17945 // predicate to match the implementation above.
17946 if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() &&
17947 CanCombineFCOPYSIGN_EXTEND_ROUND(XTy: VT,
17948 YTy: N0.getValueType())) {
17949 SDValue Tmp = DAG.getNode(Opcode: ISD::FP_ROUND, DL: SDLoc(N0), VT,
17950 N1: N0.getOperand(i: 0), N2: N1);
17951 AddToWorklist(N: Tmp.getNode());
17952 return DAG.getNode(Opcode: ISD::FCOPYSIGN, DL: SDLoc(N), VT,
17953 N1: Tmp, N2: N0.getOperand(i: 1));
17954 }
17955
17956 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
17957 return NewVSel;
17958
17959 return SDValue();
17960}
17961
17962SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
17963 SDValue N0 = N->getOperand(Num: 0);
17964 EVT VT = N->getValueType(ResNo: 0);
17965
17966 if (VT.isVector())
17967 if (SDValue FoldedVOp = SimplifyVCastOp(N, DL: SDLoc(N)))
17968 return FoldedVOp;
17969
17970 // If this is fp_round(fp_extend(x)), don't fold it; allow ourselves to be folded.
17971 if (N->hasOneUse() &&
17972 N->use_begin()->getOpcode() == ISD::FP_ROUND)
17973 return SDValue();
17974
17975 // fold (fp_extend c1fp) -> c1fp
17976 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
17977 return DAG.getNode(Opcode: ISD::FP_EXTEND, DL: SDLoc(N), VT, Operand: N0);
17978
17979 // fold (fp_extend (fp16_to_fp op)) -> (fp16_to_fp op)
17980 if (N0.getOpcode() == ISD::FP16_TO_FP &&
17981 TLI.getOperationAction(Op: ISD::FP16_TO_FP, VT) == TargetLowering::Legal)
17982 return DAG.getNode(Opcode: ISD::FP16_TO_FP, DL: SDLoc(N), VT, Operand: N0.getOperand(i: 0));
17983
17984 // Turn fp_extend(fp_round(X, 1)) -> X since the fp_round doesn't affect the
17985 // value of X.
17986 if (N0.getOpcode() == ISD::FP_ROUND
17987 && N0.getConstantOperandVal(i: 1) == 1) {
17988 SDValue In = N0.getOperand(i: 0);
17989 if (In.getValueType() == VT) return In;
17990 if (VT.bitsLT(VT: In.getValueType()))
17991 return DAG.getNode(Opcode: ISD::FP_ROUND, DL: SDLoc(N), VT,
17992 N1: In, N2: N0.getOperand(i: 1));
17993 return DAG.getNode(Opcode: ISD::FP_EXTEND, DL: SDLoc(N), VT, Operand: In);
17994 }
17995
17996 // fold (fpext (load x)) -> (fpext (fptrunc (extload x)))
17997 if (ISD::isNormalLoad(N: N0.getNode()) && N0.hasOneUse() &&
17998 TLI.isLoadExtLegalOrCustom(ExtType: ISD::EXTLOAD, ValVT: VT, MemVT: N0.getValueType())) {
17999 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
18000 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::EXTLOAD, dl: SDLoc(N), VT,
18001 Chain: LN0->getChain(),
18002 Ptr: LN0->getBasePtr(), MemVT: N0.getValueType(),
18003 MMO: LN0->getMemOperand());
18004 CombineTo(N, Res: ExtLoad);
18005 CombineTo(
18006 N: N0.getNode(),
18007 Res0: DAG.getNode(Opcode: ISD::FP_ROUND, DL: SDLoc(N0), VT: N0.getValueType(), N1: ExtLoad,
18008 N2: DAG.getIntPtrConstant(Val: 1, DL: SDLoc(N0), /*isTarget=*/true)),
18009 Res1: ExtLoad.getValue(R: 1));
18010 return SDValue(N, 0); // Return N so it doesn't get rechecked!
18011 }
18012
18013 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
18014 return NewVSel;
18015
18016 return SDValue();
18017}
18018
18019SDValue DAGCombiner::visitFCEIL(SDNode *N) {
18020 SDValue N0 = N->getOperand(Num: 0);
18021 EVT VT = N->getValueType(ResNo: 0);
18022
18023 // fold (fceil c1) -> fceil(c1)
18024 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
18025 return DAG.getNode(Opcode: ISD::FCEIL, DL: SDLoc(N), VT, Operand: N0);
18026
18027 return SDValue();
18028}
18029
18030SDValue DAGCombiner::visitFTRUNC(SDNode *N) {
18031 SDValue N0 = N->getOperand(Num: 0);
18032 EVT VT = N->getValueType(ResNo: 0);
18033
18034 // fold (ftrunc c1) -> ftrunc(c1)
18035 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
18036 return DAG.getNode(Opcode: ISD::FTRUNC, DL: SDLoc(N), VT, Operand: N0);
18037
18038 // fold ftrunc (known rounded int x) -> x
18039 // ftrunc is part of the fptosi/fptoui expansion on some targets, so this is
18040 // likely to appear when extracting an integer from an already-rounded FP value.
18041 switch (N0.getOpcode()) {
18042 default: break;
18043 case ISD::FRINT:
18044 case ISD::FTRUNC:
18045 case ISD::FNEARBYINT:
18046 case ISD::FROUNDEVEN:
18047 case ISD::FFLOOR:
18048 case ISD::FCEIL:
18049 return N0;
18050 }
18051
18052 return SDValue();
18053}
18054
18055SDValue DAGCombiner::visitFFREXP(SDNode *N) {
18056 SDValue N0 = N->getOperand(Num: 0);
18057
18058 // fold (ffrexp c1) -> ffrexp(c1)
18059 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
18060 return DAG.getNode(Opcode: ISD::FFREXP, DL: SDLoc(N), VTList: N->getVTList(), N: N0);
18061 return SDValue();
18062}
18063
18064SDValue DAGCombiner::visitFFLOOR(SDNode *N) {
18065 SDValue N0 = N->getOperand(Num: 0);
18066 EVT VT = N->getValueType(ResNo: 0);
18067
18068 // fold (ffloor c1) -> ffloor(c1)
18069 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
18070 return DAG.getNode(Opcode: ISD::FFLOOR, DL: SDLoc(N), VT, Operand: N0);
18071
18072 return SDValue();
18073}
18074
18075SDValue DAGCombiner::visitFNEG(SDNode *N) {
18076 SDValue N0 = N->getOperand(Num: 0);
18077 EVT VT = N->getValueType(ResNo: 0);
18078 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18079
18080 // Constant fold FNEG.
18081 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
18082 return DAG.getNode(Opcode: ISD::FNEG, DL: SDLoc(N), VT, Operand: N0);
18083
18084 if (SDValue NegN0 =
18085 TLI.getNegatedExpression(Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
18086 return NegN0;
18087
18088 // -(X-Y) -> (Y-X) is unsafe because when X==Y, -0.0 != +0.0
18089 // FIXME: This is duplicated in getNegatibleCost, but getNegatibleCost doesn't
18090 // know it was called from a context with an nsz flag if the input fsub does
18091 // not.
18092 if (N0.getOpcode() == ISD::FSUB &&
18093 (DAG.getTarget().Options.NoSignedZerosFPMath ||
18094 N->getFlags().hasNoSignedZeros()) && N0.hasOneUse()) {
18095 return DAG.getNode(Opcode: ISD::FSUB, DL: SDLoc(N), VT, N1: N0.getOperand(i: 1),
18096 N2: N0.getOperand(i: 0));
18097 }
18098
18099 if (SDValue Cast = foldSignChangeInBitcast(N))
18100 return Cast;
18101
18102 return SDValue();
18103}
18104
18105SDValue DAGCombiner::visitFMinMax(SDNode *N) {
18106 SDValue N0 = N->getOperand(Num: 0);
18107 SDValue N1 = N->getOperand(Num: 1);
18108 EVT VT = N->getValueType(ResNo: 0);
18109 const SDNodeFlags Flags = N->getFlags();
18110 unsigned Opc = N->getOpcode();
18111 bool PropagatesNaN = Opc == ISD::FMINIMUM || Opc == ISD::FMAXIMUM;
18112 bool IsMin = Opc == ISD::FMINNUM || Opc == ISD::FMINIMUM;
18113 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18114
18115 // Constant fold.
18116 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: Opc, DL: SDLoc(N), VT, Ops: {N0, N1}))
18117 return C;
18118
18119 // Canonicalize to constant on RHS.
18120 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0) &&
18121 !DAG.isConstantFPBuildVectorOrConstantFP(N: N1))
18122 return DAG.getNode(Opcode: N->getOpcode(), DL: SDLoc(N), VT, N1, N2: N0);
18123
18124 if (const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N: N1)) {
18125 const APFloat &AF = N1CFP->getValueAPF();
18126
18127 // minnum(X, nan) -> X
18128 // maxnum(X, nan) -> X
18129 // minimum(X, nan) -> nan
18130 // maximum(X, nan) -> nan
18131 if (AF.isNaN())
18132 return PropagatesNaN ? N->getOperand(Num: 1) : N->getOperand(Num: 0);
18133
18134 // In the following folds, inf can be replaced with the largest finite
18135 // float, if the ninf flag is set.
18136 if (AF.isInfinity() || (Flags.hasNoInfs() && AF.isLargest())) {
18137 // minnum(X, -inf) -> -inf
18138 // maxnum(X, +inf) -> +inf
18139 // minimum(X, -inf) -> -inf if nnan
18140 // maximum(X, +inf) -> +inf if nnan
18141 if (IsMin == AF.isNegative() && (!PropagatesNaN || Flags.hasNoNaNs()))
18142 return N->getOperand(Num: 1);
18143
18144 // minnum(X, +inf) -> X if nnan
18145 // maxnum(X, -inf) -> X if nnan
18146 // minimum(X, +inf) -> X
18147 // maximum(X, -inf) -> X
18148 if (IsMin != AF.isNegative() && (PropagatesNaN || Flags.hasNoNaNs()))
18149 return N->getOperand(Num: 0);
18150 }
18151 }
18152
18153 if (SDValue SD = reassociateReduction(
18154 RedOpc: PropagatesNaN
18155 ? (IsMin ? ISD::VECREDUCE_FMINIMUM : ISD::VECREDUCE_FMAXIMUM)
18156 : (IsMin ? ISD::VECREDUCE_FMIN : ISD::VECREDUCE_FMAX),
18157 Opc, DL: SDLoc(N), VT, N0, N1, Flags))
18158 return SD;
18159
18160 return SDValue();
18161}
18162
18163SDValue DAGCombiner::visitFABS(SDNode *N) {
18164 SDValue N0 = N->getOperand(Num: 0);
18165 EVT VT = N->getValueType(ResNo: 0);
18166
18167 // fold (fabs c1) -> fabs(c1)
18168 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
18169 return DAG.getNode(Opcode: ISD::FABS, DL: SDLoc(N), VT, Operand: N0);
18170
18171 // fold (fabs (fabs x)) -> (fabs x)
18172 if (N0.getOpcode() == ISD::FABS)
18173 return N->getOperand(Num: 0);
18174
18175 // fold (fabs (fneg x)) -> (fabs x)
18176 // fold (fabs (fcopysign x, y)) -> (fabs x)
18177 if (N0.getOpcode() == ISD::FNEG || N0.getOpcode() == ISD::FCOPYSIGN)
18178 return DAG.getNode(Opcode: ISD::FABS, DL: SDLoc(N), VT, Operand: N0.getOperand(i: 0));
18179
18180 if (SDValue Cast = foldSignChangeInBitcast(N))
18181 return Cast;
18182
18183 return SDValue();
18184}
18185
18186SDValue DAGCombiner::visitBRCOND(SDNode *N) {
18187 SDValue Chain = N->getOperand(Num: 0);
18188 SDValue N1 = N->getOperand(Num: 1);
18189 SDValue N2 = N->getOperand(Num: 2);
18190
18191 // BRCOND(FREEZE(cond)) is equivalent to BRCOND(cond) (both are
18192 // nondeterministic jumps).
18193 if (N1->getOpcode() == ISD::FREEZE && N1.hasOneUse()) {
18194 return DAG.getNode(Opcode: ISD::BRCOND, DL: SDLoc(N), VT: MVT::Other, N1: Chain,
18195 N2: N1->getOperand(Num: 0), N3: N2);
18196 }
18197
18198 // Variant of the previous fold where there is a SETCC in between:
18199 // BRCOND(SETCC(FREEZE(X), CONST, Cond))
18200 // =>
18201 // BRCOND(FREEZE(SETCC(X, CONST, Cond)))
18202 // =>
18203 // BRCOND(SETCC(X, CONST, Cond))
18204 // This is correct if FREEZE(X) has one use and SETCC(FREEZE(X), CONST, Cond)
18205 // isn't equivalent to true or false.
18206 // For example, SETCC(FREEZE(X), -128, SETULT) cannot be folded to
18207 // FREEZE(SETCC(X, -128, SETULT)) because X can be poison.
18208 if (N1->getOpcode() == ISD::SETCC && N1.hasOneUse()) {
18209 SDValue S0 = N1->getOperand(Num: 0), S1 = N1->getOperand(Num: 1);
18210 ISD::CondCode Cond = cast<CondCodeSDNode>(Val: N1->getOperand(Num: 2))->get();
18211 ConstantSDNode *S0C = dyn_cast<ConstantSDNode>(Val&: S0);
18212 ConstantSDNode *S1C = dyn_cast<ConstantSDNode>(Val&: S1);
18213 bool Updated = false;
18214
18215 // Is 'X Cond C' always true or false?
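// E.g. (X <u 0) is always false and (X >=u 0) is always true; in those
// cases folding away the FREEZE would be unsound, so it is kept.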
18216 auto IsAlwaysTrueOrFalse = [](ISD::CondCode Cond, ConstantSDNode *C) {
18217 bool False = (Cond == ISD::SETULT && C->isZero()) ||
18218 (Cond == ISD::SETLT && C->isMinSignedValue()) ||
18219 (Cond == ISD::SETUGT && C->isAllOnes()) ||
18220 (Cond == ISD::SETGT && C->isMaxSignedValue());
18221 bool True = (Cond == ISD::SETULE && C->isAllOnes()) ||
18222 (Cond == ISD::SETLE && C->isMaxSignedValue()) ||
18223 (Cond == ISD::SETUGE && C->isZero()) ||
18224 (Cond == ISD::SETGE && C->isMinSignedValue());
18225 return True || False;
18226 };
18227
18228 if (S0->getOpcode() == ISD::FREEZE && S0.hasOneUse() && S1C) {
18229 if (!IsAlwaysTrueOrFalse(Cond, S1C)) {
18230 S0 = S0->getOperand(Num: 0);
18231 Updated = true;
18232 }
18233 }
18234 if (S1->getOpcode() == ISD::FREEZE && S1.hasOneUse() && S0C) {
18235 if (!IsAlwaysTrueOrFalse(ISD::getSetCCSwappedOperands(Operation: Cond), S0C)) {
18236 S1 = S1->getOperand(Num: 0);
18237 Updated = true;
18238 }
18239 }
18240
18241 if (Updated)
18242 return DAG.getNode(
18243 Opcode: ISD::BRCOND, DL: SDLoc(N), VT: MVT::Other, N1: Chain,
18244 N2: DAG.getSetCC(DL: SDLoc(N1), VT: N1->getValueType(ResNo: 0), LHS: S0, RHS: S1, Cond), N3: N2);
18245 }
18246
18247 // If N is a constant we could fold this into a fallthrough or unconditional
18248 // branch. However that doesn't happen very often in normal code, because
18249 // Instcombine/SimplifyCFG should have handled the available opportunities.
18250 // If we did this folding here, it would be necessary to update the
18251 // MachineBasicBlock CFG, which is awkward.
18252
18253 // fold a brcond with a setcc condition into a BR_CC node if BR_CC is legal
18254 // on the target.
18255 if (N1.getOpcode() == ISD::SETCC &&
18256 TLI.isOperationLegalOrCustom(Op: ISD::BR_CC,
18257 VT: N1.getOperand(i: 0).getValueType())) {
18258 return DAG.getNode(Opcode: ISD::BR_CC, DL: SDLoc(N), VT: MVT::Other,
18259 N1: Chain, N2: N1.getOperand(i: 2),
18260 N3: N1.getOperand(i: 0), N4: N1.getOperand(i: 1), N5: N2);
18261 }
18262
18263 if (N1.hasOneUse()) {
18264 // rebuildSetCC calls visitXor which may change the Chain when there is a
18265 // STRICT_FSETCC/STRICT_FSETCCS involved. Use a handle to track changes.
18266 HandleSDNode ChainHandle(Chain);
18267 if (SDValue NewN1 = rebuildSetCC(N: N1))
18268 return DAG.getNode(Opcode: ISD::BRCOND, DL: SDLoc(N), VT: MVT::Other,
18269 N1: ChainHandle.getValue(), N2: NewN1, N3: N2);
18270 }
18271
18272 return SDValue();
18273}
18274
18275SDValue DAGCombiner::rebuildSetCC(SDValue N) {
18276 if (N.getOpcode() == ISD::SRL ||
18277 (N.getOpcode() == ISD::TRUNCATE &&
18278 (N.getOperand(i: 0).hasOneUse() &&
18279 N.getOperand(i: 0).getOpcode() == ISD::SRL))) {
18280 // Look past the truncate.
18281 if (N.getOpcode() == ISD::TRUNCATE)
18282 N = N.getOperand(i: 0);
18283
18284 // Match this pattern so that we can generate simpler code:
18285 //
18286 // %a = ...
18287 // %b = and i32 %a, 2
18288 // %c = srl i32 %b, 1
18289 // brcond i32 %c ...
18290 //
18291 // into
18292 //
18293 // %a = ...
18294 // %b = and i32 %a, 2
18295 // %c = setcc ne %b, 0
18296 // brcond %c ...
18297 //
18298 // This applies only when the AND constant value has one bit set and the
18299 // SRL constant is equal to the log2 of the AND constant. The back-end is
18300 // smart enough to convert the result into a TEST/JMP sequence.
18301 SDValue Op0 = N.getOperand(i: 0);
18302 SDValue Op1 = N.getOperand(i: 1);
18303
18304 if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::Constant) {
18305 SDValue AndOp1 = Op0.getOperand(i: 1);
18306
18307 if (AndOp1.getOpcode() == ISD::Constant) {
18308 const APInt &AndConst = AndOp1->getAsAPIntVal();
18309
18310 if (AndConst.isPowerOf2() &&
18311 Op1->getAsAPIntVal() == AndConst.logBase2()) {
18312 SDLoc DL(N);
18313 return DAG.getSetCC(DL, VT: getSetCCResultType(VT: Op0.getValueType()),
18314 LHS: Op0, RHS: DAG.getConstant(Val: 0, DL, VT: Op0.getValueType()),
18315 Cond: ISD::SETNE);
18316 }
18317 }
18318 }
18319 }
18320
18321 // Transform (brcond (xor x, y)) -> (brcond (setcc, x, y, ne))
18322 // Transform (brcond (xor (xor x, y), -1)) -> (brcond (setcc, x, y, eq))
18323 if (N.getOpcode() == ISD::XOR) {
18324 // Because we may call this on a speculatively constructed
18325 // SimplifiedSetCC node, we need to simplify this node first.
18326 // Ideally this should be folded into SimplifySetCC and not
18327 // here. For now, grab a handle to N so we don't lose it from
18328 // replacements internal to the visit.
18329 HandleSDNode XORHandle(N);
18330 while (N.getOpcode() == ISD::XOR) {
18331 SDValue Tmp = visitXOR(N: N.getNode());
18332 // No simplification done.
18333 if (!Tmp.getNode())
18334 break;
18335 // Returning N is a form of in-visit replacement that may have invalidated
18336 // N. Grab the value from the handle.
18337 if (Tmp.getNode() == N.getNode())
18338 N = XORHandle.getValue();
18339 else // Node simplified. Try simplifying again.
18340 N = Tmp;
18341 }
18342
18343 if (N.getOpcode() != ISD::XOR)
18344 return N;
18345
18346 SDValue Op0 = N->getOperand(Num: 0);
18347 SDValue Op1 = N->getOperand(Num: 1);
18348
18349 if (Op0.getOpcode() != ISD::SETCC && Op1.getOpcode() != ISD::SETCC) {
18350 bool Equal = false;
18351 // (brcond (xor (xor x, y), -1)) -> (brcond (setcc x, y, eq))
18352 if (isBitwiseNot(V: N) && Op0.hasOneUse() && Op0.getOpcode() == ISD::XOR &&
18353 Op0.getValueType() == MVT::i1) {
18354 N = Op0;
18355 Op0 = N->getOperand(Num: 0);
18356 Op1 = N->getOperand(Num: 1);
18357 Equal = true;
18358 }
18359
18360 EVT SetCCVT = N.getValueType();
18361 if (LegalTypes)
18362 SetCCVT = getSetCCResultType(VT: SetCCVT);
18363 // Replace the uses of XOR with SETCC
18364 return DAG.getSetCC(DL: SDLoc(N), VT: SetCCVT, LHS: Op0, RHS: Op1,
18365 Cond: Equal ? ISD::SETEQ : ISD::SETNE);
18366 }
18367 }
18368
18369 return SDValue();
18370}
18371
18372// Operand List for BR_CC: Chain, CondCC, CondLHS, CondRHS, DestBB.
18373//
18374SDValue DAGCombiner::visitBR_CC(SDNode *N) {
18375 CondCodeSDNode *CC = cast<CondCodeSDNode>(Val: N->getOperand(Num: 1));
18376 SDValue CondLHS = N->getOperand(Num: 2), CondRHS = N->getOperand(Num: 3);
18377
18378 // If N is a constant we could fold this into a fallthrough or unconditional
18379 // branch. However that doesn't happen very often in normal code, because
18380 // Instcombine/SimplifyCFG should have handled the available opportunities.
18381 // If we did this folding here, it would be necessary to update the
18382 // MachineBasicBlock CFG, which is awkward.
18383
18384 // Use SimplifySetCC to simplify SETCC's.
18385 SDValue Simp = SimplifySetCC(VT: getSetCCResultType(VT: CondLHS.getValueType()),
18386 N0: CondLHS, N1: CondRHS, Cond: CC->get(), DL: SDLoc(N),
18387 foldBooleans: false);
18388 if (Simp.getNode()) AddToWorklist(N: Simp.getNode());
18389
18390 // fold to a simpler setcc
18391 if (Simp.getNode() && Simp.getOpcode() == ISD::SETCC)
18392 return DAG.getNode(Opcode: ISD::BR_CC, DL: SDLoc(N), VT: MVT::Other,
18393 N1: N->getOperand(Num: 0), N2: Simp.getOperand(i: 2),
18394 N3: Simp.getOperand(i: 0), N4: Simp.getOperand(i: 1),
18395 N5: N->getOperand(Num: 4));
18396
18397 return SDValue();
18398}
18399
18400static bool getCombineLoadStoreParts(SDNode *N, unsigned Inc, unsigned Dec,
18401 bool &IsLoad, bool &IsMasked, SDValue &Ptr,
18402 const TargetLowering &TLI) {
18403 if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Val: N)) {
18404 if (LD->isIndexed())
18405 return false;
18406 EVT VT = LD->getMemoryVT();
18407 if (!TLI.isIndexedLoadLegal(IdxMode: Inc, VT) && !TLI.isIndexedLoadLegal(IdxMode: Dec, VT))
18408 return false;
18409 Ptr = LD->getBasePtr();
18410 } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(Val: N)) {
18411 if (ST->isIndexed())
18412 return false;
18413 EVT VT = ST->getMemoryVT();
18414 if (!TLI.isIndexedStoreLegal(IdxMode: Inc, VT) && !TLI.isIndexedStoreLegal(IdxMode: Dec, VT))
18415 return false;
18416 Ptr = ST->getBasePtr();
18417 IsLoad = false;
18418 } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(Val: N)) {
18419 if (LD->isIndexed())
18420 return false;
18421 EVT VT = LD->getMemoryVT();
18422 if (!TLI.isIndexedMaskedLoadLegal(IdxMode: Inc, VT) &&
18423 !TLI.isIndexedMaskedLoadLegal(IdxMode: Dec, VT))
18424 return false;
18425 Ptr = LD->getBasePtr();
18426 IsMasked = true;
18427 } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(Val: N)) {
18428 if (ST->isIndexed())
18429 return false;
18430 EVT VT = ST->getMemoryVT();
18431 if (!TLI.isIndexedMaskedStoreLegal(IdxMode: Inc, VT) &&
18432 !TLI.isIndexedMaskedStoreLegal(IdxMode: Dec, VT))
18433 return false;
18434 Ptr = ST->getBasePtr();
18435 IsLoad = false;
18436 IsMasked = true;
18437 } else {
18438 return false;
18439 }
18440 return true;
18441}
18442
18443/// Try turning a load/store into a pre-indexed load/store when the base
18444/// pointer is an add or subtract and it has other uses besides the load/store.
18445/// After the transformation, the new indexed load/store has effectively folded
18446/// the add/subtract in and all of its other uses are redirected to the
18447/// new load/store.
18448bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
18449 if (Level < AfterLegalizeDAG)
18450 return false;
18451
18452 bool IsLoad = true;
18453 bool IsMasked = false;
18454 SDValue Ptr;
18455 if (!getCombineLoadStoreParts(N, Inc: ISD::PRE_INC, Dec: ISD::PRE_DEC, IsLoad, IsMasked,
18456 Ptr, TLI))
18457 return false;
18458
18459 // If the pointer is not an add/sub, or if it doesn't have multiple uses, bail
18460 // out. There is no reason to make this a preinc/predec.
18461 if ((Ptr.getOpcode() != ISD::ADD && Ptr.getOpcode() != ISD::SUB) ||
18462 Ptr->hasOneUse())
18463 return false;
18464
18465 // Ask the target to do addressing mode selection.
18466 SDValue BasePtr;
18467 SDValue Offset;
18468 ISD::MemIndexedMode AM = ISD::UNINDEXED;
18469 if (!TLI.getPreIndexedAddressParts(N, BasePtr, Offset, AM, DAG))
18470 return false;
18471
18472 // Backends without true r+i pre-indexed forms may need to pass a
18473 // constant base with a variable offset so that constant coercion
18474 // will work with the patterns in canonical form.
18475 bool Swapped = false;
18476 if (isa<ConstantSDNode>(Val: BasePtr)) {
18477 std::swap(a&: BasePtr, b&: Offset);
18478 Swapped = true;
18479 }
18480
18481 // Don't create an indexed load / store with zero offset.
18482 if (isNullConstant(V: Offset))
18483 return false;
18484
18485 // Try turning it into a pre-indexed load / store except when:
18486 // 1) The new base ptr is a frame index.
18487 // 2) If N is a store and the new base ptr is either the same as or is a
18488 // predecessor of the value being stored.
18489 // 3) Another use of old base ptr is a predecessor of N. If ptr is folded
18490 // that would create a cycle.
18491 // 4) All uses are load / store ops that use it as old base ptr.
18492
18493 // Check #1. Preinc'ing a frame index would require copying the stack pointer
18494 // (plus the implicit offset) to a register to preinc anyway.
18495 if (isa<FrameIndexSDNode>(Val: BasePtr) || isa<RegisterSDNode>(Val: BasePtr))
18496 return false;
18497
18498 // Check #2.
18499 if (!IsLoad) {
18500 SDValue Val = IsMasked ? cast<MaskedStoreSDNode>(Val: N)->getValue()
18501 : cast<StoreSDNode>(Val: N)->getValue();
18502
18503 // Would require a copy.
18504 if (Val == BasePtr)
18505 return false;
18506
18507 // Would create a cycle.
18508 if (Val == Ptr || Ptr->isPredecessorOf(N: Val.getNode()))
18509 return false;
18510 }
18511
18512 // Caches for hasPredecessorHelper.
18513 SmallPtrSet<const SDNode *, 32> Visited;
18514 SmallVector<const SDNode *, 16> Worklist;
18515 Worklist.push_back(Elt: N);
18516
18517 // If the offset is a constant, there may be other adds of constants that
18518 // can be folded with this one. We should do this to avoid having to keep
18519 // a copy of the original base pointer.
18520 SmallVector<SDNode *, 16> OtherUses;
18521 constexpr unsigned int MaxSteps = 8192;
18522 if (isa<ConstantSDNode>(Val: Offset))
18523 for (SDNode::use_iterator UI = BasePtr->use_begin(),
18524 UE = BasePtr->use_end();
18525 UI != UE; ++UI) {
18526 SDUse &Use = UI.getUse();
18527 // Skip the use that is Ptr and uses of other results from BasePtr's
18528 // node (important for nodes that return multiple results).
18529 if (Use.getUser() == Ptr.getNode() || Use != BasePtr)
18530 continue;
18531
18532 if (SDNode::hasPredecessorHelper(N: Use.getUser(), Visited, Worklist,
18533 MaxSteps))
18534 continue;
18535
18536 if (Use.getUser()->getOpcode() != ISD::ADD &&
18537 Use.getUser()->getOpcode() != ISD::SUB) {
18538 OtherUses.clear();
18539 break;
18540 }
18541
18542 SDValue Op1 = Use.getUser()->getOperand(Num: (UI.getOperandNo() + 1) & 1);
18543 if (!isa<ConstantSDNode>(Val: Op1)) {
18544 OtherUses.clear();
18545 break;
18546 }
18547
18548 // FIXME: In some cases, we can be smarter about this.
18549 if (Op1.getValueType() != Offset.getValueType()) {
18550 OtherUses.clear();
18551 break;
18552 }
18553
18554 OtherUses.push_back(Elt: Use.getUser());
18555 }
18556
18557 if (Swapped)
18558 std::swap(a&: BasePtr, b&: Offset);
18559
18560 // Now check for #3 and #4.
18561 bool RealUse = false;
18562
18563 for (SDNode *Use : Ptr->uses()) {
18564 if (Use == N)
18565 continue;
18566 if (SDNode::hasPredecessorHelper(N: Use, Visited, Worklist, MaxSteps))
18567 return false;
18568
18569 // If Ptr may be folded into the addressing mode of another use, then it's
18570 // not profitable to do this transformation.
18571 if (!canFoldInAddressingMode(N: Ptr.getNode(), Use, DAG, TLI))
18572 RealUse = true;
18573 }
18574
18575 if (!RealUse)
18576 return false;
18577
18578 SDValue Result;
18579 if (!IsMasked) {
18580 if (IsLoad)
18581 Result = DAG.getIndexedLoad(OrigLoad: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr, Offset, AM);
18582 else
18583 Result =
18584 DAG.getIndexedStore(OrigStore: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr, Offset, AM);
18585 } else {
18586 if (IsLoad)
18587 Result = DAG.getIndexedMaskedLoad(OrigLoad: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr,
18588 Offset, AM);
18589 else
18590 Result = DAG.getIndexedMaskedStore(OrigStore: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr,
18591 Offset, AM);
18592 }
18593 ++PreIndexedNodes;
18594 ++NodesCombined;
18595 LLVM_DEBUG(dbgs() << "\nReplacing.4 "; N->dump(&DAG); dbgs() << "\nWith: ";
18596 Result.dump(&DAG); dbgs() << '\n');
18597 WorklistRemover DeadNodes(*this);
18598 if (IsLoad) {
18599 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result.getValue(R: 0));
18600 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: Result.getValue(R: 2));
18601 } else {
18602 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result.getValue(R: 1));
18603 }
18604
18605 // Finally, since the node is now dead, remove it from the graph.
18606 deleteAndRecombine(N);
18607
18608 if (Swapped)
18609 std::swap(a&: BasePtr, b&: Offset);
18610
18611 // Replace other uses of BasePtr that can be updated to use Ptr
18612 for (unsigned i = 0, e = OtherUses.size(); i != e; ++i) {
18613 unsigned OffsetIdx = 1;
18614 if (OtherUses[i]->getOperand(Num: OffsetIdx).getNode() == BasePtr.getNode())
18615 OffsetIdx = 0;
18616 assert(OtherUses[i]->getOperand(!OffsetIdx).getNode() ==
18617 BasePtr.getNode() && "Expected BasePtr operand");
18618
18619 // We need to replace ptr0 in the following expression:
18620 // x0 * offset0 + y0 * ptr0 = t0
18621 // knowing that
18622 // x1 * offset1 + y1 * ptr0 = t1 (the indexed load/store)
18623 //
18624 // where x0, x1, y0 and y1 in {-1, 1} are given by the types of the
18625 // indexed load/store and the expression that needs to be re-written.
18626 //
18627 // Therefore, we have:
18628 // t0 = (x0 * offset0 - x1 * y0 * y1 * offset1) + (y0 * y1) * t1
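// Worked example: for a PRE_INC store (x1 = y1 = 1, not swapped) and another
// use 't0 = add ptr0, offset0' (x0 = y0 = 1), the rewrite below produces
// t0 = (offset0 - offset1) + t1, i.e. an ADD of the constant delta to the
// value produced by the indexed node.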
18629
18630 auto *CN = cast<ConstantSDNode>(Val: OtherUses[i]->getOperand(Num: OffsetIdx));
18631 const APInt &Offset0 = CN->getAPIntValue();
18632 const APInt &Offset1 = Offset->getAsAPIntVal();
18633 int X0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 1) ? -1 : 1;
18634 int Y0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 0) ? -1 : 1;
18635 int X1 = (AM == ISD::PRE_DEC && !Swapped) ? -1 : 1;
18636 int Y1 = (AM == ISD::PRE_DEC && Swapped) ? -1 : 1;
18637
18638 unsigned Opcode = (Y0 * Y1 < 0) ? ISD::SUB : ISD::ADD;
18639
18640 APInt CNV = Offset0;
18641 if (X0 < 0) CNV = -CNV;
18642 if (X1 * Y0 * Y1 < 0) CNV = CNV + Offset1;
18643 else CNV = CNV - Offset1;
18644
18645 SDLoc DL(OtherUses[i]);
18646
18647 // We can now generate the new expression.
18648 SDValue NewOp1 = DAG.getConstant(Val: CNV, DL, VT: CN->getValueType(ResNo: 0));
18649 SDValue NewOp2 = Result.getValue(R: IsLoad ? 1 : 0);
18650
18651 SDValue NewUse = DAG.getNode(Opcode,
18652 DL,
18653 VT: OtherUses[i]->getValueType(ResNo: 0), N1: NewOp1, N2: NewOp2);
18654 DAG.ReplaceAllUsesOfValueWith(From: SDValue(OtherUses[i], 0), To: NewUse);
18655 deleteAndRecombine(N: OtherUses[i]);
18656 }
18657
18658 // Replace the uses of Ptr with uses of the updated base value.
18659 DAG.ReplaceAllUsesOfValueWith(From: Ptr, To: Result.getValue(R: IsLoad ? 1 : 0));
18660 deleteAndRecombine(N: Ptr.getNode());
18661 AddToWorklist(N: Result.getNode());
18662
18663 return true;
18664}
18665
18666static bool shouldCombineToPostInc(SDNode *N, SDValue Ptr, SDNode *PtrUse,
18667 SDValue &BasePtr, SDValue &Offset,
18668 ISD::MemIndexedMode &AM,
18669 SelectionDAG &DAG,
18670 const TargetLowering &TLI) {
18671 if (PtrUse == N ||
18672 (PtrUse->getOpcode() != ISD::ADD && PtrUse->getOpcode() != ISD::SUB))
18673 return false;
18674
18675 if (!TLI.getPostIndexedAddressParts(N, PtrUse, BasePtr, Offset, AM, DAG))
18676 return false;
18677
18678 // Don't create an indexed load / store with zero offset.
18679 if (isNullConstant(V: Offset))
18680 return false;
18681
18682 if (isa<FrameIndexSDNode>(Val: BasePtr) || isa<RegisterSDNode>(Val: BasePtr))
18683 return false;
18684
18685 SmallPtrSet<const SDNode *, 32> Visited;
18686 for (SDNode *Use : BasePtr->uses()) {
18687 if (Use == Ptr.getNode())
18688 continue;
18689
18690 // Bail out if there's a later user which could perform the indexing instead.
18691 if (isa<MemSDNode>(Val: Use)) {
18692 bool IsLoad = true;
18693 bool IsMasked = false;
18694 SDValue OtherPtr;
18695 if (getCombineLoadStoreParts(N: Use, Inc: ISD::POST_INC, Dec: ISD::POST_DEC, IsLoad,
18696 IsMasked, Ptr&: OtherPtr, TLI)) {
18697 SmallVector<const SDNode *, 2> Worklist;
18698 Worklist.push_back(Elt: Use);
18699 if (SDNode::hasPredecessorHelper(N, Visited, Worklist))
18700 return false;
18701 }
18702 }
18703
18704 // If all the uses are load / store addresses, then don't do the
18705 // transformation.
18706 if (Use->getOpcode() == ISD::ADD || Use->getOpcode() == ISD::SUB) {
18707 for (SDNode *UseUse : Use->uses())
18708 if (canFoldInAddressingMode(N: Use, Use: UseUse, DAG, TLI))
18709 return false;
18710 }
18711 }
18712 return true;
18713}
18714
18715static SDNode *getPostIndexedLoadStoreOp(SDNode *N, bool &IsLoad,
18716 bool &IsMasked, SDValue &Ptr,
18717 SDValue &BasePtr, SDValue &Offset,
18718 ISD::MemIndexedMode &AM,
18719 SelectionDAG &DAG,
18720 const TargetLowering &TLI) {
18721 if (!getCombineLoadStoreParts(N, Inc: ISD::POST_INC, Dec: ISD::POST_DEC, IsLoad,
18722 IsMasked, Ptr, TLI) ||
18723 Ptr->hasOneUse())
18724 return nullptr;
18725
18726 // Try turning it into a post-indexed load / store except when
18727 // 1) All uses are load / store ops that use it as the base ptr (and
18728 // it may be folded as an addressing mode).
18729 // 2) Op must be independent of N, i.e. Op is neither a predecessor
18730 // nor a successor of N. Otherwise, if Op is folded that would
18731 // create a cycle.
18732 for (SDNode *Op : Ptr->uses()) {
18733 // Check for #1.
18734 if (!shouldCombineToPostInc(N, Ptr, PtrUse: Op, BasePtr, Offset, AM, DAG, TLI))
18735 continue;
18736
18737 // Check for #2.
18738 SmallPtrSet<const SDNode *, 32> Visited;
18739 SmallVector<const SDNode *, 8> Worklist;
18740 constexpr unsigned int MaxSteps = 8192;
18741 // Ptr is predecessor to both N and Op.
18742 Visited.insert(Ptr: Ptr.getNode());
18743 Worklist.push_back(Elt: N);
18744 Worklist.push_back(Elt: Op);
18745 if (!SDNode::hasPredecessorHelper(N, Visited, Worklist, MaxSteps) &&
18746 !SDNode::hasPredecessorHelper(N: Op, Visited, Worklist, MaxSteps))
18747 return Op;
18748 }
18749 return nullptr;
18750}
18751
18752/// Try to combine a load/store with a add/sub of the base pointer node into a
18753/// post-indexed load/store. The transformation folded the add/subtract into the
18754/// new indexed load/store effectively and all of its uses are redirected to the
18755/// new load/store.
18756bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) {
18757 if (Level < AfterLegalizeDAG)
18758 return false;
18759
18760 bool IsLoad = true;
18761 bool IsMasked = false;
18762 SDValue Ptr;
18763 SDValue BasePtr;
18764 SDValue Offset;
18765 ISD::MemIndexedMode AM = ISD::UNINDEXED;
18766 SDNode *Op = getPostIndexedLoadStoreOp(N, IsLoad, IsMasked, Ptr, BasePtr,
18767 Offset, AM, DAG, TLI);
18768 if (!Op)
18769 return false;
18770
18771 SDValue Result;
18772 if (!IsMasked)
18773 Result = IsLoad ? DAG.getIndexedLoad(OrigLoad: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr,
18774 Offset, AM)
18775 : DAG.getIndexedStore(OrigStore: SDValue(N, 0), dl: SDLoc(N),
18776 Base: BasePtr, Offset, AM);
18777 else
18778 Result = IsLoad ? DAG.getIndexedMaskedLoad(OrigLoad: SDValue(N, 0), dl: SDLoc(N),
18779 Base: BasePtr, Offset, AM)
18780 : DAG.getIndexedMaskedStore(OrigStore: SDValue(N, 0), dl: SDLoc(N),
18781 Base: BasePtr, Offset, AM);
18782 ++PostIndexedNodes;
18783 ++NodesCombined;
18784 LLVM_DEBUG(dbgs() << "\nReplacing.5 "; N->dump(&DAG); dbgs() << "\nWith: ";
18785 Result.dump(&DAG); dbgs() << '\n');
18786 WorklistRemover DeadNodes(*this);
18787 if (IsLoad) {
18788 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result.getValue(R: 0));
18789 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: Result.getValue(R: 2));
18790 } else {
18791 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result.getValue(R: 1));
18792 }
18793
18794 // Finally, since the node is now dead, remove it from the graph.
18795 deleteAndRecombine(N);
18796
18797 // Replace the uses of Op with uses of the updated base value.
18798 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Op, 0),
18799 To: Result.getValue(R: IsLoad ? 1 : 0));
18800 deleteAndRecombine(N: Op);
18801 return true;
18802}
18803
18804/// Return the base-pointer arithmetic from an indexed \p LD.
18805SDValue DAGCombiner::SplitIndexingFromLoad(LoadSDNode *LD) {
18806 ISD::MemIndexedMode AM = LD->getAddressingMode();
18807 assert(AM != ISD::UNINDEXED);
18808 SDValue BP = LD->getOperand(Num: 1);
18809 SDValue Inc = LD->getOperand(Num: 2);
18810
18811 // Some backends use TargetConstants for load offsets, but don't expect
18812 // TargetConstants in general ADD nodes. We can convert these constants into
18813 // regular Constants (if the constant is not opaque).
18814 assert((Inc.getOpcode() != ISD::TargetConstant ||
18815 !cast<ConstantSDNode>(Inc)->isOpaque()) &&
18816 "Cannot split out indexing using opaque target constants");
18817 if (Inc.getOpcode() == ISD::TargetConstant) {
18818 ConstantSDNode *ConstInc = cast<ConstantSDNode>(Val&: Inc);
18819 Inc = DAG.getConstant(Val: *ConstInc->getConstantIntValue(), DL: SDLoc(Inc),
18820 VT: ConstInc->getValueType(ResNo: 0));
18821 }
18822
18823 unsigned Opc =
18824 (AM == ISD::PRE_INC || AM == ISD::POST_INC ? ISD::ADD : ISD::SUB);
18825 return DAG.getNode(Opcode: Opc, DL: SDLoc(LD), VT: BP.getSimpleValueType(), N1: BP, N2: Inc);
18826}
18827
18828static inline ElementCount numVectorEltsOrZero(EVT T) {
18829 return T.isVector() ? T.getVectorElementCount() : ElementCount::getFixed(MinVal: 0);
18830}
18831
18832bool DAGCombiner::getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val) {
18833 EVT STType = Val.getValueType();
18834 EVT STMemType = ST->getMemoryVT();
18835 if (STType == STMemType)
18836 return true;
18837 if (isTypeLegal(VT: STMemType))
18838 return false; // fail.
18839 if (STType.isFloatingPoint() && STMemType.isFloatingPoint() &&
18840 TLI.isOperationLegal(Op: ISD::FTRUNC, VT: STMemType)) {
18841 Val = DAG.getNode(Opcode: ISD::FTRUNC, DL: SDLoc(ST), VT: STMemType, Operand: Val);
18842 return true;
18843 }
18844 if (numVectorEltsOrZero(T: STType) == numVectorEltsOrZero(T: STMemType) &&
18845 STType.isInteger() && STMemType.isInteger()) {
18846 Val = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(ST), VT: STMemType, Operand: Val);
18847 return true;
18848 }
18849 if (STType.getSizeInBits() == STMemType.getSizeInBits()) {
18850 Val = DAG.getBitcast(VT: STMemType, V: Val);
18851 return true;
18852 }
18853 return false; // fail.
18854}
18855
18856bool DAGCombiner::extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val) {
18857 EVT LDMemType = LD->getMemoryVT();
18858 EVT LDType = LD->getValueType(ResNo: 0);
18859 assert(Val.getValueType() == LDMemType &&
18860 "Attempting to extend value of non-matching type");
18861 if (LDType == LDMemType)
18862 return true;
18863 if (LDMemType.isInteger() && LDType.isInteger()) {
18864 switch (LD->getExtensionType()) {
18865 case ISD::NON_EXTLOAD:
18866 Val = DAG.getBitcast(VT: LDType, V: Val);
18867 return true;
18868 case ISD::EXTLOAD:
18869 Val = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL: SDLoc(LD), VT: LDType, Operand: Val);
18870 return true;
18871 case ISD::SEXTLOAD:
18872 Val = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL: SDLoc(LD), VT: LDType, Operand: Val);
18873 return true;
18874 case ISD::ZEXTLOAD:
18875 Val = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: SDLoc(LD), VT: LDType, Operand: Val);
18876 return true;
18877 }
18878 }
18879 return false;
18880}
18881
18882StoreSDNode *DAGCombiner::getUniqueStoreFeeding(LoadSDNode *LD,
18883 int64_t &Offset) {
18884 SDValue Chain = LD->getOperand(Num: 0);
18885
18886 // Look through CALLSEQ_START.
18887 if (Chain.getOpcode() == ISD::CALLSEQ_START)
18888 Chain = Chain->getOperand(Num: 0);
18889
18890 StoreSDNode *ST = nullptr;
18891 SmallVector<SDValue, 8> Aliases;
18892 if (Chain.getOpcode() == ISD::TokenFactor) {
18893 // Look for unique store within the TokenFactor.
18894 for (SDValue Op : Chain->ops()) {
18895 StoreSDNode *Store = dyn_cast<StoreSDNode>(Val: Op.getNode());
18896 if (!Store)
18897 continue;
18898 BaseIndexOffset BasePtrLD = BaseIndexOffset::match(N: LD, DAG);
18899 BaseIndexOffset BasePtrST = BaseIndexOffset::match(N: Store, DAG);
18900 if (!BasePtrST.equalBaseIndex(Other: BasePtrLD, DAG, Off&: Offset))
18901 continue;
18902 // Make sure the store is not aliased with any nodes in TokenFactor.
18903 GatherAllAliases(N: Store, OriginalChain: Chain, Aliases);
18904 if (Aliases.empty() ||
18905 (Aliases.size() == 1 && Aliases.front().getNode() == Store))
18906 ST = Store;
18907 break;
18908 }
18909 } else {
18910 StoreSDNode *Store = dyn_cast<StoreSDNode>(Val: Chain.getNode());
18911 if (Store) {
18912 BaseIndexOffset BasePtrLD = BaseIndexOffset::match(N: LD, DAG);
18913 BaseIndexOffset BasePtrST = BaseIndexOffset::match(N: Store, DAG);
18914 if (BasePtrST.equalBaseIndex(Other: BasePtrLD, DAG, Off&: Offset))
18915 ST = Store;
18916 }
18917 }
18918
18919 return ST;
18920}
18921
18922SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) {
18923 if (OptLevel == CodeGenOptLevel::None || !LD->isSimple())
18924 return SDValue();
18925 SDValue Chain = LD->getOperand(Num: 0);
18926 int64_t Offset;
18927
18928 StoreSDNode *ST = getUniqueStoreFeeding(LD, Offset);
18929 // TODO: Relax this restriction for unordered atomics (see D66309)
18930 if (!ST || !ST->isSimple() || ST->getAddressSpace() != LD->getAddressSpace())
18931 return SDValue();
18932
18933 EVT LDType = LD->getValueType(ResNo: 0);
18934 EVT LDMemType = LD->getMemoryVT();
18935 EVT STMemType = ST->getMemoryVT();
18936 EVT STType = ST->getValue().getValueType();
18937
18938 // There are two cases to consider here:
18939 // 1. The store is fixed width and the load is scalable. In this case we
18940 // don't know at compile time if the store completely envelops the load
18941 // so we abandon the optimisation.
18942 // 2. The store is scalable and the load is fixed width. We could
18943 // potentially support a limited number of cases here, but there has been
18944 // no cost-benefit analysis to prove it's worth it.
18945 bool LdStScalable = LDMemType.isScalableVT();
18946 if (LdStScalable != STMemType.isScalableVT())
18947 return SDValue();
18948
18949 // If we are dealing with scalable vectors on a big endian platform the
18950 // calculation of offsets below becomes trickier, since we do not know at
18951 // compile time the absolute size of the vector. Until we've done more
18952 // analysis on big-endian platforms it seems better to bail out for now.
18953 if (LdStScalable && DAG.getDataLayout().isBigEndian())
18954 return SDValue();
18955
18956 // Normalize for endianness. After this, Offset=0 will denote that the least
18957 // significant bit in the loaded value maps to the least significant bit in
18958 // the stored value. With Offset=n (for n > 0) the loaded value starts at the
18959 // n-th least significant byte of the stored value.
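// For example, with an i64 store feeding an i16 load at byte offset 0 on a
// big-endian target, the normalized Offset becomes (64 - 16) / 8 - 0 = 6,
// i.e. the load reads the two most significant bytes of the stored value.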
18960 int64_t OrigOffset = Offset;
18961 if (DAG.getDataLayout().isBigEndian())
18962 Offset = ((int64_t)STMemType.getStoreSizeInBits().getFixedValue() -
18963 (int64_t)LDMemType.getStoreSizeInBits().getFixedValue()) /
18964 8 -
18965 Offset;
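  // For example, if an i32 is stored at address p and an i16 is loaded from
  // p+2, the address-derived Offset is 2. On little-endian that already means
  // "bits [16,32) of the stored value", so Offset stays 2; on big-endian the
  // bytes at p+2..p+3 hold the least significant half, so Offset becomes
  // (4 - 2) - 2 = 0.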
18966
18967 // Check that the stored value covers all the bits that are loaded.
18968 bool STCoversLD;
18969
18970 TypeSize LdMemSize = LDMemType.getSizeInBits();
18971 TypeSize StMemSize = STMemType.getSizeInBits();
18972 if (LdStScalable)
18973 STCoversLD = (Offset == 0) && LdMemSize == StMemSize;
18974 else
18975 STCoversLD = (Offset >= 0) && (Offset * 8 + LdMemSize.getFixedValue() <=
18976 StMemSize.getFixedValue());
18977
18978 auto ReplaceLd = [&](LoadSDNode *LD, SDValue Val, SDValue Chain) -> SDValue {
18979 if (LD->isIndexed()) {
18980 // Cannot handle opaque target constants and we must respect the user's
18981 // request not to split indexes from loads.
18982 if (!canSplitIdx(LD))
18983 return SDValue();
18984 SDValue Idx = SplitIndexingFromLoad(LD);
18985 SDValue Ops[] = {Val, Idx, Chain};
18986 return CombineTo(N: LD, To: Ops, NumTo: 3);
18987 }
18988 return CombineTo(N: LD, Res0: Val, Res1: Chain);
18989 };
18990
18991 if (!STCoversLD)
18992 return SDValue();
18993
18994 // Memory as copy space (potentially masked).
18995 if (Offset == 0 && LDType == STType && STMemType == LDMemType) {
18996 // Simple case: Direct non-truncating forwarding
18997 if (LDType.getSizeInBits() == LdMemSize)
18998 return ReplaceLd(LD, ST->getValue(), Chain);
18999 // Can we model the truncate and extension with an and mask?
19000 if (STType.isInteger() && LDMemType.isInteger() && !STType.isVector() &&
19001 !LDMemType.isVector() && LD->getExtensionType() != ISD::SEXTLOAD) {
19002 // Mask to size of LDMemType
19003 auto Mask =
19004 DAG.getConstant(Val: APInt::getLowBitsSet(numBits: STType.getFixedSizeInBits(),
19005 loBitsSet: StMemSize.getFixedValue()),
19006 DL: SDLoc(ST), VT: STType);
19007 auto Val = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(LD), VT: LDType, N1: ST->getValue(), N2: Mask);
19008 return ReplaceLd(LD, Val, Chain);
19009 }
19010 }
19011
19012 // Handle some big-endian cases that, before the endianness normalization
19013 // above, had Offset 0 and would already be handled on little-endian.
19014 SDValue Val = ST->getValue();
19015 if (DAG.getDataLayout().isBigEndian() && Offset > 0 && OrigOffset == 0) {
19016 if (STType.isInteger() && !STType.isVector() && LDType.isInteger() &&
19017 !LDType.isVector() && isTypeLegal(VT: STType) &&
19018 TLI.isOperationLegal(Op: ISD::SRL, VT: STType)) {
19019 Val = DAG.getNode(Opcode: ISD::SRL, DL: SDLoc(LD), VT: STType, N1: Val,
19020 N2: DAG.getConstant(Val: Offset * 8, DL: SDLoc(LD), VT: STType));
19021 Offset = 0;
19022 }
19023 }
19024
19025 // TODO: Deal with nonzero offset.
19026 if (LD->getBasePtr().isUndef() || Offset != 0)
19027 return SDValue();
19028 // Model the necessary truncations / extensions.
19029 // Truncate the value to the stored memory size.
19030 do {
19031 if (!getTruncatedStoreValue(ST, Val))
19032 break;
19033 if (!isTypeLegal(VT: LDMemType))
19034 break;
19035 if (STMemType != LDMemType) {
19036 // TODO: Support vectors? This requires extract_subvector/bitcast.
19037 if (!STMemType.isVector() && !LDMemType.isVector() &&
19038 STMemType.isInteger() && LDMemType.isInteger())
19039 Val = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(LD), VT: LDMemType, Operand: Val);
19040 else
19041 break;
19042 }
19043 if (!extendLoadedValueToExtension(LD, Val))
19044 break;
19045 return ReplaceLd(LD, Val, Chain);
19046 } while (false);
19047
19048 // On failure, cleanup dead nodes we may have created.
19049 if (Val->use_empty())
19050 deleteAndRecombine(N: Val.getNode());
19051 return SDValue();
19052}
19053
19054SDValue DAGCombiner::visitLOAD(SDNode *N) {
19055 LoadSDNode *LD = cast<LoadSDNode>(Val: N);
19056 SDValue Chain = LD->getChain();
19057 SDValue Ptr = LD->getBasePtr();
19058
19059 // If load is not volatile and there are no uses of the loaded value (and
19060 // the updated indexed value in case of indexed loads), change uses of the
19061 // chain value into uses of the chain input (i.e. delete the dead load).
19062 // TODO: Allow this for unordered atomics (see D66309)
19063 if (LD->isSimple()) {
19064 if (N->getValueType(ResNo: 1) == MVT::Other) {
19065 // Unindexed loads.
19066 if (!N->hasAnyUseOfValue(Value: 0)) {
19067 // It's not safe to use the two value CombineTo variant here. e.g.
19068 // v1, chain2 = load chain1, loc
19069 // v2, chain3 = load chain2, loc
19070 // v3 = add v2, c
19071 // Now we replace use of chain2 with chain1. This makes the second load
19072 // isomorphic to the one we are deleting, and thus makes this load live.
19073 LLVM_DEBUG(dbgs() << "\nReplacing.6 "; N->dump(&DAG);
19074 dbgs() << "\nWith chain: "; Chain.dump(&DAG);
19075 dbgs() << "\n");
19076 WorklistRemover DeadNodes(*this);
19077 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: Chain);
19078 AddUsersToWorklist(N: Chain.getNode());
19079 if (N->use_empty())
19080 deleteAndRecombine(N);
19081
19082 return SDValue(N, 0); // Return N so it doesn't get rechecked!
19083 }
19084 } else {
19085 // Indexed loads.
19086 assert(N->getValueType(2) == MVT::Other && "Malformed indexed loads?");
19087
19088 // If this load has an opaque TargetConstant offset, then we cannot split
19089 // the indexing into an add/sub directly (that TargetConstant may not be
19090 // valid for a different type of node, and we cannot convert an opaque
19091 // target constant into a regular constant).
19092 bool CanSplitIdx = canSplitIdx(LD);
19093
19094 if (!N->hasAnyUseOfValue(Value: 0) && (CanSplitIdx || !N->hasAnyUseOfValue(Value: 1))) {
19095 SDValue Undef = DAG.getUNDEF(VT: N->getValueType(ResNo: 0));
19096 SDValue Index;
19097 if (N->hasAnyUseOfValue(Value: 1) && CanSplitIdx) {
19098 Index = SplitIndexingFromLoad(LD);
19099 // Try to fold the base pointer arithmetic into subsequent loads and
19100 // stores.
19101 AddUsersToWorklist(N);
19102 } else
19103 Index = DAG.getUNDEF(VT: N->getValueType(ResNo: 1));
19104 LLVM_DEBUG(dbgs() << "\nReplacing.7 "; N->dump(&DAG);
19105 dbgs() << "\nWith: "; Undef.dump(&DAG);
19106 dbgs() << " and 2 other values\n");
19107 WorklistRemover DeadNodes(*this);
19108 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Undef);
19109 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: Index);
19110 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 2), To: Chain);
19111 deleteAndRecombine(N);
19112 return SDValue(N, 0); // Return N so it doesn't get rechecked!
19113 }
19114 }
19115 }
19116
19117 // If this load is directly stored, replace the load value with the stored
19118 // value.
19119 if (auto V = ForwardStoreValueToDirectLoad(LD))
19120 return V;
19121
19122 // Try to infer better alignment information than the load already has.
19123 if (OptLevel != CodeGenOptLevel::None && LD->isUnindexed() &&
19124 !LD->isAtomic()) {
19125 if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
19126 if (*Alignment > LD->getAlign() &&
19127 isAligned(Lhs: *Alignment, SizeInBytes: LD->getSrcValueOffset())) {
19128 SDValue NewLoad = DAG.getExtLoad(
19129 ExtType: LD->getExtensionType(), dl: SDLoc(N), VT: LD->getValueType(ResNo: 0), Chain, Ptr,
19130 PtrInfo: LD->getPointerInfo(), MemVT: LD->getMemoryVT(), Alignment: *Alignment,
19131 MMOFlags: LD->getMemOperand()->getFlags(), AAInfo: LD->getAAInfo());
19132 // NewLoad will always be N as we are only refining the alignment
19133 assert(NewLoad.getNode() == N);
19134 (void)NewLoad;
19135 }
19136 }
19137 }
19138
19139 if (LD->isUnindexed()) {
19140 // Walk up chain skipping non-aliasing memory nodes.
19141 SDValue BetterChain = FindBetterChain(N: LD, Chain);
19142
19143 // If there is a better chain.
19144 if (Chain != BetterChain) {
19145 SDValue ReplLoad;
19146
19147 // Replace the chain to avoid a dependency.
19148 if (LD->getExtensionType() == ISD::NON_EXTLOAD) {
19149 ReplLoad = DAG.getLoad(VT: N->getValueType(ResNo: 0), dl: SDLoc(LD),
19150 Chain: BetterChain, Ptr, MMO: LD->getMemOperand());
19151 } else {
19152 ReplLoad = DAG.getExtLoad(ExtType: LD->getExtensionType(), dl: SDLoc(LD),
19153 VT: LD->getValueType(ResNo: 0),
19154 Chain: BetterChain, Ptr, MemVT: LD->getMemoryVT(),
19155 MMO: LD->getMemOperand());
19156 }
19157
19158 // Create token factor to keep old chain connected.
19159 SDValue Token = DAG.getNode(Opcode: ISD::TokenFactor, DL: SDLoc(N),
19160 VT: MVT::Other, N1: Chain, N2: ReplLoad.getValue(R: 1));
19161
19162 // Replace uses with load result and token factor
19163 return CombineTo(N, Res0: ReplLoad.getValue(R: 0), Res1: Token);
19164 }
19165 }
19166
19167 // Try transforming N to an indexed load.
19168 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
19169 return SDValue(N, 0);
19170
19171 // Try to slice up N to more direct loads if the slices are mapped to
19172 // different register banks or pairing can take place.
19173 if (SliceUpLoad(N))
19174 return SDValue(N, 0);
19175
19176 return SDValue();
19177}
19178
19179namespace {
19180
19181/// Helper structure used to slice a load in smaller loads.
19182/// Basically a slice is obtained from the following sequence:
19183/// Origin = load Ty1, Base
19184/// Shift = srl Ty1 Origin, CstTy Amount
19185/// Inst = trunc Shift to Ty2
19186///
19187/// Then, it will be rewritten into:
19188/// Slice = load SliceTy, Base + SliceOffset
19189/// [Inst = zext Slice to Ty2], only if SliceTy <> Ty2
19190///
19191/// SliceTy is deduced from the number of bits that are actually used to
19192/// build Inst.
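/// For example, on a little-endian target:
///     Origin = load i32, Base
///     Shift  = srl i32 Origin, 16
///     Inst   = trunc i32 Shift to i16
/// becomes a single i16 load from Base + 2 (and since SliceTy == Ty2 here, no
/// zext is needed).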
19193struct LoadedSlice {
19194 /// Helper structure used to compute the cost of a slice.
19195 struct Cost {
19196 /// Are we optimizing for code size.
19197 bool ForCodeSize = false;
19198
19199 /// Various cost.
19200 unsigned Loads = 0;
19201 unsigned Truncates = 0;
19202 unsigned CrossRegisterBanksCopies = 0;
19203 unsigned ZExts = 0;
19204 unsigned Shift = 0;
19205
19206 explicit Cost(bool ForCodeSize) : ForCodeSize(ForCodeSize) {}
19207
19208 /// Get the cost of one isolated slice.
19209 Cost(const LoadedSlice &LS, bool ForCodeSize)
19210 : ForCodeSize(ForCodeSize), Loads(1) {
19211 EVT TruncType = LS.Inst->getValueType(ResNo: 0);
19212 EVT LoadedType = LS.getLoadedType();
19213 if (TruncType != LoadedType &&
19214 !LS.DAG->getTargetLoweringInfo().isZExtFree(FromTy: LoadedType, ToTy: TruncType))
19215 ZExts = 1;
19216 }
19217
19218 /// Account for slicing gain in the current cost.
19219 /// Slicing provides a few gains, like removing a shift or a
19220 /// truncate. This method allows growing the cost of the original
19221 /// load with the gain from this slice.
19222 void addSliceGain(const LoadedSlice &LS) {
19223 // Each slice saves a truncate.
19224 const TargetLowering &TLI = LS.DAG->getTargetLoweringInfo();
19225 if (!TLI.isTruncateFree(Val: LS.Inst->getOperand(Num: 0), VT2: LS.Inst->getValueType(ResNo: 0)))
19226 ++Truncates;
19227 // If there is a shift amount, this slice gets rid of it.
19228 if (LS.Shift)
19229 ++Shift;
19230 // If this slice can merge a cross register bank copy, account for it.
19231 if (LS.canMergeExpensiveCrossRegisterBankCopy())
19232 ++CrossRegisterBanksCopies;
19233 }
19234
19235 Cost &operator+=(const Cost &RHS) {
19236 Loads += RHS.Loads;
19237 Truncates += RHS.Truncates;
19238 CrossRegisterBanksCopies += RHS.CrossRegisterBanksCopies;
19239 ZExts += RHS.ZExts;
19240 Shift += RHS.Shift;
19241 return *this;
19242 }
19243
19244 bool operator==(const Cost &RHS) const {
19245 return Loads == RHS.Loads && Truncates == RHS.Truncates &&
19246 CrossRegisterBanksCopies == RHS.CrossRegisterBanksCopies &&
19247 ZExts == RHS.ZExts && Shift == RHS.Shift;
19248 }
19249
19250 bool operator!=(const Cost &RHS) const { return !(*this == RHS); }
19251
19252 bool operator<(const Cost &RHS) const {
19253 // Assume cross register banks copies are as expensive as loads.
19254 // FIXME: Do we want some more target hooks?
19255 unsigned ExpensiveOpsLHS = Loads + CrossRegisterBanksCopies;
19256 unsigned ExpensiveOpsRHS = RHS.Loads + RHS.CrossRegisterBanksCopies;
19257 // Unless we are optimizing for code size, consider the
19258 // expensive operation first.
19259 if (!ForCodeSize && ExpensiveOpsLHS != ExpensiveOpsRHS)
19260 return ExpensiveOpsLHS < ExpensiveOpsRHS;
19261 return (Truncates + ZExts + Shift + ExpensiveOpsLHS) <
19262 (RHS.Truncates + RHS.ZExts + RHS.Shift + ExpensiveOpsRHS);
19263 }
19264
19265 bool operator>(const Cost &RHS) const { return RHS < *this; }
19266
19267 bool operator<=(const Cost &RHS) const { return !(RHS < *this); }
19268
19269 bool operator>=(const Cost &RHS) const { return !(*this < RHS); }
19270 };
19271
19272 // The last instruction that represents the slice. This should be a
19273 // truncate instruction.
19274 SDNode *Inst;
19275
19276 // The original load instruction.
19277 LoadSDNode *Origin;
19278
19279 // The right shift amount in bits from the original load.
19280 unsigned Shift;
19281
19282 // The DAG that Origin comes from.
19283 // This is used to get some contextual information about legal types, etc.
19284 SelectionDAG *DAG;
19285
19286 LoadedSlice(SDNode *Inst = nullptr, LoadSDNode *Origin = nullptr,
19287 unsigned Shift = 0, SelectionDAG *DAG = nullptr)
19288 : Inst(Inst), Origin(Origin), Shift(Shift), DAG(DAG) {}
19289
19290 /// Get the bits used in a chunk of bits \p BitWidth large.
19291 /// \return Result is \p BitWidth bits wide, with used bits set to 1 and
19292 /// unused bits set to 0.
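  /// For example, with an i32 original load, an i8 truncate and Shift == 16,
  /// the result is 0x00FF0000.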
19293 APInt getUsedBits() const {
19294 // Reproduce the trunc(lshr) sequence:
19295 // - Start from the truncated value.
19296 // - Zero extend to the desired bit width.
19297 // - Shift left.
19298 assert(Origin && "No original load to compare against.");
19299 unsigned BitWidth = Origin->getValueSizeInBits(ResNo: 0);
19300 assert(Inst && "This slice is not bound to an instruction");
19301 assert(Inst->getValueSizeInBits(0) <= BitWidth &&
19302 "Extracted slice is bigger than the whole type!");
19303 APInt UsedBits(Inst->getValueSizeInBits(ResNo: 0), 0);
19304 UsedBits.setAllBits();
19305 UsedBits = UsedBits.zext(width: BitWidth);
19306 UsedBits <<= Shift;
19307 return UsedBits;
19308 }
19309
19310 /// Get the size of the slice to be loaded in bytes.
19311 unsigned getLoadedSize() const {
19312 unsigned SliceSize = getUsedBits().popcount();
19313 assert(!(SliceSize & 0x7) && "Size is not a multiple of a byte.");
19314 return SliceSize / 8;
19315 }
19316
19317 /// Get the type that will be loaded for this slice.
19318 /// Note: This may not be the final type for the slice.
19319 EVT getLoadedType() const {
19320 assert(DAG && "Missing context");
19321 LLVMContext &Ctxt = *DAG->getContext();
19322 return EVT::getIntegerVT(Context&: Ctxt, BitWidth: getLoadedSize() * 8);
19323 }
19324
19325 /// Get the alignment of the load used for this slice.
19326 Align getAlign() const {
19327 Align Alignment = Origin->getAlign();
19328 uint64_t Offset = getOffsetFromBase();
19329 if (Offset != 0)
19330 Alignment = commonAlignment(A: Alignment, Offset: Alignment.value() + Offset);
19331 return Alignment;
19332 }
19333
19334 /// Check if this slice can be rewritten with legal operations.
19335 bool isLegal() const {
19336 // An invalid slice is not legal.
19337 if (!Origin || !Inst || !DAG)
19338 return false;
19339
19340 // Offsets are for indexed loads only; we do not handle that.
19341 if (!Origin->getOffset().isUndef())
19342 return false;
19343
19344 const TargetLowering &TLI = DAG->getTargetLoweringInfo();
19345
19346 // Check that the type is legal.
19347 EVT SliceType = getLoadedType();
19348 if (!TLI.isTypeLegal(VT: SliceType))
19349 return false;
19350
19351 // Check that the load is legal for this type.
19352 if (!TLI.isOperationLegal(Op: ISD::LOAD, VT: SliceType))
19353 return false;
19354
19355 // Check that the offset can be computed.
19356 // 1. Check its type.
19357 EVT PtrType = Origin->getBasePtr().getValueType();
19358 if (PtrType == MVT::Untyped || PtrType.isExtended())
19359 return false;
19360
19361 // 2. Check that it fits in the immediate.
19362 if (!TLI.isLegalAddImmediate(getOffsetFromBase()))
19363 return false;
19364
19365 // 3. Check that the computation is legal.
19366 if (!TLI.isOperationLegal(Op: ISD::ADD, VT: PtrType))
19367 return false;
19368
19369 // Check that the zext is legal if it needs one.
19370 EVT TruncateType = Inst->getValueType(ResNo: 0);
19371 if (TruncateType != SliceType &&
19372 !TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT: TruncateType))
19373 return false;
19374
19375 return true;
19376 }
19377
19378 /// Get the offset in bytes of this slice in the original chunk of
19379 /// bits.
19380 /// \pre DAG != nullptr.
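  /// For example, for an 8-byte original load with Shift == 16 and a 2-byte
  /// slice, the offset is 2 on little-endian and 8 - 2 - 2 == 4 on big-endian.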
19381 uint64_t getOffsetFromBase() const {
19382 assert(DAG && "Missing context.");
19383 bool IsBigEndian = DAG->getDataLayout().isBigEndian();
19384 assert(!(Shift & 0x7) && "Shifts not aligned on Bytes are not supported.");
19385 uint64_t Offset = Shift / 8;
19386 unsigned TySizeInBytes = Origin->getValueSizeInBits(ResNo: 0) / 8;
19387 assert(!(Origin->getValueSizeInBits(0) & 0x7) &&
19388 "The size of the original loaded type is not a multiple of a"
19389 " byte.");
19390 // If Offset is bigger than TySizeInBytes, it means we are loading all
19391 // zeros. This should have been optimized before in the process.
19392 assert(TySizeInBytes > Offset &&
19393 "Invalid shift amount for given loaded size");
19394 if (IsBigEndian)
19395 Offset = TySizeInBytes - Offset - getLoadedSize();
19396 return Offset;
19397 }
19398
19399 /// Generate the sequence of instructions to load the slice
19400 /// represented by this object and redirect the uses of this slice to
19401 /// this new sequence of instructions.
19402 /// \pre this->Inst && this->Origin are valid Instructions and this
19403 /// object passed the legal check: LoadedSlice::isLegal returned true.
19404 /// \return The last instruction of the sequence used to load the slice.
19405 SDValue loadSlice() const {
19406 assert(Inst && Origin && "Unable to replace a non-existing slice.");
19407 const SDValue &OldBaseAddr = Origin->getBasePtr();
19408 SDValue BaseAddr = OldBaseAddr;
19409 // Get the offset in that chunk of bytes w.r.t. the endianness.
19410 int64_t Offset = static_cast<int64_t>(getOffsetFromBase());
19411 assert(Offset >= 0 && "Offset too big to fit in int64_t!");
19412 if (Offset) {
19413 // BaseAddr = BaseAddr + Offset.
19414 EVT ArithType = BaseAddr.getValueType();
19415 SDLoc DL(Origin);
19416 BaseAddr = DAG->getNode(Opcode: ISD::ADD, DL, VT: ArithType, N1: BaseAddr,
19417 N2: DAG->getConstant(Val: Offset, DL, VT: ArithType));
19418 }
19419
19420 // Create the type of the loaded slice according to its size.
19421 EVT SliceType = getLoadedType();
19422
19423 // Create the load for the slice.
19424 SDValue LastInst =
19425 DAG->getLoad(VT: SliceType, dl: SDLoc(Origin), Chain: Origin->getChain(), Ptr: BaseAddr,
19426 PtrInfo: Origin->getPointerInfo().getWithOffset(O: Offset), Alignment: getAlign(),
19427 MMOFlags: Origin->getMemOperand()->getFlags());
19428 // If the final type is not the same as the loaded type, this means that
19429 // we have to pad with zero. Create a zero extend for that.
19430 EVT FinalType = Inst->getValueType(ResNo: 0);
19431 if (SliceType != FinalType)
19432 LastInst =
19433 DAG->getNode(Opcode: ISD::ZERO_EXTEND, DL: SDLoc(LastInst), VT: FinalType, Operand: LastInst);
19434 return LastInst;
19435 }
19436
19437 /// Check if this slice can be merged with an expensive cross register
19438 /// bank copy. E.g.,
19439 /// i = load i32
19440 /// f = bitcast i32 i to float
19441 bool canMergeExpensiveCrossRegisterBankCopy() const {
19442 if (!Inst || !Inst->hasOneUse())
19443 return false;
19444 SDNode *Use = *Inst->use_begin();
19445 if (Use->getOpcode() != ISD::BITCAST)
19446 return false;
19447 assert(DAG && "Missing context");
19448 const TargetLowering &TLI = DAG->getTargetLoweringInfo();
19449 EVT ResVT = Use->getValueType(ResNo: 0);
19450 const TargetRegisterClass *ResRC =
19451 TLI.getRegClassFor(VT: ResVT.getSimpleVT(), isDivergent: Use->isDivergent());
19452 const TargetRegisterClass *ArgRC =
19453 TLI.getRegClassFor(VT: Use->getOperand(Num: 0).getValueType().getSimpleVT(),
19454 isDivergent: Use->getOperand(Num: 0)->isDivergent());
19455 if (ArgRC == ResRC || !TLI.isOperationLegal(Op: ISD::LOAD, VT: ResVT))
19456 return false;
19457
19458 // At this point, we know that we perform a cross-register-bank copy.
19459 // Check if it is expensive.
19460 const TargetRegisterInfo *TRI = DAG->getSubtarget().getRegisterInfo();
19461 // Assume bitcasts are cheap, unless both register classes do not
19462 // explicitly share a common sub class.
19463 if (!TRI || TRI->getCommonSubClass(A: ArgRC, B: ResRC))
19464 return false;
19465
19466 // Check if it will be merged with the load.
19467 // 1. Check the alignment / fast memory access constraint.
19468 unsigned IsFast = 0;
19469 if (!TLI.allowsMemoryAccess(Context&: *DAG->getContext(), DL: DAG->getDataLayout(), VT: ResVT,
19470 AddrSpace: Origin->getAddressSpace(), Alignment: getAlign(),
19471 Flags: Origin->getMemOperand()->getFlags(), Fast: &IsFast) ||
19472 !IsFast)
19473 return false;
19474
19475 // 2. Check that the load is a legal operation for that type.
19476 if (!TLI.isOperationLegal(Op: ISD::LOAD, VT: ResVT))
19477 return false;
19478
19479 // 3. Check that we do not have a zext in the way.
19480 if (Inst->getValueType(ResNo: 0) != getLoadedType())
19481 return false;
19482
19483 return true;
19484 }
19485};
19486
19487} // end anonymous namespace
19488
19489/// Check that all bits set in \p UsedBits form a dense region, i.e.,
19490/// \p UsedBits looks like 0..0 1..1 0..0.
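/// For example, 0x0FF0 is dense (a single run of ones), while 0x0F0F is not.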
19491static bool areUsedBitsDense(const APInt &UsedBits) {
19492 // If all the bits are one, this is dense!
19493 if (UsedBits.isAllOnes())
19494 return true;
19495
19496 // Get rid of the unused bits on the right.
19497 APInt NarrowedUsedBits = UsedBits.lshr(shiftAmt: UsedBits.countr_zero());
19498 // Get rid of the unused bits on the left.
19499 if (NarrowedUsedBits.countl_zero())
19500 NarrowedUsedBits = NarrowedUsedBits.trunc(width: NarrowedUsedBits.getActiveBits());
19501 // Check that the chunk of bits is completely used.
19502 return NarrowedUsedBits.isAllOnes();
19503}
19504
19505/// Check whether or not \p First and \p Second are next to each other
19506/// in memory. This means that there is no hole between the bits loaded
19507/// by \p First and the bits loaded by \p Second.
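/// For example, slices covering bits 0x000000FF and 0x0000FF00 of the same
/// origin are next to each other, while 0x000000FF and 0x00FF0000 are not.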
19508static bool areSlicesNextToEachOther(const LoadedSlice &First,
19509 const LoadedSlice &Second) {
19510 assert(First.Origin == Second.Origin && First.Origin &&
19511 "Unable to match different memory origins.");
19512 APInt UsedBits = First.getUsedBits();
19513 assert((UsedBits & Second.getUsedBits()) == 0 &&
19514 "Slices are not supposed to overlap.");
19515 UsedBits |= Second.getUsedBits();
19516 return areUsedBitsDense(UsedBits);
19517}
19518
19519/// Adjust the \p GlobalLSCost according to the target
19520/// pairing capabilities and the layout of the slices.
19521/// \pre \p GlobalLSCost should account for at least as many loads as
19522/// there are in the slices in \p LoadedSlices.
19523static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
19524 LoadedSlice::Cost &GlobalLSCost) {
19525 unsigned NumberOfSlices = LoadedSlices.size();
19526 // If there are fewer than 2 elements, no pairing is possible.
19527 if (NumberOfSlices < 2)
19528 return;
19529
19530 // Sort the slices so that elements that are likely to be next to each
19531 // other in memory are next to each other in the list.
19532 llvm::sort(C&: LoadedSlices, Comp: [](const LoadedSlice &LHS, const LoadedSlice &RHS) {
19533 assert(LHS.Origin == RHS.Origin && "Different bases not implemented.");
19534 return LHS.getOffsetFromBase() < RHS.getOffsetFromBase();
19535 });
19536 const TargetLowering &TLI = LoadedSlices[0].DAG->getTargetLoweringInfo();
19537 // First (resp. Second) is the first (resp. second) potential candidate
19538 // to be placed in a paired load.
19539 const LoadedSlice *First = nullptr;
19540 const LoadedSlice *Second = nullptr;
19541 for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice,
19542 // Set the beginning of the pair.
19543 First = Second) {
19544 Second = &LoadedSlices[CurrSlice];
19545
19546 // If First is NULL, it means we start a new pair.
19547 // Get to the next slice.
19548 if (!First)
19549 continue;
19550
19551 EVT LoadedType = First->getLoadedType();
19552
19553 // If the types of the slices are different, we cannot pair them.
19554 if (LoadedType != Second->getLoadedType())
19555 continue;
19556
19557 // Check if the target supplies paired loads for this type.
19558 Align RequiredAlignment;
19559 if (!TLI.hasPairedLoad(LoadedType, RequiredAlignment)) {
19560 // move to the next pair, this type is hopeless.
19561 Second = nullptr;
19562 continue;
19563 }
19564 // Check if we meet the alignment requirement.
19565 if (First->getAlign() < RequiredAlignment)
19566 continue;
19567
19568 // Check that both loads are next to each other in memory.
19569 if (!areSlicesNextToEachOther(First: *First, Second: *Second))
19570 continue;
19571
19572 assert(GlobalLSCost.Loads > 0 && "We save more loads than we created!");
19573 --GlobalLSCost.Loads;
19574 // Move to the next pair.
19575 Second = nullptr;
19576 }
19577}
19578
19579/// Check the profitability of all involved LoadedSlices.
19580/// Currently, slicing is considered profitable if there are exactly two
19581/// involved slices (1) which are (2) next to each other in memory, and
19582/// whose combined cost (\see LoadedSlice::Cost) is smaller than the original load's (3).
19583///
19584/// Note: The order of the elements in \p LoadedSlices may be modified, but not
19585/// the elements themselves.
19586///
19587/// FIXME: When the cost model is mature enough, we can relax
19588/// constraints (1) and (2).
19589static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
19590 const APInt &UsedBits, bool ForCodeSize) {
19591 unsigned NumberOfSlices = LoadedSlices.size();
19592 if (StressLoadSlicing)
19593 return NumberOfSlices > 1;
19594
19595 // Check (1).
19596 if (NumberOfSlices != 2)
19597 return false;
19598
19599 // Check (2).
19600 if (!areUsedBitsDense(UsedBits))
19601 return false;
19602
19603 // Check (3).
19604 LoadedSlice::Cost OrigCost(ForCodeSize), GlobalSlicingCost(ForCodeSize);
19605 // The original code has one big load.
19606 OrigCost.Loads = 1;
19607 for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice) {
19608 const LoadedSlice &LS = LoadedSlices[CurrSlice];
19609 // Accumulate the cost of all the slices.
19610 LoadedSlice::Cost SliceCost(LS, ForCodeSize);
19611 GlobalSlicingCost += SliceCost;
19612
19613 // Count the gain obtained with the current slice as extra cost in the
19614 // original configuration.
19615 OrigCost.addSliceGain(LS);
19616 }
19617
19618 // If the target supports paired load, adjust the cost accordingly.
19619 adjustCostForPairing(LoadedSlices, GlobalLSCost&: GlobalSlicingCost);
19620 return OrigCost > GlobalSlicingCost;
19621}
19622
19623/// If the given load, \p LI, is used only by trunc or trunc(lshr)
19624/// operations, split it into the various pieces being extracted.
19625///
19626/// This sort of thing is introduced by SROA.
19627/// This slicing takes care not to insert overlapping loads.
19628/// \pre LI is a simple load (i.e., not an atomic or volatile load).
19629bool DAGCombiner::SliceUpLoad(SDNode *N) {
19630 if (Level < AfterLegalizeDAG)
19631 return false;
19632
19633 LoadSDNode *LD = cast<LoadSDNode>(Val: N);
19634 if (!LD->isSimple() || !ISD::isNormalLoad(N: LD) ||
19635 !LD->getValueType(ResNo: 0).isInteger())
19636 return false;
19637
19638 // The algorithm to split up a load of a scalable vector into individual
19639 // elements currently requires knowing the length of the loaded type,
19640 // so will need adjusting to work on scalable vectors.
19641 if (LD->getValueType(ResNo: 0).isScalableVector())
19642 return false;
19643
19644 // Keep track of already used bits to detect overlapping values.
19645 // In that case, we will just abort the transformation.
19646 APInt UsedBits(LD->getValueSizeInBits(ResNo: 0), 0);
19647
19648 SmallVector<LoadedSlice, 4> LoadedSlices;
19649
19650 // Check if this load is used as several smaller chunks of bits.
19651 // Basically, look for uses in trunc or trunc(lshr) and record a new chain
19652 // of computation for each trunc.
19653 for (SDNode::use_iterator UI = LD->use_begin(), UIEnd = LD->use_end();
19654 UI != UIEnd; ++UI) {
19655 // Skip the uses of the chain.
19656 if (UI.getUse().getResNo() != 0)
19657 continue;
19658
19659 SDNode *User = *UI;
19660 unsigned Shift = 0;
19661
19662 // Check if this is a trunc(lshr).
19663 if (User->getOpcode() == ISD::SRL && User->hasOneUse() &&
19664 isa<ConstantSDNode>(Val: User->getOperand(Num: 1))) {
19665 Shift = User->getConstantOperandVal(Num: 1);
19666 User = *User->use_begin();
19667 }
19668
19669 // At this point, User is a TRUNCATE iff we encountered trunc or
19670 // trunc(lshr).
19671 if (User->getOpcode() != ISD::TRUNCATE)
19672 return false;
19673
19674 // The width of the type must be a power of 2 and at least 8 bits.
19675 // Otherwise the load cannot be represented in LLVM IR.
19676 // Moreover, if the shift amount is not a multiple of 8 bits, the slice
19677 // would straddle byte boundaries. We do not support that.
19678 unsigned Width = User->getValueSizeInBits(ResNo: 0);
19679 if (Width < 8 || !isPowerOf2_32(Value: Width) || (Shift & 0x7))
19680 return false;
19681
19682 // Build the slice for this chain of computations.
19683 LoadedSlice LS(User, LD, Shift, &DAG);
19684 APInt CurrentUsedBits = LS.getUsedBits();
19685
19686 // Check if this slice overlaps with another.
19687 if ((CurrentUsedBits & UsedBits) != 0)
19688 return false;
19689 // Update the bits used globally.
19690 UsedBits |= CurrentUsedBits;
19691
19692 // Check if the new slice would be legal.
19693 if (!LS.isLegal())
19694 return false;
19695
19696 // Record the slice.
19697 LoadedSlices.push_back(Elt: LS);
19698 }
19699
19700 // Abort slicing if it does not seem to be profitable.
19701 if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize))
19702 return false;
19703
19704 ++SlicedLoads;
19705
19706 // Rewrite each chain to use an independent load.
19707 // By construction, each chain can be represented by a unique load.
19708
19709 // Prepare the argument for the new token factor for all the slices.
19710 SmallVector<SDValue, 8> ArgChains;
19711 for (const LoadedSlice &LS : LoadedSlices) {
19712 SDValue SliceInst = LS.loadSlice();
19713 CombineTo(N: LS.Inst, Res: SliceInst, AddTo: true);
19714 if (SliceInst.getOpcode() != ISD::LOAD)
19715 SliceInst = SliceInst.getOperand(i: 0);
19716 assert(SliceInst->getOpcode() == ISD::LOAD &&
19717 "It takes more than a zext to get to the loaded slice!!");
19718 ArgChains.push_back(Elt: SliceInst.getValue(R: 1));
19719 }
19720
19721 SDValue Chain = DAG.getNode(Opcode: ISD::TokenFactor, DL: SDLoc(LD), VT: MVT::Other,
19722 Ops: ArgChains);
19723 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: Chain);
19724 AddToWorklist(N: Chain.getNode());
19725 return true;
19726}
19727
19728/// Check to see if V is (and (load ptr), imm), where the AND clears out
19729/// specific bytes of the loaded value. If so, return the number of bytes
19730/// being masked out and the byte offset of the masked region.
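/// For example, for V = (and (load i32 p), 0xFFFF00FF) only byte 1 of the
/// loaded value is cleared, so this returns {1, 1}: one masked byte at byte
/// offset 1 (assuming the chain checks below also succeed).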
19731static std::pair<unsigned, unsigned>
19732CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) {
19733 std::pair<unsigned, unsigned> Result(0, 0);
19734
19735 // Check for the structure we're looking for.
19736 if (V->getOpcode() != ISD::AND ||
19737 !isa<ConstantSDNode>(Val: V->getOperand(Num: 1)) ||
19738 !ISD::isNormalLoad(N: V->getOperand(Num: 0).getNode()))
19739 return Result;
19740
19741 // Check the chain and pointer.
19742 LoadSDNode *LD = cast<LoadSDNode>(Val: V->getOperand(Num: 0));
19743 if (LD->getBasePtr() != Ptr) return Result; // Not from same pointer.
19744
19745 // This only handles simple types.
19746 if (V.getValueType() != MVT::i16 &&
19747 V.getValueType() != MVT::i32 &&
19748 V.getValueType() != MVT::i64)
19749 return Result;
19750
19751 // Check the constant mask. Invert it so that the bits being masked out are
19752 // 0 and the bits being kept are 1. Use getSExtValue so that leading bits
19753 // follow the sign bit for uniformity.
19754 uint64_t NotMask = ~cast<ConstantSDNode>(Val: V->getOperand(Num: 1))->getSExtValue();
19755 unsigned NotMaskLZ = llvm::countl_zero(Val: NotMask);
19756 if (NotMaskLZ & 7) return Result; // Must be multiple of a byte.
19757 unsigned NotMaskTZ = llvm::countr_zero(Val: NotMask);
19758 if (NotMaskTZ & 7) return Result; // Must be multiple of a byte.
19759 if (NotMaskLZ == 64) return Result; // All zero mask.
19760
19761 // See if we have a continuous run of bits. If so, we have 0*1+0*
19762 if (llvm::countr_one(Value: NotMask >> NotMaskTZ) + NotMaskTZ + NotMaskLZ != 64)
19763 return Result;
19764
19765 // Adjust NotMaskLZ down to be from the actual size of the int instead of i64.
19766 if (V.getValueType() != MVT::i64 && NotMaskLZ)
19767 NotMaskLZ -= 64-V.getValueSizeInBits();
19768
19769 unsigned MaskedBytes = (V.getValueSizeInBits()-NotMaskLZ-NotMaskTZ)/8;
19770 switch (MaskedBytes) {
19771 case 1:
19772 case 2:
19773 case 4: break;
19774 default: return Result; // All one mask, or 5-byte mask.
19775 }
19776
19777 // Verify that the masked region starts at a multiple of its width so that
19778 // the narrowed access is naturally aligned for its width.
19779 if (NotMaskTZ && NotMaskTZ/8 % MaskedBytes) return Result;
19780
19781 // For narrowing to be valid, it must be the case that the load is the
19782 // memory operation immediately preceding the store.
19783 if (LD == Chain.getNode())
19784 ; // ok.
19785 else if (Chain->getOpcode() == ISD::TokenFactor &&
19786 SDValue(LD, 1).hasOneUse()) {
19787 // LD has only one chain use, so there are no indirect dependencies.
19788 if (!LD->isOperandOf(N: Chain.getNode()))
19789 return Result;
19790 } else
19791 return Result; // Fail.
19792
19793 Result.first = MaskedBytes;
19794 Result.second = NotMaskTZ/8;
19795 return Result;
19796}
19797
19798/// Check to see if IVal is something that provides a value as specified by
19799/// MaskInfo. If so, replace the specified store with a narrower store of
19800/// truncated IVal.
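/// For example, given MaskInfo = {1, 1} on a little-endian target and
/// assuming the masked-value and legality checks succeed, this emits an i8
/// store of (trunc (srl IVal, 8)) at Ptr + 1 in place of the original wider
/// store.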
19801static SDValue
19802ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo,
19803 SDValue IVal, StoreSDNode *St,
19804 DAGCombiner *DC) {
19805 unsigned NumBytes = MaskInfo.first;
19806 unsigned ByteShift = MaskInfo.second;
19807 SelectionDAG &DAG = DC->getDAG();
19808
19809 // Check to see if IVal is all zeros in the part being masked in by the 'or'
19810 // that uses this. If not, this is not a replacement.
19811 APInt Mask = ~APInt::getBitsSet(numBits: IVal.getValueSizeInBits(),
19812 loBit: ByteShift*8, hiBit: (ByteShift+NumBytes)*8);
19813 if (!DAG.MaskedValueIsZero(Op: IVal, Mask)) return SDValue();
19814
19815 // Check that it is legal on the target to do this. It is legal if the new
19816 // VT we're shrinking to (i8/i16/i32) is legal or we're still before type
19817 // legalization. If the source type is legal, but the store type isn't, see
19818 // if we can use a truncating store.
19819 MVT VT = MVT::getIntegerVT(BitWidth: NumBytes * 8);
19820 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
19821 bool UseTruncStore;
19822 if (DC->isTypeLegal(VT))
19823 UseTruncStore = false;
19824 else if (TLI.isTypeLegal(VT: IVal.getValueType()) &&
19825 TLI.isTruncStoreLegal(ValVT: IVal.getValueType(), MemVT: VT))
19826 UseTruncStore = true;
19827 else
19828 return SDValue();
19829
19830 // Can't do this for indexed stores.
19831 if (St->isIndexed())
19832 return SDValue();
19833
19834 // Check that the target doesn't think this is a bad idea.
19835 if (St->getMemOperand() &&
19836 !TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT,
19837 MMO: *St->getMemOperand()))
19838 return SDValue();
19839
19840 // Okay, we can do this! Replace the 'St' store with a store of IVal that is
19841 // shifted by ByteShift and truncated down to NumBytes.
19842 if (ByteShift) {
19843 SDLoc DL(IVal);
19844 IVal = DAG.getNode(
19845 Opcode: ISD::SRL, DL, VT: IVal.getValueType(), N1: IVal,
19846 N2: DAG.getShiftAmountConstant(Val: ByteShift * 8, VT: IVal.getValueType(), DL));
19847 }
19848
19849 // Figure out the offset for the store and the alignment of the access.
19850 unsigned StOffset;
19851 if (DAG.getDataLayout().isLittleEndian())
19852 StOffset = ByteShift;
19853 else
19854 StOffset = IVal.getValueType().getStoreSize() - ByteShift - NumBytes;
19855
19856 SDValue Ptr = St->getBasePtr();
19857 if (StOffset) {
19858 SDLoc DL(IVal);
19859 Ptr = DAG.getMemBasePlusOffset(Base: Ptr, Offset: TypeSize::getFixed(ExactSize: StOffset), DL);
19860 }
19861
19862 ++OpsNarrowed;
19863 if (UseTruncStore)
19864 return DAG.getTruncStore(Chain: St->getChain(), dl: SDLoc(St), Val: IVal, Ptr,
19865 PtrInfo: St->getPointerInfo().getWithOffset(O: StOffset),
19866 SVT: VT, Alignment: St->getOriginalAlign());
19867
19868 // Truncate down to the new size.
19869 IVal = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(IVal), VT, Operand: IVal);
19870
19871 return DAG
19872 .getStore(Chain: St->getChain(), dl: SDLoc(St), Val: IVal, Ptr,
19873 PtrInfo: St->getPointerInfo().getWithOffset(O: StOffset),
19874 Alignment: St->getOriginalAlign());
19875}
19876
19877/// Look for a sequence of load / op / store where op is one of 'or', 'xor',
19878/// or 'and' with an immediate. If 'op' only touches some of the loaded bits,
19879/// try narrowing the load and store if it would end up being a win for
19880/// performance or code size.
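/// For example, "store (or (load i32 p), 0x00FF0000), p" only touches byte 2
/// of the value, so (subject to the legality and profitability checks below)
/// it can be narrowed to an i8 load / or / store at offset 2 on a
/// little-endian target.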
19881SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) {
19882 StoreSDNode *ST = cast<StoreSDNode>(Val: N);
19883 if (!ST->isSimple())
19884 return SDValue();
19885
19886 SDValue Chain = ST->getChain();
19887 SDValue Value = ST->getValue();
19888 SDValue Ptr = ST->getBasePtr();
19889 EVT VT = Value.getValueType();
19890
19891 if (ST->isTruncatingStore() || VT.isVector())
19892 return SDValue();
19893
19894 unsigned Opc = Value.getOpcode();
19895
19896 if ((Opc != ISD::OR && Opc != ISD::XOR && Opc != ISD::AND) ||
19897 !Value.hasOneUse())
19898 return SDValue();
19899
19900 // If this is "store (or X, Y), P" and X is "(and (load P), cst)", where cst
19901 // is a byte mask indicating a consecutive number of bytes, check to see if
19902 // Y is known to provide just those bytes. If so, we try to replace the
19903 // load + 'or' + store sequence with a single (narrower) store, which makes
19904 // the load dead.
19905 if (Opc == ISD::OR && EnableShrinkLoadReplaceStoreWithStore) {
19906 std::pair<unsigned, unsigned> MaskedLoad;
19907 MaskedLoad = CheckForMaskedLoad(V: Value.getOperand(i: 0), Ptr, Chain);
19908 if (MaskedLoad.first)
19909 if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskInfo: MaskedLoad,
19910 IVal: Value.getOperand(i: 1), St: ST,DC: this))
19911 return NewST;
19912
19913 // Or is commutative, so try swapping X and Y.
19914 MaskedLoad = CheckForMaskedLoad(V: Value.getOperand(i: 1), Ptr, Chain);
19915 if (MaskedLoad.first)
19916 if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskInfo: MaskedLoad,
19917 IVal: Value.getOperand(i: 0), St: ST,DC: this))
19918 return NewST;
19919 }
19920
19921 if (!EnableReduceLoadOpStoreWidth)
19922 return SDValue();
19923
19924 if (Value.getOperand(i: 1).getOpcode() != ISD::Constant)
19925 return SDValue();
19926
19927 SDValue N0 = Value.getOperand(i: 0);
19928 if (ISD::isNormalLoad(N: N0.getNode()) && N0.hasOneUse() &&
19929 Chain == SDValue(N0.getNode(), 1)) {
19930 LoadSDNode *LD = cast<LoadSDNode>(Val&: N0);
19931 if (LD->getBasePtr() != Ptr ||
19932 LD->getPointerInfo().getAddrSpace() !=
19933 ST->getPointerInfo().getAddrSpace())
19934 return SDValue();
19935
19936 // Find the type to narrow the load / op / store to.
19937 SDValue N1 = Value.getOperand(i: 1);
19938 unsigned BitWidth = N1.getValueSizeInBits();
19939 APInt Imm = N1->getAsAPIntVal();
19940 if (Opc == ISD::AND)
19941 Imm ^= APInt::getAllOnes(numBits: BitWidth);
19942 if (Imm == 0 || Imm.isAllOnes())
19943 return SDValue();
19944 unsigned ShAmt = Imm.countr_zero();
19945 unsigned MSB = BitWidth - Imm.countl_zero() - 1;
19946 unsigned NewBW = NextPowerOf2(A: MSB - ShAmt);
19947 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: NewBW);
19948 // The narrowing should be profitable, the load/store operation should be
19949 // legal (or custom) and the store size should be equal to the NewVT width.
19950 while (NewBW < BitWidth &&
19951 (NewVT.getStoreSizeInBits() != NewBW ||
19952 !TLI.isOperationLegalOrCustom(Op: Opc, VT: NewVT) ||
19953 !TLI.isNarrowingProfitable(SrcVT: VT, DestVT: NewVT))) {
19954 NewBW = NextPowerOf2(A: NewBW);
19955 NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: NewBW);
19956 }
19957 if (NewBW >= BitWidth)
19958 return SDValue();
19959
19960 // If the lowest changed bit does not fall on a NewBW-bit boundary,
19961 // start the narrowed access at the previous boundary.
19962 if (ShAmt % NewBW)
19963 ShAmt = (((ShAmt + NewBW - 1) / NewBW) * NewBW) - NewBW;
19964 APInt Mask = APInt::getBitsSet(numBits: BitWidth, loBit: ShAmt,
19965 hiBit: std::min(a: BitWidth, b: ShAmt + NewBW));
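    // For example, with BitWidth == 32 and Imm == 0x0000FF00: ShAmt == 8,
    // MSB == 15, NewBW == 8, Mask == 0x0000FF00, NewImm == 0xFF and
    // PtrOff == 1 on little-endian.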
19966 if ((Imm & Mask) == Imm) {
19967 APInt NewImm = (Imm & Mask).lshr(shiftAmt: ShAmt).trunc(width: NewBW);
19968 if (Opc == ISD::AND)
19969 NewImm ^= APInt::getAllOnes(numBits: NewBW);
19970 uint64_t PtrOff = ShAmt / 8;
19971 // For big endian targets, we need to adjust the offset to the pointer to
19972 // load the correct bytes.
19973 if (DAG.getDataLayout().isBigEndian())
19974 PtrOff = (BitWidth + 7 - NewBW) / 8 - PtrOff;
19975
19976 unsigned IsFast = 0;
19977 Align NewAlign = commonAlignment(A: LD->getAlign(), Offset: PtrOff);
19978 if (!TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: NewVT,
19979 AddrSpace: LD->getAddressSpace(), Alignment: NewAlign,
19980 Flags: LD->getMemOperand()->getFlags(), Fast: &IsFast) ||
19981 !IsFast)
19982 return SDValue();
19983
19984 SDValue NewPtr =
19985 DAG.getMemBasePlusOffset(Base: Ptr, Offset: TypeSize::getFixed(ExactSize: PtrOff), DL: SDLoc(LD));
19986 SDValue NewLD =
19987 DAG.getLoad(VT: NewVT, dl: SDLoc(N0), Chain: LD->getChain(), Ptr: NewPtr,
19988 PtrInfo: LD->getPointerInfo().getWithOffset(O: PtrOff), Alignment: NewAlign,
19989 MMOFlags: LD->getMemOperand()->getFlags(), AAInfo: LD->getAAInfo());
19990 SDValue NewVal = DAG.getNode(Opcode: Opc, DL: SDLoc(Value), VT: NewVT, N1: NewLD,
19991 N2: DAG.getConstant(Val: NewImm, DL: SDLoc(Value),
19992 VT: NewVT));
19993 SDValue NewST =
19994 DAG.getStore(Chain, dl: SDLoc(N), Val: NewVal, Ptr: NewPtr,
19995 PtrInfo: ST->getPointerInfo().getWithOffset(O: PtrOff), Alignment: NewAlign);
19996
19997 AddToWorklist(N: NewPtr.getNode());
19998 AddToWorklist(N: NewLD.getNode());
19999 AddToWorklist(N: NewVal.getNode());
20000 WorklistRemover DeadNodes(*this);
20001 DAG.ReplaceAllUsesOfValueWith(From: N0.getValue(R: 1), To: NewLD.getValue(R: 1));
20002 ++OpsNarrowed;
20003 return NewST;
20004 }
20005 }
20006
20007 return SDValue();
20008}
20009
20010/// For a given floating point load / store pair, if the load value isn't used
20011/// by any other operations, then consider transforming the pair to integer
20012/// load / store operations if the target deems the transformation profitable.
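/// For example, "store float (load float p), q" can become
/// "store i32 (load i32 p), q" when the target reports that the integer
/// load/store forms are both legal and desirable for this width.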
20013SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) {
20014 StoreSDNode *ST = cast<StoreSDNode>(Val: N);
20015 SDValue Value = ST->getValue();
20016 if (ISD::isNormalStore(N: ST) && ISD::isNormalLoad(N: Value.getNode()) &&
20017 Value.hasOneUse()) {
20018 LoadSDNode *LD = cast<LoadSDNode>(Val&: Value);
20019 EVT VT = LD->getMemoryVT();
20020 if (!VT.isFloatingPoint() ||
20021 VT != ST->getMemoryVT() ||
20022 LD->isNonTemporal() ||
20023 ST->isNonTemporal() ||
20024 LD->getPointerInfo().getAddrSpace() != 0 ||
20025 ST->getPointerInfo().getAddrSpace() != 0)
20026 return SDValue();
20027
20028 TypeSize VTSize = VT.getSizeInBits();
20029
20030 // We don't know the size of scalable types at compile time so we cannot
20031 // create an integer of the equivalent size.
20032 if (VTSize.isScalable())
20033 return SDValue();
20034
20035 unsigned FastLD = 0, FastST = 0;
20036 EVT IntVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: VTSize.getFixedValue());
20037 if (!TLI.isOperationLegal(Op: ISD::LOAD, VT: IntVT) ||
20038 !TLI.isOperationLegal(Op: ISD::STORE, VT: IntVT) ||
20039 !TLI.isDesirableToTransformToIntegerOp(ISD::LOAD, VT) ||
20040 !TLI.isDesirableToTransformToIntegerOp(ISD::STORE, VT) ||
20041 !TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: IntVT,
20042 MMO: *LD->getMemOperand(), Fast: &FastLD) ||
20043 !TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: IntVT,
20044 MMO: *ST->getMemOperand(), Fast: &FastST) ||
20045 !FastLD || !FastST)
20046 return SDValue();
20047
20048 SDValue NewLD =
20049 DAG.getLoad(VT: IntVT, dl: SDLoc(Value), Chain: LD->getChain(), Ptr: LD->getBasePtr(),
20050 PtrInfo: LD->getPointerInfo(), Alignment: LD->getAlign());
20051
20052 SDValue NewST =
20053 DAG.getStore(Chain: ST->getChain(), dl: SDLoc(N), Val: NewLD, Ptr: ST->getBasePtr(),
20054 PtrInfo: ST->getPointerInfo(), Alignment: ST->getAlign());
20055
20056 AddToWorklist(N: NewLD.getNode());
20057 AddToWorklist(N: NewST.getNode());
20058 WorklistRemover DeadNodes(*this);
20059 DAG.ReplaceAllUsesOfValueWith(From: Value.getValue(R: 1), To: NewLD.getValue(R: 1));
20060 ++LdStFP2Int;
20061 return NewST;
20062 }
20063
20064 return SDValue();
20065}
20066
20067// This is a helper function for visitMUL to check the profitability
20068// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
20069// MulNode is the original multiply, AddNode is (add x, c1),
20070// and ConstNode is c2.
20071//
20072// If the (add x, c1) has multiple uses, we could increase
20073// the number of adds if we make this transformation.
20074// It would only be worth doing this if we can remove a
20075// multiply in the process. Check for that here.
20076// To illustrate:
20077// (A + c1) * c3
20078// (A + c2) * c3
20079// We're checking for cases where we have common "c3 * A" expressions.
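// For example, with c3 == 5:
//   (A + 1) * 5 --> A*5 + 5
//   (A + 2) * 5 --> A*5 + 10
// Both expansions share the single multiply A*5, which is what makes the
// fold worthwhile even though (A + c1) has other uses.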
20080bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
20081 SDValue ConstNode) {
20082 APInt Val;
20083
20084 // If the add only has one use, and the target thinks the folding is
20085 // profitable or does not lead to worse code, this would be OK to do.
20086 if (AddNode->hasOneUse() &&
20087 TLI.isMulAddWithConstProfitable(AddNode, ConstNode))
20088 return true;
20089
20090 // Walk all the users of the constant with which we're multiplying.
20091 for (SDNode *Use : ConstNode->uses()) {
20092 if (Use == MulNode) // This use is the one we're on right now. Skip it.
20093 continue;
20094
20095 if (Use->getOpcode() == ISD::MUL) { // We have another multiply use.
20096 SDNode *OtherOp;
20097 SDNode *MulVar = AddNode.getOperand(i: 0).getNode();
20098
20099 // OtherOp is what we're multiplying against the constant.
20100 if (Use->getOperand(Num: 0) == ConstNode)
20101 OtherOp = Use->getOperand(Num: 1).getNode();
20102 else
20103 OtherOp = Use->getOperand(Num: 0).getNode();
20104
20105 // Check to see if multiply is with the same operand of our "add".
20106 //
20107 // ConstNode = CONST
20108 // Use = ConstNode * A <-- visiting Use. OtherOp is A.
20109 // ...
20110 // AddNode = (A + c1) <-- MulVar is A.
20111 // = AddNode * ConstNode <-- current visiting instruction.
20112 //
20113 // If we make this transformation, we will have a common
20114 // multiply (ConstNode * A) that we can save.
20115 if (OtherOp == MulVar)
20116 return true;
20117
20118 // Now check to see if a future expansion will give us a common
20119 // multiply.
20120 //
20121 // ConstNode = CONST
20122 // AddNode = (A + c1)
20123 // ... = AddNode * ConstNode <-- current visiting instruction.
20124 // ...
20125 // OtherOp = (A + c2)
20126 // Use = OtherOp * ConstNode <-- visiting Use.
20127 //
20128 // If we make this transformation, we will have a common
20129 // multiply (CONST * A) after we also do the same transformation
20130 // to the 'Use' instruction.
20131 if (OtherOp->getOpcode() == ISD::ADD &&
20132 DAG.isConstantIntBuildVectorOrConstantInt(N: OtherOp->getOperand(Num: 1)) &&
20133 OtherOp->getOperand(Num: 0).getNode() == MulVar)
20134 return true;
20135 }
20136 }
20137
20138 // Didn't find a case where this would be profitable.
20139 return false;
20140}
20141
20142SDValue DAGCombiner::getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
20143 unsigned NumStores) {
20144 SmallVector<SDValue, 8> Chains;
20145 SmallPtrSet<const SDNode *, 8> Visited;
20146 SDLoc StoreDL(StoreNodes[0].MemNode);
20147
20148 for (unsigned i = 0; i < NumStores; ++i) {
20149 Visited.insert(Ptr: StoreNodes[i].MemNode);
20150 }
20151
20152 // Don't include nodes that are children (other stores being merged) or repeated nodes.
20153 for (unsigned i = 0; i < NumStores; ++i) {
20154 if (Visited.insert(Ptr: StoreNodes[i].MemNode->getChain().getNode()).second)
20155 Chains.push_back(Elt: StoreNodes[i].MemNode->getChain());
20156 }
20157
20158 assert(!Chains.empty() && "Chain should have generated a chain");
20159 return DAG.getTokenFactor(DL: StoreDL, Vals&: Chains);
20160}
20161
20162bool DAGCombiner::hasSameUnderlyingObj(ArrayRef<MemOpLink> StoreNodes) {
20163 const Value *UnderlyingObj = nullptr;
20164 for (const auto &MemOp : StoreNodes) {
20165 const MachineMemOperand *MMO = MemOp.MemNode->getMemOperand();
20166 // A pseudo value like a stack frame has its own frame index and size; we
20167 // should not use the first store's frame index for other frames.
20168 if (MMO->getPseudoValue())
20169 return false;
20170
20171 if (!MMO->getValue())
20172 return false;
20173
20174 const Value *Obj = getUnderlyingObject(V: MMO->getValue());
20175
20176 if (UnderlyingObj && UnderlyingObj != Obj)
20177 return false;
20178
20179 if (!UnderlyingObj)
20180 UnderlyingObj = Obj;
20181 }
20182
20183 return true;
20184}
20185
20186bool DAGCombiner::mergeStoresOfConstantsOrVecElts(
20187 SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores,
20188 bool IsConstantSrc, bool UseVector, bool UseTrunc) {
20189 // Make sure we have something to merge.
20190 if (NumStores < 2)
20191 return false;
20192
20193 assert((!UseTrunc || !UseVector) &&
20194 "This optimization cannot emit a vector truncating store");
20195
20196 // The latest Node in the DAG.
20197 SDLoc DL(StoreNodes[0].MemNode);
20198
20199 TypeSize ElementSizeBits = MemVT.getStoreSizeInBits();
20200 unsigned SizeInBits = NumStores * ElementSizeBits;
20201 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
20202
20203 std::optional<MachineMemOperand::Flags> Flags;
20204 AAMDNodes AAInfo;
20205 for (unsigned I = 0; I != NumStores; ++I) {
20206 StoreSDNode *St = cast<StoreSDNode>(Val: StoreNodes[I].MemNode);
20207 if (!Flags) {
20208 Flags = St->getMemOperand()->getFlags();
20209 AAInfo = St->getAAInfo();
20210 continue;
20211 }
20212 // Skip merging if there's an inconsistent flag.
20213 if (Flags != St->getMemOperand()->getFlags())
20214 return false;
20215 // Concatenate AA metadata.
20216 AAInfo = AAInfo.concat(Other: St->getAAInfo());
20217 }
20218
20219 EVT StoreTy;
20220 if (UseVector) {
20221 unsigned Elts = NumStores * NumMemElts;
20222 // Get the type for the merged vector store.
20223 StoreTy = EVT::getVectorVT(Context&: *DAG.getContext(), VT: MemVT.getScalarType(), NumElements: Elts);
20224 } else
20225 StoreTy = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SizeInBits);
20226
20227 SDValue StoredVal;
20228 if (UseVector) {
20229 if (IsConstantSrc) {
20230 SmallVector<SDValue, 8> BuildVector;
20231 for (unsigned I = 0; I != NumStores; ++I) {
20232 StoreSDNode *St = cast<StoreSDNode>(Val: StoreNodes[I].MemNode);
20233 SDValue Val = St->getValue();
20234 // If constant is of the wrong type, convert it now. This comes up
20235 // when one of our stores was truncating.
20236 if (MemVT != Val.getValueType()) {
20237 Val = peekThroughBitcasts(V: Val);
20238 // Deal with constants of wrong size.
20239 if (ElementSizeBits != Val.getValueSizeInBits()) {
20240 auto *C = dyn_cast<ConstantSDNode>(Val);
20241 if (!C)
20242 // Not clear how to truncate FP values.
20243 // TODO: Handle truncation of build_vector constants
20244 return false;
20245
20246 EVT IntMemVT =
20247 EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: MemVT.getSizeInBits());
20248 Val = DAG.getConstant(Val: C->getAPIntValue()
20249 .zextOrTrunc(width: Val.getValueSizeInBits())
20250 .zextOrTrunc(width: ElementSizeBits),
20251 DL: SDLoc(C), VT: IntMemVT);
20252 }
20253 // Bitcast to make sure the value has the correct (MemVT) type.
20254 Val = DAG.getBitcast(VT: MemVT, V: Val);
20255 }
20256 BuildVector.push_back(Elt: Val);
20257 }
20258 StoredVal = DAG.getNode(Opcode: MemVT.isVector() ? ISD::CONCAT_VECTORS
20259 : ISD::BUILD_VECTOR,
20260 DL, VT: StoreTy, Ops: BuildVector);
20261 } else {
20262 SmallVector<SDValue, 8> Ops;
20263 for (unsigned i = 0; i < NumStores; ++i) {
20264 StoreSDNode *St = cast<StoreSDNode>(Val: StoreNodes[i].MemNode);
20265 SDValue Val = peekThroughBitcasts(V: St->getValue());
20266 // All operands of BUILD_VECTOR / CONCAT_VECTOR must be of
20267 // type MemVT. If the underlying value is not the correct
20268 // type, but it is an extraction of an appropriate vector we
20269 // can recast Val to be of the correct type. This may require
20270 // converting between EXTRACT_VECTOR_ELT and
20271 // EXTRACT_SUBVECTOR.
20272 if ((MemVT != Val.getValueType()) &&
20273 (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
20274 Val.getOpcode() == ISD::EXTRACT_SUBVECTOR)) {
20275 EVT MemVTScalarTy = MemVT.getScalarType();
20276 // We may need to add a bitcast here to get types to line up.
20277 if (MemVTScalarTy != Val.getValueType().getScalarType()) {
20278 Val = DAG.getBitcast(VT: MemVT, V: Val);
20279 } else if (MemVT.isVector() &&
20280 Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
20281 Val = DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: MemVT, Operand: Val);
20282 } else {
20283 unsigned OpC = MemVT.isVector() ? ISD::EXTRACT_SUBVECTOR
20284 : ISD::EXTRACT_VECTOR_ELT;
20285 SDValue Vec = Val.getOperand(i: 0);
20286 SDValue Idx = Val.getOperand(i: 1);
20287 Val = DAG.getNode(Opcode: OpC, DL: SDLoc(Val), VT: MemVT, N1: Vec, N2: Idx);
20288 }
20289 }
20290 Ops.push_back(Elt: Val);
20291 }
20292
20293 // Build the extracted vector elements back into a vector.
20294 StoredVal = DAG.getNode(Opcode: MemVT.isVector() ? ISD::CONCAT_VECTORS
20295 : ISD::BUILD_VECTOR,
20296 DL, VT: StoreTy, Ops);
20297 }
20298 } else {
20299 // We should always use a vector store when merging extracted vector
20300 // elements, so this path implies a store of constants.
20301 assert(IsConstantSrc && "Merged vector elements should use vector store");
20302
20303 APInt StoreInt(SizeInBits, 0);
20304
20305 // Construct a single integer constant which is made of the smaller
20306 // constant inputs.
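    // For example, merging two i16 constant stores 0x1111 and 0x2222 (with
    // StoreNodes[0] being the store at the lower address) on a little-endian
    // target yields StoreInt == 0x22221111, which lays the bytes out in memory
    // exactly as the two original stores did.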
20307 bool IsLE = DAG.getDataLayout().isLittleEndian();
20308 for (unsigned i = 0; i < NumStores; ++i) {
20309 unsigned Idx = IsLE ? (NumStores - 1 - i) : i;
20310 StoreSDNode *St = cast<StoreSDNode>(Val: StoreNodes[Idx].MemNode);
20311
20312 SDValue Val = St->getValue();
20313 Val = peekThroughBitcasts(V: Val);
20314 StoreInt <<= ElementSizeBits;
20315 if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val)) {
20316 StoreInt |= C->getAPIntValue()
20317 .zextOrTrunc(width: ElementSizeBits)
20318 .zextOrTrunc(width: SizeInBits);
20319 } else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val)) {
20320 StoreInt |= C->getValueAPF()
20321 .bitcastToAPInt()
20322 .zextOrTrunc(width: ElementSizeBits)
20323 .zextOrTrunc(width: SizeInBits);
20324 // If fp truncation is necessary give up for now.
20325 if (MemVT.getSizeInBits() != ElementSizeBits)
20326 return false;
20327 } else if (ISD::isBuildVectorOfConstantSDNodes(N: Val.getNode()) ||
20328 ISD::isBuildVectorOfConstantFPSDNodes(N: Val.getNode())) {
20329 // Not yet handled
20330 return false;
20331 } else {
20332 llvm_unreachable("Invalid constant element type");
20333 }
20334 }
20335
20336 // Create the new Load and Store operations.
20337 StoredVal = DAG.getConstant(Val: StoreInt, DL, VT: StoreTy);
20338 }
20339
20340 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
20341 SDValue NewChain = getMergeStoreChains(StoreNodes, NumStores);
20342 bool CanReusePtrInfo = hasSameUnderlyingObj(StoreNodes);
20343
20344 // Make sure we use a truncating store where that is necessary for legality.
20345 // When generating the new widened store, if the first store's pointer info
20346 // cannot be reused, discard the pointer info except for the address space,
20347 // because the widened store can no longer be represented by the original
20348 // pointer info, which describes only the narrower memory object.
20349 SDValue NewStore;
20350 if (!UseTrunc) {
20351 NewStore = DAG.getStore(
20352 Chain: NewChain, dl: DL, Val: StoredVal, Ptr: FirstInChain->getBasePtr(),
20353 PtrInfo: CanReusePtrInfo
20354 ? FirstInChain->getPointerInfo()
20355 : MachinePointerInfo(FirstInChain->getPointerInfo().getAddrSpace()),
20356 Alignment: FirstInChain->getAlign(), MMOFlags: *Flags, AAInfo);
20357 } else { // Must be realized as a trunc store
20358 EVT LegalizedStoredValTy =
20359 TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: StoredVal.getValueType());
20360 unsigned LegalizedStoreSize = LegalizedStoredValTy.getSizeInBits();
20361 ConstantSDNode *C = cast<ConstantSDNode>(Val&: StoredVal);
20362 SDValue ExtendedStoreVal =
20363 DAG.getConstant(Val: C->getAPIntValue().zextOrTrunc(width: LegalizedStoreSize), DL,
20364 VT: LegalizedStoredValTy);
20365 NewStore = DAG.getTruncStore(
20366 Chain: NewChain, dl: DL, Val: ExtendedStoreVal, Ptr: FirstInChain->getBasePtr(),
20367 PtrInfo: CanReusePtrInfo
20368 ? FirstInChain->getPointerInfo()
20369 : MachinePointerInfo(FirstInChain->getPointerInfo().getAddrSpace()),
20370 SVT: StoredVal.getValueType() /*TVT*/, Alignment: FirstInChain->getAlign(), MMOFlags: *Flags,
20371 AAInfo);
20372 }
20373
20374 // Replace all merged stores with the new store.
20375 for (unsigned i = 0; i < NumStores; ++i)
20376 CombineTo(N: StoreNodes[i].MemNode, Res: NewStore);
20377
20378 AddToWorklist(N: NewChain.getNode());
20379 return true;
20380}
20381
20382SDNode *
20383DAGCombiner::getStoreMergeCandidates(StoreSDNode *St,
20384 SmallVectorImpl<MemOpLink> &StoreNodes) {
20385 // This holds the base pointer, index, and the offset in bytes from the base
20386 // pointer. We must have a base and an offset. Do not handle stores to undef
20387 // base pointers.
20388 BaseIndexOffset BasePtr = BaseIndexOffset::match(N: St, DAG);
20389 if (!BasePtr.getBase().getNode() || BasePtr.getBase().isUndef())
20390 return nullptr;
20391
20392 SDValue Val = peekThroughBitcasts(V: St->getValue());
20393 StoreSource StoreSrc = getStoreSource(StoreVal: Val);
20394 assert(StoreSrc != StoreSource::Unknown && "Expected known source for store");
20395
20396 // Match on loadbaseptr if relevant.
20397 EVT MemVT = St->getMemoryVT();
20398 BaseIndexOffset LBasePtr;
20399 EVT LoadVT;
20400 if (StoreSrc == StoreSource::Load) {
20401 auto *Ld = cast<LoadSDNode>(Val);
20402 LBasePtr = BaseIndexOffset::match(N: Ld, DAG);
20403 LoadVT = Ld->getMemoryVT();
20404 // Load and store should be the same type.
20405 if (MemVT != LoadVT)
20406 return nullptr;
20407 // Loads must only have one use.
20408 if (!Ld->hasNUsesOfValue(NUses: 1, Value: 0))
20409 return nullptr;
20410 // The memory operands must not be volatile/indexed/atomic.
20411 // TODO: May be able to relax for unordered atomics (see D66309)
20412 if (!Ld->isSimple() || Ld->isIndexed())
20413 return nullptr;
20414 }
20415 auto CandidateMatch = [&](StoreSDNode *Other, BaseIndexOffset &Ptr,
20416 int64_t &Offset) -> bool {
20417 // The memory operands must not be volatile/indexed/atomic.
20418 // TODO: May be able to relax for unordered atomics (see D66309)
20419 if (!Other->isSimple() || Other->isIndexed())
20420 return false;
20421 // Don't mix temporal stores with non-temporal stores.
20422 if (St->isNonTemporal() != Other->isNonTemporal())
20423 return false;
20424 if (!TLI.areTwoSDNodeTargetMMOFlagsMergeable(NodeX: *St, NodeY: *Other))
20425 return false;
20426 SDValue OtherBC = peekThroughBitcasts(V: Other->getValue());
20427 // Allow merging constants of different types as integers.
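    // For example, when MemVT is i32, an adjacent 'store float' of the same
    // 32-bit width may still be merged, since the constants are combined as
    // integers via bitcasting later on.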
20428 bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(VT: Other->getMemoryVT())
20429 : Other->getMemoryVT() != MemVT;
20430 switch (StoreSrc) {
20431 case StoreSource::Load: {
20432 if (NoTypeMatch)
20433 return false;
20434 // The Load's Base Ptr must also match.
20435 auto *OtherLd = dyn_cast<LoadSDNode>(Val&: OtherBC);
20436 if (!OtherLd)
20437 return false;
20438 BaseIndexOffset LPtr = BaseIndexOffset::match(N: OtherLd, DAG);
20439 if (LoadVT != OtherLd->getMemoryVT())
20440 return false;
20441 // Loads must only have one use.
20442 if (!OtherLd->hasNUsesOfValue(NUses: 1, Value: 0))
20443 return false;
20444 // The memory operands must not be volatile/indexed/atomic.
20445 // TODO: May be able to relax for unordered atomics (see D66309)
20446 if (!OtherLd->isSimple() || OtherLd->isIndexed())
20447 return false;
20448 // Don't mix temporal loads with non-temporal loads.
20449 if (cast<LoadSDNode>(Val)->isNonTemporal() != OtherLd->isNonTemporal())
20450 return false;
20451 if (!TLI.areTwoSDNodeTargetMMOFlagsMergeable(NodeX: *cast<LoadSDNode>(Val),
20452 NodeY: *OtherLd))
20453 return false;
20454 if (!(LBasePtr.equalBaseIndex(Other: LPtr, DAG)))
20455 return false;
20456 break;
20457 }
20458 case StoreSource::Constant:
20459 if (NoTypeMatch)
20460 return false;
20461 if (getStoreSource(StoreVal: OtherBC) != StoreSource::Constant)
20462 return false;
20463 break;
20464 case StoreSource::Extract:
20465 // Do not merge truncated stores here.
20466 if (Other->isTruncatingStore())
20467 return false;
20468 if (!MemVT.bitsEq(VT: OtherBC.getValueType()))
20469 return false;
20470 if (OtherBC.getOpcode() != ISD::EXTRACT_VECTOR_ELT &&
20471 OtherBC.getOpcode() != ISD::EXTRACT_SUBVECTOR)
20472 return false;
20473 break;
20474 default:
20475 llvm_unreachable("Unhandled store source for merging");
20476 }
20477 Ptr = BaseIndexOffset::match(N: Other, DAG);
20478 return (BasePtr.equalBaseIndex(Other: Ptr, DAG, Off&: Offset));
20479 };
20480
20481  // We are looking for a root node which is an ancestor to all mergeable
20482  // stores. We search up through a load, to our root and then down
20483  // through all children. For instance we will find Store{1,2,3} if
20484  // St is Store1, Store2, or Store3 where the root is not a load,
20485  // which is always true for nonvolatile ops. TODO: Expand
20486  // the search to find all valid candidates through multiple layers of loads.
20487 //
20488 // Root
20489 // |-------|-------|
20490 // Load Load Store3
20491 // | |
20492 // Store1 Store2
20493 //
20494 // FIXME: We should be able to climb and
20495 // descend TokenFactors to find candidates as well.
20496
20497 SDNode *RootNode = St->getChain().getNode();
20498 // Bail out if we already analyzed this root node and found nothing.
20499 if (ChainsWithoutMergeableStores.contains(Ptr: RootNode))
20500 return nullptr;
20501
20502  // Check whether this StoreNode/RootNode pair has already bailed out of the
20503  // dependence check more times than the limit allows.
20504 auto OverLimitInDependenceCheck = [&](SDNode *StoreNode,
20505 SDNode *RootNode) -> bool {
20506 auto RootCount = StoreRootCountMap.find(Val: StoreNode);
20507 return RootCount != StoreRootCountMap.end() &&
20508 RootCount->second.first == RootNode &&
20509 RootCount->second.second > StoreMergeDependenceLimit;
20510 };
20511
20512 auto TryToAddCandidate = [&](SDNode::use_iterator UseIter) {
20513 // This must be a chain use.
20514 if (UseIter.getOperandNo() != 0)
20515 return;
20516 if (auto *OtherStore = dyn_cast<StoreSDNode>(Val: *UseIter)) {
20517 BaseIndexOffset Ptr;
20518 int64_t PtrDiff;
20519 if (CandidateMatch(OtherStore, Ptr, PtrDiff) &&
20520 !OverLimitInDependenceCheck(OtherStore, RootNode))
20521 StoreNodes.push_back(Elt: MemOpLink(OtherStore, PtrDiff));
20522 }
20523 };
20524
20525 unsigned NumNodesExplored = 0;
20526 const unsigned MaxSearchNodes = 1024;
20527 if (auto *Ldn = dyn_cast<LoadSDNode>(Val: RootNode)) {
20528 RootNode = Ldn->getChain().getNode();
20529 // Bail out if we already analyzed this root node and found nothing.
20530 if (ChainsWithoutMergeableStores.contains(Ptr: RootNode))
20531 return nullptr;
20532 for (auto I = RootNode->use_begin(), E = RootNode->use_end();
20533 I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored) {
20534 if (I.getOperandNo() == 0 && isa<LoadSDNode>(Val: *I)) { // walk down chain
20535 for (auto I2 = (*I)->use_begin(), E2 = (*I)->use_end(); I2 != E2; ++I2)
20536 TryToAddCandidate(I2);
20537 }
20538 // Check stores that depend on the root (e.g. Store 3 in the chart above).
20539 if (I.getOperandNo() == 0 && isa<StoreSDNode>(Val: *I)) {
20540 TryToAddCandidate(I);
20541 }
20542 }
20543 } else {
20544 for (auto I = RootNode->use_begin(), E = RootNode->use_end();
20545 I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored)
20546 TryToAddCandidate(I);
20547 }
20548
20549 return RootNode;
20550}
20551
20552// We need to check that merging these stores does not cause a loop in the
20553// DAG. Any store candidate may depend on another candidate indirectly through
20554// its operands. Check in parallel by searching up from operands of candidates.
20555bool DAGCombiner::checkMergeStoreCandidatesForDependencies(
20556 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
20557 SDNode *RootNode) {
20558  // FIXME: We should be able to truncate a full search of
20559  // predecessors by doing a BFS and keeping tabs on the originating
20560  // stores from which worklist nodes come, in a similar way to
20561  // TokenFactor simplification.
20562
20563 SmallPtrSet<const SDNode *, 32> Visited;
20564 SmallVector<const SDNode *, 8> Worklist;
20565
20566 // RootNode is a predecessor to all candidates so we need not search
20567 // past it. Add RootNode (peeking through TokenFactors). Do not count
20568 // these towards size check.
20569
20570 Worklist.push_back(Elt: RootNode);
20571 while (!Worklist.empty()) {
20572 auto N = Worklist.pop_back_val();
20573 if (!Visited.insert(Ptr: N).second)
20574 continue; // Already present in Visited.
20575 if (N->getOpcode() == ISD::TokenFactor) {
20576 for (SDValue Op : N->ops())
20577 Worklist.push_back(Elt: Op.getNode());
20578 }
20579 }
20580
20581 // Don't count pruning nodes towards max.
20582 unsigned int Max = 1024 + Visited.size();
20583 // Search Ops of store candidates.
20584 for (unsigned i = 0; i < NumStores; ++i) {
20585 SDNode *N = StoreNodes[i].MemNode;
20586 // Of the 4 Store Operands:
20587 // * Chain (Op 0) -> We have already considered these
20588 // in candidate selection, but only by following the
20589 // chain dependencies. We could still have a chain
20590 // dependency to a load, that has a non-chain dep to
20591 // another load, that depends on a store, etc. So it is
20592 // possible to have dependencies that consist of a mix
20593 // of chain and non-chain deps, and we need to include
20594    //                   chain operands in the analysis here.
20595 // * Value (Op 1) -> Cycles may happen (e.g. through load chains)
20596 // * Address (Op 2) -> Merged addresses may only vary by a fixed constant,
20597    //                      but aren't necessarily from the same base node, so
20598    //                      cycles are possible (e.g. via indexed store).
20599 // * (Op 3) -> Represents the pre or post-indexing offset (or undef for
20600 // non-indexed stores). Not constant on all targets (e.g. ARM)
20601 // and so can participate in a cycle.
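    // For example, if Store1's stored value is a load whose chain depends on
    // Store2, merging Store1 and Store2 would make the merged store both a
    // predecessor and a successor of that load, forming a cycle.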
20602 for (const SDValue &Op : N->op_values())
20603 Worklist.push_back(Elt: Op.getNode());
20604 }
20605 // Search through DAG. We can stop early if we find a store node.
20606 for (unsigned i = 0; i < NumStores; ++i)
20607 if (SDNode::hasPredecessorHelper(N: StoreNodes[i].MemNode, Visited, Worklist,
20608 MaxSteps: Max)) {
20609      // If the search bails out, record the StoreNode and RootNode in the
20610      // StoreRootCountMap. If we have seen the pair more times than the limit,
20611      // we won't add the StoreNode into the StoreNodes set again.
20612 if (Visited.size() >= Max) {
20613 auto &RootCount = StoreRootCountMap[StoreNodes[i].MemNode];
20614 if (RootCount.first == RootNode)
20615 RootCount.second++;
20616 else
20617 RootCount = {RootNode, 1};
20618 }
20619 return false;
20620 }
20621 return true;
20622}
20623
20624unsigned
20625DAGCombiner::getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
20626 int64_t ElementSizeBytes) const {
20627 while (true) {
20628 // Find a store past the width of the first store.
20629 size_t StartIdx = 0;
20630 while ((StartIdx + 1 < StoreNodes.size()) &&
20631 StoreNodes[StartIdx].OffsetFromBase + ElementSizeBytes !=
20632 StoreNodes[StartIdx + 1].OffsetFromBase)
20633 ++StartIdx;
20634
20635 // Bail if we don't have enough candidates to merge.
20636 if (StartIdx + 1 >= StoreNodes.size())
20637 return 0;
20638
20639 // Trim stores that overlapped with the first store.
20640 if (StartIdx)
20641 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + StartIdx);
20642
20643 // Scan the memory operations on the chain and find the first
20644 // non-consecutive store memory address.
20645 unsigned NumConsecutiveStores = 1;
20646 int64_t StartAddress = StoreNodes[0].OffsetFromBase;
20647 // Check that the addresses are consecutive starting from the second
20648 // element in the list of stores.
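    // For example, with 4-byte elements, offsets {0, 4, 8, 16} give three
    // consecutive stores; the scan stops before the store at offset 16.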
20649 for (unsigned i = 1, e = StoreNodes.size(); i < e; ++i) {
20650 int64_t CurrAddress = StoreNodes[i].OffsetFromBase;
20651 if (CurrAddress - StartAddress != (ElementSizeBytes * i))
20652 break;
20653 NumConsecutiveStores = i + 1;
20654 }
20655 if (NumConsecutiveStores > 1)
20656 return NumConsecutiveStores;
20657
20658 // There are no consecutive stores at the start of the list.
20659 // Remove the first store and try again.
20660 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + 1);
20661 }
20662}
20663
20664bool DAGCombiner::tryStoreMergeOfConstants(
20665 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
20666 EVT MemVT, SDNode *RootNode, bool AllowVectors) {
20667 LLVMContext &Context = *DAG.getContext();
20668 const DataLayout &DL = DAG.getDataLayout();
20669 int64_t ElementSizeBytes = MemVT.getStoreSize();
20670 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
20671 bool MadeChange = false;
20672
20673 // Store the constants into memory as one consecutive store.
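  // For example, four adjacent 'store i8 C, p+i' operations may become a
  // single i32 store of the combined constant, or a v4i8 store when the
  // target reports that a vector constant store is cheap.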
20674 while (NumConsecutiveStores >= 2) {
20675 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
20676 unsigned FirstStoreAS = FirstInChain->getAddressSpace();
20677 Align FirstStoreAlign = FirstInChain->getAlign();
20678 unsigned LastLegalType = 1;
20679 unsigned LastLegalVectorType = 1;
20680 bool LastIntegerTrunc = false;
20681 bool NonZero = false;
20682 unsigned FirstZeroAfterNonZero = NumConsecutiveStores;
20683 for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
20684 StoreSDNode *ST = cast<StoreSDNode>(Val: StoreNodes[i].MemNode);
20685 SDValue StoredVal = ST->getValue();
20686 bool IsElementZero = false;
20687 if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val&: StoredVal))
20688 IsElementZero = C->isZero();
20689 else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val&: StoredVal))
20690 IsElementZero = C->getConstantFPValue()->isNullValue();
20691 else if (ISD::isBuildVectorAllZeros(N: StoredVal.getNode()))
20692 IsElementZero = true;
20693 if (IsElementZero) {
20694 if (NonZero && FirstZeroAfterNonZero == NumConsecutiveStores)
20695 FirstZeroAfterNonZero = i;
20696 }
20697 NonZero |= !IsElementZero;
20698
20699 // Find a legal type for the constant store.
20700 unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
20701 EVT StoreTy = EVT::getIntegerVT(Context, BitWidth: SizeInBits);
20702 unsigned IsFast = 0;
20703
20704 // Break early when size is too large to be legal.
20705 if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
20706 break;
20707
20708 if (TLI.isTypeLegal(VT: StoreTy) &&
20709 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: StoreTy,
20710 MF: DAG.getMachineFunction()) &&
20711 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
20712 MMO: *FirstInChain->getMemOperand(), Fast: &IsFast) &&
20713 IsFast) {
20714 LastIntegerTrunc = false;
20715 LastLegalType = i + 1;
20716 // Or check whether a truncstore is legal.
20717 } else if (TLI.getTypeAction(Context, VT: StoreTy) ==
20718 TargetLowering::TypePromoteInteger) {
20719 EVT LegalizedStoredValTy =
20720 TLI.getTypeToTransformTo(Context, VT: StoredVal.getValueType());
20721 if (TLI.isTruncStoreLegal(ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
20722 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: LegalizedStoredValTy,
20723 MF: DAG.getMachineFunction()) &&
20724 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
20725 MMO: *FirstInChain->getMemOperand(), Fast: &IsFast) &&
20726 IsFast) {
20727 LastIntegerTrunc = true;
20728 LastLegalType = i + 1;
20729 }
20730 }
20731
20732 // We only use vectors if the target allows it and the function is not
20733 // marked with the noimplicitfloat attribute.
20734 if (TLI.storeOfVectorConstantIsCheap(IsZero: !NonZero, MemVT, NumElem: i + 1, AddrSpace: FirstStoreAS) &&
20735 AllowVectors) {
20736 // Find a legal type for the vector store.
20737 unsigned Elts = (i + 1) * NumMemElts;
20738 EVT Ty = EVT::getVectorVT(Context, VT: MemVT.getScalarType(), NumElements: Elts);
20739 if (TLI.isTypeLegal(VT: Ty) && TLI.isTypeLegal(VT: MemVT) &&
20740 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: Ty, MF: DAG.getMachineFunction()) &&
20741 TLI.allowsMemoryAccess(Context, DL, VT: Ty,
20742 MMO: *FirstInChain->getMemOperand(), Fast: &IsFast) &&
20743 IsFast)
20744 LastLegalVectorType = i + 1;
20745 }
20746 }
20747
20748 bool UseVector = (LastLegalVectorType > LastLegalType) && AllowVectors;
20749 unsigned NumElem = (UseVector) ? LastLegalVectorType : LastLegalType;
20750 bool UseTrunc = LastIntegerTrunc && !UseVector;
20751
20752 // Check if we found a legal integer type that creates a meaningful
20753 // merge.
20754 if (NumElem < 2) {
20755 // We know that candidate stores are in order and of correct
20756 // shape. While there is no mergeable sequence from the
20757 // beginning one may start later in the sequence. The only
20758 // reason a merge of size N could have failed where another of
20759 // the same size would not have, is if the alignment has
20760 // improved or we've dropped a non-zero value. Drop as many
20761 // candidates as we can here.
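      // For example, if the first store is only 1-byte aligned but a later
      // store in the run is 4-byte aligned, restarting the search at the
      // better-aligned store may allow a legal and fast wide store.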
20762 unsigned NumSkip = 1;
20763 while ((NumSkip < NumConsecutiveStores) &&
20764 (NumSkip < FirstZeroAfterNonZero) &&
20765 (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
20766 NumSkip++;
20767
20768 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumSkip);
20769 NumConsecutiveStores -= NumSkip;
20770 continue;
20771 }
20772
20773 // Check that we can merge these candidates without causing a cycle.
20774 if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStores: NumElem,
20775 RootNode)) {
20776 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumElem);
20777 NumConsecutiveStores -= NumElem;
20778 continue;
20779 }
20780
20781 MadeChange |= mergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumStores: NumElem,
20782 /*IsConstantSrc*/ true,
20783 UseVector, UseTrunc);
20784
20785 // Remove merged stores for next iteration.
20786 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumElem);
20787 NumConsecutiveStores -= NumElem;
20788 }
20789 return MadeChange;
20790}
20791
20792bool DAGCombiner::tryStoreMergeOfExtracts(
20793 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
20794 EVT MemVT, SDNode *RootNode) {
20795 LLVMContext &Context = *DAG.getContext();
20796 const DataLayout &DL = DAG.getDataLayout();
20797 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
20798 bool MadeChange = false;
20799
20800 // Loop on Consecutive Stores on success.
20801 while (NumConsecutiveStores >= 2) {
20802 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
20803 unsigned FirstStoreAS = FirstInChain->getAddressSpace();
20804 Align FirstStoreAlign = FirstInChain->getAlign();
20805 unsigned NumStoresToMerge = 1;
20806 for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
20807 // Find a legal type for the vector store.
20808 unsigned Elts = (i + 1) * NumMemElts;
20809 EVT Ty = EVT::getVectorVT(Context&: *DAG.getContext(), VT: MemVT.getScalarType(), NumElements: Elts);
20810 unsigned IsFast = 0;
20811
20812 // Break early when size is too large to be legal.
20813 if (Ty.getSizeInBits() > MaximumLegalStoreInBits)
20814 break;
20815
20816 if (TLI.isTypeLegal(VT: Ty) &&
20817 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: Ty, MF: DAG.getMachineFunction()) &&
20818 TLI.allowsMemoryAccess(Context, DL, VT: Ty,
20819 MMO: *FirstInChain->getMemOperand(), Fast: &IsFast) &&
20820 IsFast)
20821 NumStoresToMerge = i + 1;
20822 }
20823
20824    // Check if we found a legal vector type creating a meaningful
20825    // merge.
20826 if (NumStoresToMerge < 2) {
20827 // We know that candidate stores are in order and of correct
20828 // shape. While there is no mergeable sequence from the
20829 // beginning one may start later in the sequence. The only
20830 // reason a merge of size N could have failed where another of
20831 // the same size would not have, is if the alignment has
20832 // improved. Drop as many candidates as we can here.
20833 unsigned NumSkip = 1;
20834 while ((NumSkip < NumConsecutiveStores) &&
20835 (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
20836 NumSkip++;
20837
20838 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumSkip);
20839 NumConsecutiveStores -= NumSkip;
20840 continue;
20841 }
20842
20843 // Check that we can merge these candidates without causing a cycle.
20844 if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStores: NumStoresToMerge,
20845 RootNode)) {
20846 StoreNodes.erase(CS: StoreNodes.begin(),
20847 CE: StoreNodes.begin() + NumStoresToMerge);
20848 NumConsecutiveStores -= NumStoresToMerge;
20849 continue;
20850 }
20851
20852 MadeChange |= mergeStoresOfConstantsOrVecElts(
20853 StoreNodes, MemVT, NumStores: NumStoresToMerge, /*IsConstantSrc*/ false,
20854 /*UseVector*/ true, /*UseTrunc*/ false);
20855
20856 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumStoresToMerge);
20857 NumConsecutiveStores -= NumStoresToMerge;
20858 }
20859 return MadeChange;
20860}
20861
20862bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
20863 unsigned NumConsecutiveStores, EVT MemVT,
20864 SDNode *RootNode, bool AllowVectors,
20865 bool IsNonTemporalStore,
20866 bool IsNonTemporalLoad) {
20867 LLVMContext &Context = *DAG.getContext();
20868 const DataLayout &DL = DAG.getDataLayout();
20869 int64_t ElementSizeBytes = MemVT.getStoreSize();
20870 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
20871 bool MadeChange = false;
20872
20873 // Look for load nodes which are used by the stored values.
20874 SmallVector<MemOpLink, 8> LoadNodes;
20875
20876  // Find acceptable loads. Loads need to have the same chain (token factor),
20877  // must not be zext, volatile, or indexed, and they must be consecutive.
20878 BaseIndexOffset LdBasePtr;
20879
20880 for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
20881 StoreSDNode *St = cast<StoreSDNode>(Val: StoreNodes[i].MemNode);
20882 SDValue Val = peekThroughBitcasts(V: St->getValue());
20883 LoadSDNode *Ld = cast<LoadSDNode>(Val);
20884
20885 BaseIndexOffset LdPtr = BaseIndexOffset::match(N: Ld, DAG);
20886 // If this is not the first ptr that we check.
20887 int64_t LdOffset = 0;
20888 if (LdBasePtr.getBase().getNode()) {
20889 // The base ptr must be the same.
20890 if (!LdBasePtr.equalBaseIndex(Other: LdPtr, DAG, Off&: LdOffset))
20891 break;
20892 } else {
20893 // Check that all other base pointers are the same as this one.
20894 LdBasePtr = LdPtr;
20895 }
20896
20897 // We found a potential memory operand to merge.
20898 LoadNodes.push_back(Elt: MemOpLink(Ld, LdOffset));
20899 }
20900
20901 while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) {
20902 Align RequiredAlignment;
20903 bool NeedRotate = false;
20904 if (LoadNodes.size() == 2) {
20905 // If we have load/store pair instructions and we only have two values,
20906 // don't bother merging.
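      // (For example, a target with load/store pair instructions, such as
      // AArch64's LDP/STP, can handle the two accesses directly.)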
20907 if (TLI.hasPairedLoad(MemVT, RequiredAlignment) &&
20908 StoreNodes[0].MemNode->getAlign() >= RequiredAlignment) {
20909 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + 2);
20910 LoadNodes.erase(CS: LoadNodes.begin(), CE: LoadNodes.begin() + 2);
20911 break;
20912 }
20913 // If the loads are reversed, see if we can rotate the halves into place.
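      // For example, if the store at the lower address uses the value loaded
      // from q+N while the next store uses the value loaded from q, a single
      // wide load from q rotated by half its width produces the value to store.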
20914 int64_t Offset0 = LoadNodes[0].OffsetFromBase;
20915 int64_t Offset1 = LoadNodes[1].OffsetFromBase;
20916 EVT PairVT = EVT::getIntegerVT(Context, BitWidth: ElementSizeBytes * 8 * 2);
20917 if (Offset0 - Offset1 == ElementSizeBytes &&
20918 (hasOperation(Opcode: ISD::ROTL, VT: PairVT) ||
20919 hasOperation(Opcode: ISD::ROTR, VT: PairVT))) {
20920 std::swap(a&: LoadNodes[0], b&: LoadNodes[1]);
20921 NeedRotate = true;
20922 }
20923 }
20924 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
20925 unsigned FirstStoreAS = FirstInChain->getAddressSpace();
20926 Align FirstStoreAlign = FirstInChain->getAlign();
20927 LoadSDNode *FirstLoad = cast<LoadSDNode>(Val: LoadNodes[0].MemNode);
20928
20929 // Scan the memory operations on the chain and find the first
20930 // non-consecutive load memory address. These variables hold the index in
20931 // the store node array.
20932
20933 unsigned LastConsecutiveLoad = 1;
20934
20935    // These variables refer to a size (count) and not an index in the array.
20936 unsigned LastLegalVectorType = 1;
20937 unsigned LastLegalIntegerType = 1;
20938 bool isDereferenceable = true;
20939 bool DoIntegerTruncate = false;
20940 int64_t StartAddress = LoadNodes[0].OffsetFromBase;
20941 SDValue LoadChain = FirstLoad->getChain();
20942 for (unsigned i = 1; i < LoadNodes.size(); ++i) {
20943 // All loads must share the same chain.
20944 if (LoadNodes[i].MemNode->getChain() != LoadChain)
20945 break;
20946
20947 int64_t CurrAddress = LoadNodes[i].OffsetFromBase;
20948 if (CurrAddress - StartAddress != (ElementSizeBytes * i))
20949 break;
20950 LastConsecutiveLoad = i;
20951
20952 if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable())
20953 isDereferenceable = false;
20954
20955 // Find a legal type for the vector store.
20956 unsigned Elts = (i + 1) * NumMemElts;
20957 EVT StoreTy = EVT::getVectorVT(Context, VT: MemVT.getScalarType(), NumElements: Elts);
20958
20959 // Break early when size is too large to be legal.
20960 if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
20961 break;
20962
20963 unsigned IsFastSt = 0;
20964 unsigned IsFastLd = 0;
20965 // Don't try vector types if we need a rotate. We may still fail the
20966 // legality checks for the integer type, but we can't handle the rotate
20967 // case with vectors.
20968 // FIXME: We could use a shuffle in place of the rotate.
20969 if (!NeedRotate && TLI.isTypeLegal(VT: StoreTy) &&
20970 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: StoreTy,
20971 MF: DAG.getMachineFunction()) &&
20972 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
20973 MMO: *FirstInChain->getMemOperand(), Fast: &IsFastSt) &&
20974 IsFastSt &&
20975 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
20976 MMO: *FirstLoad->getMemOperand(), Fast: &IsFastLd) &&
20977 IsFastLd) {
20978 LastLegalVectorType = i + 1;
20979 }
20980
20981 // Find a legal type for the integer store.
20982 unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
20983 StoreTy = EVT::getIntegerVT(Context, BitWidth: SizeInBits);
20984 if (TLI.isTypeLegal(VT: StoreTy) &&
20985 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: StoreTy,
20986 MF: DAG.getMachineFunction()) &&
20987 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
20988 MMO: *FirstInChain->getMemOperand(), Fast: &IsFastSt) &&
20989 IsFastSt &&
20990 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
20991 MMO: *FirstLoad->getMemOperand(), Fast: &IsFastLd) &&
20992 IsFastLd) {
20993 LastLegalIntegerType = i + 1;
20994 DoIntegerTruncate = false;
20995 // Or check whether a truncstore and extload is legal.
20996 } else if (TLI.getTypeAction(Context, VT: StoreTy) ==
20997 TargetLowering::TypePromoteInteger) {
20998 EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, VT: StoreTy);
20999 if (TLI.isTruncStoreLegal(ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
21000 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: LegalizedStoredValTy,
21001 MF: DAG.getMachineFunction()) &&
21002 TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
21003 TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
21004 TLI.isLoadExtLegal(ExtType: ISD::EXTLOAD, ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
21005 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
21006 MMO: *FirstInChain->getMemOperand(), Fast: &IsFastSt) &&
21007 IsFastSt &&
21008 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
21009 MMO: *FirstLoad->getMemOperand(), Fast: &IsFastLd) &&
21010 IsFastLd) {
21011 LastLegalIntegerType = i + 1;
21012 DoIntegerTruncate = true;
21013 }
21014 }
21015 }
21016
21017 // Only use vector types if the vector type is larger than the integer
21018 // type. If they are the same, use integers.
21019 bool UseVectorTy =
21020 LastLegalVectorType > LastLegalIntegerType && AllowVectors;
21021 unsigned LastLegalType =
21022 std::max(a: LastLegalVectorType, b: LastLegalIntegerType);
21023
21024    // We add +1 here because the LastXXX variables refer to a location (index)
21025    // while NumElem refers to an array size (count).
21026 unsigned NumElem = std::min(a: NumConsecutiveStores, b: LastConsecutiveLoad + 1);
21027 NumElem = std::min(a: LastLegalType, b: NumElem);
21028 Align FirstLoadAlign = FirstLoad->getAlign();
21029
21030 if (NumElem < 2) {
21031 // We know that candidate stores are in order and of correct
21032 // shape. While there is no mergeable sequence from the
21033 // beginning one may start later in the sequence. The only
21034 // reason a merge of size N could have failed where another of
21035 // the same size would not have is if the alignment or either
21036 // the load or store has improved. Drop as many candidates as we
21037 // can here.
21038 unsigned NumSkip = 1;
21039 while ((NumSkip < LoadNodes.size()) &&
21040 (LoadNodes[NumSkip].MemNode->getAlign() <= FirstLoadAlign) &&
21041 (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
21042 NumSkip++;
21043 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumSkip);
21044 LoadNodes.erase(CS: LoadNodes.begin(), CE: LoadNodes.begin() + NumSkip);
21045 NumConsecutiveStores -= NumSkip;
21046 continue;
21047 }
21048
21049 // Check that we can merge these candidates without causing a cycle.
21050 if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStores: NumElem,
21051 RootNode)) {
21052 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumElem);
21053 LoadNodes.erase(CS: LoadNodes.begin(), CE: LoadNodes.begin() + NumElem);
21054 NumConsecutiveStores -= NumElem;
21055 continue;
21056 }
21057
21058 // Find if it is better to use vectors or integers to load and store
21059 // to memory.
21060 EVT JointMemOpVT;
21061 if (UseVectorTy) {
21062 // Find a legal type for the vector store.
21063 unsigned Elts = NumElem * NumMemElts;
21064 JointMemOpVT = EVT::getVectorVT(Context, VT: MemVT.getScalarType(), NumElements: Elts);
21065 } else {
21066 unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
21067 JointMemOpVT = EVT::getIntegerVT(Context, BitWidth: SizeInBits);
21068 }
21069
21070 SDLoc LoadDL(LoadNodes[0].MemNode);
21071 SDLoc StoreDL(StoreNodes[0].MemNode);
21072
21073 // The merged loads are required to have the same incoming chain, so
21074 // using the first's chain is acceptable.
21075
21076 SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumStores: NumElem);
21077 bool CanReusePtrInfo = hasSameUnderlyingObj(StoreNodes);
21078 AddToWorklist(N: NewStoreChain.getNode());
21079
21080 MachineMemOperand::Flags LdMMOFlags =
21081 isDereferenceable ? MachineMemOperand::MODereferenceable
21082 : MachineMemOperand::MONone;
21083 if (IsNonTemporalLoad)
21084 LdMMOFlags |= MachineMemOperand::MONonTemporal;
21085
21086 LdMMOFlags |= TLI.getTargetMMOFlags(Node: *FirstLoad);
21087
21088 MachineMemOperand::Flags StMMOFlags = IsNonTemporalStore
21089 ? MachineMemOperand::MONonTemporal
21090 : MachineMemOperand::MONone;
21091
21092 StMMOFlags |= TLI.getTargetMMOFlags(Node: *StoreNodes[0].MemNode);
21093
21094 SDValue NewLoad, NewStore;
21095 if (UseVectorTy || !DoIntegerTruncate) {
21096 NewLoad = DAG.getLoad(
21097 VT: JointMemOpVT, dl: LoadDL, Chain: FirstLoad->getChain(), Ptr: FirstLoad->getBasePtr(),
21098 PtrInfo: FirstLoad->getPointerInfo(), Alignment: FirstLoadAlign, MMOFlags: LdMMOFlags);
21099 SDValue StoreOp = NewLoad;
21100 if (NeedRotate) {
21101 unsigned LoadWidth = ElementSizeBytes * 8 * 2;
21102 assert(JointMemOpVT == EVT::getIntegerVT(Context, LoadWidth) &&
21103 "Unexpected type for rotate-able load pair");
21104 SDValue RotAmt =
21105 DAG.getShiftAmountConstant(Val: LoadWidth / 2, VT: JointMemOpVT, DL: LoadDL);
21106 // Target can convert to the identical ROTR if it does not have ROTL.
21107 StoreOp = DAG.getNode(Opcode: ISD::ROTL, DL: LoadDL, VT: JointMemOpVT, N1: NewLoad, N2: RotAmt);
21108 }
21109 NewStore = DAG.getStore(
21110 Chain: NewStoreChain, dl: StoreDL, Val: StoreOp, Ptr: FirstInChain->getBasePtr(),
21111 PtrInfo: CanReusePtrInfo ? FirstInChain->getPointerInfo()
21112 : MachinePointerInfo(FirstStoreAS),
21113 Alignment: FirstStoreAlign, MMOFlags: StMMOFlags);
21114 } else { // This must be the truncstore/extload case
21115 EVT ExtendedTy =
21116 TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: JointMemOpVT);
21117 NewLoad = DAG.getExtLoad(ExtType: ISD::EXTLOAD, dl: LoadDL, VT: ExtendedTy,
21118 Chain: FirstLoad->getChain(), Ptr: FirstLoad->getBasePtr(),
21119 PtrInfo: FirstLoad->getPointerInfo(), MemVT: JointMemOpVT,
21120 Alignment: FirstLoadAlign, MMOFlags: LdMMOFlags);
21121 NewStore = DAG.getTruncStore(
21122 Chain: NewStoreChain, dl: StoreDL, Val: NewLoad, Ptr: FirstInChain->getBasePtr(),
21123 PtrInfo: CanReusePtrInfo ? FirstInChain->getPointerInfo()
21124 : MachinePointerInfo(FirstStoreAS),
21125 SVT: JointMemOpVT, Alignment: FirstInChain->getAlign(),
21126 MMOFlags: FirstInChain->getMemOperand()->getFlags());
21127 }
21128
21129 // Transfer chain users from old loads to the new load.
21130 for (unsigned i = 0; i < NumElem; ++i) {
21131 LoadSDNode *Ld = cast<LoadSDNode>(Val: LoadNodes[i].MemNode);
21132 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Ld, 1),
21133 To: SDValue(NewLoad.getNode(), 1));
21134 }
21135
21136 // Replace all stores with the new store. Recursively remove corresponding
21137 // values if they are no longer used.
21138 for (unsigned i = 0; i < NumElem; ++i) {
21139 SDValue Val = StoreNodes[i].MemNode->getOperand(Num: 1);
21140 CombineTo(N: StoreNodes[i].MemNode, Res: NewStore);
21141 if (Val->use_empty())
21142 recursivelyDeleteUnusedNodes(N: Val.getNode());
21143 }
21144
21145 MadeChange = true;
21146 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumElem);
21147 LoadNodes.erase(CS: LoadNodes.begin(), CE: LoadNodes.begin() + NumElem);
21148 NumConsecutiveStores -= NumElem;
21149 }
21150 return MadeChange;
21151}
21152
21153bool DAGCombiner::mergeConsecutiveStores(StoreSDNode *St) {
21154 if (OptLevel == CodeGenOptLevel::None || !EnableStoreMerging)
21155 return false;
21156
21157 // TODO: Extend this function to merge stores of scalable vectors.
21158 // (i.e. two <vscale x 8 x i8> stores can be merged to one <vscale x 16 x i8>
21159 // store since we know <vscale x 16 x i8> is exactly twice as large as
21160 // <vscale x 8 x i8>). Until then, bail out for scalable vectors.
21161 EVT MemVT = St->getMemoryVT();
21162 if (MemVT.isScalableVT())
21163 return false;
21164 if (!MemVT.isSimple() || MemVT.getSizeInBits() * 2 > MaximumLegalStoreInBits)
21165 return false;
21166
21167 // This function cannot currently deal with non-byte-sized memory sizes.
21168 int64_t ElementSizeBytes = MemVT.getStoreSize();
21169 if (ElementSizeBytes * 8 != (int64_t)MemVT.getSizeInBits())
21170 return false;
21171
21172 // Do not bother looking at stored values that are not constants, loads, or
21173 // extracted vector elements.
21174 SDValue StoredVal = peekThroughBitcasts(V: St->getValue());
21175 const StoreSource StoreSrc = getStoreSource(StoreVal: StoredVal);
21176 if (StoreSrc == StoreSource::Unknown)
21177 return false;
21178
21179 SmallVector<MemOpLink, 8> StoreNodes;
21180 // Find potential store merge candidates by searching through chain sub-DAG
21181 SDNode *RootNode = getStoreMergeCandidates(St, StoreNodes);
21182
21183 // Check if there is anything to merge.
21184 if (StoreNodes.size() < 2)
21185 return false;
21186
21187 // Sort the memory operands according to their distance from the
21188 // base pointer.
21189 llvm::sort(C&: StoreNodes, Comp: [](MemOpLink LHS, MemOpLink RHS) {
21190 return LHS.OffsetFromBase < RHS.OffsetFromBase;
21191 });
21192
21193 bool AllowVectors = !DAG.getMachineFunction().getFunction().hasFnAttribute(
21194 Kind: Attribute::NoImplicitFloat);
21195 bool IsNonTemporalStore = St->isNonTemporal();
21196 bool IsNonTemporalLoad = StoreSrc == StoreSource::Load &&
21197 cast<LoadSDNode>(Val&: StoredVal)->isNonTemporal();
21198
21199  // Store merging attempts to merge the lowest (earliest-offset) stores
21200  // first. This generally works out: if the merge succeeds, the remaining
21201  // stores are checked after the first collection of stores is merged.
21202  // However, if a non-mergeable store is found first, e.g., {p[-2],
21203  // p[0], p[1], p[2], p[3]}, we would fail and miss the subsequent
21204  // mergeable cases. To prevent this, we prune such stores from the
21205  // front of StoreNodes here.
21206 bool MadeChange = false;
21207 while (StoreNodes.size() > 1) {
21208 unsigned NumConsecutiveStores =
21209 getConsecutiveStores(StoreNodes, ElementSizeBytes);
21210 // There are no more stores in the list to examine.
21211 if (NumConsecutiveStores == 0)
21212 return MadeChange;
21213
21214 // We have at least 2 consecutive stores. Try to merge them.
21215 assert(NumConsecutiveStores >= 2 && "Expected at least 2 stores");
21216 switch (StoreSrc) {
21217 case StoreSource::Constant:
21218 MadeChange |= tryStoreMergeOfConstants(StoreNodes, NumConsecutiveStores,
21219 MemVT, RootNode, AllowVectors);
21220 break;
21221
21222 case StoreSource::Extract:
21223 MadeChange |= tryStoreMergeOfExtracts(StoreNodes, NumConsecutiveStores,
21224 MemVT, RootNode);
21225 break;
21226
21227 case StoreSource::Load:
21228 MadeChange |= tryStoreMergeOfLoads(StoreNodes, NumConsecutiveStores,
21229 MemVT, RootNode, AllowVectors,
21230 IsNonTemporalStore, IsNonTemporalLoad);
21231 break;
21232
21233 default:
21234 llvm_unreachable("Unhandled store source type");
21235 }
21236 }
21237
21238 // Remember if we failed to optimize, to save compile time.
21239 if (!MadeChange)
21240 ChainsWithoutMergeableStores.insert(Ptr: RootNode);
21241
21242 return MadeChange;
21243}
21244
21245SDValue DAGCombiner::replaceStoreChain(StoreSDNode *ST, SDValue BetterChain) {
21246 SDLoc SL(ST);
21247 SDValue ReplStore;
21248
21249 // Replace the chain to avoid dependency.
21250 if (ST->isTruncatingStore()) {
21251 ReplStore = DAG.getTruncStore(Chain: BetterChain, dl: SL, Val: ST->getValue(),
21252 Ptr: ST->getBasePtr(), SVT: ST->getMemoryVT(),
21253 MMO: ST->getMemOperand());
21254 } else {
21255 ReplStore = DAG.getStore(Chain: BetterChain, dl: SL, Val: ST->getValue(), Ptr: ST->getBasePtr(),
21256 MMO: ST->getMemOperand());
21257 }
21258
21259 // Create token to keep both nodes around.
21260 SDValue Token = DAG.getNode(Opcode: ISD::TokenFactor, DL: SL,
21261 VT: MVT::Other, N1: ST->getChain(), N2: ReplStore);
21262
21263 // Make sure the new and old chains are cleaned up.
21264 AddToWorklist(N: Token.getNode());
21265
21266 // Don't add users to work list.
21267 return CombineTo(N: ST, Res: Token, AddTo: false);
21268}
21269
21270SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) {
21271 SDValue Value = ST->getValue();
21272 if (Value.getOpcode() == ISD::TargetConstantFP)
21273 return SDValue();
21274
21275 if (!ISD::isNormalStore(N: ST))
21276 return SDValue();
21277
21278 SDLoc DL(ST);
21279
21280 SDValue Chain = ST->getChain();
21281 SDValue Ptr = ST->getBasePtr();
21282
21283 const ConstantFPSDNode *CFP = cast<ConstantFPSDNode>(Val&: Value);
21284
21285 // NOTE: If the original store is volatile, this transform must not increase
21286 // the number of stores. For example, on x86-32 an f64 can be stored in one
21287 // processor operation but an i64 (which is not legal) requires two. So the
21288 // transform should not be done in this case.
21289
21290 SDValue Tmp;
21291 switch (CFP->getSimpleValueType(ResNo: 0).SimpleTy) {
21292 default:
21293 llvm_unreachable("Unknown FP type");
21294 case MVT::f16: // We don't do this for these yet.
21295 case MVT::bf16:
21296 case MVT::f80:
21297 case MVT::f128:
21298 case MVT::ppcf128:
21299 return SDValue();
21300 case MVT::f32:
21301 if ((isTypeLegal(VT: MVT::i32) && !LegalOperations && ST->isSimple()) ||
21302 TLI.isOperationLegalOrCustom(Op: ISD::STORE, VT: MVT::i32)) {
21303 Tmp = DAG.getConstant(Val: (uint32_t)CFP->getValueAPF().
21304 bitcastToAPInt().getZExtValue(), DL: SDLoc(CFP),
21305 VT: MVT::i32);
21306 return DAG.getStore(Chain, dl: DL, Val: Tmp, Ptr, MMO: ST->getMemOperand());
21307 }
21308
21309 return SDValue();
21310 case MVT::f64:
21311 if ((TLI.isTypeLegal(VT: MVT::i64) && !LegalOperations &&
21312 ST->isSimple()) ||
21313 TLI.isOperationLegalOrCustom(Op: ISD::STORE, VT: MVT::i64)) {
21314 Tmp = DAG.getConstant(Val: CFP->getValueAPF().bitcastToAPInt().
21315 getZExtValue(), DL: SDLoc(CFP), VT: MVT::i64);
21316 return DAG.getStore(Chain, dl: DL, Val: Tmp,
21317 Ptr, MMO: ST->getMemOperand());
21318 }
21319
21320 if (ST->isSimple() && TLI.isOperationLegalOrCustom(Op: ISD::STORE, VT: MVT::i32) &&
21321 !TLI.isFPImmLegal(CFP->getValueAPF(), MVT::f64)) {
21322 // Many FP stores are not made apparent until after legalize, e.g. for
21323 // argument passing. Since this is so common, custom legalize the
21324 // 64-bit integer store into two 32-bit stores.
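      // For example, 'store double 1.0, p' (bit pattern 0x3FF0000000000000)
      // becomes i32 stores of 0x00000000 and 0x3FF00000; on a little-endian
      // target the low word goes at p and the high word at p+4.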
21325 uint64_t Val = CFP->getValueAPF().bitcastToAPInt().getZExtValue();
21326 SDValue Lo = DAG.getConstant(Val: Val & 0xFFFFFFFF, DL: SDLoc(CFP), VT: MVT::i32);
21327 SDValue Hi = DAG.getConstant(Val: Val >> 32, DL: SDLoc(CFP), VT: MVT::i32);
21328 if (DAG.getDataLayout().isBigEndian())
21329 std::swap(a&: Lo, b&: Hi);
21330
21331 MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
21332 AAMDNodes AAInfo = ST->getAAInfo();
21333
21334 SDValue St0 = DAG.getStore(Chain, dl: DL, Val: Lo, Ptr, PtrInfo: ST->getPointerInfo(),
21335 Alignment: ST->getOriginalAlign(), MMOFlags, AAInfo);
21336 Ptr = DAG.getMemBasePlusOffset(Base: Ptr, Offset: TypeSize::getFixed(ExactSize: 4), DL);
21337 SDValue St1 = DAG.getStore(Chain, dl: DL, Val: Hi, Ptr,
21338 PtrInfo: ST->getPointerInfo().getWithOffset(O: 4),
21339 Alignment: ST->getOriginalAlign(), MMOFlags, AAInfo);
21340 return DAG.getNode(Opcode: ISD::TokenFactor, DL, VT: MVT::Other,
21341 N1: St0, N2: St1);
21342 }
21343
21344 return SDValue();
21345 }
21346}
21347
21348// (store (insert_vector_elt (load p), x, i), p) -> (store x, p+offset)
21349//
21350// If a store of a load with an element inserted into it has no other
21351// uses in between the chain, then we can consider the vector store
21352// dead and replace it with just the single scalar element store.
21353SDValue DAGCombiner::replaceStoreOfInsertLoad(StoreSDNode *ST) {
21354 SDLoc DL(ST);
21355 SDValue Value = ST->getValue();
21356 SDValue Ptr = ST->getBasePtr();
21357 SDValue Chain = ST->getChain();
21358 if (Value.getOpcode() != ISD::INSERT_VECTOR_ELT || !Value.hasOneUse())
21359 return SDValue();
21360
21361 SDValue Elt = Value.getOperand(i: 1);
21362 SDValue Idx = Value.getOperand(i: 2);
21363
21364 // If the element isn't byte sized or is implicitly truncated then we can't
21365 // compute an offset.
21366 EVT EltVT = Elt.getValueType();
21367 if (!EltVT.isByteSized() ||
21368 EltVT != Value.getOperand(i: 0).getValueType().getVectorElementType())
21369 return SDValue();
21370
21371 auto *Ld = dyn_cast<LoadSDNode>(Val: Value.getOperand(i: 0));
21372 if (!Ld || Ld->getBasePtr() != Ptr ||
21373 ST->getMemoryVT() != Ld->getMemoryVT() || !ST->isSimple() ||
21374 !ISD::isNormalStore(N: ST) ||
21375 Ld->getAddressSpace() != ST->getAddressSpace() ||
21376 !Chain.reachesChainWithoutSideEffects(Dest: SDValue(Ld, 1)))
21377 return SDValue();
21378
21379 unsigned IsFast;
21380 if (!TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(),
21381 VT: Elt.getValueType(), AddrSpace: ST->getAddressSpace(),
21382 Alignment: ST->getAlign(), Flags: ST->getMemOperand()->getFlags(),
21383 Fast: &IsFast) ||
21384 !IsFast)
21385 return SDValue();
21386
21387 MachinePointerInfo PointerInfo(ST->getAddressSpace());
21388
21389 // If the offset is a known constant then try to recover the pointer
21390 // info
21391 SDValue NewPtr;
21392 if (auto *CIdx = dyn_cast<ConstantSDNode>(Val&: Idx)) {
21393 unsigned COffset = CIdx->getSExtValue() * EltVT.getSizeInBits() / 8;
21394 NewPtr = DAG.getMemBasePlusOffset(Base: Ptr, Offset: TypeSize::getFixed(ExactSize: COffset), DL);
21395 PointerInfo = ST->getPointerInfo().getWithOffset(O: COffset);
21396 } else {
21397 NewPtr = TLI.getVectorElementPointer(DAG, VecPtr: Ptr, VecVT: Value.getValueType(), Index: Idx);
21398 }
21399
21400 return DAG.getStore(Chain, dl: DL, Val: Elt, Ptr: NewPtr, PtrInfo: PointerInfo, Alignment: ST->getAlign(),
21401 MMOFlags: ST->getMemOperand()->getFlags());
21402}
21403
21404SDValue DAGCombiner::visitATOMIC_STORE(SDNode *N) {
21405 AtomicSDNode *ST = cast<AtomicSDNode>(Val: N);
21406 SDValue Val = ST->getVal();
21407 EVT VT = Val.getValueType();
21408 EVT MemVT = ST->getMemoryVT();
21409
21410 if (MemVT.bitsLT(VT)) { // Is truncating store
21411 APInt TruncDemandedBits = APInt::getLowBitsSet(numBits: VT.getScalarSizeInBits(),
21412 loBitsSet: MemVT.getScalarSizeInBits());
21413 // See if we can simplify the operation with SimplifyDemandedBits, which
21414 // only works if the value has a single use.
21415 if (SimplifyDemandedBits(Op: Val, DemandedBits: TruncDemandedBits))
21416 return SDValue(N, 0);
21417 }
21418
21419 return SDValue();
21420}
21421
21422SDValue DAGCombiner::visitSTORE(SDNode *N) {
21423 StoreSDNode *ST = cast<StoreSDNode>(Val: N);
21424 SDValue Chain = ST->getChain();
21425 SDValue Value = ST->getValue();
21426 SDValue Ptr = ST->getBasePtr();
21427
21428 // If this is a store of a bit convert, store the input value if the
21429 // resultant store does not need a higher alignment than the original.
21430 if (Value.getOpcode() == ISD::BITCAST && !ST->isTruncatingStore() &&
21431 ST->isUnindexed()) {
21432 EVT SVT = Value.getOperand(i: 0).getValueType();
21433 // If the store is volatile, we only want to change the store type if the
21434 // resulting store is legal. Otherwise we might increase the number of
21435 // memory accesses. We don't care if the original type was legal or not
21436 // as we assume software couldn't rely on the number of accesses of an
21437 // illegal type.
21438 // TODO: May be able to relax for unordered atomics (see D66309)
21439 if (((!LegalOperations && ST->isSimple()) ||
21440 TLI.isOperationLegal(Op: ISD::STORE, VT: SVT)) &&
21441 TLI.isStoreBitCastBeneficial(StoreVT: Value.getValueType(), BitcastVT: SVT,
21442 DAG, MMO: *ST->getMemOperand())) {
21443 return DAG.getStore(Chain, dl: SDLoc(N), Val: Value.getOperand(i: 0), Ptr,
21444 MMO: ST->getMemOperand());
21445 }
21446 }
21447
21448 // Turn 'store undef, Ptr' -> nothing.
21449 if (Value.isUndef() && ST->isUnindexed())
21450 return Chain;
21451
21452 // Try to infer better alignment information than the store already has.
21453 if (OptLevel != CodeGenOptLevel::None && ST->isUnindexed() &&
21454 !ST->isAtomic()) {
21455 if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
21456 if (*Alignment > ST->getAlign() &&
21457 isAligned(Lhs: *Alignment, SizeInBytes: ST->getSrcValueOffset())) {
21458 SDValue NewStore =
21459 DAG.getTruncStore(Chain, dl: SDLoc(N), Val: Value, Ptr, PtrInfo: ST->getPointerInfo(),
21460 SVT: ST->getMemoryVT(), Alignment: *Alignment,
21461 MMOFlags: ST->getMemOperand()->getFlags(), AAInfo: ST->getAAInfo());
21462 // NewStore will always be N as we are only refining the alignment
21463 assert(NewStore.getNode() == N);
21464 (void)NewStore;
21465 }
21466 }
21467 }
21468
21469 // Try transforming a pair floating point load / store ops to integer
21470 // load / store ops.
21471 if (SDValue NewST = TransformFPLoadStorePair(N))
21472 return NewST;
21473
21474 // Try transforming several stores into STORE (BSWAP).
21475 if (SDValue Store = mergeTruncStores(N: ST))
21476 return Store;
21477
21478 if (ST->isUnindexed()) {
21479 // Walk up chain skipping non-aliasing memory nodes, on this store and any
21480 // adjacent stores.
21481 if (findBetterNeighborChains(St: ST)) {
21482 // replaceStoreChain uses CombineTo, which handled all of the worklist
21483 // manipulation. Return the original node to not do anything else.
21484 return SDValue(ST, 0);
21485 }
21486 Chain = ST->getChain();
21487 }
21488
21489 // FIXME: is there such a thing as a truncating indexed store?
21490 if (ST->isTruncatingStore() && ST->isUnindexed() &&
21491 Value.getValueType().isInteger() &&
21492 (!isa<ConstantSDNode>(Val: Value) ||
21493 !cast<ConstantSDNode>(Val&: Value)->isOpaque())) {
21494    // Convert a truncating store of an extension into a standard store.
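    // For example, (truncstore (zext i16 X to i32), p, i16) can simply become
    // (store i16 X, p), since only the low 16 bits are stored anyway.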
21495 if ((Value.getOpcode() == ISD::ZERO_EXTEND ||
21496 Value.getOpcode() == ISD::SIGN_EXTEND ||
21497 Value.getOpcode() == ISD::ANY_EXTEND) &&
21498 Value.getOperand(i: 0).getValueType() == ST->getMemoryVT() &&
21499 TLI.isOperationLegalOrCustom(Op: ISD::STORE, VT: ST->getMemoryVT()))
21500 return DAG.getStore(Chain, dl: SDLoc(N), Val: Value.getOperand(i: 0), Ptr,
21501 MMO: ST->getMemOperand());
21502
21503 APInt TruncDemandedBits =
21504 APInt::getLowBitsSet(numBits: Value.getScalarValueSizeInBits(),
21505 loBitsSet: ST->getMemoryVT().getScalarSizeInBits());
21506
21507 // See if we can simplify the operation with SimplifyDemandedBits, which
21508 // only works if the value has a single use.
21509 AddToWorklist(N: Value.getNode());
21510 if (SimplifyDemandedBits(Op: Value, DemandedBits: TruncDemandedBits)) {
21511 // Re-visit the store if anything changed and the store hasn't been merged
21512      // with another node (N is deleted). SimplifyDemandedBits will add Value's
21513 // node back to the worklist if necessary, but we also need to re-visit
21514 // the Store node itself.
21515 if (N->getOpcode() != ISD::DELETED_NODE)
21516 AddToWorklist(N);
21517 return SDValue(N, 0);
21518 }
21519
21520 // Otherwise, see if we can simplify the input to this truncstore with
21521 // knowledge that only the low bits are being used. For example:
21522 // "truncstore (or (shl x, 8), y), i8" -> "truncstore y, i8"
21523 if (SDValue Shorter =
21524 TLI.SimplifyMultipleUseDemandedBits(Op: Value, DemandedBits: TruncDemandedBits, DAG))
21525 return DAG.getTruncStore(Chain, dl: SDLoc(N), Val: Shorter, Ptr, SVT: ST->getMemoryVT(),
21526 MMO: ST->getMemOperand());
21527
21528 // If we're storing a truncated constant, see if we can simplify it.
21529 // TODO: Move this to targetShrinkDemandedConstant?
21530 if (auto *Cst = dyn_cast<ConstantSDNode>(Val&: Value))
21531 if (!Cst->isOpaque()) {
21532 const APInt &CValue = Cst->getAPIntValue();
21533 APInt NewVal = CValue & TruncDemandedBits;
21534 if (NewVal != CValue) {
21535 SDValue Shorter =
21536 DAG.getConstant(Val: NewVal, DL: SDLoc(N), VT: Value.getValueType());
21537 return DAG.getTruncStore(Chain, dl: SDLoc(N), Val: Shorter, Ptr,
21538 SVT: ST->getMemoryVT(), MMO: ST->getMemOperand());
21539 }
21540 }
21541 }
21542
21543 // If this is a load followed by a store to the same location, then the store
21544 // is dead/noop. Peek through any truncates if canCombineTruncStore failed.
21545 // TODO: Add big-endian truncate support with test coverage.
21546 // TODO: Can relax for unordered atomics (see D66309)
21547 SDValue TruncVal = DAG.getDataLayout().isLittleEndian()
21548 ? peekThroughTruncates(V: Value)
21549 : Value;
21550 if (auto *Ld = dyn_cast<LoadSDNode>(Val&: TruncVal)) {
21551 if (Ld->getBasePtr() == Ptr && ST->getMemoryVT() == Ld->getMemoryVT() &&
21552 ST->isUnindexed() && ST->isSimple() &&
21553 Ld->getAddressSpace() == ST->getAddressSpace() &&
21554 // There can't be any side effects between the load and store, such as
21555 // a call or store.
21556 Chain.reachesChainWithoutSideEffects(Dest: SDValue(Ld, 1))) {
21557 // The store is dead, remove it.
21558 return Chain;
21559 }
21560 }
21561
21562 // Try scalarizing vector stores of loads where we only change one element
21563 if (SDValue NewST = replaceStoreOfInsertLoad(ST))
21564 return NewST;
21565
21566 // TODO: Can relax for unordered atomics (see D66309)
21567 if (StoreSDNode *ST1 = dyn_cast<StoreSDNode>(Val&: Chain)) {
21568 if (ST->isUnindexed() && ST->isSimple() &&
21569 ST1->isUnindexed() && ST1->isSimple()) {
21570 if (OptLevel != CodeGenOptLevel::None && ST1->getBasePtr() == Ptr &&
21571 ST1->getValue() == Value && ST->getMemoryVT() == ST1->getMemoryVT() &&
21572 ST->getAddressSpace() == ST1->getAddressSpace()) {
21573 // If this is a store followed by a store with the same value to the
21574 // same location, then the store is dead/noop.
21575 return Chain;
21576 }
21577
21578 if (OptLevel != CodeGenOptLevel::None && ST1->hasOneUse() &&
21579 !ST1->getBasePtr().isUndef() &&
21580 ST->getAddressSpace() == ST1->getAddressSpace()) {
21581        // If one of the two stores has a scalable vector type and the other
21582        // a fixed-size type, we cannot allow removal of the scalable store
21583        // based on a size comparison alone, because its actual size is not
21584        // known until runtime.
21585 if (ST->getMemoryVT().isScalableVector() ||
21586 ST1->getMemoryVT().isScalableVector()) {
21587 if (ST1->getBasePtr() == Ptr &&
21588 TypeSize::isKnownLE(LHS: ST1->getMemoryVT().getStoreSize(),
21589 RHS: ST->getMemoryVT().getStoreSize())) {
21590 CombineTo(N: ST1, Res: ST1->getChain());
21591 return SDValue(N, 0);
21592 }
21593 } else {
21594 const BaseIndexOffset STBase = BaseIndexOffset::match(N: ST, DAG);
21595 const BaseIndexOffset ChainBase = BaseIndexOffset::match(N: ST1, DAG);
21596          // If this store's preceding store writes to a subset of the current
21597          // location and no other node is chained to that store, we can
21598          // effectively drop the preceding store. Do not remove stores to undef
21599          // as they may be used as data sinks.
21600 if (STBase.contains(DAG, BitSize: ST->getMemoryVT().getFixedSizeInBits(),
21601 Other: ChainBase,
21602 OtherBitSize: ST1->getMemoryVT().getFixedSizeInBits())) {
21603 CombineTo(N: ST1, Res: ST1->getChain());
21604 return SDValue(N, 0);
21605 }
21606 }
21607 }
21608 }
21609 }
21610
21611 // If this is an FP_ROUND or TRUNC followed by a store, fold this into a
21612 // truncating store. We can do this even if this is already a truncstore.
21613 if ((Value.getOpcode() == ISD::FP_ROUND ||
21614 Value.getOpcode() == ISD::TRUNCATE) &&
21615 Value->hasOneUse() && ST->isUnindexed() &&
21616 TLI.canCombineTruncStore(ValVT: Value.getOperand(i: 0).getValueType(),
21617 MemVT: ST->getMemoryVT(), LegalOnly: LegalOperations)) {
21618 return DAG.getTruncStore(Chain, dl: SDLoc(N), Val: Value.getOperand(i: 0),
21619 Ptr, SVT: ST->getMemoryVT(), MMO: ST->getMemOperand());
21620 }
21621
21622 // Always perform this optimization before types are legal. If the target
21623 // prefers, also try this after legalization to catch stores that were created
21624 // by intrinsics or other nodes.
21625 if (!LegalTypes || (TLI.mergeStoresAfterLegalization(MemVT: ST->getMemoryVT()))) {
21626 while (true) {
21627 // There can be multiple store sequences on the same chain.
21628 // Keep trying to merge store sequences until we are unable to do so
21629 // or until we merge the last store on the chain.
21630 bool Changed = mergeConsecutiveStores(St: ST);
21631 if (!Changed) break;
21632 // Return N as merge only uses CombineTo and no worklist clean
21633 // up is necessary.
21634 if (N->getOpcode() == ISD::DELETED_NODE || !isa<StoreSDNode>(Val: N))
21635 return SDValue(N, 0);
21636 }
21637 }
21638
21639 // Try transforming N to an indexed store.
21640 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
21641 return SDValue(N, 0);
21642
21643 // Turn 'store float 1.0, Ptr' -> 'store int 0x12345678, Ptr'
21644 //
21645 // Make sure to do this only after attempting to merge stores in order to
21646 // avoid changing the types of some subset of stores due to visit order,
21647 // preventing their merging.
21648 if (isa<ConstantFPSDNode>(Val: ST->getValue())) {
21649 if (SDValue NewSt = replaceStoreOfFPConstant(ST))
21650 return NewSt;
21651 }
21652
21653 if (SDValue NewSt = splitMergedValStore(ST))
21654 return NewSt;
21655
21656 return ReduceLoadOpStoreWidth(N);
21657}
21658
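// Remove stores that are provably dead because the stored-to object's
// lifetime ends right after them: any store on the incoming chain that lies
// entirely within the bounds covered by the LIFETIME_END can never be
// observed and may be dropped.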
21659SDValue DAGCombiner::visitLIFETIME_END(SDNode *N) {
21660 const auto *LifetimeEnd = cast<LifetimeSDNode>(Val: N);
21661 if (!LifetimeEnd->hasOffset())
21662 return SDValue();
21663
21664 const BaseIndexOffset LifetimeEndBase(N->getOperand(Num: 1), SDValue(),
21665 LifetimeEnd->getOffset(), false);
21666
21667 // We walk up the chains to find stores.
21668 SmallVector<SDValue, 8> Chains = {N->getOperand(Num: 0)};
21669 while (!Chains.empty()) {
21670 SDValue Chain = Chains.pop_back_val();
21671 if (!Chain.hasOneUse())
21672 continue;
21673 switch (Chain.getOpcode()) {
21674 case ISD::TokenFactor:
21675 for (unsigned Nops = Chain.getNumOperands(); Nops;)
21676 Chains.push_back(Elt: Chain.getOperand(i: --Nops));
21677 break;
21678 case ISD::LIFETIME_START:
21679 case ISD::LIFETIME_END:
21680 // We can forward past any lifetime start/end that can be proven not to
21681 // alias the node.
21682 if (!mayAlias(Op0: Chain.getNode(), Op1: N))
21683 Chains.push_back(Elt: Chain.getOperand(i: 0));
21684 break;
21685 case ISD::STORE: {
21686 StoreSDNode *ST = dyn_cast<StoreSDNode>(Val&: Chain);
21687 // TODO: Can relax for unordered atomics (see D66309)
21688 if (!ST->isSimple() || ST->isIndexed())
21689 continue;
21690 const TypeSize StoreSize = ST->getMemoryVT().getStoreSize();
21691 // The bounds of a scalable store are not known until runtime, so this
21692 // store cannot be elided.
21693 if (StoreSize.isScalable())
21694 continue;
21695 const BaseIndexOffset StoreBase = BaseIndexOffset::match(N: ST, DAG);
21696 // If we store purely within object bounds just before its lifetime ends,
21697 // we can remove the store.
21698 if (LifetimeEndBase.contains(DAG, BitSize: LifetimeEnd->getSize() * 8, Other: StoreBase,
21699 OtherBitSize: StoreSize.getFixedValue() * 8)) {
21700 LLVM_DEBUG(dbgs() << "\nRemoving store:"; StoreBase.dump();
21701 dbgs() << "\nwithin LIFETIME_END of : ";
21702 LifetimeEndBase.dump(); dbgs() << "\n");
21703 CombineTo(N: ST, Res: ST->getChain());
21704 return SDValue(N, 0);
21705 }
21706 }
21707 }
21708 }
21709 return SDValue();
21710}
21711
/// For the store instruction sequence below, the F and I values
/// are bundled together as an i64 value before being stored into memory.
/// Sometimes it is more efficient to generate separate stores for F and I,
/// which can remove the bitwise instructions or sink them to colder places.
21716///
21717/// (store (or (zext (bitcast F to i32) to i64),
21718/// (shl (zext I to i64), 32)), addr) -->
21719/// (store F, addr) and (store I, addr+4)
21720///
21721/// Similarly, splitting for other merged store can also be beneficial, like:
21722/// For pair of {i32, i32}, i64 store --> two i32 stores.
21723/// For pair of {i32, i16}, i64 store --> two i32 stores.
21724/// For pair of {i16, i16}, i32 store --> two i16 stores.
21725/// For pair of {i16, i8}, i32 store --> two i16 stores.
21726/// For pair of {i8, i8}, i16 store --> two i8 stores.
21727///
21728/// We allow each target to determine specifically which kind of splitting is
21729/// supported.
21730///
/// The store patterns are commonly seen in the simple code snippet below
/// when only std::make_pair(...) is SROA-transformed before being inlined
/// into hoo.
21733/// void goo(const std::pair<int, float> &);
21734/// hoo() {
21735/// ...
21736/// goo(std::make_pair(tmp, ftmp));
21737/// ...
21738/// }
21739///
21740SDValue DAGCombiner::splitMergedValStore(StoreSDNode *ST) {
21741 if (OptLevel == CodeGenOptLevel::None)
21742 return SDValue();
21743
21744 // Can't change the number of memory accesses for a volatile store or break
21745 // atomicity for an atomic one.
21746 if (!ST->isSimple())
21747 return SDValue();
21748
21749 SDValue Val = ST->getValue();
21750 SDLoc DL(ST);
21751
21752 // Match OR operand.
21753 if (!Val.getValueType().isScalarInteger() || Val.getOpcode() != ISD::OR)
21754 return SDValue();
21755
21756 // Match SHL operand and get Lower and Higher parts of Val.
21757 SDValue Op1 = Val.getOperand(i: 0);
21758 SDValue Op2 = Val.getOperand(i: 1);
21759 SDValue Lo, Hi;
21760 if (Op1.getOpcode() != ISD::SHL) {
21761 std::swap(a&: Op1, b&: Op2);
21762 if (Op1.getOpcode() != ISD::SHL)
21763 return SDValue();
21764 }
21765 Lo = Op2;
21766 Hi = Op1.getOperand(i: 0);
21767 if (!Op1.hasOneUse())
21768 return SDValue();
21769
21770 // Match shift amount to HalfValBitSize.
21771 unsigned HalfValBitSize = Val.getValueSizeInBits() / 2;
21772 ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(Val: Op1.getOperand(i: 1));
21773 if (!ShAmt || ShAmt->getAPIntValue() != HalfValBitSize)
21774 return SDValue();
21775
21776 // Lo and Hi are zero-extended from int with size less equal than 32
21777 // to i64.
21778 if (Lo.getOpcode() != ISD::ZERO_EXTEND || !Lo.hasOneUse() ||
21779 !Lo.getOperand(i: 0).getValueType().isScalarInteger() ||
21780 Lo.getOperand(i: 0).getValueSizeInBits() > HalfValBitSize ||
21781 Hi.getOpcode() != ISD::ZERO_EXTEND || !Hi.hasOneUse() ||
21782 !Hi.getOperand(i: 0).getValueType().isScalarInteger() ||
21783 Hi.getOperand(i: 0).getValueSizeInBits() > HalfValBitSize)
21784 return SDValue();
21785
21786 // Use the EVT of low and high parts before bitcast as the input
21787 // of target query.
21788 EVT LowTy = (Lo.getOperand(i: 0).getOpcode() == ISD::BITCAST)
21789 ? Lo.getOperand(i: 0).getValueType()
21790 : Lo.getValueType();
21791 EVT HighTy = (Hi.getOperand(i: 0).getOpcode() == ISD::BITCAST)
21792 ? Hi.getOperand(i: 0).getValueType()
21793 : Hi.getValueType();
21794 if (!TLI.isMultiStoresCheaperThanBitsMerge(LTy: LowTy, HTy: HighTy))
21795 return SDValue();
21796
21797 // Start to split store.
21798 MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
21799 AAMDNodes AAInfo = ST->getAAInfo();
21800
21801 // Change the sizes of Lo and Hi's value types to HalfValBitSize.
21802 EVT VT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: HalfValBitSize);
21803 Lo = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: Lo.getOperand(i: 0));
21804 Hi = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: Hi.getOperand(i: 0));
21805
21806 SDValue Chain = ST->getChain();
21807 SDValue Ptr = ST->getBasePtr();
21808 // Lower value store.
21809 SDValue St0 = DAG.getStore(Chain, dl: DL, Val: Lo, Ptr, PtrInfo: ST->getPointerInfo(),
21810 Alignment: ST->getOriginalAlign(), MMOFlags, AAInfo);
21811 Ptr =
21812 DAG.getMemBasePlusOffset(Base: Ptr, Offset: TypeSize::getFixed(ExactSize: HalfValBitSize / 8), DL);
21813 // Higher value store.
21814 SDValue St1 = DAG.getStore(
21815 Chain: St0, dl: DL, Val: Hi, Ptr, PtrInfo: ST->getPointerInfo().getWithOffset(O: HalfValBitSize / 8),
21816 Alignment: ST->getOriginalAlign(), MMOFlags, AAInfo);
21817 return St1;
21818}
21819
21820// Merge an insertion into an existing shuffle:
21821// (insert_vector_elt (vector_shuffle X, Y, Mask),
//                    (extract_vector_elt X, N), InsIndex)
21823// --> (vector_shuffle X, Y, NewMask)
21824// and variations where shuffle operands may be CONCAT_VECTORS.
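// For example, inserting X[2] at position 1 of a 4-element shuffle:
//   (insert_vector_elt (vector_shuffle X, Y, <0,1,4,5>),
//                      (extract_vector_elt X, 2), 1)
//     --> (vector_shuffle X, Y, <0,2,4,5>)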
21825static bool mergeEltWithShuffle(SDValue &X, SDValue &Y, ArrayRef<int> Mask,
21826 SmallVectorImpl<int> &NewMask, SDValue Elt,
21827 unsigned InsIndex) {
21828 if (Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
21829 !isa<ConstantSDNode>(Val: Elt.getOperand(i: 1)))
21830 return false;
21831
21832 // Vec's operand 0 is using indices from 0 to N-1 and
21833 // operand 1 from N to 2N - 1, where N is the number of
21834 // elements in the vectors.
21835 SDValue InsertVal0 = Elt.getOperand(i: 0);
21836 int ElementOffset = -1;
21837
21838 // We explore the inputs of the shuffle in order to see if we find the
21839 // source of the extract_vector_elt. If so, we can use it to modify the
21840 // shuffle rather than perform an insert_vector_elt.
21841 SmallVector<std::pair<int, SDValue>, 8> ArgWorkList;
21842 ArgWorkList.emplace_back(Args: Mask.size(), Args&: Y);
21843 ArgWorkList.emplace_back(Args: 0, Args&: X);
21844
21845 while (!ArgWorkList.empty()) {
21846 int ArgOffset;
21847 SDValue ArgVal;
21848 std::tie(args&: ArgOffset, args&: ArgVal) = ArgWorkList.pop_back_val();
21849
21850 if (ArgVal == InsertVal0) {
21851 ElementOffset = ArgOffset;
21852 break;
21853 }
21854
21855 // Peek through concat_vector.
21856 if (ArgVal.getOpcode() == ISD::CONCAT_VECTORS) {
21857 int CurrentArgOffset =
21858 ArgOffset + ArgVal.getValueType().getVectorNumElements();
21859 int Step = ArgVal.getOperand(i: 0).getValueType().getVectorNumElements();
21860 for (SDValue Op : reverse(C: ArgVal->ops())) {
21861 CurrentArgOffset -= Step;
21862 ArgWorkList.emplace_back(Args&: CurrentArgOffset, Args&: Op);
21863 }
21864
21865 // Make sure we went through all the elements and did not screw up index
21866 // computation.
21867 assert(CurrentArgOffset == ArgOffset);
21868 }
21869 }
21870
21871 // If we failed to find a match, see if we can replace an UNDEF shuffle
21872 // operand.
21873 if (ElementOffset == -1) {
21874 if (!Y.isUndef() || InsertVal0.getValueType() != Y.getValueType())
21875 return false;
21876 ElementOffset = Mask.size();
21877 Y = InsertVal0;
21878 }
21879
21880 NewMask.assign(in_start: Mask.begin(), in_end: Mask.end());
21881 NewMask[InsIndex] = ElementOffset + Elt.getConstantOperandVal(i: 1);
21882 assert(NewMask[InsIndex] < (int)(2 * Mask.size()) && NewMask[InsIndex] >= 0 &&
         "NewMask[InsIndex] is out of bounds");
21884 return true;
21885}
21886
21887// Merge an insertion into an existing shuffle:
21888// (insert_vector_elt (vector_shuffle X, Y), (extract_vector_elt X, N),
21889// InsIndex)
21890// --> (vector_shuffle X, Y) and variations where shuffle operands may be
21891// CONCAT_VECTORS.
21892SDValue DAGCombiner::mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex) {
21893 assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
         "Expected insert_vector_elt");
21895 SDValue InsertVal = N->getOperand(Num: 1);
21896 SDValue Vec = N->getOperand(Num: 0);
21897
21898 auto *SVN = dyn_cast<ShuffleVectorSDNode>(Val&: Vec);
21899 if (!SVN || !Vec.hasOneUse())
21900 return SDValue();
21901
21902 ArrayRef<int> Mask = SVN->getMask();
21903 SDValue X = Vec.getOperand(i: 0);
21904 SDValue Y = Vec.getOperand(i: 1);
21905
21906 SmallVector<int, 16> NewMask(Mask);
21907 if (mergeEltWithShuffle(X, Y, Mask, NewMask, Elt: InsertVal, InsIndex)) {
21908 SDValue LegalShuffle = TLI.buildLegalVectorShuffle(
21909 VT: Vec.getValueType(), DL: SDLoc(N), N0: X, N1: Y, Mask: NewMask, DAG);
21910 if (LegalShuffle)
21911 return LegalShuffle;
21912 }
21913
21914 return SDValue();
21915}
21916
21917// Convert a disguised subvector insertion into a shuffle:
21918// insert_vector_elt V, (bitcast X from vector type), IdxC -->
21919// bitcast(shuffle (bitcast V), (extended X), Mask)
21920// Note: We do not use an insert_subvector node because that requires a
21921// legal subvector type.
21922SDValue DAGCombiner::combineInsertEltToShuffle(SDNode *N, unsigned InsIndex) {
21923 assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
         "Expected insert_vector_elt");
21925 SDValue InsertVal = N->getOperand(Num: 1);
21926
21927 if (InsertVal.getOpcode() != ISD::BITCAST || !InsertVal.hasOneUse() ||
21928 !InsertVal.getOperand(i: 0).getValueType().isVector())
21929 return SDValue();
21930
21931 SDValue SubVec = InsertVal.getOperand(i: 0);
21932 SDValue DestVec = N->getOperand(Num: 0);
21933 EVT SubVecVT = SubVec.getValueType();
21934 EVT VT = DestVec.getValueType();
21935 unsigned NumSrcElts = SubVecVT.getVectorNumElements();
  // If the source only has a single vector element, the cost of creating the
  // shuffle is likely to exceed the cost of a single insert_vector_elt.
21938 if (NumSrcElts == 1)
21939 return SDValue();
21940 unsigned ExtendRatio = VT.getSizeInBits() / SubVecVT.getSizeInBits();
21941 unsigned NumMaskVals = ExtendRatio * NumSrcElts;
21942
21943 // Step 1: Create a shuffle mask that implements this insert operation. The
21944 // vector that we are inserting into will be operand 0 of the shuffle, so
21945 // those elements are just 'i'. The inserted subvector is in the first
21946 // positions of operand 1 of the shuffle. Example:
21947 // insert v4i32 V, (v2i16 X), 2 --> shuffle v8i16 V', X', {0,1,2,3,8,9,6,7}
21948 SmallVector<int, 16> Mask(NumMaskVals);
21949 for (unsigned i = 0; i != NumMaskVals; ++i) {
21950 if (i / NumSrcElts == InsIndex)
21951 Mask[i] = (i % NumSrcElts) + NumMaskVals;
21952 else
21953 Mask[i] = i;
21954 }
21955
21956 // Bail out if the target can not handle the shuffle we want to create.
21957 EVT SubVecEltVT = SubVecVT.getVectorElementType();
21958 EVT ShufVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SubVecEltVT, NumElements: NumMaskVals);
21959 if (!TLI.isShuffleMaskLegal(Mask, ShufVT))
21960 return SDValue();
21961
21962 // Step 2: Create a wide vector from the inserted source vector by appending
21963 // undefined elements. This is the same size as our destination vector.
21964 SDLoc DL(N);
21965 SmallVector<SDValue, 8> ConcatOps(ExtendRatio, DAG.getUNDEF(VT: SubVecVT));
21966 ConcatOps[0] = SubVec;
21967 SDValue PaddedSubV = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT: ShufVT, Ops: ConcatOps);
21968
21969 // Step 3: Shuffle in the padded subvector.
21970 SDValue DestVecBC = DAG.getBitcast(VT: ShufVT, V: DestVec);
21971 SDValue Shuf = DAG.getVectorShuffle(VT: ShufVT, dl: DL, N1: DestVecBC, N2: PaddedSubV, Mask);
21972 AddToWorklist(N: PaddedSubV.getNode());
21973 AddToWorklist(N: DestVecBC.getNode());
21974 AddToWorklist(N: Shuf.getNode());
21975 return DAG.getBitcast(VT, V: Shuf);
21976}
21977
21978// Combine insert(shuffle(load, <u,0,1,2>), load, 0) into a single load if
21979// possible and the new load will be quick. We use more loads but less shuffles
21980// and inserts.
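// For example, with InsIndex == 0 the scalar load must sit exactly one
// element before the vector load in memory, so the combined load starts at
// the scalar's address; with InsIndex == NumElts-1 it must sit immediately
// after the vector load.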
21981SDValue DAGCombiner::combineInsertEltToLoad(SDNode *N, unsigned InsIndex) {
21982 EVT VT = N->getValueType(ResNo: 0);
21983
  // InsIndex is expected to be the first or last lane.
21985 if (!VT.isFixedLengthVector() ||
21986 (InsIndex != 0 && InsIndex != VT.getVectorNumElements() - 1))
21987 return SDValue();
21988
21989 // Look for a shuffle with the mask u,0,1,2,3,4,5,6 or 1,2,3,4,5,6,7,u
21990 // depending on the InsIndex.
21991 auto *Shuffle = dyn_cast<ShuffleVectorSDNode>(Val: N->getOperand(Num: 0));
21992 SDValue Scalar = N->getOperand(Num: 1);
21993 if (!Shuffle || !all_of(Range: enumerate(First: Shuffle->getMask()), P: [&](auto P) {
21994 return InsIndex == P.index() || P.value() < 0 ||
21995 (InsIndex == 0 && P.value() == (int)P.index() - 1) ||
21996 (InsIndex == VT.getVectorNumElements() - 1 &&
21997 P.value() == (int)P.index() + 1);
21998 }))
21999 return SDValue();
22000
22001 // We optionally skip over an extend so long as both loads are extended in the
22002 // same way from the same type.
22003 unsigned Extend = 0;
22004 if (Scalar.getOpcode() == ISD::ZERO_EXTEND ||
22005 Scalar.getOpcode() == ISD::SIGN_EXTEND ||
22006 Scalar.getOpcode() == ISD::ANY_EXTEND) {
22007 Extend = Scalar.getOpcode();
22008 Scalar = Scalar.getOperand(i: 0);
22009 }
22010
22011 auto *ScalarLoad = dyn_cast<LoadSDNode>(Val&: Scalar);
22012 if (!ScalarLoad)
22013 return SDValue();
22014
22015 SDValue Vec = Shuffle->getOperand(Num: 0);
22016 if (Extend) {
22017 if (Vec.getOpcode() != Extend)
22018 return SDValue();
22019 Vec = Vec.getOperand(i: 0);
22020 }
22021 auto *VecLoad = dyn_cast<LoadSDNode>(Val&: Vec);
22022 if (!VecLoad || Vec.getValueType().getScalarType() != Scalar.getValueType())
22023 return SDValue();
22024
22025 int EltSize = ScalarLoad->getValueType(ResNo: 0).getScalarSizeInBits();
22026 if (EltSize == 0 || EltSize % 8 != 0 || !ScalarLoad->isSimple() ||
22027 !VecLoad->isSimple() || VecLoad->getExtensionType() != ISD::NON_EXTLOAD ||
22028 ScalarLoad->getExtensionType() != ISD::NON_EXTLOAD ||
22029 ScalarLoad->getAddressSpace() != VecLoad->getAddressSpace())
22030 return SDValue();
22031
  // Check that the offset between the pointers is such that the two loads
  // form a single contiguous load.
22034 if (InsIndex == 0) {
22035 if (!DAG.areNonVolatileConsecutiveLoads(LD: ScalarLoad, Base: VecLoad, Bytes: EltSize / 8,
22036 Dist: -1))
22037 return SDValue();
22038 } else {
22039 if (!DAG.areNonVolatileConsecutiveLoads(
22040 LD: VecLoad, Base: ScalarLoad, Bytes: VT.getVectorNumElements() * EltSize / 8, Dist: -1))
22041 return SDValue();
22042 }
22043
22044 // And that the new unaligned load will be fast.
22045 unsigned IsFast = 0;
22046 Align NewAlign = commonAlignment(A: VecLoad->getAlign(), Offset: EltSize / 8);
22047 if (!TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(),
22048 VT: Vec.getValueType(), AddrSpace: VecLoad->getAddressSpace(),
22049 Alignment: NewAlign, Flags: VecLoad->getMemOperand()->getFlags(),
22050 Fast: &IsFast) ||
22051 !IsFast)
22052 return SDValue();
22053
22054 // Calculate the new Ptr and create the new load.
22055 SDLoc DL(N);
22056 SDValue Ptr = ScalarLoad->getBasePtr();
22057 if (InsIndex != 0)
22058 Ptr = DAG.getNode(Opcode: ISD::ADD, DL, VT: Ptr.getValueType(), N1: VecLoad->getBasePtr(),
22059 N2: DAG.getConstant(Val: EltSize / 8, DL, VT: Ptr.getValueType()));
22060 MachinePointerInfo PtrInfo =
22061 InsIndex == 0 ? ScalarLoad->getPointerInfo()
22062 : VecLoad->getPointerInfo().getWithOffset(O: EltSize / 8);
22063
22064 SDValue Load = DAG.getLoad(VT: VecLoad->getValueType(ResNo: 0), dl: DL,
22065 Chain: ScalarLoad->getChain(), Ptr, PtrInfo, Alignment: NewAlign);
22066 DAG.makeEquivalentMemoryOrdering(OldLoad: ScalarLoad, NewMemOp: Load.getValue(R: 1));
22067 DAG.makeEquivalentMemoryOrdering(OldLoad: VecLoad, NewMemOp: Load.getValue(R: 1));
22068 return Extend ? DAG.getNode(Opcode: Extend, DL, VT, Operand: Load) : Load;
22069}
22070
22071SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
22072 SDValue InVec = N->getOperand(Num: 0);
22073 SDValue InVal = N->getOperand(Num: 1);
22074 SDValue EltNo = N->getOperand(Num: 2);
22075 SDLoc DL(N);
22076
22077 EVT VT = InVec.getValueType();
22078 auto *IndexC = dyn_cast<ConstantSDNode>(Val&: EltNo);
22079
22080 // Insert into out-of-bounds element is undefined.
22081 if (IndexC && VT.isFixedLengthVector() &&
22082 IndexC->getZExtValue() >= VT.getVectorNumElements())
22083 return DAG.getUNDEF(VT);
22084
22085 // Remove redundant insertions:
22086 // (insert_vector_elt x (extract_vector_elt x idx) idx) -> x
22087 if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
22088 InVec == InVal.getOperand(i: 0) && EltNo == InVal.getOperand(i: 1))
22089 return InVec;
22090
22091 if (!IndexC) {
22092 // If this is variable insert to undef vector, it might be better to splat:
22093 // inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
22094 if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT))
22095 return DAG.getSplat(VT, DL, Op: InVal);
22096 return SDValue();
22097 }
22098
22099 if (VT.isScalableVector())
22100 return SDValue();
22101
22102 unsigned NumElts = VT.getVectorNumElements();
22103
22104 // We must know which element is being inserted for folds below here.
22105 unsigned Elt = IndexC->getZExtValue();
22106
22107 // Handle <1 x ???> vector insertion special cases.
22108 if (NumElts == 1) {
22109 // insert_vector_elt(x, extract_vector_elt(y, 0), 0) -> y
22110 if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
22111 InVal.getOperand(i: 0).getValueType() == VT &&
22112 isNullConstant(V: InVal.getOperand(i: 1)))
22113 return InVal.getOperand(i: 0);
22114 }
22115
22116 // Canonicalize insert_vector_elt dag nodes.
22117 // Example:
22118 // (insert_vector_elt (insert_vector_elt A, Idx0), Idx1)
22119 // -> (insert_vector_elt (insert_vector_elt A, Idx1), Idx0)
22120 //
22121 // Do this only if the child insert_vector node has one use; also
22122 // do this only if indices are both constants and Idx1 < Idx0.
22123 if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT && InVec.hasOneUse()
22124 && isa<ConstantSDNode>(Val: InVec.getOperand(i: 2))) {
22125 unsigned OtherElt = InVec.getConstantOperandVal(i: 2);
22126 if (Elt < OtherElt) {
22127 // Swap nodes.
22128 SDValue NewOp = DAG.getNode(Opcode: ISD::INSERT_VECTOR_ELT, DL, VT,
22129 N1: InVec.getOperand(i: 0), N2: InVal, N3: EltNo);
22130 AddToWorklist(N: NewOp.getNode());
22131 return DAG.getNode(Opcode: ISD::INSERT_VECTOR_ELT, DL: SDLoc(InVec.getNode()),
22132 VT, N1: NewOp, N2: InVec.getOperand(i: 1), N3: InVec.getOperand(i: 2));
22133 }
22134 }
22135
22136 if (SDValue Shuf = mergeInsertEltWithShuffle(N, InsIndex: Elt))
22137 return Shuf;
22138
22139 if (SDValue Shuf = combineInsertEltToShuffle(N, InsIndex: Elt))
22140 return Shuf;
22141
22142 if (SDValue Shuf = combineInsertEltToLoad(N, InsIndex: Elt))
22143 return Shuf;
22144
22145 // Attempt to convert an insert_vector_elt chain into a legal build_vector.
22146 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT)) {
22147 // vXi1 vector - we don't need to recurse.
22148 if (NumElts == 1)
22149 return DAG.getBuildVector(VT, DL, Ops: {InVal});
22150
22151 // If we haven't already collected the element, insert into the op list.
22152 EVT MaxEltVT = InVal.getValueType();
22153 auto AddBuildVectorOp = [&](SmallVectorImpl<SDValue> &Ops, SDValue Elt,
22154 unsigned Idx) {
22155 if (!Ops[Idx]) {
22156 Ops[Idx] = Elt;
22157 if (VT.isInteger()) {
22158 EVT EltVT = Elt.getValueType();
22159 MaxEltVT = MaxEltVT.bitsGE(VT: EltVT) ? MaxEltVT : EltVT;
22160 }
22161 }
22162 };
22163
22164 // Ensure all the operands are the same value type, fill any missing
22165 // operands with UNDEF and create the BUILD_VECTOR.
22166 auto CanonicalizeBuildVector = [&](SmallVectorImpl<SDValue> &Ops) {
22167 assert(Ops.size() == NumElts && "Unexpected vector size");
22168 for (SDValue &Op : Ops) {
22169 if (Op)
22170 Op = VT.isInteger() ? DAG.getAnyExtOrTrunc(Op, DL, VT: MaxEltVT) : Op;
22171 else
22172 Op = DAG.getUNDEF(VT: MaxEltVT);
22173 }
22174 return DAG.getBuildVector(VT, DL, Ops);
22175 };
22176
22177 SmallVector<SDValue, 8> Ops(NumElts, SDValue());
22178 Ops[Elt] = InVal;
22179
22180 // Recurse up a INSERT_VECTOR_ELT chain to build a BUILD_VECTOR.
22181 for (SDValue CurVec = InVec; CurVec;) {
22182 // UNDEF - build new BUILD_VECTOR from already inserted operands.
22183 if (CurVec.isUndef())
22184 return CanonicalizeBuildVector(Ops);
22185
22186 // BUILD_VECTOR - insert unused operands and build new BUILD_VECTOR.
22187 if (CurVec.getOpcode() == ISD::BUILD_VECTOR && CurVec.hasOneUse()) {
22188 for (unsigned I = 0; I != NumElts; ++I)
22189 AddBuildVectorOp(Ops, CurVec.getOperand(i: I), I);
22190 return CanonicalizeBuildVector(Ops);
22191 }
22192
22193 // SCALAR_TO_VECTOR - insert unused scalar and build new BUILD_VECTOR.
22194 if (CurVec.getOpcode() == ISD::SCALAR_TO_VECTOR && CurVec.hasOneUse()) {
22195 AddBuildVectorOp(Ops, CurVec.getOperand(i: 0), 0);
22196 return CanonicalizeBuildVector(Ops);
22197 }
22198
22199 // INSERT_VECTOR_ELT - insert operand and continue up the chain.
22200 if (CurVec.getOpcode() == ISD::INSERT_VECTOR_ELT && CurVec.hasOneUse())
22201 if (auto *CurIdx = dyn_cast<ConstantSDNode>(Val: CurVec.getOperand(i: 2)))
22202 if (CurIdx->getAPIntValue().ult(RHS: NumElts)) {
22203 unsigned Idx = CurIdx->getZExtValue();
22204 AddBuildVectorOp(Ops, CurVec.getOperand(i: 1), Idx);
22205
22206 // Found entire BUILD_VECTOR.
22207 if (all_of(Range&: Ops, P: [](SDValue Op) { return !!Op; }))
22208 return CanonicalizeBuildVector(Ops);
22209
22210 CurVec = CurVec->getOperand(Num: 0);
22211 continue;
22212 }
22213
22214 // VECTOR_SHUFFLE - if all the operands match the shuffle's sources,
22215 // update the shuffle mask (and second operand if we started with unary
22216 // shuffle) and create a new legal shuffle.
22217 if (CurVec.getOpcode() == ISD::VECTOR_SHUFFLE && CurVec.hasOneUse()) {
22218 auto *SVN = cast<ShuffleVectorSDNode>(Val&: CurVec);
22219 SDValue LHS = SVN->getOperand(Num: 0);
22220 SDValue RHS = SVN->getOperand(Num: 1);
22221 SmallVector<int, 16> Mask(SVN->getMask());
22222 bool Merged = true;
22223 for (auto I : enumerate(First&: Ops)) {
22224 SDValue &Op = I.value();
22225 if (Op) {
22226 SmallVector<int, 16> NewMask;
22227 if (!mergeEltWithShuffle(X&: LHS, Y&: RHS, Mask, NewMask, Elt: Op, InsIndex: I.index())) {
22228 Merged = false;
22229 break;
22230 }
22231 Mask = std::move(NewMask);
22232 }
22233 }
22234 if (Merged)
22235 if (SDValue NewShuffle =
22236 TLI.buildLegalVectorShuffle(VT, DL, N0: LHS, N1: RHS, Mask, DAG))
22237 return NewShuffle;
22238 }
22239
22240 // If all insertions are zero value, try to convert to AND mask.
22241 // TODO: Do this for -1 with OR mask?
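      // For example, inserting zero at indices 1 and 3 of a v4i32 vector V
      // becomes (and V, <-1, 0, -1, 0>).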
22242 if (!LegalOperations && llvm::isNullConstant(V: InVal) &&
22243 all_of(Range&: Ops, P: [InVal](SDValue Op) { return !Op || Op == InVal; }) &&
22244 count_if(Range&: Ops, P: [InVal](SDValue Op) { return Op == InVal; }) >= 2) {
22245 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: MaxEltVT);
22246 SDValue AllOnes = DAG.getAllOnesConstant(DL, VT: MaxEltVT);
22247 SmallVector<SDValue, 8> Mask(NumElts);
22248 for (unsigned I = 0; I != NumElts; ++I)
22249 Mask[I] = Ops[I] ? Zero : AllOnes;
22250 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: CurVec,
22251 N2: DAG.getBuildVector(VT, DL, Ops: Mask));
22252 }
22253
22254 // Failed to find a match in the chain - bail.
22255 break;
22256 }
22257
22258 // See if we can fill in the missing constant elements as zeros.
22259 // TODO: Should we do this for any constant?
22260 APInt DemandedZeroElts = APInt::getZero(numBits: NumElts);
22261 for (unsigned I = 0; I != NumElts; ++I)
22262 if (!Ops[I])
22263 DemandedZeroElts.setBit(I);
22264
22265 if (DAG.MaskedVectorIsZero(Op: InVec, DemandedElts: DemandedZeroElts)) {
22266 SDValue Zero = VT.isInteger() ? DAG.getConstant(Val: 0, DL, VT: MaxEltVT)
22267 : DAG.getConstantFP(Val: 0, DL, VT: MaxEltVT);
22268 for (unsigned I = 0; I != NumElts; ++I)
22269 if (!Ops[I])
22270 Ops[I] = Zero;
22271
22272 return CanonicalizeBuildVector(Ops);
22273 }
22274 }
22275
22276 return SDValue();
22277}
22278
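// Replace an extract of a loaded vector element with a narrow scalar load of
// just that element, e.g.:
//   (i32 (extract_vector_elt (v4i32 load $addr), 2)) -> (i32 load $addr+8)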
22279SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
22280 SDValue EltNo,
22281 LoadSDNode *OriginalLoad) {
22282 assert(OriginalLoad->isSimple());
22283
22284 EVT ResultVT = EVE->getValueType(ResNo: 0);
22285 EVT VecEltVT = InVecVT.getVectorElementType();
22286
22287 // If the vector element type is not a multiple of a byte then we are unable
22288 // to correctly compute an address to load only the extracted element as a
22289 // scalar.
22290 if (!VecEltVT.isByteSized())
22291 return SDValue();
22292
22293 ISD::LoadExtType ExtTy =
22294 ResultVT.bitsGT(VT: VecEltVT) ? ISD::NON_EXTLOAD : ISD::EXTLOAD;
22295 if (!TLI.isOperationLegalOrCustom(Op: ISD::LOAD, VT: VecEltVT) ||
22296 !TLI.shouldReduceLoadWidth(Load: OriginalLoad, ExtTy, NewVT: VecEltVT))
22297 return SDValue();
22298
22299 Align Alignment = OriginalLoad->getAlign();
22300 MachinePointerInfo MPI;
22301 SDLoc DL(EVE);
22302 if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(Val&: EltNo)) {
22303 int Elt = ConstEltNo->getZExtValue();
22304 unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
22305 MPI = OriginalLoad->getPointerInfo().getWithOffset(O: PtrOff);
22306 Alignment = commonAlignment(A: Alignment, Offset: PtrOff);
22307 } else {
22308 // Discard the pointer info except the address space because the memory
22309 // operand can't represent this new access since the offset is variable.
22310 MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
22311 Alignment = commonAlignment(A: Alignment, Offset: VecEltVT.getSizeInBits() / 8);
22312 }
22313
22314 unsigned IsFast = 0;
22315 if (!TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: VecEltVT,
22316 AddrSpace: OriginalLoad->getAddressSpace(), Alignment,
22317 Flags: OriginalLoad->getMemOperand()->getFlags(),
22318 Fast: &IsFast) ||
22319 !IsFast)
22320 return SDValue();
22321
22322 SDValue NewPtr = TLI.getVectorElementPointer(DAG, VecPtr: OriginalLoad->getBasePtr(),
22323 VecVT: InVecVT, Index: EltNo);
22324
22325 // We are replacing a vector load with a scalar load. The new load must have
22326 // identical memory op ordering to the original.
22327 SDValue Load;
22328 if (ResultVT.bitsGT(VT: VecEltVT)) {
22329 // If the result type of vextract is wider than the load, then issue an
22330 // extending load instead.
22331 ISD::LoadExtType ExtType =
22332 TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: ResultVT, MemVT: VecEltVT) ? ISD::ZEXTLOAD
22333 : ISD::EXTLOAD;
22334 Load = DAG.getExtLoad(ExtType, dl: DL, VT: ResultVT, Chain: OriginalLoad->getChain(),
22335 Ptr: NewPtr, PtrInfo: MPI, MemVT: VecEltVT, Alignment,
22336 MMOFlags: OriginalLoad->getMemOperand()->getFlags(),
22337 AAInfo: OriginalLoad->getAAInfo());
22338 DAG.makeEquivalentMemoryOrdering(OldLoad: OriginalLoad, NewMemOp: Load);
22339 } else {
22340 // The result type is narrower or the same width as the vector element
22341 Load = DAG.getLoad(VT: VecEltVT, dl: DL, Chain: OriginalLoad->getChain(), Ptr: NewPtr, PtrInfo: MPI,
22342 Alignment, MMOFlags: OriginalLoad->getMemOperand()->getFlags(),
22343 AAInfo: OriginalLoad->getAAInfo());
22344 DAG.makeEquivalentMemoryOrdering(OldLoad: OriginalLoad, NewMemOp: Load);
22345 if (ResultVT.bitsLT(VT: VecEltVT))
22346 Load = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: ResultVT, Operand: Load);
22347 else
22348 Load = DAG.getBitcast(VT: ResultVT, V: Load);
22349 }
22350 ++OpsNarrowed;
22351 return Load;
22352}
22353
22354/// Transform a vector binary operation into a scalar binary operation by moving
22355/// the math/logic after an extract element of a vector.
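/// For example:
///   (extract_vector_elt (add X, <1,2,3,4>), 2)
///     --> (add (extract_vector_elt X, 2), 3)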
22356static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
22357 const SDLoc &DL, bool LegalOperations) {
22358 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
22359 SDValue Vec = ExtElt->getOperand(Num: 0);
22360 SDValue Index = ExtElt->getOperand(Num: 1);
22361 auto *IndexC = dyn_cast<ConstantSDNode>(Val&: Index);
22362 if (!IndexC || !TLI.isBinOp(Opcode: Vec.getOpcode()) || !Vec.hasOneUse() ||
22363 Vec->getNumValues() != 1)
22364 return SDValue();
22365
22366 // Targets may want to avoid this to prevent an expensive register transfer.
22367 if (!TLI.shouldScalarizeBinop(VecOp: Vec))
22368 return SDValue();
22369
22370 // Extracting an element of a vector constant is constant-folded, so this
22371 // transform is just replacing a vector op with a scalar op while moving the
22372 // extract.
22373 SDValue Op0 = Vec.getOperand(i: 0);
22374 SDValue Op1 = Vec.getOperand(i: 1);
22375 APInt SplatVal;
22376 if (isAnyConstantBuildVector(V: Op0, NoOpaques: true) ||
22377 ISD::isConstantSplatVector(N: Op0.getNode(), SplatValue&: SplatVal) ||
22378 isAnyConstantBuildVector(V: Op1, NoOpaques: true) ||
22379 ISD::isConstantSplatVector(N: Op1.getNode(), SplatValue&: SplatVal)) {
22380 // extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
22381 // extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
22382 EVT VT = ExtElt->getValueType(ResNo: 0);
22383 SDValue Ext0 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT, N1: Op0, N2: Index);
22384 SDValue Ext1 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT, N1: Op1, N2: Index);
22385 return DAG.getNode(Opcode: Vec.getOpcode(), DL, VT, N1: Ext0, N2: Ext1);
22386 }
22387
22388 return SDValue();
22389}
22390
// Given an ISD::EXTRACT_VECTOR_ELT, which is a glorified bit sequence extract,
// recursively analyse all of its users and try to model them as bit sequence
// extractions as well. If all of them agree on the new, narrower element
// type, and all of them can be modelled as ISD::EXTRACT_VECTOR_ELT's of that
// new element type, do so now.
// This is mainly useful to recover from legalization that scalarized
// the vector with wide elements, by rebuilding it with narrower elements.
22398//
22399// Some more nodes could be modelled if that helps cover interesting patterns.
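// For example (little-endian), if the only users of
//   (i64 (extract_vector_elt v2i64:X, 0))
// are (i32 (trunc ...)) and (i32 (trunc (srl ..., 32))), both can be rebuilt
// as extract_vector_elt's of elements 0 and 1 of (v4i32 (bitcast X)).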
22400bool DAGCombiner::refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(
22401 SDNode *N) {
22402 // We perform this optimization post type-legalization because
22403 // the type-legalizer often scalarizes integer-promoted vectors.
  // Performing it earlier may cause legalization cycles.
22405 if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
22406 return false;
22407
22408 // TODO: Add support for big-endian.
22409 if (DAG.getDataLayout().isBigEndian())
22410 return false;
22411
22412 SDValue VecOp = N->getOperand(Num: 0);
22413 EVT VecVT = VecOp.getValueType();
22414 assert(!VecVT.isScalableVector() && "Only for fixed vectors.");
22415
22416 // We must start with a constant extraction index.
22417 auto *IndexC = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
22418 if (!IndexC)
22419 return false;
22420
22421 assert(IndexC->getZExtValue() < VecVT.getVectorNumElements() &&
         "Original ISD::EXTRACT_VECTOR_ELT is undefined?");
22423
22424 // TODO: deal with the case of implicit anyext of the extraction.
22425 unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
22426 EVT ScalarVT = N->getValueType(ResNo: 0);
22427 if (VecVT.getScalarType() != ScalarVT)
22428 return false;
22429
22430 // TODO: deal with the cases other than everything being integer-typed.
22431 if (!ScalarVT.isScalarInteger())
22432 return false;
22433
22434 struct Entry {
22435 SDNode *Producer;
22436
22437 // Which bits of VecOp does it contain?
22438 unsigned BitPos;
22439 int NumBits;
22440 // NOTE: the actual width of \p Producer may be wider than NumBits!
22441
22442 Entry(Entry &&) = default;
22443 Entry(SDNode *Producer_, unsigned BitPos_, int NumBits_)
22444 : Producer(Producer_), BitPos(BitPos_), NumBits(NumBits_) {}
22445
22446 Entry() = delete;
22447 Entry(const Entry &) = delete;
22448 Entry &operator=(const Entry &) = delete;
22449 Entry &operator=(Entry &&) = delete;
22450 };
22451 SmallVector<Entry, 32> Worklist;
22452 SmallVector<Entry, 32> Leafs;
22453
22454 // We start at the "root" ISD::EXTRACT_VECTOR_ELT.
22455 Worklist.emplace_back(Args&: N, /*BitPos=*/Args: VecEltBitWidth * IndexC->getZExtValue(),
22456 /*NumBits=*/Args&: VecEltBitWidth);
22457
22458 while (!Worklist.empty()) {
22459 Entry E = Worklist.pop_back_val();
22460 // Does the node not even use any of the VecOp bits?
22461 if (!(E.NumBits > 0 && E.BitPos < VecVT.getSizeInBits() &&
22462 E.BitPos + E.NumBits <= VecVT.getSizeInBits()))
22463 return false; // Let's allow the other combines clean this up first.
22464 // Did we fail to model any of the users of the Producer?
22465 bool ProducerIsLeaf = false;
22466 // Look at each user of this Producer.
22467 for (SDNode *User : E.Producer->uses()) {
22468 switch (User->getOpcode()) {
22469 // TODO: support ISD::BITCAST
22470 // TODO: support ISD::ANY_EXTEND
22471 // TODO: support ISD::ZERO_EXTEND
22472 // TODO: support ISD::SIGN_EXTEND
22473 case ISD::TRUNCATE:
22474 // Truncation simply means we keep position, but extract less bits.
22475 Worklist.emplace_back(Args&: User, Args&: E.BitPos,
22476 /*NumBits=*/Args: User->getValueSizeInBits(ResNo: 0));
22477 break;
22478 // TODO: support ISD::SRA
22479 // TODO: support ISD::SHL
22480 case ISD::SRL:
22481 // We should be shifting the Producer by a constant amount.
22482 if (auto *ShAmtC = dyn_cast<ConstantSDNode>(Val: User->getOperand(Num: 1));
22483 User->getOperand(Num: 0).getNode() == E.Producer && ShAmtC) {
22484 // Logical right-shift means that we start extraction later,
22485 // but stop it at the same position we did previously.
22486 unsigned ShAmt = ShAmtC->getZExtValue();
22487 Worklist.emplace_back(Args&: User, Args: E.BitPos + ShAmt, Args: E.NumBits - ShAmt);
22488 break;
22489 }
22490 [[fallthrough]];
22491 default:
22492 // We can not model this user of the Producer.
22493 // Which means the current Producer will be a ISD::EXTRACT_VECTOR_ELT.
22494 ProducerIsLeaf = true;
22495 // Profitability check: all users that we can not model
22496 // must be ISD::BUILD_VECTOR's.
22497 if (User->getOpcode() != ISD::BUILD_VECTOR)
22498 return false;
22499 break;
22500 }
22501 }
22502 if (ProducerIsLeaf)
22503 Leafs.emplace_back(Args: std::move(E));
22504 }
22505
22506 unsigned NewVecEltBitWidth = Leafs.front().NumBits;
22507
  // If we are still at the same element granularity, give up.
22509 if (NewVecEltBitWidth == VecEltBitWidth)
22510 return false;
22511
22512 // The vector width must be a multiple of the new element width.
22513 if (VecVT.getSizeInBits() % NewVecEltBitWidth != 0)
22514 return false;
22515
22516 // All leafs must agree on the new element width.
  // All leafs must not expect any "padding" bits on top of that width.
  // All leafs must start extraction at a multiple of that width.
22519 if (!all_of(Range&: Leafs, P: [NewVecEltBitWidth](const Entry &E) {
22520 return (unsigned)E.NumBits == NewVecEltBitWidth &&
22521 E.Producer->getValueSizeInBits(ResNo: 0) == NewVecEltBitWidth &&
22522 E.BitPos % NewVecEltBitWidth == 0;
22523 }))
22524 return false;
22525
22526 EVT NewScalarVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: NewVecEltBitWidth);
22527 EVT NewVecVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NewScalarVT,
22528 NumElements: VecVT.getSizeInBits() / NewVecEltBitWidth);
22529
22530 if (LegalTypes &&
22531 !(TLI.isTypeLegal(VT: NewScalarVT) && TLI.isTypeLegal(VT: NewVecVT)))
22532 return false;
22533
22534 if (LegalOperations &&
22535 !(TLI.isOperationLegalOrCustom(Op: ISD::BITCAST, VT: NewVecVT) &&
22536 TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_VECTOR_ELT, VT: NewVecVT)))
22537 return false;
22538
22539 SDValue NewVecOp = DAG.getBitcast(VT: NewVecVT, V: VecOp);
22540 for (const Entry &E : Leafs) {
22541 SDLoc DL(E.Producer);
22542 unsigned NewIndex = E.BitPos / NewVecEltBitWidth;
22543 assert(NewIndex < NewVecVT.getVectorNumElements() &&
22544 "Creating out-of-bounds ISD::EXTRACT_VECTOR_ELT?");
22545 SDValue V = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: NewScalarVT, N1: NewVecOp,
22546 N2: DAG.getVectorIdxConstant(Val: NewIndex, DL));
22547 CombineTo(N: E.Producer, Res: V);
22548 }
22549
22550 return true;
22551}
22552
22553SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
22554 SDValue VecOp = N->getOperand(Num: 0);
22555 SDValue Index = N->getOperand(Num: 1);
22556 EVT ScalarVT = N->getValueType(ResNo: 0);
22557 EVT VecVT = VecOp.getValueType();
22558 if (VecOp.isUndef())
22559 return DAG.getUNDEF(VT: ScalarVT);
22560
22561 // extract_vector_elt (insert_vector_elt vec, val, idx), idx) -> val
22562 //
22563 // This only really matters if the index is non-constant since other combines
22564 // on the constant elements already work.
22565 SDLoc DL(N);
22566 if (VecOp.getOpcode() == ISD::INSERT_VECTOR_ELT &&
22567 Index == VecOp.getOperand(i: 2)) {
22568 SDValue Elt = VecOp.getOperand(i: 1);
22569 return VecVT.isInteger() ? DAG.getAnyExtOrTrunc(Op: Elt, DL, VT: ScalarVT) : Elt;
22570 }
22571
  // (vextract (scalar_to_vector val), 0) -> val
22573 if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR) {
22574 // Only 0'th element of SCALAR_TO_VECTOR is defined.
22575 if (DAG.isKnownNeverZero(Op: Index))
22576 return DAG.getUNDEF(VT: ScalarVT);
22577
    // Check if the result type doesn't match the inserted element type.
    // The inserted element and extracted element may have mismatched bitwidth;
    // as a result, EXTRACT_VECTOR_ELT may extend or truncate the extracted
    // value.
22581 SDValue InOp = VecOp.getOperand(i: 0);
22582 if (InOp.getValueType() != ScalarVT) {
22583 assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
22584 if (InOp.getValueType().bitsGT(VT: ScalarVT))
22585 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: ScalarVT, Operand: InOp);
22586 return DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: ScalarVT, Operand: InOp);
22587 }
22588 return InOp;
22589 }
22590
22591 // extract_vector_elt of out-of-bounds element -> UNDEF
22592 auto *IndexC = dyn_cast<ConstantSDNode>(Val&: Index);
22593 if (IndexC && VecVT.isFixedLengthVector() &&
22594 IndexC->getAPIntValue().uge(RHS: VecVT.getVectorNumElements()))
22595 return DAG.getUNDEF(VT: ScalarVT);
22596
22597 // extract_vector_elt (build_vector x, y), 1 -> y
22598 if (((IndexC && VecOp.getOpcode() == ISD::BUILD_VECTOR) ||
22599 VecOp.getOpcode() == ISD::SPLAT_VECTOR) &&
22600 TLI.isTypeLegal(VT: VecVT)) {
22601 assert((VecOp.getOpcode() != ISD::BUILD_VECTOR ||
22602 VecVT.isFixedLengthVector()) &&
22603 "BUILD_VECTOR used for scalable vectors");
22604 unsigned IndexVal =
22605 VecOp.getOpcode() == ISD::BUILD_VECTOR ? IndexC->getZExtValue() : 0;
22606 SDValue Elt = VecOp.getOperand(i: IndexVal);
22607 EVT InEltVT = Elt.getValueType();
22608
22609 if (VecOp.hasOneUse() || TLI.aggressivelyPreferBuildVectorSources(VecVT) ||
22610 isNullConstant(V: Elt)) {
22611 // Sometimes build_vector's scalar input types do not match result type.
22612 if (ScalarVT == InEltVT)
22613 return Elt;
22614
22615 // TODO: It may be useful to truncate if free if the build_vector
22616 // implicitly converts.
22617 }
22618 }
22619
22620 if (SDValue BO = scalarizeExtractedBinop(ExtElt: N, DAG, DL, LegalOperations))
22621 return BO;
22622
22623 if (VecVT.isScalableVector())
22624 return SDValue();
22625
22626 // All the code from this point onwards assumes fixed width vectors, but it's
22627 // possible that some of the combinations could be made to work for scalable
22628 // vectors too.
22629 unsigned NumElts = VecVT.getVectorNumElements();
22630 unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
22631
  // See if the extracted element is constant, in which case fold it if it's
  // a legal fp immediate.
22634 if (IndexC && ScalarVT.isFloatingPoint()) {
22635 APInt EltMask = APInt::getOneBitSet(numBits: NumElts, BitNo: IndexC->getZExtValue());
22636 KnownBits KnownElt = DAG.computeKnownBits(Op: VecOp, DemandedElts: EltMask);
22637 if (KnownElt.isConstant()) {
22638 APFloat CstFP =
22639 APFloat(DAG.EVTToAPFloatSemantics(VT: ScalarVT), KnownElt.getConstant());
22640 if (TLI.isFPImmLegal(CstFP, ScalarVT))
22641 return DAG.getConstantFP(Val: CstFP, DL, VT: ScalarVT);
22642 }
22643 }
22644
22645 // TODO: These transforms should not require the 'hasOneUse' restriction, but
22646 // there are regressions on multiple targets without it. We can end up with a
22647 // mess of scalar and vector code if we reduce only part of the DAG to scalar.
22648 if (IndexC && VecOp.getOpcode() == ISD::BITCAST && VecVT.isInteger() &&
22649 VecOp.hasOneUse()) {
    // The vector index of the LSBs of the source depends on the endianness.
22651 bool IsLE = DAG.getDataLayout().isLittleEndian();
22652 unsigned ExtractIndex = IndexC->getZExtValue();
22653 // extract_elt (v2i32 (bitcast i64:x)), BCTruncElt -> i32 (trunc i64:x)
22654 unsigned BCTruncElt = IsLE ? 0 : NumElts - 1;
22655 SDValue BCSrc = VecOp.getOperand(i: 0);
22656 if (ExtractIndex == BCTruncElt && BCSrc.getValueType().isScalarInteger())
22657 return DAG.getAnyExtOrTrunc(Op: BCSrc, DL, VT: ScalarVT);
22658
22659 if (LegalTypes && BCSrc.getValueType().isInteger() &&
22660 BCSrc.getOpcode() == ISD::SCALAR_TO_VECTOR) {
22661 // ext_elt (bitcast (scalar_to_vec i64 X to v2i64) to v4i32), TruncElt -->
22662 // trunc i64 X to i32
22663 SDValue X = BCSrc.getOperand(i: 0);
22664 assert(X.getValueType().isScalarInteger() && ScalarVT.isScalarInteger() &&
22665 "Extract element and scalar to vector can't change element type "
22666 "from FP to integer.");
22667 unsigned XBitWidth = X.getValueSizeInBits();
22668 BCTruncElt = IsLE ? 0 : XBitWidth / VecEltBitWidth - 1;
22669
22670 // An extract element return value type can be wider than its vector
22671 // operand element type. In that case, the high bits are undefined, so
22672 // it's possible that we may need to extend rather than truncate.
22673 if (ExtractIndex == BCTruncElt && XBitWidth > VecEltBitWidth) {
22674 assert(XBitWidth % VecEltBitWidth == 0 &&
22675 "Scalar bitwidth must be a multiple of vector element bitwidth");
22676 return DAG.getAnyExtOrTrunc(Op: X, DL, VT: ScalarVT);
22677 }
22678 }
22679 }
22680
22681 // Transform: (EXTRACT_VECTOR_ELT( VECTOR_SHUFFLE )) -> EXTRACT_VECTOR_ELT.
  // We only perform this optimization before the op legalization phase because
  // we may introduce new vector instructions which are not backed by TD
  // patterns. For example, on AVX we could end up extracting elements from a
  // wide vector without using extract_subvector. However, if we can find an
  // underlying scalar value, then we can always use that.
22687 if (IndexC && VecOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
22688 auto *Shuf = cast<ShuffleVectorSDNode>(Val&: VecOp);
22689 // Find the new index to extract from.
22690 int OrigElt = Shuf->getMaskElt(Idx: IndexC->getZExtValue());
22691
22692 // Extracting an undef index is undef.
22693 if (OrigElt == -1)
22694 return DAG.getUNDEF(VT: ScalarVT);
22695
22696 // Select the right vector half to extract from.
22697 SDValue SVInVec;
22698 if (OrigElt < (int)NumElts) {
22699 SVInVec = VecOp.getOperand(i: 0);
22700 } else {
22701 SVInVec = VecOp.getOperand(i: 1);
22702 OrigElt -= NumElts;
22703 }
22704
22705 if (SVInVec.getOpcode() == ISD::BUILD_VECTOR) {
22706 SDValue InOp = SVInVec.getOperand(i: OrigElt);
22707 if (InOp.getValueType() != ScalarVT) {
22708 assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
22709 InOp = DAG.getSExtOrTrunc(Op: InOp, DL, VT: ScalarVT);
22710 }
22711
22712 return InOp;
22713 }
22714
22715 // FIXME: We should handle recursing on other vector shuffles and
22716 // scalar_to_vector here as well.
22717
22718 if (!LegalOperations ||
22719 // FIXME: Should really be just isOperationLegalOrCustom.
22720 TLI.isOperationLegal(Op: ISD::EXTRACT_VECTOR_ELT, VT: VecVT) ||
22721 TLI.isOperationExpand(Op: ISD::VECTOR_SHUFFLE, VT: VecVT)) {
22722 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ScalarVT, N1: SVInVec,
22723 N2: DAG.getVectorIdxConstant(Val: OrigElt, DL));
22724 }
22725 }
22726
22727 // If only EXTRACT_VECTOR_ELT nodes use the source vector we can
22728 // simplify it based on the (valid) extraction indices.
22729 if (llvm::all_of(Range: VecOp->uses(), P: [&](SDNode *Use) {
22730 return Use->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
22731 Use->getOperand(Num: 0) == VecOp &&
22732 isa<ConstantSDNode>(Val: Use->getOperand(Num: 1));
22733 })) {
22734 APInt DemandedElts = APInt::getZero(numBits: NumElts);
22735 for (SDNode *Use : VecOp->uses()) {
22736 auto *CstElt = cast<ConstantSDNode>(Val: Use->getOperand(Num: 1));
22737 if (CstElt->getAPIntValue().ult(RHS: NumElts))
22738 DemandedElts.setBit(CstElt->getZExtValue());
22739 }
22740 if (SimplifyDemandedVectorElts(Op: VecOp, DemandedElts, AssumeSingleUse: true)) {
22741 // We simplified the vector operand of this extract element. If this
22742 // extract is not dead, visit it again so it is folded properly.
22743 if (N->getOpcode() != ISD::DELETED_NODE)
22744 AddToWorklist(N);
22745 return SDValue(N, 0);
22746 }
22747 APInt DemandedBits = APInt::getAllOnes(numBits: VecEltBitWidth);
22748 if (SimplifyDemandedBits(Op: VecOp, DemandedBits, DemandedElts, AssumeSingleUse: true)) {
22749 // We simplified the vector operand of this extract element. If this
22750 // extract is not dead, visit it again so it is folded properly.
22751 if (N->getOpcode() != ISD::DELETED_NODE)
22752 AddToWorklist(N);
22753 return SDValue(N, 0);
22754 }
22755 }
22756
22757 if (refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(N))
22758 return SDValue(N, 0);
22759
22760 // Everything under here is trying to match an extract of a loaded value.
  // If the result of the load has to be truncated, then it's not necessarily
22762 // profitable.
22763 bool BCNumEltsChanged = false;
22764 EVT ExtVT = VecVT.getVectorElementType();
22765 EVT LVT = ExtVT;
22766 if (ScalarVT.bitsLT(VT: LVT) && !TLI.isTruncateFree(FromVT: LVT, ToVT: ScalarVT))
22767 return SDValue();
22768
22769 if (VecOp.getOpcode() == ISD::BITCAST) {
22770 // Don't duplicate a load with other uses.
22771 if (!VecOp.hasOneUse())
22772 return SDValue();
22773
22774 EVT BCVT = VecOp.getOperand(i: 0).getValueType();
22775 if (!BCVT.isVector() || ExtVT.bitsGT(VT: BCVT.getVectorElementType()))
22776 return SDValue();
22777 if (NumElts != BCVT.getVectorNumElements())
22778 BCNumEltsChanged = true;
22779 VecOp = VecOp.getOperand(i: 0);
22780 ExtVT = BCVT.getVectorElementType();
22781 }
22782
22783 // extract (vector load $addr), i --> load $addr + i * size
22784 if (!LegalOperations && !IndexC && VecOp.hasOneUse() &&
22785 ISD::isNormalLoad(N: VecOp.getNode()) &&
22786 !Index->hasPredecessor(N: VecOp.getNode())) {
22787 auto *VecLoad = dyn_cast<LoadSDNode>(Val&: VecOp);
22788 if (VecLoad && VecLoad->isSimple())
22789 return scalarizeExtractedVectorLoad(EVE: N, InVecVT: VecVT, EltNo: Index, OriginalLoad: VecLoad);
22790 }
22791
22792 // Perform only after legalization to ensure build_vector / vector_shuffle
22793 // optimizations have already been done.
22794 if (!LegalOperations || !IndexC)
22795 return SDValue();
22796
22797 // (vextract (v4f32 load $addr), c) -> (f32 load $addr+c*size)
22798 // (vextract (v4f32 s2v (f32 load $addr)), c) -> (f32 load $addr+c*size)
22799 // (vextract (v4f32 shuffle (load $addr), <1,u,u,u>), 0) -> (f32 load $addr)
22800 int Elt = IndexC->getZExtValue();
22801 LoadSDNode *LN0 = nullptr;
22802 if (ISD::isNormalLoad(N: VecOp.getNode())) {
22803 LN0 = cast<LoadSDNode>(Val&: VecOp);
22804 } else if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR &&
22805 VecOp.getOperand(i: 0).getValueType() == ExtVT &&
22806 ISD::isNormalLoad(N: VecOp.getOperand(i: 0).getNode())) {
22807 // Don't duplicate a load with other uses.
22808 if (!VecOp.hasOneUse())
22809 return SDValue();
22810
22811 LN0 = cast<LoadSDNode>(Val: VecOp.getOperand(i: 0));
22812 }
22813 if (auto *Shuf = dyn_cast<ShuffleVectorSDNode>(Val&: VecOp)) {
22814 // (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1)
22815 // =>
22816 // (load $addr+1*size)
22817
22818 // Don't duplicate a load with other uses.
22819 if (!VecOp.hasOneUse())
22820 return SDValue();
22821
22822 // If the bit convert changed the number of elements, it is unsafe
22823 // to examine the mask.
22824 if (BCNumEltsChanged)
22825 return SDValue();
22826
22827 // Select the input vector, guarding against out of range extract vector.
22828 int Idx = (Elt > (int)NumElts) ? -1 : Shuf->getMaskElt(Idx: Elt);
22829 VecOp = (Idx < (int)NumElts) ? VecOp.getOperand(i: 0) : VecOp.getOperand(i: 1);
22830
22831 if (VecOp.getOpcode() == ISD::BITCAST) {
22832 // Don't duplicate a load with other uses.
22833 if (!VecOp.hasOneUse())
22834 return SDValue();
22835
22836 VecOp = VecOp.getOperand(i: 0);
22837 }
22838 if (ISD::isNormalLoad(N: VecOp.getNode())) {
22839 LN0 = cast<LoadSDNode>(Val&: VecOp);
22840 Elt = (Idx < (int)NumElts) ? Idx : Idx - (int)NumElts;
22841 Index = DAG.getConstant(Val: Elt, DL, VT: Index.getValueType());
22842 }
22843 } else if (VecOp.getOpcode() == ISD::CONCAT_VECTORS && !BCNumEltsChanged &&
22844 VecVT.getVectorElementType() == ScalarVT &&
22845 (!LegalTypes ||
22846 TLI.isTypeLegal(
22847 VT: VecOp.getOperand(i: 0).getValueType().getVectorElementType()))) {
22848 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 0
22849 // -> extract_vector_elt a, 0
22850 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 1
22851 // -> extract_vector_elt a, 1
22852 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 2
22853 // -> extract_vector_elt b, 0
22854 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 3
22855 // -> extract_vector_elt b, 1
22856 EVT ConcatVT = VecOp.getOperand(i: 0).getValueType();
22857 unsigned ConcatNumElts = ConcatVT.getVectorNumElements();
22858 SDValue NewIdx = DAG.getConstant(Val: Elt % ConcatNumElts, DL,
22859 VT: Index.getValueType());
22860
22861 SDValue ConcatOp = VecOp.getOperand(i: Elt / ConcatNumElts);
22862 SDValue Elt = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL,
22863 VT: ConcatVT.getVectorElementType(),
22864 N1: ConcatOp, N2: NewIdx);
22865 return DAG.getNode(Opcode: ISD::BITCAST, DL, VT: ScalarVT, Operand: Elt);
22866 }
22867
22868 // Make sure we found a non-volatile load and the extractelement is
22869 // the only use.
22870 if (!LN0 || !LN0->hasNUsesOfValue(NUses: 1,Value: 0) || !LN0->isSimple())
22871 return SDValue();
22872
22873 // If Idx was -1 above, Elt is going to be -1, so just return undef.
22874 if (Elt == -1)
22875 return DAG.getUNDEF(VT: LVT);
22876
22877 return scalarizeExtractedVectorLoad(EVE: N, InVecVT: VecVT, EltNo: Index, OriginalLoad: LN0);
22878}
22879
22880// Simplify (build_vec (ext )) to (bitcast (build_vec ))
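// For example, on little-endian targets:
//   (v4i32 build_vector (zext i16:a), (zext i16:b),
//                       (zext i16:c), (zext i16:d))
//     --> (v4i32 bitcast (v8i16 build_vector a, 0, b, 0, c, 0, d, 0))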
22881SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) {
22882 // We perform this optimization post type-legalization because
22883 // the type-legalizer often scalarizes integer-promoted vectors.
  // Performing it earlier may create bit-casts which
  // will be type-legalized into complex code sequences.
  // We perform this optimization only before the operation legalizer because
  // we may introduce illegal operations.
22888 if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
22889 return SDValue();
22890
22891 unsigned NumInScalars = N->getNumOperands();
22892 SDLoc DL(N);
22893 EVT VT = N->getValueType(ResNo: 0);
22894
22895 // Check to see if this is a BUILD_VECTOR of a bunch of values
22896 // which come from any_extend or zero_extend nodes. If so, we can create
22897 // a new BUILD_VECTOR using bit-casts which may enable other BUILD_VECTOR
22898 // optimizations. We do not handle sign-extend because we can't fill the sign
22899 // using shuffles.
22900 EVT SourceType = MVT::Other;
22901 bool AllAnyExt = true;
22902
22903 for (unsigned i = 0; i != NumInScalars; ++i) {
22904 SDValue In = N->getOperand(Num: i);
22905 // Ignore undef inputs.
22906 if (In.isUndef()) continue;
22907
22908 bool AnyExt = In.getOpcode() == ISD::ANY_EXTEND;
22909 bool ZeroExt = In.getOpcode() == ISD::ZERO_EXTEND;
22910
22911 // Abort if the element is not an extension.
22912 if (!ZeroExt && !AnyExt) {
22913 SourceType = MVT::Other;
22914 break;
22915 }
22916
22917 // The input is a ZeroExt or AnyExt. Check the original type.
22918 EVT InTy = In.getOperand(i: 0).getValueType();
22919
22920 // Check that all of the widened source types are the same.
22921 if (SourceType == MVT::Other)
22922 // First time.
22923 SourceType = InTy;
22924 else if (InTy != SourceType) {
      // Multiple incoming types. Abort.
22926 SourceType = MVT::Other;
22927 break;
22928 }
22929
22930 // Check if all of the extends are ANY_EXTENDs.
22931 AllAnyExt &= AnyExt;
22932 }
22933
22934 // In order to have valid types, all of the inputs must be extended from the
22935 // same source type and all of the inputs must be any or zero extend.
22936 // Scalar sizes must be a power of two.
22937 EVT OutScalarTy = VT.getScalarType();
22938 bool ValidTypes =
22939 SourceType != MVT::Other &&
22940 llvm::has_single_bit<uint32_t>(Value: OutScalarTy.getSizeInBits()) &&
22941 llvm::has_single_bit<uint32_t>(Value: SourceType.getSizeInBits());
22942
22943 // Create a new simpler BUILD_VECTOR sequence which other optimizations can
22944 // turn into a single shuffle instruction.
22945 if (!ValidTypes)
22946 return SDValue();
22947
22948 // If we already have a splat buildvector, then don't fold it if it means
22949 // introducing zeros.
22950 if (!AllAnyExt && DAG.isSplatValue(V: SDValue(N, 0), /*AllowUndefs*/ true))
22951 return SDValue();
22952
22953 bool isLE = DAG.getDataLayout().isLittleEndian();
22954 unsigned ElemRatio = OutScalarTy.getSizeInBits()/SourceType.getSizeInBits();
22955 assert(ElemRatio > 1 && "Invalid element size ratio");
22956 SDValue Filler = AllAnyExt ? DAG.getUNDEF(VT: SourceType):
22957 DAG.getConstant(Val: 0, DL, VT: SourceType);
22958
22959 unsigned NewBVElems = ElemRatio * VT.getVectorNumElements();
22960 SmallVector<SDValue, 8> Ops(NewBVElems, Filler);
22961
22962 // Populate the new build_vector
22963 for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
22964 SDValue Cast = N->getOperand(Num: i);
22965 assert((Cast.getOpcode() == ISD::ANY_EXTEND ||
22966 Cast.getOpcode() == ISD::ZERO_EXTEND ||
22967 Cast.isUndef()) && "Invalid cast opcode");
22968 SDValue In;
22969 if (Cast.isUndef())
22970 In = DAG.getUNDEF(VT: SourceType);
22971 else
22972 In = Cast->getOperand(Num: 0);
22973 unsigned Index = isLE ? (i * ElemRatio) :
22974 (i * ElemRatio + (ElemRatio - 1));
22975
22976 assert(Index < Ops.size() && "Invalid index");
22977 Ops[Index] = In;
22978 }
22979
22980 // The type of the new BUILD_VECTOR node.
22981 EVT VecVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SourceType, NumElements: NewBVElems);
22982 assert(VecVT.getSizeInBits() == VT.getSizeInBits() &&
22983 "Invalid vector size");
22984 // Check if the new vector type is legal.
22985 if (!isTypeLegal(VT: VecVT) ||
22986 (!TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT: VecVT) &&
22987 TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT)))
22988 return SDValue();
22989
22990 // Make the new BUILD_VECTOR.
22991 SDValue BV = DAG.getBuildVector(VT: VecVT, DL, Ops);
22992
22993 // The new BUILD_VECTOR node has the potential to be further optimized.
22994 AddToWorklist(N: BV.getNode());
22995 // Bitcast to the desired type.
22996 return DAG.getBitcast(VT, V: BV);
22997}
22998
22999// Simplify (build_vec (trunc $1)
23000// (trunc (srl $1 half-width))
23001// (trunc (srl $1 (2 * half-width))))
23002// to (bitcast $1)
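// For example (little-endian):
//   (v2i32 build_vector (trunc i64 %x), (trunc (srl %x, 32)))
// could be simplified to:
//   (v2i32 bitcast %x)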
23003SDValue DAGCombiner::reduceBuildVecTruncToBitCast(SDNode *N) {
23004 assert(N->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
23005
23006 EVT VT = N->getValueType(ResNo: 0);
23007
23008 // Don't run this before LegalizeTypes if VT is legal.
23009 // Targets may have other preferences.
23010 if (Level < AfterLegalizeTypes && TLI.isTypeLegal(VT))
23011 return SDValue();
23012
23013 // Only for little endian
23014 if (!DAG.getDataLayout().isLittleEndian())
23015 return SDValue();
23016
23017 SDLoc DL(N);
23018 EVT OutScalarTy = VT.getScalarType();
23019 uint64_t ScalarTypeBitsize = OutScalarTy.getSizeInBits();
23020
23021 // Only for power of two types to be sure that bitcast works well
23022 if (!isPowerOf2_64(Value: ScalarTypeBitsize))
23023 return SDValue();
23024
23025 unsigned NumInScalars = N->getNumOperands();
23026
23027 // Look through bitcasts
23028 auto PeekThroughBitcast = [](SDValue Op) {
23029 if (Op.getOpcode() == ISD::BITCAST)
23030 return Op.getOperand(i: 0);
23031 return Op;
23032 };
23033
23034 // The source value where all the parts are extracted.
23035 SDValue Src;
23036 for (unsigned i = 0; i != NumInScalars; ++i) {
23037 SDValue In = PeekThroughBitcast(N->getOperand(Num: i));
23038 // Ignore undef inputs.
23039 if (In.isUndef()) continue;
23040
23041 if (In.getOpcode() != ISD::TRUNCATE)
23042 return SDValue();
23043
23044 In = PeekThroughBitcast(In.getOperand(i: 0));
23045
23046 if (In.getOpcode() != ISD::SRL) {
23047 // For now we only handle a build_vec without shuffling; handle shifts here
23048 // in the future.
23049 if (i != 0)
23050 return SDValue();
23051
23052 Src = In;
23053 } else {
23054 // In is SRL
23055 SDValue part = PeekThroughBitcast(In.getOperand(i: 0));
23056
23057 if (!Src) {
23058 Src = part;
23059 } else if (Src != part) {
23060 // Vector parts do not stem from the same variable
23061 return SDValue();
23062 }
23063
23064 SDValue ShiftAmtVal = In.getOperand(i: 1);
23065 if (!isa<ConstantSDNode>(Val: ShiftAmtVal))
23066 return SDValue();
23067
23068 uint64_t ShiftAmt = In.getConstantOperandVal(i: 1);
23069
23070 // The extracted value is not extracted at the right position
23071 if (ShiftAmt != i * ScalarTypeBitsize)
23072 return SDValue();
23073 }
23074 }
23075
23076 // Only cast if the size is the same
23077 if (!Src || Src.getValueType().getSizeInBits() != VT.getSizeInBits())
23078 return SDValue();
23079
23080 return DAG.getBitcast(VT, V: Src);
23081}
23082
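// Helper for reduceBuildVecToShuffle (below): given the BUILD_VECTOR node N,
// the VectorMask describing which source vector each element comes from, and
// the pair of source vectors at positions LeftIdx and LeftIdx + 1, try to
// build a single VECTOR_SHUFFLE (plus any concat/extract fixups needed to
// match vector widths) that produces those elements of N.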
23083SDValue DAGCombiner::createBuildVecShuffle(const SDLoc &DL, SDNode *N,
23084 ArrayRef<int> VectorMask,
23085 SDValue VecIn1, SDValue VecIn2,
23086 unsigned LeftIdx, bool DidSplitVec) {
23087 SDValue ZeroIdx = DAG.getVectorIdxConstant(Val: 0, DL);
23088
23089 EVT VT = N->getValueType(ResNo: 0);
23090 EVT InVT1 = VecIn1.getValueType();
23091 EVT InVT2 = VecIn2.getNode() ? VecIn2.getValueType() : InVT1;
23092
23093 unsigned NumElems = VT.getVectorNumElements();
23094 unsigned ShuffleNumElems = NumElems;
23095
23096 // If we artificially split a vector in two already, then the offsets in the
23097 // operands will all be based off of VecIn1, even those in VecIn2.
23098 unsigned Vec2Offset = DidSplitVec ? 0 : InVT1.getVectorNumElements();
23099
23100 uint64_t VTSize = VT.getFixedSizeInBits();
23101 uint64_t InVT1Size = InVT1.getFixedSizeInBits();
23102 uint64_t InVT2Size = InVT2.getFixedSizeInBits();
23103
23104 assert(InVT2Size <= InVT1Size &&
23105 "Inputs must be sorted to be in non-increasing vector size order.");
23106
23107 // We can't generate a shuffle node with mismatched input and output types.
23108 // Try to make the types match the type of the output.
23109 if (InVT1 != VT || InVT2 != VT) {
23110 if ((VTSize % InVT1Size == 0) && InVT1 == InVT2) {
23111 // If the output vector length is a multiple of both input lengths,
23112 // we can concatenate them and pad the rest with undefs.
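// For example, with VT = v8i32 and InVT1 = InVT2 = v2i32, NumConcats is 4,
// so VecIn1 becomes (v8i32 concat_vectors VecIn1, VecIn2, undef, undef) and
// VecIn2 is cleared.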
23113 unsigned NumConcats = VTSize / InVT1Size;
23114 assert(NumConcats >= 2 && "Concat needs at least two inputs!");
23115 SmallVector<SDValue, 2> ConcatOps(NumConcats, DAG.getUNDEF(VT: InVT1));
23116 ConcatOps[0] = VecIn1;
23117 ConcatOps[1] = VecIn2 ? VecIn2 : DAG.getUNDEF(VT: InVT1);
23118 VecIn1 = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT, Ops: ConcatOps);
23119 VecIn2 = SDValue();
23120 } else if (InVT1Size == VTSize * 2) {
23121 if (!TLI.isExtractSubvectorCheap(ResVT: VT, SrcVT: InVT1, Index: NumElems))
23122 return SDValue();
23123
23124 if (!VecIn2.getNode()) {
23125 // If we only have one input vector, and it's twice the size of the
23126 // output, split it in two.
23127 VecIn2 = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT, N1: VecIn1,
23128 N2: DAG.getVectorIdxConstant(Val: NumElems, DL));
23129 VecIn1 = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT, N1: VecIn1, N2: ZeroIdx);
23130 // Since we now have shorter input vectors, adjust the offset of the
23131 // second vector's start.
23132 Vec2Offset = NumElems;
23133 } else {
23134 assert(InVT2Size <= InVT1Size &&
23135 "Second input is not going to be larger than the first one.");
23136
23137 // VecIn1 is wider than the output, and we have another, possibly
23138 // smaller input. Pad the smaller input with undefs, shuffle at the
23139 // input vector width, and extract the output.
23140 // The shuffle type is different than VT, so check legality again.
23141 if (LegalOperations &&
23142 !TLI.isOperationLegal(Op: ISD::VECTOR_SHUFFLE, VT: InVT1))
23143 return SDValue();
23144
23145 // Legalizing INSERT_SUBVECTOR is tricky - you basically have to
23146 // lower it back into a BUILD_VECTOR. So if the inserted type is
23147 // illegal, don't even try.
23148 if (InVT1 != InVT2) {
23149 if (!TLI.isTypeLegal(VT: InVT2))
23150 return SDValue();
23151 VecIn2 = DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL, VT: InVT1,
23152 N1: DAG.getUNDEF(VT: InVT1), N2: VecIn2, N3: ZeroIdx);
23153 }
23154 ShuffleNumElems = NumElems * 2;
23155 }
23156 } else if (InVT2Size * 2 == VTSize && InVT1Size == VTSize) {
23157 SmallVector<SDValue, 2> ConcatOps(2, DAG.getUNDEF(VT: InVT2));
23158 ConcatOps[0] = VecIn2;
23159 VecIn2 = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT, Ops: ConcatOps);
23160 } else if (InVT1Size / VTSize > 1 && InVT1Size % VTSize == 0) {
23161 if (!TLI.isExtractSubvectorCheap(ResVT: VT, SrcVT: InVT1, Index: NumElems) ||
23162 !TLI.isTypeLegal(VT: InVT1) || !TLI.isTypeLegal(VT: InVT2))
23163 return SDValue();
23164 // If the dest vector has fewer than two elements, then using a shuffle and
23165 // extracting from larger registers will cost even more.
23166 if (VT.getVectorNumElements() <= 2 || !VecIn2.getNode())
23167 return SDValue();
23168 assert(InVT2Size <= InVT1Size &&
23169 "Second input is not going to be larger than the first one.");
23170
23171 // VecIn1 is wider than the output, and we have another, possibly
23172 // smaller input. Pad the smaller input with undefs, shuffle at the
23173 // input vector width, and extract the output.
23174 // The shuffle type is different than VT, so check legality again.
23175 if (LegalOperations && !TLI.isOperationLegal(Op: ISD::VECTOR_SHUFFLE, VT: InVT1))
23176 return SDValue();
23177
23178 if (InVT1 != InVT2) {
23179 VecIn2 = DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL, VT: InVT1,
23180 N1: DAG.getUNDEF(VT: InVT1), N2: VecIn2, N3: ZeroIdx);
23181 }
23182 ShuffleNumElems = InVT1Size / VTSize * NumElems;
23183 } else {
23184 // TODO: Support cases where the length mismatch isn't exactly by a
23185 // factor of 2.
23186 // TODO: Move this check upwards, so that if we have bad type
23187 // mismatches, we don't create any DAG nodes.
23188 return SDValue();
23189 }
23190 }
23191
23192 // Initialize mask to undef.
23193 SmallVector<int, 8> Mask(ShuffleNumElems, -1);
23194
23195 // Only need to run up to the number of elements actually used, not the
23196 // total number of elements in the shuffle - if we are shuffling a wider
23197 // vector, the high lanes should be set to undef.
23198 for (unsigned i = 0; i != NumElems; ++i) {
23199 if (VectorMask[i] <= 0)
23200 continue;
23201
23202 unsigned ExtIndex = N->getOperand(Num: i).getConstantOperandVal(i: 1);
23203 if (VectorMask[i] == (int)LeftIdx) {
23204 Mask[i] = ExtIndex;
23205 } else if (VectorMask[i] == (int)LeftIdx + 1) {
23206 Mask[i] = Vec2Offset + ExtIndex;
23207 }
23208 }
23209
23210 // The type of the input vectors may have changed above.
23211 InVT1 = VecIn1.getValueType();
23212
23213 // If we already have a VecIn2, it should have the same type as VecIn1.
23214 // If we don't, get an undef/zero vector of the appropriate type.
23215 VecIn2 = VecIn2.getNode() ? VecIn2 : DAG.getUNDEF(VT: InVT1);
23216 assert(InVT1 == VecIn2.getValueType() && "Unexpected second input type.");
23217
23218 SDValue Shuffle = DAG.getVectorShuffle(VT: InVT1, dl: DL, N1: VecIn1, N2: VecIn2, Mask);
23219 if (ShuffleNumElems > NumElems)
23220 Shuffle = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT, N1: Shuffle, N2: ZeroIdx);
23221
23222 return Shuffle;
23223}
23224
23225static SDValue reduceBuildVecToShuffleWithZero(SDNode *BV, SelectionDAG &DAG) {
23226 assert(BV->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
23227
23228 // First, determine where the build vector is not undef.
23229 // TODO: We could extend this to handle zero elements as well as undefs.
23230 int NumBVOps = BV->getNumOperands();
23231 int ZextElt = -1;
23232 for (int i = 0; i != NumBVOps; ++i) {
23233 SDValue Op = BV->getOperand(Num: i);
23234 if (Op.isUndef())
23235 continue;
23236 if (ZextElt == -1)
23237 ZextElt = i;
23238 else
23239 return SDValue();
23240 }
23241 // Bail out if there's no non-undef element.
23242 if (ZextElt == -1)
23243 return SDValue();
23244
23245 // The build vector contains some number of undef elements and exactly
23246 // one other element. That other element must be a zero-extended scalar
23247 // extracted from a vector at a constant index to turn this into a shuffle.
23248 // Also, require that the build vector does not implicitly truncate/extend
23249 // its elements.
23250 // TODO: This could be enhanced to allow ANY_EXTEND as well as ZERO_EXTEND.
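// For example (little-endian), with V of type v8i16:
//   (v4i32 build_vector undef, (i32 zext (i16 extractelt V, C)), undef, undef)
// could become:
//   (v4i32 bitcast (v8i16 vector_shuffle V, zerovec, <u,u,C,8,u,u,u,u>))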
23251 EVT VT = BV->getValueType(ResNo: 0);
23252 SDValue Zext = BV->getOperand(Num: ZextElt);
23253 if (Zext.getOpcode() != ISD::ZERO_EXTEND || !Zext.hasOneUse() ||
23254 Zext.getOperand(i: 0).getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
23255 !isa<ConstantSDNode>(Val: Zext.getOperand(i: 0).getOperand(i: 1)) ||
23256 Zext.getValueSizeInBits() != VT.getScalarSizeInBits())
23257 return SDValue();
23258
23259 // The zero-extend must be a multiple of the source size, and we must be
23260 // building a vector of the same size as the source of the extract element.
23261 SDValue Extract = Zext.getOperand(i: 0);
23262 unsigned DestSize = Zext.getValueSizeInBits();
23263 unsigned SrcSize = Extract.getValueSizeInBits();
23264 if (DestSize % SrcSize != 0 ||
23265 Extract.getOperand(i: 0).getValueSizeInBits() != VT.getSizeInBits())
23266 return SDValue();
23267
23268 // Create a shuffle mask that will combine the extracted element with zeros
23269 // and undefs.
23270 int ZextRatio = DestSize / SrcSize;
23271 int NumMaskElts = NumBVOps * ZextRatio;
23272 SmallVector<int, 32> ShufMask(NumMaskElts, -1);
23273 for (int i = 0; i != NumMaskElts; ++i) {
23274 if (i / ZextRatio == ZextElt) {
23275 // The low bits of the (potentially translated) extracted element map to
23276 // the source vector. The high bits map to zero. We will use a zero vector
23277 // as the 2nd source operand of the shuffle, so use the 1st element of
23278 // that vector (mask value is number-of-elements) for the high bits.
23279 int Low = DAG.getDataLayout().isBigEndian() ? (ZextRatio - 1) : 0;
23280 ShufMask[i] = (i % ZextRatio == Low) ? Extract.getConstantOperandVal(i: 1)
23281 : NumMaskElts;
23282 }
23283
23284 // Undef elements of the build vector remain undef because we initialize
23285 // the shuffle mask with -1.
23286 }
23287
23288 // buildvec undef, ..., (zext (extractelt V, IndexC)), undef... -->
23289 // bitcast (shuffle V, ZeroVec, VectorMask)
23290 SDLoc DL(BV);
23291 EVT VecVT = Extract.getOperand(i: 0).getValueType();
23292 SDValue ZeroVec = DAG.getConstant(Val: 0, DL, VT: VecVT);
23293 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
23294 SDValue Shuf = TLI.buildLegalVectorShuffle(VT: VecVT, DL, N0: Extract.getOperand(i: 0),
23295 N1: ZeroVec, Mask: ShufMask, DAG);
23296 if (!Shuf)
23297 return SDValue();
23298 return DAG.getBitcast(VT, V: Shuf);
23299}
23300
23301// FIXME: promote to STLExtras.
23302template <typename R, typename T>
23303static auto getFirstIndexOf(R &&Range, const T &Val) {
23304 auto I = find(Range, Val);
23305 if (I == Range.end())
23306 return static_cast<decltype(std::distance(Range.begin(), I))>(-1);
23307 return std::distance(Range.begin(), I);
23308}
23309
23310// Check to see if this is a BUILD_VECTOR of a bunch of EXTRACT_VECTOR_ELT
23311// operations. If the types of the vectors we're extracting from allow it,
23312// turn this into a vector_shuffle node.
23313SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) {
23314 SDLoc DL(N);
23315 EVT VT = N->getValueType(ResNo: 0);
23316
23317 // Only type-legal BUILD_VECTOR nodes are converted to shuffle nodes.
23318 if (!isTypeLegal(VT))
23319 return SDValue();
23320
23321 if (SDValue V = reduceBuildVecToShuffleWithZero(BV: N, DAG))
23322 return V;
23323
23324 // May only combine to shuffle after legalize if shuffle is legal.
23325 if (LegalOperations && !TLI.isOperationLegal(Op: ISD::VECTOR_SHUFFLE, VT))
23326 return SDValue();
23327
23328 bool UsesZeroVector = false;
23329 unsigned NumElems = N->getNumOperands();
23330
23331 // Record, for each element of the newly built vector, which input vector
23332 // that element comes from. -1 stands for undef, 0 for the zero vector,
23333 // and positive values for the input vectors.
23334 // VectorMask maps each element to its vector number, and VecIn maps vector
23335 // numbers to their initial SDValues.
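// For example, for (build_vector (extractelt t1, 0), 0, undef, (extractelt t2, 1))
// we would end up with VectorMask = [1, 0, -1, 2] and VecIn = [<null>, t1, t2].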
23336
23337 SmallVector<int, 8> VectorMask(NumElems, -1);
23338 SmallVector<SDValue, 8> VecIn;
23339 VecIn.push_back(Elt: SDValue());
23340
23341 for (unsigned i = 0; i != NumElems; ++i) {
23342 SDValue Op = N->getOperand(Num: i);
23343
23344 if (Op.isUndef())
23345 continue;
23346
23347 // See if we can use a blend with a zero vector.
23348 // TODO: Should we generalize this to a blend with an arbitrary constant
23349 // vector?
23350 if (isNullConstant(V: Op) || isNullFPConstant(V: Op)) {
23351 UsesZeroVector = true;
23352 VectorMask[i] = 0;
23353 continue;
23354 }
23355
23356 // Not an undef or zero. If the input is something other than an
23357 // EXTRACT_VECTOR_ELT with an in-range constant index, bail out.
23358 if (Op.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
23359 !isa<ConstantSDNode>(Val: Op.getOperand(i: 1)))
23360 return SDValue();
23361 SDValue ExtractedFromVec = Op.getOperand(i: 0);
23362
23363 if (ExtractedFromVec.getValueType().isScalableVector())
23364 return SDValue();
23365
23366 const APInt &ExtractIdx = Op.getConstantOperandAPInt(i: 1);
23367 if (ExtractIdx.uge(RHS: ExtractedFromVec.getValueType().getVectorNumElements()))
23368 return SDValue();
23369
23370 // All inputs must have the same element type as the output.
23371 if (VT.getVectorElementType() !=
23372 ExtractedFromVec.getValueType().getVectorElementType())
23373 return SDValue();
23374
23375 // Have we seen this input vector before?
23376 // The vectors are expected to be tiny (usually 1 or 2 elements), so using
23377 // a map back from SDValues to numbers isn't worth it.
23378 int Idx = getFirstIndexOf(Range&: VecIn, Val: ExtractedFromVec);
23379 if (Idx == -1) { // A new source vector?
23380 Idx = VecIn.size();
23381 VecIn.push_back(Elt: ExtractedFromVec);
23382 }
23383
23384 VectorMask[i] = Idx;
23385 }
23386
23387 // If we didn't find at least one input vector, bail out.
23388 if (VecIn.size() < 2)
23389 return SDValue();
23390
23391 // If all the Operands of BUILD_VECTOR extract from same
23392 // vector, then split the vector efficiently based on the maximum
23393 // vector access index and adjust the VectorMask and
23394 // VecIn accordingly.
23395 bool DidSplitVec = false;
23396 if (VecIn.size() == 2) {
23397 unsigned MaxIndex = 0;
23398 unsigned NearestPow2 = 0;
23399 SDValue Vec = VecIn.back();
23400 EVT InVT = Vec.getValueType();
23401 SmallVector<unsigned, 8> IndexVec(NumElems, 0);
23402
23403 for (unsigned i = 0; i < NumElems; i++) {
23404 if (VectorMask[i] <= 0)
23405 continue;
23406 unsigned Index = N->getOperand(Num: i).getConstantOperandVal(i: 1);
23407 IndexVec[i] = Index;
23408 MaxIndex = std::max(a: MaxIndex, b: Index);
23409 }
23410
23411 NearestPow2 = PowerOf2Ceil(A: MaxIndex);
23412 if (InVT.isSimple() && NearestPow2 > 2 && MaxIndex < NearestPow2 &&
23413 NumElems * 2 < NearestPow2) {
23414 unsigned SplitSize = NearestPow2 / 2;
23415 EVT SplitVT = EVT::getVectorVT(Context&: *DAG.getContext(),
23416 VT: InVT.getVectorElementType(), NumElements: SplitSize);
23417 if (TLI.isTypeLegal(VT: SplitVT) &&
23418 SplitSize + SplitVT.getVectorNumElements() <=
23419 InVT.getVectorNumElements()) {
23420 SDValue VecIn2 = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: SplitVT, N1: Vec,
23421 N2: DAG.getVectorIdxConstant(Val: SplitSize, DL));
23422 SDValue VecIn1 = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: SplitVT, N1: Vec,
23423 N2: DAG.getVectorIdxConstant(Val: 0, DL));
23424 VecIn.pop_back();
23425 VecIn.push_back(Elt: VecIn1);
23426 VecIn.push_back(Elt: VecIn2);
23427 DidSplitVec = true;
23428
23429 for (unsigned i = 0; i < NumElems; i++) {
23430 if (VectorMask[i] <= 0)
23431 continue;
23432 VectorMask[i] = (IndexVec[i] < SplitSize) ? 1 : 2;
23433 }
23434 }
23435 }
23436 }
23437
23438 // Sort input vectors by decreasing vector element count,
23439 // while preserving the relative order of equally-sized vectors.
23440 // Note that we keep the first "implicit" zero vector as-is.
23441 SmallVector<SDValue, 8> SortedVecIn(VecIn);
23442 llvm::stable_sort(Range: MutableArrayRef<SDValue>(SortedVecIn).drop_front(),
23443 C: [](const SDValue &a, const SDValue &b) {
23444 return a.getValueType().getVectorNumElements() >
23445 b.getValueType().getVectorNumElements();
23446 });
23447
23448 // We now also need to rebuild the VectorMask, because it referenced element
23449 // order in VecIn, and we just sorted them.
23450 for (int &SourceVectorIndex : VectorMask) {
23451 if (SourceVectorIndex <= 0)
23452 continue;
23453 unsigned Idx = getFirstIndexOf(Range&: SortedVecIn, Val: VecIn[SourceVectorIndex]);
23454 assert(Idx > 0 && Idx < SortedVecIn.size() &&
23455 VecIn[SourceVectorIndex] == SortedVecIn[Idx] && "Remapping failure");
23456 SourceVectorIndex = Idx;
23457 }
23458
23459 VecIn = std::move(SortedVecIn);
23460
23461 // TODO: Should this fire if some of the input vectors have an illegal type
23462 // (like it does now), or should we let legalization run its course first?
23463
23464 // Shuffle phase:
23465 // Take pairs of vectors, and shuffle them so that the result has elements
23466 // from these vectors in the correct places.
23467 // For example, given:
23468 // t10: i32 = extract_vector_elt t1, Constant:i64<0>
23469 // t11: i32 = extract_vector_elt t2, Constant:i64<0>
23470 // t12: i32 = extract_vector_elt t3, Constant:i64<0>
23471 // t13: i32 = extract_vector_elt t1, Constant:i64<1>
23472 // t14: v4i32 = BUILD_VECTOR t10, t11, t12, t13
23473 // We will generate:
23474 // t20: v4i32 = vector_shuffle<0,4,u,1> t1, t2
23475 // t21: v4i32 = vector_shuffle<u,u,0,u> t3, undef
23476 SmallVector<SDValue, 4> Shuffles;
23477 for (unsigned In = 0, Len = (VecIn.size() / 2); In < Len; ++In) {
23478 unsigned LeftIdx = 2 * In + 1;
23479 SDValue VecLeft = VecIn[LeftIdx];
23480 SDValue VecRight =
23481 (LeftIdx + 1) < VecIn.size() ? VecIn[LeftIdx + 1] : SDValue();
23482
23483 if (SDValue Shuffle = createBuildVecShuffle(DL, N, VectorMask, VecIn1: VecLeft,
23484 VecIn2: VecRight, LeftIdx, DidSplitVec))
23485 Shuffles.push_back(Elt: Shuffle);
23486 else
23487 return SDValue();
23488 }
23489
23490 // If we need the zero vector as an "ingredient" in the blend tree, add it
23491 // to the list of shuffles.
23492 if (UsesZeroVector)
23493 Shuffles.push_back(Elt: VT.isInteger() ? DAG.getConstant(Val: 0, DL, VT)
23494 : DAG.getConstantFP(Val: 0.0, DL, VT));
23495
23496 // If we only have one shuffle, we're done.
23497 if (Shuffles.size() == 1)
23498 return Shuffles[0];
23499
23500 // Update the vector mask to point to the post-shuffle vectors.
23501 for (int &Vec : VectorMask)
23502 if (Vec == 0)
23503 Vec = Shuffles.size() - 1;
23504 else
23505 Vec = (Vec - 1) / 2;
23506
23507 // More than one shuffle. Generate a binary tree of blends, e.g. if from
23508 // the previous step we got the set of shuffles t10, t11, t12, t13, we will
23509 // generate:
23510 // t10: v8i32 = vector_shuffle<0,8,u,u,u,u,u,u> t1, t2
23511 // t11: v8i32 = vector_shuffle<u,u,0,8,u,u,u,u> t3, t4
23512 // t12: v8i32 = vector_shuffle<u,u,u,u,0,8,u,u> t5, t6
23513 // t13: v8i32 = vector_shuffle<u,u,u,u,u,u,0,8> t7, t8
23514 // t20: v8i32 = vector_shuffle<0,1,10,11,u,u,u,u> t10, t11
23515 // t21: v8i32 = vector_shuffle<u,u,u,u,4,5,14,15> t12, t13
23516 // t30: v8i32 = vector_shuffle<0,1,2,3,12,13,14,15> t20, t21
23517
23518 // Make sure the initial size of the shuffle list is even.
23519 if (Shuffles.size() % 2)
23520 Shuffles.push_back(Elt: DAG.getUNDEF(VT));
23521
23522 for (unsigned CurSize = Shuffles.size(); CurSize > 1; CurSize /= 2) {
23523 if (CurSize % 2) {
23524 Shuffles[CurSize] = DAG.getUNDEF(VT);
23525 CurSize++;
23526 }
23527 for (unsigned In = 0, Len = CurSize / 2; In < Len; ++In) {
23528 int Left = 2 * In;
23529 int Right = 2 * In + 1;
23530 SmallVector<int, 8> Mask(NumElems, -1);
23531 SDValue L = Shuffles[Left];
23532 ArrayRef<int> LMask;
23533 bool IsLeftShuffle = L.getOpcode() == ISD::VECTOR_SHUFFLE &&
23534 L.use_empty() && L.getOperand(i: 1).isUndef() &&
23535 L.getOperand(i: 0).getValueType() == L.getValueType();
23536 if (IsLeftShuffle) {
23537 LMask = cast<ShuffleVectorSDNode>(Val: L.getNode())->getMask();
23538 L = L.getOperand(i: 0);
23539 }
23540 SDValue R = Shuffles[Right];
23541 ArrayRef<int> RMask;
23542 bool IsRightShuffle = R.getOpcode() == ISD::VECTOR_SHUFFLE &&
23543 R.use_empty() && R.getOperand(i: 1).isUndef() &&
23544 R.getOperand(i: 0).getValueType() == R.getValueType();
23545 if (IsRightShuffle) {
23546 RMask = cast<ShuffleVectorSDNode>(Val: R.getNode())->getMask();
23547 R = R.getOperand(i: 0);
23548 }
23549 for (unsigned I = 0; I != NumElems; ++I) {
23550 if (VectorMask[I] == Left) {
23551 Mask[I] = I;
23552 if (IsLeftShuffle)
23553 Mask[I] = LMask[I];
23554 VectorMask[I] = In;
23555 } else if (VectorMask[I] == Right) {
23556 Mask[I] = I + NumElems;
23557 if (IsRightShuffle)
23558 Mask[I] = RMask[I] + NumElems;
23559 VectorMask[I] = In;
23560 }
23561 }
23562
23563 Shuffles[In] = DAG.getVectorShuffle(VT, dl: DL, N1: L, N2: R, Mask);
23564 }
23565 }
23566 return Shuffles[0];
23567}
23568
23569// Try to turn a build vector of zero extends of extract vector elts into a
23570 // vector zero extend and possibly an extract subvector.
23571// TODO: Support sign extend?
23572// TODO: Allow undef elements?
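// For example, with X of type v8i16:
//   (v4i32 build_vector (zext (extractelt X, 4)), (zext (extractelt X, 5)),
//                       (zext (extractelt X, 6)), (zext (extractelt X, 7)))
// could become:
//   (v4i32 zero_extend (v4i16 extract_subvector X, 4))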
23573SDValue DAGCombiner::convertBuildVecZextToZext(SDNode *N) {
23574 if (LegalOperations)
23575 return SDValue();
23576
23577 EVT VT = N->getValueType(ResNo: 0);
23578
23579 bool FoundZeroExtend = false;
23580 SDValue Op0 = N->getOperand(Num: 0);
23581 auto checkElem = [&](SDValue Op) -> int64_t {
23582 unsigned Opc = Op.getOpcode();
23583 FoundZeroExtend |= (Opc == ISD::ZERO_EXTEND);
23584 if ((Opc == ISD::ZERO_EXTEND || Opc == ISD::ANY_EXTEND) &&
23585 Op.getOperand(i: 0).getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
23586 Op0.getOperand(i: 0).getOperand(i: 0) == Op.getOperand(i: 0).getOperand(i: 0))
23587 if (auto *C = dyn_cast<ConstantSDNode>(Val: Op.getOperand(i: 0).getOperand(i: 1)))
23588 return C->getZExtValue();
23589 return -1;
23590 };
23591
23592 // Make sure the first element matches
23593 // (zext (extract_vector_elt X, C))
23594 // Offset must be a constant multiple of the
23595 // known-minimum vector length of the result type.
23596 int64_t Offset = checkElem(Op0);
23597 if (Offset < 0 || (Offset % VT.getVectorNumElements()) != 0)
23598 return SDValue();
23599
23600 unsigned NumElems = N->getNumOperands();
23601 SDValue In = Op0.getOperand(i: 0).getOperand(i: 0);
23602 EVT InSVT = In.getValueType().getScalarType();
23603 EVT InVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: InSVT, NumElements: NumElems);
23604
23605 // Don't create an illegal input type after type legalization.
23606 if (LegalTypes && !TLI.isTypeLegal(VT: InVT))
23607 return SDValue();
23608
23609 // Ensure all the elements come from the same vector and are adjacent.
23610 for (unsigned i = 1; i != NumElems; ++i) {
23611 if ((Offset + i) != checkElem(N->getOperand(Num: i)))
23612 return SDValue();
23613 }
23614
23615 SDLoc DL(N);
23616 In = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: InVT, N1: In,
23617 N2: Op0.getOperand(i: 0).getOperand(i: 1));
23618 return DAG.getNode(Opcode: FoundZeroExtend ? ISD::ZERO_EXTEND : ISD::ANY_EXTEND, DL,
23619 VT, Operand: In);
23620}
23621
23622 // If this is a very simple BUILD_VECTOR whose first element is a ZERO_EXTEND,
23623 // and all other elements are constant zeros, granularize the BUILD_VECTOR's
23624 // element width, absorbing the ZERO_EXTEND, turning it into a constant zero op.
23625 // This pattern can appear during legalization.
23626 //
23627 // NOTE: This can be generalized to allow more than a single
23628 // non-constant-zero op, UNDEFs, and to be KnownBits-based.
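// For example, on a little-endian target where v4i32 is legal:
//   (v2i64 build_vector (i64 zero_extend (i32 %x)), (i64 0))
// could become:
//   (v2i64 bitcast (v4i32 build_vector (i32 trunc ...), 0, 0, 0))
// where the truncate recovers %x from the zero-extended first operand.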
23629SDValue DAGCombiner::convertBuildVecZextToBuildVecWithZeros(SDNode *N) {
23630 // Don't run this after legalization. Targets may have other preferences.
23631 if (Level >= AfterLegalizeDAG)
23632 return SDValue();
23633
23634 // FIXME: support big-endian.
23635 if (DAG.getDataLayout().isBigEndian())
23636 return SDValue();
23637
23638 EVT VT = N->getValueType(ResNo: 0);
23639 EVT OpVT = N->getOperand(Num: 0).getValueType();
23640 assert(!VT.isScalableVector() && "Encountered scalable BUILD_VECTOR?");
23641
23642 EVT OpIntVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: OpVT.getSizeInBits());
23643
23644 if (!TLI.isTypeLegal(VT: OpIntVT) ||
23645 (LegalOperations && !TLI.isOperationLegalOrCustom(Op: ISD::BITCAST, VT: OpIntVT)))
23646 return SDValue();
23647
23648 unsigned EltBitwidth = VT.getScalarSizeInBits();
23649 // NOTE: the actual width of operands may be wider than that!
23650
23651 // Analyze all operands of this BUILD_VECTOR. What is the largest number of
23652 // active bits they all have? We'll want to truncate them all to that width.
23653 unsigned ActiveBits = 0;
23654 APInt KnownZeroOps(VT.getVectorNumElements(), 0);
23655 for (auto I : enumerate(First: N->ops())) {
23656 SDValue Op = I.value();
23657 // FIXME: support UNDEF elements?
23658 if (auto *Cst = dyn_cast<ConstantSDNode>(Val&: Op)) {
23659 unsigned OpActiveBits =
23660 Cst->getAPIntValue().trunc(width: EltBitwidth).getActiveBits();
23661 if (OpActiveBits == 0) {
23662 KnownZeroOps.setBit(I.index());
23663 continue;
23664 }
23665 // Profitability check: don't allow non-zero constant operands.
23666 return SDValue();
23667 }
23668 // Profitability check: there must only be a single non-zero operand,
23669 // and it must be the first operand of the BUILD_VECTOR.
23670 if (I.index() != 0)
23671 return SDValue();
23672 // The operand must be a zero-extension itself.
23673 // FIXME: this could be generalized to known leading zeros check.
23674 if (Op.getOpcode() != ISD::ZERO_EXTEND)
23675 return SDValue();
23676 unsigned CurrActiveBits =
23677 Op.getOperand(i: 0).getValueSizeInBits().getFixedValue();
23678 assert(!ActiveBits && "Already encountered non-constant-zero operand?");
23679 ActiveBits = CurrActiveBits;
23680 // We want to at least halve the element size.
23681 if (2 * ActiveBits > EltBitwidth)
23682 return SDValue();
23683 }
23684
23685 // This BUILD_VECTOR must have at least one non-constant-zero operand.
23686 if (ActiveBits == 0)
23687 return SDValue();
23688
23689 // We have EltBitwidth bits; the *minimal* chunk size is ActiveBits. Into how
23690 // many chunks can we split our element width?
23691 EVT NewScalarIntVT, NewIntVT;
23692 std::optional<unsigned> Factor;
23693 // We can split the element into at least two chunks, but not into more
23694 // than |_ EltBitwidth / ActiveBits _| chunks. Find the largest split factor
23695 // such that the element width is a multiple of it
23696 // and the resulting types/operations on that chunk width are legal.
23697 assert(2 * ActiveBits <= EltBitwidth &&
23698 "We know that half or less bits of the element are active.");
23699 for (unsigned Scale = EltBitwidth / ActiveBits; Scale >= 2; --Scale) {
23700 if (EltBitwidth % Scale != 0)
23701 continue;
23702 unsigned ChunkBitwidth = EltBitwidth / Scale;
23703 assert(ChunkBitwidth >= ActiveBits && "As per starting point.");
23704 NewScalarIntVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ChunkBitwidth);
23705 NewIntVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NewScalarIntVT,
23706 NumElements: Scale * N->getNumOperands());
23707 if (!TLI.isTypeLegal(VT: NewScalarIntVT) || !TLI.isTypeLegal(VT: NewIntVT) ||
23708 (LegalOperations &&
23709 !(TLI.isOperationLegalOrCustom(Op: ISD::TRUNCATE, VT: NewScalarIntVT) &&
23710 TLI.isOperationLegalOrCustom(Op: ISD::BUILD_VECTOR, VT: NewIntVT))))
23711 continue;
23712 Factor = Scale;
23713 break;
23714 }
23715 if (!Factor)
23716 return SDValue();
23717
23718 SDLoc DL(N);
23719 SDValue ZeroOp = DAG.getConstant(Val: 0, DL, VT: NewScalarIntVT);
23720
23721 // Recreate the BUILD_VECTOR, with elements now being Factor times smaller.
23722 SmallVector<SDValue, 16> NewOps;
23723 NewOps.reserve(N: NewIntVT.getVectorNumElements());
23724 for (auto I : enumerate(First: N->ops())) {
23725 SDValue Op = I.value();
23726 assert(!Op.isUndef() && "FIXME: after allowing UNDEF's, handle them here.");
23727 unsigned SrcOpIdx = I.index();
23728 if (KnownZeroOps[SrcOpIdx]) {
23729 NewOps.append(NumInputs: *Factor, Elt: ZeroOp);
23730 continue;
23731 }
23732 Op = DAG.getBitcast(VT: OpIntVT, V: Op);
23733 Op = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: NewScalarIntVT, Operand: Op);
23734 NewOps.emplace_back(Args&: Op);
23735 NewOps.append(NumInputs: *Factor - 1, Elt: ZeroOp);
23736 }
23737 assert(NewOps.size() == NewIntVT.getVectorNumElements());
23738 SDValue NewBV = DAG.getBuildVector(VT: NewIntVT, DL, Ops: NewOps);
23739 NewBV = DAG.getBitcast(VT, V: NewBV);
23740 return NewBV;
23741}
23742
23743SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) {
23744 EVT VT = N->getValueType(ResNo: 0);
23745
23746 // A vector built entirely of undefs is undef.
23747 if (ISD::allOperandsUndef(N))
23748 return DAG.getUNDEF(VT);
23749
23750 // If this is a splat of a bitcast from another vector, change to a
23751 // concat_vector.
23752 // For example:
23753 // (build_vector (i64 (bitcast (v2i32 X))), (i64 (bitcast (v2i32 X)))) ->
23754 // (v2i64 (bitcast (concat_vectors (v2i32 X), (v2i32 X))))
23755 //
23756 // If X is a build_vector itself, the concat can become a larger build_vector.
23757 // TODO: Maybe this is useful for non-splat too?
23758 if (!LegalOperations) {
23759 SDValue Splat = cast<BuildVectorSDNode>(Val: N)->getSplatValue();
23760 // Only change build_vector to a concat_vector if the splat value type is
23761 // the same as the vector element type.
23762 if (Splat && Splat.getValueType() == VT.getVectorElementType()) {
23763 Splat = peekThroughBitcasts(V: Splat);
23764 EVT SrcVT = Splat.getValueType();
23765 if (SrcVT.isVector()) {
23766 unsigned NumElts = N->getNumOperands() * SrcVT.getVectorNumElements();
23767 EVT NewVT = EVT::getVectorVT(Context&: *DAG.getContext(),
23768 VT: SrcVT.getVectorElementType(), NumElements: NumElts);
23769 if (!LegalTypes || TLI.isTypeLegal(VT: NewVT)) {
23770 SmallVector<SDValue, 8> Ops(N->getNumOperands(), Splat);
23771 SDValue Concat =
23772 DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT: NewVT, Ops);
23773 return DAG.getBitcast(VT, V: Concat);
23774 }
23775 }
23776 }
23777 }
23778
23779 // Check if we can express the BUILD_VECTOR via a subvector extract.
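// For example:
//   (v2i32 build_vector (extractelt (v8i32 X), 4), (extractelt X, 5))
// can be expressed as (v2i32 extract_subvector X, 4).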
23780 if (!LegalTypes && (N->getNumOperands() > 1)) {
23781 SDValue Op0 = N->getOperand(Num: 0);
23782 auto checkElem = [&](SDValue Op) -> uint64_t {
23783 if ((Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT) &&
23784 (Op0.getOperand(i: 0) == Op.getOperand(i: 0)))
23785 if (auto CNode = dyn_cast<ConstantSDNode>(Val: Op.getOperand(i: 1)))
23786 return CNode->getZExtValue();
23787 return -1;
23788 };
23789
23790 int Offset = checkElem(Op0);
23791 for (unsigned i = 0; i < N->getNumOperands(); ++i) {
23792 if (Offset + i != checkElem(N->getOperand(Num: i))) {
23793 Offset = -1;
23794 break;
23795 }
23796 }
23797
23798 if ((Offset == 0) &&
23799 (Op0.getOperand(i: 0).getValueType() == N->getValueType(ResNo: 0)))
23800 return Op0.getOperand(i: 0);
23801 if ((Offset != -1) &&
23802 ((Offset % N->getValueType(ResNo: 0).getVectorNumElements()) ==
23803 0)) // IDX must be multiple of output size.
23804 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N), VT: N->getValueType(ResNo: 0),
23805 N1: Op0.getOperand(i: 0), N2: Op0.getOperand(i: 1));
23806 }
23807
23808 if (SDValue V = convertBuildVecZextToZext(N))
23809 return V;
23810
23811 if (SDValue V = convertBuildVecZextToBuildVecWithZeros(N))
23812 return V;
23813
23814 if (SDValue V = reduceBuildVecExtToExtBuildVec(N))
23815 return V;
23816
23817 if (SDValue V = reduceBuildVecTruncToBitCast(N))
23818 return V;
23819
23820 if (SDValue V = reduceBuildVecToShuffle(N))
23821 return V;
23822
23823 // A splat of a single element is a SPLAT_VECTOR if supported on the target.
23824 // Do this late as some of the above may replace the splat.
23825 if (TLI.getOperationAction(Op: ISD::SPLAT_VECTOR, VT) != TargetLowering::Expand)
23826 if (SDValue V = cast<BuildVectorSDNode>(Val: N)->getSplatValue()) {
23827 assert(!V.isUndef() && "Splat of undef should have been handled earlier");
23828 return DAG.getNode(Opcode: ISD::SPLAT_VECTOR, DL: SDLoc(N), VT, Operand: V);
23829 }
23830
23831 return SDValue();
23832}
23833
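// Fold a CONCAT_VECTORS whose operands are all bitcasts of scalars (or undef)
// into a bitcast of a single wider BUILD_VECTOR. For example, if v2i32 were
// not a legal type on the target:
//   (v4i32 concat_vectors (v2i32 bitcast (i64 a)), (v2i32 bitcast (i64 b)))
// could become:
//   (v4i32 bitcast (v2i64 build_vector a, b))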
23834static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
23835 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
23836 EVT OpVT = N->getOperand(Num: 0).getValueType();
23837
23838 // If the operands are legal vectors, leave them alone.
23839 if (TLI.isTypeLegal(VT: OpVT) || OpVT.isScalableVector())
23840 return SDValue();
23841
23842 SDLoc DL(N);
23843 EVT VT = N->getValueType(ResNo: 0);
23844 SmallVector<SDValue, 8> Ops;
23845 EVT SVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: OpVT.getSizeInBits());
23846
23847 // Keep track of what we encounter.
23848 bool AnyInteger = false;
23849 bool AnyFP = false;
23850 for (const SDValue &Op : N->ops()) {
23851 if (ISD::BITCAST == Op.getOpcode() &&
23852 !Op.getOperand(i: 0).getValueType().isVector())
23853 Ops.push_back(Elt: Op.getOperand(i: 0));
23854 else if (ISD::UNDEF == Op.getOpcode())
23855 Ops.push_back(Elt: DAG.getNode(Opcode: ISD::UNDEF, DL, VT: SVT));
23856 else
23857 return SDValue();
23858
23859 // Note whether we encounter an integer or floating point scalar.
23860 // If it's neither, bail out; it could be something weird like x86mmx.
23861 EVT LastOpVT = Ops.back().getValueType();
23862 if (LastOpVT.isFloatingPoint())
23863 AnyFP = true;
23864 else if (LastOpVT.isInteger())
23865 AnyInteger = true;
23866 else
23867 return SDValue();
23868 }
23869
23870 // If any of the operands is a floating point scalar bitcast to a vector,
23871 // use floating point types throughout, and bitcast everything.
23872 // Replace UNDEFs by another scalar UNDEF node, of the final desired type.
23873 if (AnyFP) {
23874 SVT = EVT::getFloatingPointVT(BitWidth: OpVT.getSizeInBits());
23875 if (AnyInteger) {
23876 for (SDValue &Op : Ops) {
23877 if (Op.getValueType() == SVT)
23878 continue;
23879 if (Op.isUndef())
23880 Op = DAG.getNode(Opcode: ISD::UNDEF, DL, VT: SVT);
23881 else
23882 Op = DAG.getBitcast(VT: SVT, V: Op);
23883 }
23884 }
23885 }
23886
23887 EVT VecVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SVT,
23888 NumElements: VT.getSizeInBits() / SVT.getSizeInBits());
23889 return DAG.getBitcast(VT, V: DAG.getBuildVector(VT: VecVT, DL, Ops));
23890}
23891
23892// Attempt to merge nested concat_vectors/undefs.
23893// Fold concat_vectors(concat_vectors(x,y,z,w),u,u,concat_vectors(a,b,c,d))
23894// --> concat_vectors(x,y,z,w,u,u,u,u,u,u,u,u,a,b,c,d)
23895static SDValue combineConcatVectorOfConcatVectors(SDNode *N,
23896 SelectionDAG &DAG) {
23897 EVT VT = N->getValueType(ResNo: 0);
23898
23899 // Ensure we're concatenating UNDEF and CONCAT_VECTORS nodes of similar types.
23900 EVT SubVT;
23901 SDValue FirstConcat;
23902 for (const SDValue &Op : N->ops()) {
23903 if (Op.isUndef())
23904 continue;
23905 if (Op.getOpcode() != ISD::CONCAT_VECTORS)
23906 return SDValue();
23907 if (!FirstConcat) {
23908 SubVT = Op.getOperand(i: 0).getValueType();
23909 if (!DAG.getTargetLoweringInfo().isTypeLegal(VT: SubVT))
23910 return SDValue();
23911 FirstConcat = Op;
23912 continue;
23913 }
23914 if (SubVT != Op.getOperand(i: 0).getValueType())
23915 return SDValue();
23916 }
23917 assert(FirstConcat && "Concat of all-undefs found");
23918
23919 SmallVector<SDValue> ConcatOps;
23920 for (const SDValue &Op : N->ops()) {
23921 if (Op.isUndef()) {
23922 ConcatOps.append(NumInputs: FirstConcat->getNumOperands(), Elt: DAG.getUNDEF(VT: SubVT));
23923 continue;
23924 }
23925 ConcatOps.append(in_start: Op->op_begin(), in_end: Op->op_end());
23926 }
23927 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, Ops: ConcatOps);
23928}
23929
23930// Check to see if this is a CONCAT_VECTORS of a bunch of EXTRACT_SUBVECTOR
23931// operations. If so, and if the EXTRACT_SUBVECTOR vector inputs come from at
23932// most two distinct vectors the same size as the result, attempt to turn this
23933// into a legal shuffle.
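// For example:
//   (v4i32 concat_vectors (v2i32 extract_subvector (v4i32 A), 0),
//                         (v2i32 extract_subvector (v4i32 B), 2))
// could become (v4i32 vector_shuffle A, B, <0,1,6,7>) if that mask is legal.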
23934static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) {
23935 EVT VT = N->getValueType(ResNo: 0);
23936 EVT OpVT = N->getOperand(Num: 0).getValueType();
23937
23938 // We currently can't generate an appropriate shuffle for a scalable vector.
23939 if (VT.isScalableVector())
23940 return SDValue();
23941
23942 int NumElts = VT.getVectorNumElements();
23943 int NumOpElts = OpVT.getVectorNumElements();
23944
23945 SDValue SV0 = DAG.getUNDEF(VT), SV1 = DAG.getUNDEF(VT);
23946 SmallVector<int, 8> Mask;
23947
23948 for (SDValue Op : N->ops()) {
23949 Op = peekThroughBitcasts(V: Op);
23950
23951 // UNDEF nodes convert to UNDEF shuffle mask values.
23952 if (Op.isUndef()) {
23953 Mask.append(NumInputs: (unsigned)NumOpElts, Elt: -1);
23954 continue;
23955 }
23956
23957 if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
23958 return SDValue();
23959
23960 // What vector are we extracting the subvector from and at what index?
23961 SDValue ExtVec = Op.getOperand(i: 0);
23962 int ExtIdx = Op.getConstantOperandVal(i: 1);
23963
23964 // We want the EVT of the original extraction to correctly scale the
23965 // extraction index.
23966 EVT ExtVT = ExtVec.getValueType();
23967 ExtVec = peekThroughBitcasts(V: ExtVec);
23968
23969 // UNDEF nodes convert to UNDEF shuffle mask values.
23970 if (ExtVec.isUndef()) {
23971 Mask.append(NumInputs: (unsigned)NumOpElts, Elt: -1);
23972 continue;
23973 }
23974
23975 // Ensure that we are extracting a subvector from a vector the same
23976 // size as the result.
23977 if (ExtVT.getSizeInBits() != VT.getSizeInBits())
23978 return SDValue();
23979
23980 // Scale the subvector index to account for any bitcast.
23981 int NumExtElts = ExtVT.getVectorNumElements();
23982 if (0 == (NumExtElts % NumElts))
23983 ExtIdx /= (NumExtElts / NumElts);
23984 else if (0 == (NumElts % NumExtElts))
23985 ExtIdx *= (NumElts / NumExtElts);
23986 else
23987 return SDValue();
23988
23989 // At most we can reference 2 inputs in the final shuffle.
23990 if (SV0.isUndef() || SV0 == ExtVec) {
23991 SV0 = ExtVec;
23992 for (int i = 0; i != NumOpElts; ++i)
23993 Mask.push_back(Elt: i + ExtIdx);
23994 } else if (SV1.isUndef() || SV1 == ExtVec) {
23995 SV1 = ExtVec;
23996 for (int i = 0; i != NumOpElts; ++i)
23997 Mask.push_back(Elt: i + ExtIdx + NumElts);
23998 } else {
23999 return SDValue();
24000 }
24001 }
24002
24003 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24004 return TLI.buildLegalVectorShuffle(VT, DL: SDLoc(N), N0: DAG.getBitcast(VT, V: SV0),
24005 N1: DAG.getBitcast(VT, V: SV1), Mask, DAG);
24006}
24007
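// Fold a CONCAT_VECTORS of identical int<->fp casts into a single cast of a
// wider concat, when the wider operations are supported. For example:
//   (v8f32 concat_vectors (v4f32 sint_to_fp A), (v4f32 sint_to_fp B))
// could become:
//   (v8f32 sint_to_fp (v8i32 concat_vectors A, B))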
24008static SDValue combineConcatVectorOfCasts(SDNode *N, SelectionDAG &DAG) {
24009 unsigned CastOpcode = N->getOperand(Num: 0).getOpcode();
24010 switch (CastOpcode) {
24011 case ISD::SINT_TO_FP:
24012 case ISD::UINT_TO_FP:
24013 case ISD::FP_TO_SINT:
24014 case ISD::FP_TO_UINT:
24015 // TODO: Allow more opcodes?
24016 // case ISD::BITCAST:
24017 // case ISD::TRUNCATE:
24018 // case ISD::ZERO_EXTEND:
24019 // case ISD::SIGN_EXTEND:
24020 // case ISD::FP_EXTEND:
24021 break;
24022 default:
24023 return SDValue();
24024 }
24025
24026 EVT SrcVT = N->getOperand(Num: 0).getOperand(i: 0).getValueType();
24027 if (!SrcVT.isVector())
24028 return SDValue();
24029
24030 // All operands of the concat must be the same kind of cast from the same
24031 // source type.
24032 SmallVector<SDValue, 4> SrcOps;
24033 for (SDValue Op : N->ops()) {
24034 if (Op.getOpcode() != CastOpcode || !Op.hasOneUse() ||
24035 Op.getOperand(i: 0).getValueType() != SrcVT)
24036 return SDValue();
24037 SrcOps.push_back(Elt: Op.getOperand(i: 0));
24038 }
24039
24040 // The wider cast must be supported by the target. This is unusual because
24041 // the operation support type parameter depends on the opcode. In addition,
24042 // check the other type in the cast to make sure this is really legal.
24043 EVT VT = N->getValueType(ResNo: 0);
24044 EVT SrcEltVT = SrcVT.getVectorElementType();
24045 ElementCount NumElts = SrcVT.getVectorElementCount() * N->getNumOperands();
24046 EVT ConcatSrcVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SrcEltVT, EC: NumElts);
24047 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24048 switch (CastOpcode) {
24049 case ISD::SINT_TO_FP:
24050 case ISD::UINT_TO_FP:
24051 if (!TLI.isOperationLegalOrCustom(Op: CastOpcode, VT: ConcatSrcVT) ||
24052 !TLI.isTypeLegal(VT))
24053 return SDValue();
24054 break;
24055 case ISD::FP_TO_SINT:
24056 case ISD::FP_TO_UINT:
24057 if (!TLI.isOperationLegalOrCustom(Op: CastOpcode, VT) ||
24058 !TLI.isTypeLegal(VT: ConcatSrcVT))
24059 return SDValue();
24060 break;
24061 default:
24062 llvm_unreachable("Unexpected cast opcode");
24063 }
24064
24065 // concat (cast X), (cast Y)... -> cast (concat X, Y...)
24066 SDLoc DL(N);
24067 SDValue NewConcat = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT: ConcatSrcVT, Ops: SrcOps);
24068 return DAG.getNode(Opcode: CastOpcode, DL, VT, Operand: NewConcat);
24069}
24070
24071// See if this is a simple CONCAT_VECTORS with no UNDEF operands, and if one of
24072// the operands is a SHUFFLE_VECTOR, and all other operands are also operands
24073// to that SHUFFLE_VECTOR, create wider SHUFFLE_VECTOR.
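// For example, with X of type v2f32:
//   (v4f32 concat_vectors (v2f32 vector_shuffle<1,0> X, undef), X)
// could become:
//   (v4f32 vector_shuffle<1,0,0,1> (v4f32 concat_vectors X, undef), undef)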
24074static SDValue combineConcatVectorOfShuffleAndItsOperands(
24075 SDNode *N, SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes,
24076 bool LegalOperations) {
24077 EVT VT = N->getValueType(ResNo: 0);
24078 EVT OpVT = N->getOperand(Num: 0).getValueType();
24079 if (VT.isScalableVector())
24080 return SDValue();
24081
24082 // For now, only allow simple 2-operand concatenations.
24083 if (N->getNumOperands() != 2)
24084 return SDValue();
24085
24086 // Don't create illegal types/shuffles when not allowed to.
24087 if ((LegalTypes && !TLI.isTypeLegal(VT)) ||
24088 (LegalOperations &&
24089 !TLI.isOperationLegalOrCustom(Op: ISD::VECTOR_SHUFFLE, VT)))
24090 return SDValue();
24091
24092 // Analyze all of the operands of the CONCAT_VECTORS. We want to find one
24093 // that is: (1) a SHUFFLE_VECTOR, (2) only used by us, and such that
24094 // (3) all operands of the CONCAT_VECTORS are either that SHUFFLE_VECTOR
24095 // or one of the operands of that SHUFFLE_VECTOR (but not UNDEF!), and
24096 // (4) for now, the SHUFFLE_VECTOR must be unary.
24097 ShuffleVectorSDNode *SVN = nullptr;
24098 for (SDValue Op : N->ops()) {
24099 if (auto *CurSVN = dyn_cast<ShuffleVectorSDNode>(Val&: Op);
24100 CurSVN && CurSVN->getOperand(Num: 1).isUndef() && N->isOnlyUserOf(N: CurSVN) &&
24101 all_of(Range: N->ops(), P: [CurSVN](SDValue Op) {
24102 // FIXME: can we allow UNDEF operands?
24103 return !Op.isUndef() &&
24104 (Op.getNode() == CurSVN || is_contained(Range: CurSVN->ops(), Element: Op));
24105 })) {
24106 SVN = CurSVN;
24107 break;
24108 }
24109 }
24110 if (!SVN)
24111 return SDValue();
24112
24113 // We are going to pad the shuffle operands, so any index that was picking
24114 // from the second operand must be adjusted.
24115 SmallVector<int, 16> AdjustedMask;
24116 AdjustedMask.reserve(N: SVN->getMask().size());
24117 assert(SVN->getOperand(1).isUndef() && "Expected unary shuffle!");
24118 append_range(C&: AdjustedMask, R: SVN->getMask());
24119
24120 // Identity masks for the operands of the (padded) shuffle.
24121 SmallVector<int, 32> IdentityMask(2 * OpVT.getVectorNumElements());
24122 MutableArrayRef<int> FirstShufOpIdentityMask =
24123 MutableArrayRef<int>(IdentityMask)
24124 .take_front(N: OpVT.getVectorNumElements());
24125 MutableArrayRef<int> SecondShufOpIdentityMask =
24126 MutableArrayRef<int>(IdentityMask).take_back(N: OpVT.getVectorNumElements());
24127 std::iota(first: FirstShufOpIdentityMask.begin(), last: FirstShufOpIdentityMask.end(), value: 0);
24128 std::iota(first: SecondShufOpIdentityMask.begin(), last: SecondShufOpIdentityMask.end(),
24129 value: VT.getVectorNumElements());
24130
24131 // New combined shuffle mask.
24132 SmallVector<int, 32> Mask;
24133 Mask.reserve(N: VT.getVectorNumElements());
24134 for (SDValue Op : N->ops()) {
24135 assert(!Op.isUndef() && "Not expecting to concatenate UNDEF.");
24136 if (Op.getNode() == SVN) {
24137 append_range(C&: Mask, R&: AdjustedMask);
24138 continue;
24139 }
24140 if (Op == SVN->getOperand(Num: 0)) {
24141 append_range(C&: Mask, R&: FirstShufOpIdentityMask);
24142 continue;
24143 }
24144 if (Op == SVN->getOperand(Num: 1)) {
24145 append_range(C&: Mask, R&: SecondShufOpIdentityMask);
24146 continue;
24147 }
24148 llvm_unreachable("Unexpected operand!");
24149 }
24150
24151 // Don't create illegal shuffle masks.
24152 if (!TLI.isShuffleMaskLegal(Mask, VT))
24153 return SDValue();
24154
24155 // Pad the shuffle operands with UNDEF.
24156 SDLoc dl(N);
24157 std::array<SDValue, 2> ShufOps;
24158 for (auto I : zip(t: SVN->ops(), u&: ShufOps)) {
24159 SDValue ShufOp = std::get<0>(t&: I);
24160 SDValue &NewShufOp = std::get<1>(t&: I);
24161 if (ShufOp.isUndef())
24162 NewShufOp = DAG.getUNDEF(VT);
24163 else {
24164 SmallVector<SDValue, 2> ShufOpParts(N->getNumOperands(),
24165 DAG.getUNDEF(VT: OpVT));
24166 ShufOpParts[0] = ShufOp;
24167 NewShufOp = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: dl, VT, Ops: ShufOpParts);
24168 }
24169 }
24170 // Finally, create the new wide shuffle.
24171 return DAG.getVectorShuffle(VT, dl, N1: ShufOps[0], N2: ShufOps[1], Mask);
24172}
24173
24174SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
24175 // If we only have one input vector, we don't need to do any concatenation.
24176 if (N->getNumOperands() == 1)
24177 return N->getOperand(Num: 0);
24178
24179 // Check if all of the operands are undefs.
24180 EVT VT = N->getValueType(ResNo: 0);
24181 if (ISD::allOperandsUndef(N))
24182 return DAG.getUNDEF(VT);
24183
24184 // Optimize concat_vectors where all but the first of the vectors are undef.
24185 if (all_of(Range: drop_begin(RangeOrContainer: N->ops()),
24186 P: [](const SDValue &Op) { return Op.isUndef(); })) {
24187 SDValue In = N->getOperand(Num: 0);
24188 assert(In.getValueType().isVector() && "Must concat vectors");
24189
24190 // If the input is a concat_vectors, just make a larger concat by padding
24191 // with smaller undefs.
24192 //
24193 // Legalizing in AArch64TargetLowering::LowerCONCAT_VECTORS() and combining
24194 // here could cause an infinite loop. That legalization happens when LegalDAG
24195 // is true and the input to AArch64TargetLowering::LowerCONCAT_VECTORS() is
24196 // scalable.
24197 if (In.getOpcode() == ISD::CONCAT_VECTORS && In.hasOneUse() &&
24198 !(LegalDAG && In.getValueType().isScalableVector())) {
24199 unsigned NumOps = N->getNumOperands() * In.getNumOperands();
24200 SmallVector<SDValue, 4> Ops(In->op_begin(), In->op_end());
24201 Ops.resize(N: NumOps, NV: DAG.getUNDEF(VT: Ops[0].getValueType()));
24202 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, Ops);
24203 }
24204
24205 SDValue Scalar = peekThroughOneUseBitcasts(V: In);
24206
24207 // concat_vectors(scalar_to_vector(scalar), undef) ->
24208 // scalar_to_vector(scalar)
24209 if (!LegalOperations && Scalar.getOpcode() == ISD::SCALAR_TO_VECTOR &&
24210 Scalar.hasOneUse()) {
24211 EVT SVT = Scalar.getValueType().getVectorElementType();
24212 if (SVT == Scalar.getOperand(i: 0).getValueType())
24213 Scalar = Scalar.getOperand(i: 0);
24214 }
24215
24216 // concat_vectors(scalar, undef) -> scalar_to_vector(scalar)
24217 if (!Scalar.getValueType().isVector() && In.hasOneUse()) {
24218 // If the bitcast type isn't legal, it might be a trunc of a legal type;
24219 // look through the trunc so we can still do the transform:
24220 // concat_vectors(trunc(scalar), undef) -> scalar_to_vector(scalar)
24221 if (Scalar->getOpcode() == ISD::TRUNCATE &&
24222 !TLI.isTypeLegal(VT: Scalar.getValueType()) &&
24223 TLI.isTypeLegal(VT: Scalar->getOperand(Num: 0).getValueType()))
24224 Scalar = Scalar->getOperand(Num: 0);
24225
24226 EVT SclTy = Scalar.getValueType();
24227
24228 if (!SclTy.isFloatingPoint() && !SclTy.isInteger())
24229 return SDValue();
24230
24231 // Bail out if the vector size is not a multiple of the scalar size.
24232 if (VT.getSizeInBits() % SclTy.getSizeInBits())
24233 return SDValue();
24234
24235 unsigned VNTNumElms = VT.getSizeInBits() / SclTy.getSizeInBits();
24236 if (VNTNumElms < 2)
24237 return SDValue();
24238
24239 EVT NVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SclTy, NumElements: VNTNumElms);
24240 if (!TLI.isTypeLegal(VT: NVT) || !TLI.isTypeLegal(VT: Scalar.getValueType()))
24241 return SDValue();
24242
24243 SDValue Res = DAG.getNode(Opcode: ISD::SCALAR_TO_VECTOR, DL: SDLoc(N), VT: NVT, Operand: Scalar);
24244 return DAG.getBitcast(VT, V: Res);
24245 }
24246 }
24247
24248 // Fold any combination of BUILD_VECTOR or UNDEF nodes into one BUILD_VECTOR.
24249 // We have already tested above for an UNDEF only concatenation.
24250 // fold (concat_vectors (BUILD_VECTOR A, B, ...), (BUILD_VECTOR C, D, ...))
24251 // -> (BUILD_VECTOR A, B, ..., C, D, ...)
24252 auto IsBuildVectorOrUndef = [](const SDValue &Op) {
24253 return ISD::UNDEF == Op.getOpcode() || ISD::BUILD_VECTOR == Op.getOpcode();
24254 };
24255 if (llvm::all_of(Range: N->ops(), P: IsBuildVectorOrUndef)) {
24256 SmallVector<SDValue, 8> Opnds;
24257 EVT SVT = VT.getScalarType();
24258
24259 EVT MinVT = SVT;
24260 if (!SVT.isFloatingPoint()) {
24261 // If the BUILD_VECTOR nodes are built from integers, they may have different
24262 // operand types. Get the smallest type and truncate all operands to it.
24263 bool FoundMinVT = false;
24264 for (const SDValue &Op : N->ops())
24265 if (ISD::BUILD_VECTOR == Op.getOpcode()) {
24266 EVT OpSVT = Op.getOperand(i: 0).getValueType();
24267 MinVT = (!FoundMinVT || OpSVT.bitsLE(VT: MinVT)) ? OpSVT : MinVT;
24268 FoundMinVT = true;
24269 }
24270 assert(FoundMinVT && "Concat vector type mismatch");
24271 }
24272
24273 for (const SDValue &Op : N->ops()) {
24274 EVT OpVT = Op.getValueType();
24275 unsigned NumElts = OpVT.getVectorNumElements();
24276
24277 if (ISD::UNDEF == Op.getOpcode())
24278 Opnds.append(NumInputs: NumElts, Elt: DAG.getUNDEF(VT: MinVT));
24279
24280 if (ISD::BUILD_VECTOR == Op.getOpcode()) {
24281 if (SVT.isFloatingPoint()) {
24282 assert(SVT == OpVT.getScalarType() && "Concat vector type mismatch");
24283 Opnds.append(in_start: Op->op_begin(), in_end: Op->op_begin() + NumElts);
24284 } else {
24285 for (unsigned i = 0; i != NumElts; ++i)
24286 Opnds.push_back(
24287 Elt: DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N), VT: MinVT, Operand: Op.getOperand(i)));
24288 }
24289 }
24290 }
24291
24292 assert(VT.getVectorNumElements() == Opnds.size() &&
24293 "Concat vector type mismatch");
24294 return DAG.getBuildVector(VT, DL: SDLoc(N), Ops: Opnds);
24295 }
24296
24297 // Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR.
24298 // FIXME: Add support for concat_vectors(bitcast(vec0),bitcast(vec1),...).
24299 if (SDValue V = combineConcatVectorOfScalars(N, DAG))
24300 return V;
24301
24302 if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT)) {
24303 // Fold CONCAT_VECTORS of CONCAT_VECTORS (or undef) to VECTOR_SHUFFLE.
24304 if (SDValue V = combineConcatVectorOfConcatVectors(N, DAG))
24305 return V;
24306
24307 // Fold CONCAT_VECTORS of EXTRACT_SUBVECTOR (or undef) to VECTOR_SHUFFLE.
24308 if (SDValue V = combineConcatVectorOfExtracts(N, DAG))
24309 return V;
24310 }
24311
24312 if (SDValue V = combineConcatVectorOfCasts(N, DAG))
24313 return V;
24314
24315 if (SDValue V = combineConcatVectorOfShuffleAndItsOperands(
24316 N, DAG, TLI, LegalTypes, LegalOperations))
24317 return V;
24318
24319 // Type legalization of vectors and DAG canonicalization of SHUFFLE_VECTOR
24320 // nodes often generate nop CONCAT_VECTOR nodes. Scan the CONCAT_VECTOR
24321 // operands and look for CONCAT operations that place the incoming vectors
24322 // at the exact same location.
24323 //
24324 // For scalable vectors, EXTRACT_SUBVECTOR indexes are implicitly scaled.
24325 SDValue SingleSource = SDValue();
24326 unsigned PartNumElem =
24327 N->getOperand(Num: 0).getValueType().getVectorMinNumElements();
24328
24329 for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
24330 SDValue Op = N->getOperand(Num: i);
24331
24332 if (Op.isUndef())
24333 continue;
24334
24335 // Check if this is the identity extract:
24336 if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
24337 return SDValue();
24338
24339 // Find the single incoming vector for the extract_subvector.
24340 if (SingleSource.getNode()) {
24341 if (Op.getOperand(i: 0) != SingleSource)
24342 return SDValue();
24343 } else {
24344 SingleSource = Op.getOperand(i: 0);
24345
24346 // Check the source type is the same as the type of the result.
24347 // If not, this concat may extend the vector, so we can not
24348 // optimize it away.
24349 if (SingleSource.getValueType() != N->getValueType(ResNo: 0))
24350 return SDValue();
24351 }
24352
24353 // Check that we are reading from the identity index.
24354 unsigned IdentityIndex = i * PartNumElem;
24355 if (Op.getConstantOperandAPInt(i: 1) != IdentityIndex)
24356 return SDValue();
24357 }
24358
24359 if (SingleSource.getNode())
24360 return SingleSource;
24361
24362 return SDValue();
24363}
24364
24365// Helper that peeks through INSERT_SUBVECTOR/CONCAT_VECTORS to find
24366// if the subvector can be sourced for free.
24367static SDValue getSubVectorSrc(SDValue V, SDValue Index, EVT SubVT) {
24368 if (V.getOpcode() == ISD::INSERT_SUBVECTOR &&
24369 V.getOperand(i: 1).getValueType() == SubVT && V.getOperand(i: 2) == Index) {
24370 return V.getOperand(i: 1);
24371 }
24372 auto *IndexC = dyn_cast<ConstantSDNode>(Val&: Index);
24373 if (IndexC && V.getOpcode() == ISD::CONCAT_VECTORS &&
24374 V.getOperand(i: 0).getValueType() == SubVT &&
24375 (IndexC->getZExtValue() % SubVT.getVectorMinNumElements()) == 0) {
24376 uint64_t SubIdx = IndexC->getZExtValue() / SubVT.getVectorMinNumElements();
24377 return V.getOperand(i: SubIdx);
24378 }
24379 return SDValue();
24380}
24381
24382static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract,
24383 SelectionDAG &DAG,
24384 bool LegalOperations) {
24385 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24386 SDValue BinOp = Extract->getOperand(Num: 0);
24387 unsigned BinOpcode = BinOp.getOpcode();
24388 if (!TLI.isBinOp(Opcode: BinOpcode) || BinOp->getNumValues() != 1)
24389 return SDValue();
24390
24391 EVT VecVT = BinOp.getValueType();
24392 SDValue Bop0 = BinOp.getOperand(i: 0), Bop1 = BinOp.getOperand(i: 1);
24393 if (VecVT != Bop0.getValueType() || VecVT != Bop1.getValueType())
24394 return SDValue();
24395
24396 SDValue Index = Extract->getOperand(Num: 1);
24397 EVT SubVT = Extract->getValueType(ResNo: 0);
24398 if (!TLI.isOperationLegalOrCustom(Op: BinOpcode, VT: SubVT, LegalOnly: LegalOperations))
24399 return SDValue();
24400
24401 SDValue Sub0 = getSubVectorSrc(V: Bop0, Index, SubVT);
24402 SDValue Sub1 = getSubVectorSrc(V: Bop1, Index, SubVT);
24403
24404 // TODO: We could handle the case where only 1 operand is being inserted by
24405 // creating an extract of the other operand, but that requires checking
24406 // number of uses and/or costs.
24407 if (!Sub0 || !Sub1)
24408 return SDValue();
24409
24410 // We are inserting both operands of the wide binop only to extract back
24411 // to the narrow vector size. Eliminate all of the insert/extract:
24412 // ext (binop (ins ?, X, Index), (ins ?, Y, Index)), Index --> binop X, Y
24413 return DAG.getNode(Opcode: BinOpcode, DL: SDLoc(Extract), VT: SubVT, N1: Sub0, N2: Sub1,
24414 Flags: BinOp->getFlags());
24415}
24416
24417/// If we are extracting a subvector produced by a wide binary operator try
24418/// to use a narrow binary operator and/or avoid concatenation and extraction.
24419static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
24420 bool LegalOperations) {
24421 // TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share
24422 // some of these bailouts with other transforms.
24423
24424 if (SDValue V = narrowInsertExtractVectorBinOp(Extract, DAG, LegalOperations))
24425 return V;
24426
24427 // The extract index must be a constant, so we can map it to a concat operand.
24428 auto *ExtractIndexC = dyn_cast<ConstantSDNode>(Val: Extract->getOperand(Num: 1));
24429 if (!ExtractIndexC)
24430 return SDValue();
24431
24432 // We are looking for an optionally bitcasted wide vector binary operator
24433 // feeding an extract subvector.
24434 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24435 SDValue BinOp = peekThroughBitcasts(V: Extract->getOperand(Num: 0));
24436 unsigned BOpcode = BinOp.getOpcode();
24437 if (!TLI.isBinOp(Opcode: BOpcode) || BinOp->getNumValues() != 1)
24438 return SDValue();
24439
24440 // Exclude the fake form of fneg (fsub -0.0, x) because that is likely to be
24441 // reduced to the unary fneg when it is visited, and we probably want to deal
24442 // with fneg in a target-specific way.
24443 if (BOpcode == ISD::FSUB) {
24444 auto *C = isConstOrConstSplatFP(N: BinOp.getOperand(i: 0), /*AllowUndefs*/ true);
24445 if (C && C->getValueAPF().isNegZero())
24446 return SDValue();
24447 }
24448
24449 // The binop must be a vector type, so we can extract some fraction of it.
24450 EVT WideBVT = BinOp.getValueType();
24451 // The optimisations below currently assume we are dealing with fixed length
24452 // vectors. It is possible to add support for scalable vectors, but at the
24453 // moment we've done no analysis to prove whether they are profitable or not.
24454 if (!WideBVT.isFixedLengthVector())
24455 return SDValue();
24456
24457 EVT VT = Extract->getValueType(ResNo: 0);
24458 unsigned ExtractIndex = ExtractIndexC->getZExtValue();
24459 assert(ExtractIndex % VT.getVectorNumElements() == 0 &&
24460 "Extract index is not a multiple of the vector length.");
24461
24462 // Bail out if this is not a proper multiple width extraction.
24463 unsigned WideWidth = WideBVT.getSizeInBits();
24464 unsigned NarrowWidth = VT.getSizeInBits();
24465 if (WideWidth % NarrowWidth != 0)
24466 return SDValue();
24467
24468 // Bail out if we are extracting a fraction of a single operation. This can
24469 // occur because we potentially looked through a bitcast of the binop.
24470 unsigned NarrowingRatio = WideWidth / NarrowWidth;
24471 unsigned WideNumElts = WideBVT.getVectorNumElements();
24472 if (WideNumElts % NarrowingRatio != 0)
24473 return SDValue();
24474
24475 // Bail out if the target does not support a narrower version of the binop.
24476 EVT NarrowBVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: WideBVT.getScalarType(),
24477 NumElements: WideNumElts / NarrowingRatio);
24478 if (!TLI.isOperationLegalOrCustomOrPromote(Op: BOpcode, VT: NarrowBVT,
24479 LegalOnly: LegalOperations))
24480 return SDValue();
24481
24482 // If extraction is cheap, we don't need to look at the binop operands
24483 // for concat ops. The narrow binop alone makes this transform profitable.
24484 // We can't just reuse the original extract index operand because we may have
24485 // bitcasted.
24486 unsigned ConcatOpNum = ExtractIndex / VT.getVectorNumElements();
24487 unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements();
24488 if (TLI.isExtractSubvectorCheap(ResVT: NarrowBVT, SrcVT: WideBVT, Index: ExtBOIdx) &&
24489 BinOp.hasOneUse() && Extract->getOperand(Num: 0)->hasOneUse()) {
24490 // extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N)
24491 SDLoc DL(Extract);
24492 SDValue NewExtIndex = DAG.getVectorIdxConstant(Val: ExtBOIdx, DL);
24493 SDValue X = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowBVT,
24494 N1: BinOp.getOperand(i: 0), N2: NewExtIndex);
24495 SDValue Y = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowBVT,
24496 N1: BinOp.getOperand(i: 1), N2: NewExtIndex);
24497 SDValue NarrowBinOp =
24498 DAG.getNode(Opcode: BOpcode, DL, VT: NarrowBVT, N1: X, N2: Y, Flags: BinOp->getFlags());
24499 return DAG.getBitcast(VT, V: NarrowBinOp);
24500 }
24501
24502 // Only handle the case where we are doubling and then halving. A larger ratio
24503 // may require more than two narrow binops to replace the wide binop.
24504 if (NarrowingRatio != 2)
24505 return SDValue();
24506
24507 // TODO: The motivating case for this transform is an x86 AVX1 target. That
24508 // target has temptingly almost legal versions of bitwise logic ops in 256-bit
24509 // flavors, but no other 256-bit integer support. This could be extended to
24510 // handle any binop, but that may require fixing/adding other folds to avoid
24511 // codegen regressions.
24512 if (BOpcode != ISD::AND && BOpcode != ISD::OR && BOpcode != ISD::XOR)
24513 return SDValue();
24514
24515 // We need at least one concatenation operation of a binop operand to make
24516 // this transform worthwhile. The concat must double the input vector sizes.
24517 auto GetSubVector = [ConcatOpNum](SDValue V) -> SDValue {
24518 if (V.getOpcode() == ISD::CONCAT_VECTORS && V.getNumOperands() == 2)
24519 return V.getOperand(i: ConcatOpNum);
24520 return SDValue();
24521 };
24522 SDValue SubVecL = GetSubVector(peekThroughBitcasts(V: BinOp.getOperand(i: 0)));
24523 SDValue SubVecR = GetSubVector(peekThroughBitcasts(V: BinOp.getOperand(i: 1)));
24524
24525 if (SubVecL || SubVecR) {
24526 // If a binop operand was not the result of a concat, we must extract a
24527 // half-sized operand for our new narrow binop:
24528 // extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN
24529 // extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, IndexC)
24530 // extract (binop X, (concat Y1, Y2)), N --> binop (extract X, IndexC), YN
24531 SDLoc DL(Extract);
24532 SDValue IndexC = DAG.getVectorIdxConstant(Val: ExtBOIdx, DL);
24533 SDValue X = SubVecL ? DAG.getBitcast(VT: NarrowBVT, V: SubVecL)
24534 : DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowBVT,
24535 N1: BinOp.getOperand(i: 0), N2: IndexC);
24536
24537 SDValue Y = SubVecR ? DAG.getBitcast(VT: NarrowBVT, V: SubVecR)
24538 : DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowBVT,
24539 N1: BinOp.getOperand(i: 1), N2: IndexC);
24540
24541 SDValue NarrowBinOp = DAG.getNode(Opcode: BOpcode, DL, VT: NarrowBVT, N1: X, N2: Y);
24542 return DAG.getBitcast(VT, V: NarrowBinOp);
24543 }
24544
24545 return SDValue();
24546}
24547
24548/// If we are extracting a subvector from a wide vector load, convert to a
24549/// narrow load to eliminate the extraction:
24550/// (extract_subvector (load wide vector)) --> (load narrow vector)
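/// For example (illustrative, little-endian, assuming the target allows the
/// narrower load):
///   v2i64 extract_subvector (v4i64 load %p), 2 --> v2i64 load (%p + 16 bytes)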
24551static SDValue narrowExtractedVectorLoad(SDNode *Extract, SelectionDAG &DAG) {
24552 // TODO: Add support for big-endian. The offset calculation must be adjusted.
24553 if (DAG.getDataLayout().isBigEndian())
24554 return SDValue();
24555
24556 auto *Ld = dyn_cast<LoadSDNode>(Val: Extract->getOperand(Num: 0));
24557 if (!Ld || Ld->getExtensionType() || !Ld->isSimple())
24558 return SDValue();
24559
24560 // Allow targets to opt-out.
24561 EVT VT = Extract->getValueType(ResNo: 0);
24562
24563 // We can only create byte sized loads.
24564 if (!VT.isByteSized())
24565 return SDValue();
24566
24567 unsigned Index = Extract->getConstantOperandVal(Num: 1);
24568 unsigned NumElts = VT.getVectorMinNumElements();
24569 // A fixed length vector being extracted from a scalable vector
24570 // may not be any *smaller* than the scalable one.
24571 if (Index == 0 && NumElts >= Ld->getValueType(ResNo: 0).getVectorMinNumElements())
24572 return SDValue();
24573
24574 // The definition of EXTRACT_SUBVECTOR states that the index must be a
24575 // multiple of the minimum number of elements in the result type.
24576 assert(Index % NumElts == 0 && "The extract subvector index is not a "
24577 "multiple of the result's element count");
24578
24579 // It's fine to use TypeSize here as we know the offset will not be negative.
24580 TypeSize Offset = VT.getStoreSize() * (Index / NumElts);
24581
24582 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24583 if (!TLI.shouldReduceLoadWidth(Load: Ld, ExtTy: Ld->getExtensionType(), NewVT: VT))
24584 return SDValue();
24585
24586 // The narrow load will be offset from the base address of the old load if
24587 // we are extracting from something besides index 0 (little-endian).
24588 SDLoc DL(Extract);
24589
24590 // TODO: Use "BaseIndexOffset" to make this more effective.
24591 SDValue NewAddr = DAG.getMemBasePlusOffset(Base: Ld->getBasePtr(), Offset, DL);
24592
24593 LocationSize StoreSize = LocationSize::precise(Value: VT.getStoreSize());
24594 MachineFunction &MF = DAG.getMachineFunction();
24595 MachineMemOperand *MMO;
24596 if (Offset.isScalable()) {
24597 MachinePointerInfo MPI =
24598 MachinePointerInfo(Ld->getPointerInfo().getAddrSpace());
24599 MMO = MF.getMachineMemOperand(MMO: Ld->getMemOperand(), PtrInfo: MPI, Size: StoreSize);
24600 } else
24601 MMO = MF.getMachineMemOperand(MMO: Ld->getMemOperand(), Offset: Offset.getFixedValue(),
24602 Size: StoreSize);
24603
24604 SDValue NewLd = DAG.getLoad(VT, dl: DL, Chain: Ld->getChain(), Ptr: NewAddr, MMO);
24605 DAG.makeEquivalentMemoryOrdering(OldLoad: Ld, NewMemOp: NewLd);
24606 return NewLd;
24607}
24608
24609/// Given EXTRACT_SUBVECTOR(VECTOR_SHUFFLE(Op0, Op1, Mask)),
24610/// try to produce VECTOR_SHUFFLE(EXTRACT_SUBVECTOR(Op?, ?),
24611/// EXTRACT_SUBVECTOR(Op?, ?),
24612/// Mask')
24613/// iff it is legal and profitable to do so. Notably, the trimmed mask
24614/// (containing only the elements that are extracted)
24615/// must reference at most two subvectors.
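/// For example (illustrative, assuming the narrow shuffle and the extracts are
/// considered legal/cheap by the target):
///   v4i32 extract_subvector (v8i32 shuffle<0,1,8,9,2,3,10,11> X, Y), 4
///     --> v4i32 shuffle<2,3,6,7> (v4i32 extract_subvector X, 0),
///                                (v4i32 extract_subvector Y, 0)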
24616static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N,
24617 SelectionDAG &DAG,
24618 const TargetLowering &TLI,
24619 bool LegalOperations) {
24620 assert(N->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
24621 "Must only be called on EXTRACT_SUBVECTOR's");
24622
24623 SDValue N0 = N->getOperand(Num: 0);
24624
24625 // Only deal with non-scalable vectors.
24626 EVT NarrowVT = N->getValueType(ResNo: 0);
24627 EVT WideVT = N0.getValueType();
24628 if (!NarrowVT.isFixedLengthVector() || !WideVT.isFixedLengthVector())
24629 return SDValue();
24630
24631 // The operand must be a shufflevector.
24632 auto *WideShuffleVector = dyn_cast<ShuffleVectorSDNode>(Val&: N0);
24633 if (!WideShuffleVector)
24634 return SDValue();
24635
24636 // The old shuffle needs to go away.
24637 if (!WideShuffleVector->hasOneUse())
24638 return SDValue();
24639
24640 // And the narrow shufflevector that we'll form must be legal.
24641 if (LegalOperations &&
24642 !TLI.isOperationLegalOrCustom(Op: ISD::VECTOR_SHUFFLE, VT: NarrowVT))
24643 return SDValue();
24644
24645 uint64_t FirstExtractedEltIdx = N->getConstantOperandVal(Num: 1);
24646 int NumEltsExtracted = NarrowVT.getVectorNumElements();
24647 assert((FirstExtractedEltIdx % NumEltsExtracted) == 0 &&
24648 "Extract index is not a multiple of the output vector length.");
24649
24650 int WideNumElts = WideVT.getVectorNumElements();
24651
24652 SmallVector<int, 16> NewMask;
24653 NewMask.reserve(N: NumEltsExtracted);
24654 SmallSetVector<std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>, 2>
24655 DemandedSubvectors;
24656
24657 // Try to decode the wide mask into a narrow mask from at most two subvectors.
24658 for (int M : WideShuffleVector->getMask().slice(N: FirstExtractedEltIdx,
24659 M: NumEltsExtracted)) {
24660 assert((M >= -1) && (M < (2 * WideNumElts)) &&
24661 "Out-of-bounds shuffle mask?");
24662
24663 if (M < 0) {
24664 // Does not depend on operands, does not require adjustment.
24665 NewMask.emplace_back(Args&: M);
24666 continue;
24667 }
24668
24669 // From which operand of the shuffle does this shuffle mask element pick?
24670 int WideShufOpIdx = M / WideNumElts;
24671 // Which element of that operand is picked?
24672 int OpEltIdx = M % WideNumElts;
24673
24674 assert((OpEltIdx + WideShufOpIdx * WideNumElts) == M &&
24675 "Shuffle mask vector decomposition failure.");
24676
24677 // And which NumEltsExtracted-sized subvector of that operand is that?
24678 int OpSubvecIdx = OpEltIdx / NumEltsExtracted;
24679 // And which element within that subvector of that operand is that?
24680 int OpEltIdxInSubvec = OpEltIdx % NumEltsExtracted;
24681
24682 assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted) == OpEltIdx &&
24683 "Shuffle mask subvector decomposition failure.");
24684
24685 assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted +
24686 WideShufOpIdx * WideNumElts) == M &&
24687 "Shuffle mask full decomposition failure.");
24688
24689 SDValue Op = WideShuffleVector->getOperand(Num: WideShufOpIdx);
24690
24691 if (Op.isUndef()) {
24692 // Picking from an undef operand. Let's adjust mask instead.
24693 NewMask.emplace_back(Args: -1);
24694 continue;
24695 }
24696
24697 const std::pair<SDValue, int> DemandedSubvector =
24698 std::make_pair(x&: Op, y&: OpSubvecIdx);
24699
24700 if (DemandedSubvectors.insert(X: DemandedSubvector)) {
24701 if (DemandedSubvectors.size() > 2)
24702 return SDValue(); // We can't handle more than two subvectors.
24703 // How many elements into the WideVT does this subvector start?
24704 int Index = NumEltsExtracted * OpSubvecIdx;
24705 // Bail out if the extraction isn't going to be cheap.
24706 if (!TLI.isExtractSubvectorCheap(ResVT: NarrowVT, SrcVT: WideVT, Index))
24707 return SDValue();
24708 }
24709
24710 // Ok, but from which operand of the new shuffle will this element pick?
24711 int NewOpIdx =
24712 getFirstIndexOf(Range: DemandedSubvectors.getArrayRef(), Val: DemandedSubvector);
24713 assert((NewOpIdx == 0 || NewOpIdx == 1) && "Unexpected operand index.");
24714
24715 int AdjM = OpEltIdxInSubvec + NewOpIdx * NumEltsExtracted;
24716 NewMask.emplace_back(Args&: AdjM);
24717 }
24718 assert(NewMask.size() == (unsigned)NumEltsExtracted && "Produced bad mask.");
24719 assert(DemandedSubvectors.size() <= 2 &&
24720 "Should have ended up demanding at most two subvectors.");
24721
24722 // Did we discover that the shuffle does not actually depend on operands?
24723 if (DemandedSubvectors.empty())
24724 return DAG.getUNDEF(VT: NarrowVT);
24725
24726 // Profitability check: only deal with extractions from the first subvector
24727 // unless the mask becomes an identity mask.
24728 if (!ShuffleVectorInst::isIdentityMask(Mask: NewMask, NumSrcElts: NewMask.size()) ||
24729 any_of(Range&: NewMask, P: [](int M) { return M < 0; }))
24730 for (auto &DemandedSubvector : DemandedSubvectors)
24731 if (DemandedSubvector.second != 0)
24732 return SDValue();
24733
24734 // We still perform the exact same EXTRACT_SUBVECTOR, just on different
24735 // operand[s]/index[es], so there is no point in checking for its legality.
24736
24737 // Do not turn a legal shuffle into an illegal one.
24738 if (TLI.isShuffleMaskLegal(WideShuffleVector->getMask(), WideVT) &&
24739 !TLI.isShuffleMaskLegal(NewMask, NarrowVT))
24740 return SDValue();
24741
24742 SDLoc DL(N);
24743
24744 SmallVector<SDValue, 2> NewOps;
24745 for (const std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>
24746 &DemandedSubvector : DemandedSubvectors) {
24747 // How many elements into the WideVT does this subvector start?
24748 int Index = NumEltsExtracted * DemandedSubvector.second;
24749 SDValue IndexC = DAG.getVectorIdxConstant(Val: Index, DL);
24750 NewOps.emplace_back(Args: DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowVT,
24751 N1: DemandedSubvector.first, N2: IndexC));
24752 }
24753 assert((NewOps.size() == 1 || NewOps.size() == 2) &&
24754 "Should end up with either one or two ops");
24755
24756 // If we ended up with only one operand, pad with an undef.
24757 if (NewOps.size() == 1)
24758 NewOps.emplace_back(Args: DAG.getUNDEF(VT: NarrowVT));
24759
24760 return DAG.getVectorShuffle(VT: NarrowVT, dl: DL, N1: NewOps[0], N2: NewOps[1], Mask: NewMask);
24761}
24762
24763SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
24764 EVT NVT = N->getValueType(ResNo: 0);
24765 SDValue V = N->getOperand(Num: 0);
24766 uint64_t ExtIdx = N->getConstantOperandVal(Num: 1);
24767 SDLoc DL(N);
24768
24769 // Extract from UNDEF is UNDEF.
24770 if (V.isUndef())
24771 return DAG.getUNDEF(VT: NVT);
24772
24773 if (TLI.isOperationLegalOrCustomOrPromote(Op: ISD::LOAD, VT: NVT))
24774 if (SDValue NarrowLoad = narrowExtractedVectorLoad(Extract: N, DAG))
24775 return NarrowLoad;
24776
24777 // Combine an extract of an extract into a single extract_subvector.
24778 // ext (ext X, C), 0 --> ext X, C
24779 if (ExtIdx == 0 && V.getOpcode() == ISD::EXTRACT_SUBVECTOR && V.hasOneUse()) {
24780 if (TLI.isExtractSubvectorCheap(ResVT: NVT, SrcVT: V.getOperand(i: 0).getValueType(),
24781 Index: V.getConstantOperandVal(i: 1)) &&
24782 TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_SUBVECTOR, VT: NVT)) {
24783 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NVT, N1: V.getOperand(i: 0),
24784 N2: V.getOperand(i: 1));
24785 }
24786 }
24787
24788 // ty1 extract_vector(ty2 splat(V)) -> ty1 splat(V)
24789 if (V.getOpcode() == ISD::SPLAT_VECTOR)
24790 if (DAG.isConstantValueOfAnyType(N: V.getOperand(i: 0)) || V.hasOneUse())
24791 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::SPLAT_VECTOR, VT: NVT))
24792 return DAG.getSplatVector(VT: NVT, DL, Op: V.getOperand(i: 0));
24793
24794 // extract_subvector(insert_subvector(x,y,c1),c2)
24795 // --> extract_subvector(y,c2-c1)
24796 // iff we're just extracting from the inserted subvector.
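  // For example (illustrative, fixed-length types, assuming the narrower
  // extract is considered cheap by the target):
  //   v2i64 extract_subvector (v8i64 insert_subvector X, (v4i64 Y), 4), 6
  //     --> v2i64 extract_subvector Y, 2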
24797 if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
24798 SDValue InsSub = V.getOperand(i: 1);
24799 EVT InsSubVT = InsSub.getValueType();
24800 unsigned NumInsElts = InsSubVT.getVectorMinNumElements();
24801 unsigned InsIdx = V.getConstantOperandVal(i: 2);
24802 unsigned NumSubElts = NVT.getVectorMinNumElements();
24803 if (InsIdx <= ExtIdx && (ExtIdx + NumSubElts) <= (InsIdx + NumInsElts) &&
24804 TLI.isExtractSubvectorCheap(ResVT: NVT, SrcVT: InsSubVT, Index: ExtIdx - InsIdx) &&
24805 InsSubVT.isFixedLengthVector() && NVT.isFixedLengthVector() &&
24806 V.getValueType().isFixedLengthVector())
24807 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NVT, N1: InsSub,
24808 N2: DAG.getVectorIdxConstant(Val: ExtIdx - InsIdx, DL));
24809 }
24810
24811 // Try to move vector bitcast after extract_subv by scaling extraction index:
24812 // extract_subv (bitcast X), Index --> bitcast (extract_subv X, Index')
24813 if (V.getOpcode() == ISD::BITCAST &&
24814 V.getOperand(i: 0).getValueType().isVector() &&
24815 (!LegalOperations || TLI.isOperationLegal(Op: ISD::BITCAST, VT: NVT))) {
24816 SDValue SrcOp = V.getOperand(i: 0);
24817 EVT SrcVT = SrcOp.getValueType();
24818 unsigned SrcNumElts = SrcVT.getVectorMinNumElements();
24819 unsigned DestNumElts = V.getValueType().getVectorMinNumElements();
24820 if ((SrcNumElts % DestNumElts) == 0) {
24821 unsigned SrcDestRatio = SrcNumElts / DestNumElts;
24822 ElementCount NewExtEC = NVT.getVectorElementCount() * SrcDestRatio;
24823 EVT NewExtVT =
24824 EVT::getVectorVT(Context&: *DAG.getContext(), VT: SrcVT.getScalarType(), EC: NewExtEC);
24825 if (TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_SUBVECTOR, VT: NewExtVT)) {
24826 SDValue NewIndex = DAG.getVectorIdxConstant(Val: ExtIdx * SrcDestRatio, DL);
24827 SDValue NewExtract = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NewExtVT,
24828 N1: V.getOperand(i: 0), N2: NewIndex);
24829 return DAG.getBitcast(VT: NVT, V: NewExtract);
24830 }
24831 }
24832 if ((DestNumElts % SrcNumElts) == 0) {
24833 unsigned DestSrcRatio = DestNumElts / SrcNumElts;
24834 if (NVT.getVectorElementCount().isKnownMultipleOf(RHS: DestSrcRatio)) {
24835 ElementCount NewExtEC =
24836 NVT.getVectorElementCount().divideCoefficientBy(RHS: DestSrcRatio);
24837 EVT ScalarVT = SrcVT.getScalarType();
24838 if ((ExtIdx % DestSrcRatio) == 0) {
24839 unsigned IndexValScaled = ExtIdx / DestSrcRatio;
24840 EVT NewExtVT =
24841 EVT::getVectorVT(Context&: *DAG.getContext(), VT: ScalarVT, EC: NewExtEC);
24842 if (TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_SUBVECTOR, VT: NewExtVT)) {
24843 SDValue NewIndex = DAG.getVectorIdxConstant(Val: IndexValScaled, DL);
24844 SDValue NewExtract =
24845 DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NewExtVT,
24846 N1: V.getOperand(i: 0), N2: NewIndex);
24847 return DAG.getBitcast(VT: NVT, V: NewExtract);
24848 }
24849 if (NewExtEC.isScalar() &&
24850 TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_VECTOR_ELT, VT: ScalarVT)) {
24851 SDValue NewIndex = DAG.getVectorIdxConstant(Val: IndexValScaled, DL);
24852 SDValue NewExtract =
24853 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ScalarVT,
24854 N1: V.getOperand(i: 0), N2: NewIndex);
24855 return DAG.getBitcast(VT: NVT, V: NewExtract);
24856 }
24857 }
24858 }
24859 }
24860 }
24861
24862 if (V.getOpcode() == ISD::CONCAT_VECTORS) {
24863 unsigned ExtNumElts = NVT.getVectorMinNumElements();
24864 EVT ConcatSrcVT = V.getOperand(i: 0).getValueType();
24865 assert(ConcatSrcVT.getVectorElementType() == NVT.getVectorElementType() &&
24866 "Concat and extract subvector do not change element type");
24867 assert((ExtIdx % ExtNumElts) == 0 &&
24868 "Extract index is not a multiple of the input vector length.");
24869
24870 unsigned ConcatSrcNumElts = ConcatSrcVT.getVectorMinNumElements();
24871 unsigned ConcatOpIdx = ExtIdx / ConcatSrcNumElts;
24872
24873 // If the concatenated source types match this extract, it's a direct
24874 // simplification:
24875 // extract_subvec (concat V1, V2, ...), i --> Vi
24876 if (NVT.getVectorElementCount() == ConcatSrcVT.getVectorElementCount())
24877 return V.getOperand(i: ConcatOpIdx);
24878
24879 // If the concatenated source vectors are a multiple of the length of this extract,
24880 // then extract a fraction of one of those source vectors directly from a
24881 // concat operand. Example:
24882 // v2i8 extract_subvec (v16i8 concat (v8i8 X), (v8i8 Y)), 14 -->
24883 // v2i8 extract_subvec v8i8 Y, 6
24884 if (NVT.isFixedLengthVector() && ConcatSrcVT.isFixedLengthVector() &&
24885 ConcatSrcNumElts % ExtNumElts == 0) {
24886 unsigned NewExtIdx = ExtIdx - ConcatOpIdx * ConcatSrcNumElts;
24887 assert(NewExtIdx + ExtNumElts <= ConcatSrcNumElts &&
24888 "Trying to extract from >1 concat operand?");
24889 assert(NewExtIdx % ExtNumElts == 0 &&
24890 "Extract index is not a multiple of the input vector length.");
24891 SDValue NewIndexC = DAG.getVectorIdxConstant(Val: NewExtIdx, DL);
24892 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NVT,
24893 N1: V.getOperand(i: ConcatOpIdx), N2: NewIndexC);
24894 }
24895 }
24896
24897 if (SDValue V =
24898 foldExtractSubvectorFromShuffleVector(N, DAG, TLI, LegalOperations))
24899 return V;
24900
24901 V = peekThroughBitcasts(V);
24902
24903 // If the input is a build vector, try to make a smaller build vector.
24904 if (V.getOpcode() == ISD::BUILD_VECTOR) {
24905 EVT InVT = V.getValueType();
24906 unsigned ExtractSize = NVT.getSizeInBits();
24907 unsigned EltSize = InVT.getScalarSizeInBits();
24908 // Only do this if we won't split any elements.
24909 if (ExtractSize % EltSize == 0) {
24910 unsigned NumElems = ExtractSize / EltSize;
24911 EVT EltVT = InVT.getVectorElementType();
24912 EVT ExtractVT =
24913 NumElems == 1 ? EltVT
24914 : EVT::getVectorVT(Context&: *DAG.getContext(), VT: EltVT, NumElements: NumElems);
24915 if ((Level < AfterLegalizeDAG ||
24916 (NumElems == 1 ||
24917 TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT: ExtractVT))) &&
24918 (!LegalTypes || TLI.isTypeLegal(VT: ExtractVT))) {
24919 unsigned IdxVal = (ExtIdx * NVT.getScalarSizeInBits()) / EltSize;
24920
24921 if (NumElems == 1) {
24922 SDValue Src = V->getOperand(Num: IdxVal);
24923 if (EltVT != Src.getValueType())
24924 Src = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: EltVT, Operand: Src);
24925 return DAG.getBitcast(VT: NVT, V: Src);
24926 }
24927
24928 // Extract the pieces from the original build_vector.
24929 SDValue BuildVec =
24930 DAG.getBuildVector(VT: ExtractVT, DL, Ops: V->ops().slice(N: IdxVal, M: NumElems));
24931 return DAG.getBitcast(VT: NVT, V: BuildVec);
24932 }
24933 }
24934 }
24935
24936 if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
24937 // Handle only the simple case where the vector being inserted and the
24938 // vector being extracted are of the same size.
24939 EVT SmallVT = V.getOperand(i: 1).getValueType();
24940 if (!NVT.bitsEq(VT: SmallVT))
24941 return SDValue();
24942
24943 // Combine:
24944 // (extract_subvec (insert_subvec V1, V2, InsIdx), ExtIdx)
24945 // Into:
24946 // indices are equal or bit offsets are equal => V2
24947 // otherwise => (extract_subvec V1, ExtIdx)
24948 uint64_t InsIdx = V.getConstantOperandVal(i: 2);
24949 if (InsIdx * SmallVT.getScalarSizeInBits() ==
24950 ExtIdx * NVT.getScalarSizeInBits()) {
24951 if (LegalOperations && !TLI.isOperationLegal(Op: ISD::BITCAST, VT: NVT))
24952 return SDValue();
24953
24954 return DAG.getBitcast(VT: NVT, V: V.getOperand(i: 1));
24955 }
24956 return DAG.getNode(
24957 Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NVT,
24958 N1: DAG.getBitcast(VT: N->getOperand(Num: 0).getValueType(), V: V.getOperand(i: 0)),
24959 N2: N->getOperand(Num: 1));
24960 }
24961
24962 if (SDValue NarrowBOp = narrowExtractedVectorBinOp(Extract: N, DAG, LegalOperations))
24963 return NarrowBOp;
24964
24965 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
24966 return SDValue(N, 0);
24967
24968 return SDValue();
24969}
24970
24971/// Try to convert a wide shuffle of concatenated vectors into 2 narrow shuffles
24972/// followed by concatenation. Narrow vector ops may have better performance
24973/// than wide ops, and this can unlock further narrowing of other vector ops.
24974/// Targets can invert this transform later if it is not profitable.
24975static SDValue foldShuffleOfConcatUndefs(ShuffleVectorSDNode *Shuf,
24976 SelectionDAG &DAG) {
24977 SDValue N0 = Shuf->getOperand(Num: 0), N1 = Shuf->getOperand(Num: 1);
24978 if (N0.getOpcode() != ISD::CONCAT_VECTORS || N0.getNumOperands() != 2 ||
24979 N1.getOpcode() != ISD::CONCAT_VECTORS || N1.getNumOperands() != 2 ||
24980 !N0.getOperand(i: 1).isUndef() || !N1.getOperand(i: 1).isUndef())
24981 return SDValue();
24982
24983 // Split the wide shuffle mask into halves. Any mask element that is accessing
24984 // operand 1 is offset down to account for narrowing of the vectors.
24985 ArrayRef<int> Mask = Shuf->getMask();
24986 EVT VT = Shuf->getValueType(ResNo: 0);
24987 unsigned NumElts = VT.getVectorNumElements();
24988 unsigned HalfNumElts = NumElts / 2;
24989 SmallVector<int, 16> Mask0(HalfNumElts, -1);
24990 SmallVector<int, 16> Mask1(HalfNumElts, -1);
24991 for (unsigned i = 0; i != NumElts; ++i) {
24992 if (Mask[i] == -1)
24993 continue;
24994 // If we reference the upper (undef) subvector then the element is undef.
24995 if ((Mask[i] % NumElts) >= HalfNumElts)
24996 continue;
24997 int M = Mask[i] < (int)NumElts ? Mask[i] : Mask[i] - (int)HalfNumElts;
24998 if (i < HalfNumElts)
24999 Mask0[i] = M;
25000 else
25001 Mask1[i - HalfNumElts] = M;
25002 }
25003
25004 // Ask the target if this is a valid transform.
25005 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25006 EVT HalfVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: VT.getScalarType(),
25007 NumElements: HalfNumElts);
25008 if (!TLI.isShuffleMaskLegal(Mask0, HalfVT) ||
25009 !TLI.isShuffleMaskLegal(Mask1, HalfVT))
25010 return SDValue();
25011
25012 // shuffle (concat X, undef), (concat Y, undef), Mask -->
25013 // concat (shuffle X, Y, Mask0), (shuffle X, Y, Mask1)
25014 SDValue X = N0.getOperand(i: 0), Y = N1.getOperand(i: 0);
25015 SDLoc DL(Shuf);
25016 SDValue Shuf0 = DAG.getVectorShuffle(VT: HalfVT, dl: DL, N1: X, N2: Y, Mask: Mask0);
25017 SDValue Shuf1 = DAG.getVectorShuffle(VT: HalfVT, dl: DL, N1: X, N2: Y, Mask: Mask1);
25018 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT, N1: Shuf0, N2: Shuf1);
25019}
25020
25021// Tries to turn a shuffle of two CONCAT_VECTORS into a single concat,
25022// or to turn a shuffle of a single concat into a simpler shuffle followed by a concat.
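// For example (illustrative):
//   v4i32 shuffle (concat_vectors A, B), (concat_vectors C, D), <0,1,6,7>
//     --> v4i32 concat_vectors A, D
// where A, B, C and D are assumed to be v2i32 values.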
25023static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) {
25024 EVT VT = N->getValueType(ResNo: 0);
25025 unsigned NumElts = VT.getVectorNumElements();
25026
25027 SDValue N0 = N->getOperand(Num: 0);
25028 SDValue N1 = N->getOperand(Num: 1);
25029 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Val: N);
25030 ArrayRef<int> Mask = SVN->getMask();
25031
25032 SmallVector<SDValue, 4> Ops;
25033 EVT ConcatVT = N0.getOperand(i: 0).getValueType();
25034 unsigned NumElemsPerConcat = ConcatVT.getVectorNumElements();
25035 unsigned NumConcats = NumElts / NumElemsPerConcat;
25036
25037 auto IsUndefMaskElt = [](int i) { return i == -1; };
25038
25039 // Special case: shuffle(concat(A,B)) can be more efficiently represented
25040 // as concat(shuffle(A,B),UNDEF) if the shuffle doesn't set any of the high
25041 // half vector elements.
25042 if (NumElemsPerConcat * 2 == NumElts && N1.isUndef() &&
25043 llvm::all_of(Range: Mask.slice(N: NumElemsPerConcat, M: NumElemsPerConcat),
25044 P: IsUndefMaskElt)) {
25045 N0 = DAG.getVectorShuffle(VT: ConcatVT, dl: SDLoc(N), N1: N0.getOperand(i: 0),
25046 N2: N0.getOperand(i: 1),
25047 Mask: Mask.slice(N: 0, M: NumElemsPerConcat));
25048 N1 = DAG.getUNDEF(VT: ConcatVT);
25049 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, N1: N0, N2: N1);
25050 }
25051
25052 // Look at every vector that's inserted. We're looking for exact
25053 // subvector-sized copies from a concatenated vector
25054 for (unsigned I = 0; I != NumConcats; ++I) {
25055 unsigned Begin = I * NumElemsPerConcat;
25056 ArrayRef<int> SubMask = Mask.slice(N: Begin, M: NumElemsPerConcat);
25057
25058 // Make sure we're dealing with a copy.
25059 if (llvm::all_of(Range&: SubMask, P: IsUndefMaskElt)) {
25060 Ops.push_back(Elt: DAG.getUNDEF(VT: ConcatVT));
25061 continue;
25062 }
25063
25064 int OpIdx = -1;
25065 for (int i = 0; i != (int)NumElemsPerConcat; ++i) {
25066 if (IsUndefMaskElt(SubMask[i]))
25067 continue;
25068 if ((SubMask[i] % (int)NumElemsPerConcat) != i)
25069 return SDValue();
25070 int EltOpIdx = SubMask[i] / NumElemsPerConcat;
25071 if (0 <= OpIdx && EltOpIdx != OpIdx)
25072 return SDValue();
25073 OpIdx = EltOpIdx;
25074 }
25075 assert(0 <= OpIdx && "Unknown concat_vectors op");
25076
25077 if (OpIdx < (int)N0.getNumOperands())
25078 Ops.push_back(Elt: N0.getOperand(i: OpIdx));
25079 else
25080 Ops.push_back(Elt: N1.getOperand(i: OpIdx - N0.getNumOperands()));
25081 }
25082
25083 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, Ops);
25084}
25085
25086// Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
25087// BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
25088//
25089// SHUFFLE(BUILD_VECTOR(), BUILD_VECTOR()) -> BUILD_VECTOR() is always
25090// a simplification in some sense, but it isn't appropriate in general: some
25091// BUILD_VECTORs are substantially cheaper than others. The general case
25092// of a BUILD_VECTOR requires inserting each element individually (or
25093// performing the equivalent in a temporary stack variable). A BUILD_VECTOR of
25094// all constants is a single constant pool load. A BUILD_VECTOR where each
25095// element is identical is a splat. A BUILD_VECTOR where most of the operands
25096// are undef lowers to a small number of element insertions.
25097//
25098// To deal with this, we currently use a bunch of mostly arbitrary heuristics.
25099// We don't fold shuffles where one side is a non-zero constant, and we don't
25100// fold shuffles if the resulting (non-splat) BUILD_VECTOR would have duplicate
25101// non-constant operands. This seems to work out reasonably well in practice.
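// For example (illustrative, assuming a, b, c and d are distinct non-constant
// scalars of the same type):
//   v4i32 shuffle (build_vector a, b, c, d), undef, <3,2,1,0>
//     --> v4i32 build_vector d, c, b, a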
25102static SDValue combineShuffleOfScalars(ShuffleVectorSDNode *SVN,
25103 SelectionDAG &DAG,
25104 const TargetLowering &TLI) {
25105 EVT VT = SVN->getValueType(ResNo: 0);
25106 unsigned NumElts = VT.getVectorNumElements();
25107 SDValue N0 = SVN->getOperand(Num: 0);
25108 SDValue N1 = SVN->getOperand(Num: 1);
25109
25110 if (!N0->hasOneUse())
25111 return SDValue();
25112
25113 // If only one of N0,N1 is constant, bail out if it is not ALL_ZEROS as
25114 // discussed above.
25115 if (!N1.isUndef()) {
25116 if (!N1->hasOneUse())
25117 return SDValue();
25118
25119 bool N0AnyConst = isAnyConstantBuildVector(V: N0);
25120 bool N1AnyConst = isAnyConstantBuildVector(V: N1);
25121 if (N0AnyConst && !N1AnyConst && !ISD::isBuildVectorAllZeros(N: N0.getNode()))
25122 return SDValue();
25123 if (!N0AnyConst && N1AnyConst && !ISD::isBuildVectorAllZeros(N: N1.getNode()))
25124 return SDValue();
25125 }
25126
25127 // If both inputs are splats of the same value then we can safely merge this
25128 // to a single BUILD_VECTOR with undef elements based on the shuffle mask.
25129 bool IsSplat = false;
25130 auto *BV0 = dyn_cast<BuildVectorSDNode>(Val&: N0);
25131 auto *BV1 = dyn_cast<BuildVectorSDNode>(Val&: N1);
25132 if (BV0 && BV1)
25133 if (SDValue Splat0 = BV0->getSplatValue())
25134 IsSplat = (Splat0 == BV1->getSplatValue());
25135
25136 SmallVector<SDValue, 8> Ops;
25137 SmallSet<SDValue, 16> DuplicateOps;
25138 for (int M : SVN->getMask()) {
25139 SDValue Op = DAG.getUNDEF(VT: VT.getScalarType());
25140 if (M >= 0) {
25141 int Idx = M < (int)NumElts ? M : M - NumElts;
25142 SDValue &S = (M < (int)NumElts ? N0 : N1);
25143 if (S.getOpcode() == ISD::BUILD_VECTOR) {
25144 Op = S.getOperand(i: Idx);
25145 } else if (S.getOpcode() == ISD::SCALAR_TO_VECTOR) {
25146 SDValue Op0 = S.getOperand(i: 0);
25147 Op = Idx == 0 ? Op0 : DAG.getUNDEF(VT: Op0.getValueType());
25148 } else {
25149 // Operand can't be combined - bail out.
25150 return SDValue();
25151 }
25152 }
25153
25154 // Don't duplicate a non-constant BUILD_VECTOR operand unless we're
25155 // generating a splat; semantically, this is fine, but it's likely to
25156 // generate low-quality code if the target can't reconstruct an appropriate
25157 // shuffle.
25158 if (!Op.isUndef() && !isIntOrFPConstant(V: Op))
25159 if (!IsSplat && !DuplicateOps.insert(V: Op).second)
25160 return SDValue();
25161
25162 Ops.push_back(Elt: Op);
25163 }
25164
25165 // BUILD_VECTOR requires all inputs to be of the same type; find the
25166 // maximum type and extend them all.
25167 EVT SVT = VT.getScalarType();
25168 if (SVT.isInteger())
25169 for (SDValue &Op : Ops)
25170 SVT = (SVT.bitsLT(VT: Op.getValueType()) ? Op.getValueType() : SVT);
25171 if (SVT != VT.getScalarType())
25172 for (SDValue &Op : Ops)
25173 Op = Op.isUndef() ? DAG.getUNDEF(VT: SVT)
25174 : (TLI.isZExtFree(FromTy: Op.getValueType(), ToTy: SVT)
25175 ? DAG.getZExtOrTrunc(Op, DL: SDLoc(SVN), VT: SVT)
25176 : DAG.getSExtOrTrunc(Op, DL: SDLoc(SVN), VT: SVT));
25177 return DAG.getBuildVector(VT, DL: SDLoc(SVN), Ops);
25178}
25179
25180// Match shuffles that can be converted to *_vector_extend_in_reg.
25181// This is often generated during legalization.
25182// e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src)),
25183// and returns the EVT to which the extension should be performed.
25184// NOTE: this assumes that the src is the first operand of the shuffle.
25185static std::optional<EVT> canCombineShuffleToExtendVectorInreg(
25186 unsigned Opcode, EVT VT, std::function<bool(unsigned)> Match,
25187 SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes,
25188 bool LegalOperations) {
25189 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
25190
25191 // TODO Add support for big-endian when we have a test case.
25192 if (!VT.isInteger() || IsBigEndian)
25193 return std::nullopt;
25194
25195 unsigned NumElts = VT.getVectorNumElements();
25196 unsigned EltSizeInBits = VT.getScalarSizeInBits();
25197
25198 // Attempt to match a '*_extend_vector_inreg' shuffle; we just search for
25199 // power-of-2 extensions as they are the most likely.
25200 // FIXME: should try Scale == NumElts case too,
25201 for (unsigned Scale = 2; Scale < NumElts; Scale *= 2) {
25202 // The vector width must be a multiple of Scale.
25203 if (NumElts % Scale != 0)
25204 continue;
25205
25206 EVT OutSVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: EltSizeInBits * Scale);
25207 EVT OutVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: OutSVT, NumElements: NumElts / Scale);
25208
25209 if ((LegalTypes && !TLI.isTypeLegal(VT: OutVT)) ||
25210 (LegalOperations && !TLI.isOperationLegalOrCustom(Op: Opcode, VT: OutVT)))
25211 continue;
25212
25213 if (Match(Scale))
25214 return OutVT;
25215 }
25216
25217 return std::nullopt;
25218}
25219
25220// Match shuffles that can be converted to any_vector_extend_in_reg.
25221// This is often generated during legalization.
25222// e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src))
25223static SDValue combineShuffleToAnyExtendVectorInreg(ShuffleVectorSDNode *SVN,
25224 SelectionDAG &DAG,
25225 const TargetLowering &TLI,
25226 bool LegalOperations) {
25227 EVT VT = SVN->getValueType(ResNo: 0);
25228 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
25229
25230 // TODO Add support for big-endian when we have a test case.
25231 if (!VT.isInteger() || IsBigEndian)
25232 return SDValue();
25233
25234 // shuffle<0,-1,1,-1> == (v2i64 anyextend_vector_inreg(v4i32))
25235 auto isAnyExtend = [NumElts = VT.getVectorNumElements(),
25236 Mask = SVN->getMask()](unsigned Scale) {
25237 for (unsigned i = 0; i != NumElts; ++i) {
25238 if (Mask[i] < 0)
25239 continue;
25240 if ((i % Scale) == 0 && Mask[i] == (int)(i / Scale))
25241 continue;
25242 return false;
25243 }
25244 return true;
25245 };
25246
25247 unsigned Opcode = ISD::ANY_EXTEND_VECTOR_INREG;
25248 SDValue N0 = SVN->getOperand(Num: 0);
25249 // Never create an illegal type. Only create unsupported operations if we
25250 // are pre-legalization.
25251 std::optional<EVT> OutVT = canCombineShuffleToExtendVectorInreg(
25252 Opcode, VT, Match: isAnyExtend, DAG, TLI, /*LegalTypes=*/true, LegalOperations);
25253 if (!OutVT)
25254 return SDValue();
25255 return DAG.getBitcast(VT, V: DAG.getNode(Opcode, DL: SDLoc(SVN), VT: *OutVT, Operand: N0));
25256}
25257
25258// Match shuffles that can be converted to zero_extend_vector_inreg.
25259// This is often generated during legalization.
25260// e.g. v4i32 <0,z,1,u> -> (v2i64 zero_extend_vector_inreg(v4i32 src))
25261static SDValue combineShuffleToZeroExtendVectorInReg(ShuffleVectorSDNode *SVN,
25262 SelectionDAG &DAG,
25263 const TargetLowering &TLI,
25264 bool LegalOperations) {
25265 bool LegalTypes = true;
25266 EVT VT = SVN->getValueType(ResNo: 0);
25267 assert(!VT.isScalableVector() && "Encountered scalable shuffle?");
25268 unsigned NumElts = VT.getVectorNumElements();
25269 unsigned EltSizeInBits = VT.getScalarSizeInBits();
25270
25271 // TODO: add support for big-endian when we have a test case.
25272 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
25273 if (!VT.isInteger() || IsBigEndian)
25274 return SDValue();
25275
25276 SmallVector<int, 16> Mask(SVN->getMask().begin(), SVN->getMask().end());
25277 auto ForEachDecomposedIndice = [NumElts, &Mask](auto Fn) {
25278 for (int &Indice : Mask) {
25279 if (Indice < 0)
25280 continue;
25281 int OpIdx = (unsigned)Indice < NumElts ? 0 : 1;
25282 int OpEltIdx = (unsigned)Indice < NumElts ? Indice : Indice - NumElts;
25283 Fn(Indice, OpIdx, OpEltIdx);
25284 }
25285 };
25286
25287 // Which elements of which operand does this shuffle demand?
25288 std::array<APInt, 2> OpsDemandedElts;
25289 for (APInt &OpDemandedElts : OpsDemandedElts)
25290 OpDemandedElts = APInt::getZero(numBits: NumElts);
25291 ForEachDecomposedIndice(
25292 [&OpsDemandedElts](int &Indice, int OpIdx, int OpEltIdx) {
25293 OpsDemandedElts[OpIdx].setBit(OpEltIdx);
25294 });
25295
25296 // Element-wise(!), which of these demanded elements are known to be zero?
25297 std::array<APInt, 2> OpsKnownZeroElts;
25298 for (auto I : zip(t: SVN->ops(), u&: OpsDemandedElts, args&: OpsKnownZeroElts))
25299 std::get<2>(t&: I) =
25300 DAG.computeVectorKnownZeroElements(Op: std::get<0>(t&: I), DemandedElts: std::get<1>(t&: I));
25301
25302 // Manifest zeroable element knowledge in the shuffle mask.
25303 // NOTE: we don't have a 'zeroable' sentinel value in the generic DAG;
25304 // this is a local invention, but it won't leak into the DAG.
25305 // FIXME: should we not manifest them, but just check when matching?
25306 bool HadZeroableElts = false;
25307 ForEachDecomposedIndice([&OpsKnownZeroElts, &HadZeroableElts](
25308 int &Indice, int OpIdx, int OpEltIdx) {
25309 if (OpsKnownZeroElts[OpIdx][OpEltIdx]) {
25310 Indice = -2; // Zeroable element.
25311 HadZeroableElts = true;
25312 }
25313 });
25314
25315 // Don't proceed unless we've refined at least one zeroable mask index.
25316 // If we didn't, then we are still trying to match the same shuffle mask
25317 // we previously tried to match as ISD::ANY_EXTEND_VECTOR_INREG,
25318 // and evidently failed. Proceeding will lead to endless combine loops.
25319 if (!HadZeroableElts)
25320 return SDValue();
25321
25322 // The shuffle may be more fine-grained than we want. Widen elements first.
25323 // FIXME: should we do this before manifesting zeroable shuffle mask indices?
25324 SmallVector<int, 16> ScaledMask;
25325 getShuffleMaskWithWidestElts(Mask, ScaledMask);
25326 assert(Mask.size() >= ScaledMask.size() &&
25327 Mask.size() % ScaledMask.size() == 0 && "Unexpected mask widening.");
25328 int Prescale = Mask.size() / ScaledMask.size();
25329
25330 NumElts = ScaledMask.size();
25331 EltSizeInBits *= Prescale;
25332
25333 EVT PrescaledVT = EVT::getVectorVT(
25334 Context&: *DAG.getContext(), VT: EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: EltSizeInBits),
25335 NumElements: NumElts);
25336
25337 if (LegalTypes && !TLI.isTypeLegal(VT: PrescaledVT) && TLI.isTypeLegal(VT))
25338 return SDValue();
25339
25340 // For example,
25341 // shuffle<0,z,1,-1> == (v2i64 zero_extend_vector_inreg(v4i32))
25342 // But not shuffle<z,z,1,-1> and not shuffle<0,z,z,-1> ! (for same types)
25343 auto isZeroExtend = [NumElts, &ScaledMask](unsigned Scale) {
25344 assert(Scale >= 2 && Scale <= NumElts && NumElts % Scale == 0 &&
25345 "Unexpected mask scaling factor.");
25346 ArrayRef<int> Mask = ScaledMask;
25347 for (unsigned SrcElt = 0, NumSrcElts = NumElts / Scale;
25348 SrcElt != NumSrcElts; ++SrcElt) {
25349 // Analyze the shuffle mask in Scale-sized chunks.
25350 ArrayRef<int> MaskChunk = Mask.take_front(N: Scale);
25351 assert(MaskChunk.size() == Scale && "Unexpected mask size.");
25352 Mask = Mask.drop_front(N: MaskChunk.size());
25353 // The first index in this chunk must be SrcElt, but not zero!
25354 // FIXME: undef should be fine, but that results in more-defined result.
25355 if (int FirstIndice = MaskChunk[0]; (unsigned)FirstIndice != SrcElt)
25356 return false;
25357 // The rest of the indices in this chunk must be zeros.
25358 // FIXME: undef should be fine, but that results in more-defined result.
25359 if (!all_of(Range: MaskChunk.drop_front(N: 1),
25360 P: [](int Indice) { return Indice == -2; }))
25361 return false;
25362 }
25363 assert(Mask.empty() && "Did not process the whole mask?");
25364 return true;
25365 };
25366
25367 unsigned Opcode = ISD::ZERO_EXTEND_VECTOR_INREG;
25368 for (bool Commuted : {false, true}) {
25369 SDValue Op = SVN->getOperand(Num: !Commuted ? 0 : 1);
25370 if (Commuted)
25371 ShuffleVectorSDNode::commuteMask(Mask: ScaledMask);
25372 std::optional<EVT> OutVT = canCombineShuffleToExtendVectorInreg(
25373 Opcode, VT: PrescaledVT, Match: isZeroExtend, DAG, TLI, LegalTypes,
25374 LegalOperations);
25375 if (OutVT)
25376 return DAG.getBitcast(VT, V: DAG.getNode(Opcode, DL: SDLoc(SVN), VT: *OutVT,
25377 Operand: DAG.getBitcast(VT: PrescaledVT, V: Op)));
25378 }
25379 return SDValue();
25380}
25381
25382// Detect 'truncate_vector_inreg' style shuffles that pack the lower parts of
25383// each source element of a large type into the lowest elements of a smaller
25384// destination type. This is often generated during legalization.
25385// If the source node itself was a '*_extend_vector_inreg' node then we
25386// should be able to remove it.
25387static SDValue combineTruncationShuffle(ShuffleVectorSDNode *SVN,
25388 SelectionDAG &DAG) {
25389 EVT VT = SVN->getValueType(ResNo: 0);
25390 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
25391
25392 // TODO Add support for big-endian when we have a test case.
25393 if (!VT.isInteger() || IsBigEndian)
25394 return SDValue();
25395
25396 SDValue N0 = peekThroughBitcasts(V: SVN->getOperand(Num: 0));
25397
25398 unsigned Opcode = N0.getOpcode();
25399 if (!ISD::isExtVecInRegOpcode(Opcode))
25400 return SDValue();
25401
25402 SDValue N00 = N0.getOperand(i: 0);
25403 ArrayRef<int> Mask = SVN->getMask();
25404 unsigned NumElts = VT.getVectorNumElements();
25405 unsigned EltSizeInBits = VT.getScalarSizeInBits();
25406 unsigned ExtSrcSizeInBits = N00.getScalarValueSizeInBits();
25407 unsigned ExtDstSizeInBits = N0.getScalarValueSizeInBits();
25408
25409 if (ExtDstSizeInBits % ExtSrcSizeInBits != 0)
25410 return SDValue();
25411 unsigned ExtScale = ExtDstSizeInBits / ExtSrcSizeInBits;
25412
25413 // (v4i32 truncate_vector_inreg(v2i64)) == shuffle<0,2-1,-1>
25414 // (v8i16 truncate_vector_inreg(v4i32)) == shuffle<0,2,4,6,-1,-1,-1,-1>
25415 // (v8i16 truncate_vector_inreg(v2i64)) == shuffle<0,4,-1,-1,-1,-1,-1,-1>
25416 auto isTruncate = [&Mask, &NumElts](unsigned Scale) {
25417 for (unsigned i = 0; i != NumElts; ++i) {
25418 if (Mask[i] < 0)
25419 continue;
25420 if ((i * Scale) < NumElts && Mask[i] == (int)(i * Scale))
25421 continue;
25422 return false;
25423 }
25424 return true;
25425 };
25426
25427 // At the moment we just handle the case where we've truncated back to the
25428 // same size as before the extension.
25429 // TODO: handle more extension/truncation cases as cases arise.
25430 if (EltSizeInBits != ExtSrcSizeInBits)
25431 return SDValue();
25432
25433 // We can remove *extend_vector_inreg only if the truncation happens at
25434 // the same scale as the extension.
25435 if (isTruncate(ExtScale))
25436 return DAG.getBitcast(VT, V: N00);
25437
25438 return SDValue();
25439}
25440
25441// Combine shuffles of splat-shuffles of the form:
25442// shuffle (shuffle V, undef, splat-mask), undef, M
25443// If splat-mask contains undef elements, we need to be careful about
25444// introducing undefs in the folded mask which are not the result of composing
25445// the masks of the shuffles.
25446static SDValue combineShuffleOfSplatVal(ShuffleVectorSDNode *Shuf,
25447 SelectionDAG &DAG) {
25448 EVT VT = Shuf->getValueType(ResNo: 0);
25449 unsigned NumElts = VT.getVectorNumElements();
25450
25451 if (!Shuf->getOperand(Num: 1).isUndef())
25452 return SDValue();
25453
25454 // See if this unary non-splat shuffle actually *is* a splat shuffle,
25455 // in disguise, with all demanded elements being identical.
25456 // FIXME: this can be done per-operand.
25457 if (!Shuf->isSplat()) {
25458 APInt DemandedElts(NumElts, 0);
25459 for (int Idx : Shuf->getMask()) {
25460 if (Idx < 0)
25461 continue; // Ignore sentinel indices.
25462 assert((unsigned)Idx < NumElts && "Out-of-bounds shuffle indice?");
25463 DemandedElts.setBit(Idx);
25464 }
25465 assert(DemandedElts.popcount() > 1 && "Is a splat shuffle already?");
25466 APInt UndefElts;
25467 if (DAG.isSplatValue(V: Shuf->getOperand(Num: 0), DemandedElts, UndefElts)) {
25468 // Even if all demanded elements are splat, some of them could be undef.
25469 // Which lowest demanded element is *not* known-undef?
25470 std::optional<unsigned> MinNonUndefIdx;
25471 for (int Idx : Shuf->getMask()) {
25472 if (Idx < 0 || UndefElts[Idx])
25473 continue; // Ignore sentinel indices, and undef elements.
25474 MinNonUndefIdx = std::min<unsigned>(a: Idx, b: MinNonUndefIdx.value_or(u: ~0U));
25475 }
25476 if (!MinNonUndefIdx)
25477 return DAG.getUNDEF(VT); // All undef - result is undef.
25478 assert(*MinNonUndefIdx < NumElts && "Expected valid element index.");
25479 SmallVector<int, 8> SplatMask(Shuf->getMask().begin(),
25480 Shuf->getMask().end());
25481 for (int &Idx : SplatMask) {
25482 if (Idx < 0)
25483 continue; // Passthrough sentinel indices.
25484 // Otherwise, just pick the lowest demanded non-undef element.
25485 // Or sentinel undef, if we know we'd pick a known-undef element.
25486 Idx = UndefElts[Idx] ? -1 : *MinNonUndefIdx;
25487 }
25488 assert(SplatMask != Shuf->getMask() && "Expected mask to change!");
25489 return DAG.getVectorShuffle(VT, dl: SDLoc(Shuf), N1: Shuf->getOperand(Num: 0),
25490 N2: Shuf->getOperand(Num: 1), Mask: SplatMask);
25491 }
25492 }
25493
25494 // If the inner operand is a known splat with no undefs, just return that directly.
25495 // TODO: Create DemandedElts mask from Shuf's mask.
25496 // TODO: Allow undef elements and merge with the shuffle code below.
25497 if (DAG.isSplatValue(V: Shuf->getOperand(Num: 0), /*AllowUndefs*/ false))
25498 return Shuf->getOperand(Num: 0);
25499
25500 auto *Splat = dyn_cast<ShuffleVectorSDNode>(Val: Shuf->getOperand(Num: 0));
25501 if (!Splat || !Splat->isSplat())
25502 return SDValue();
25503
25504 ArrayRef<int> ShufMask = Shuf->getMask();
25505 ArrayRef<int> SplatMask = Splat->getMask();
25506 assert(ShufMask.size() == SplatMask.size() && "Mask length mismatch");
25507
25508 // Prefer simplifying to the splat-shuffle, if possible. This is legal if
25509 // every undef mask element in the splat-shuffle has a corresponding undef
25510 // element in the user-shuffle's mask or if the composition of mask elements
25511 // would result in undef.
25512 // Examples for (shuffle (shuffle v, undef, SplatMask), undef, UserMask):
25513 // * UserMask=[0,2,u,u], SplatMask=[2,u,2,u] -> [2,2,u,u]
25514 // In this case it is not legal to simplify to the splat-shuffle because we
25515 // may be exposing to the users of the shuffle an undef element at index 1
25516 // which was not there before the combine.
25517 // * UserMask=[0,u,2,u], SplatMask=[2,u,2,u] -> [2,u,2,u]
25518 // In this case the composition of masks yields SplatMask, so it's ok to
25519 // simplify to the splat-shuffle.
25520 // * UserMask=[3,u,2,u], SplatMask=[2,u,2,u] -> [u,u,2,u]
25521 // In this case the composed mask includes all undef elements of SplatMask
25522 // and in addition sets element zero to undef. It is safe to simplify to
25523 // the splat-shuffle.
25524 auto CanSimplifyToExistingSplat = [](ArrayRef<int> UserMask,
25525 ArrayRef<int> SplatMask) {
25526 for (unsigned i = 0, e = UserMask.size(); i != e; ++i)
25527 if (UserMask[i] != -1 && SplatMask[i] == -1 &&
25528 SplatMask[UserMask[i]] != -1)
25529 return false;
25530 return true;
25531 };
25532 if (CanSimplifyToExistingSplat(ShufMask, SplatMask))
25533 return Shuf->getOperand(Num: 0);
25534
25535 // Create a new shuffle with a mask that is composed of the two shuffles'
25536 // masks.
25537 SmallVector<int, 32> NewMask;
25538 for (int Idx : ShufMask)
25539 NewMask.push_back(Elt: Idx == -1 ? -1 : SplatMask[Idx]);
25540
25541 return DAG.getVectorShuffle(VT: Splat->getValueType(ResNo: 0), dl: SDLoc(Splat),
25542 N1: Splat->getOperand(Num: 0), N2: Splat->getOperand(Num: 1),
25543 Mask: NewMask);
25544}
25545
25546// Combine shuffles of bitcasts into a shuffle of the bitcast type, provided
25547// the mask can be scaled to match the wider element type.
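// For example (illustrative, assuming the widened mask is legal for the
// target):
//   v8i16 shuffle (bitcast v4i32 X), (bitcast v4i32 Y), <0,1,2,3,8,9,10,11>
//     --> bitcast (v4i32 shuffle X, Y, <0,1,4,5>)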
25548static SDValue combineShuffleOfBitcast(ShuffleVectorSDNode *SVN,
25549 SelectionDAG &DAG,
25550 const TargetLowering &TLI,
25551 bool LegalOperations) {
25552 SDValue Op0 = SVN->getOperand(Num: 0);
25553 SDValue Op1 = SVN->getOperand(Num: 1);
25554 EVT VT = SVN->getValueType(ResNo: 0);
25555 if (Op0.getOpcode() != ISD::BITCAST)
25556 return SDValue();
25557 EVT InVT = Op0.getOperand(i: 0).getValueType();
25558 if (!InVT.isVector() ||
25559 (!Op1.isUndef() && (Op1.getOpcode() != ISD::BITCAST ||
25560 Op1.getOperand(i: 0).getValueType() != InVT)))
25561 return SDValue();
25562 if (isAnyConstantBuildVector(V: Op0.getOperand(i: 0)) &&
25563 (Op1.isUndef() || isAnyConstantBuildVector(V: Op1.getOperand(i: 0))))
25564 return SDValue();
25565
25566 int VTLanes = VT.getVectorNumElements();
25567 int InLanes = InVT.getVectorNumElements();
25568 if (VTLanes <= InLanes || VTLanes % InLanes != 0 ||
25569 (LegalOperations &&
25570 !TLI.isOperationLegalOrCustom(Op: ISD::VECTOR_SHUFFLE, VT: InVT)))
25571 return SDValue();
25572 int Factor = VTLanes / InLanes;
25573
25574 // Check that each group of lanes in the mask is either undef or makes a valid
25575 // mask for the wider lane type.
25576 ArrayRef<int> Mask = SVN->getMask();
25577 SmallVector<int> NewMask;
25578 if (!widenShuffleMaskElts(Scale: Factor, Mask, ScaledMask&: NewMask))
25579 return SDValue();
25580
25581 if (!TLI.isShuffleMaskLegal(NewMask, InVT))
25582 return SDValue();
25583
25584 // Create the new shuffle with the new mask and bitcast it back to the
25585 // original type.
25586 SDLoc DL(SVN);
25587 Op0 = Op0.getOperand(i: 0);
25588 Op1 = Op1.isUndef() ? DAG.getUNDEF(VT: InVT) : Op1.getOperand(i: 0);
25589 SDValue NewShuf = DAG.getVectorShuffle(VT: InVT, dl: DL, N1: Op0, N2: Op1, Mask: NewMask);
25590 return DAG.getBitcast(VT, V: NewShuf);
25591}
25592
25593/// Combine shuffle of shuffle of the form:
25594/// shuf (shuf X, undef, InnerMask), undef, OuterMask --> splat X
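/// For example (illustrative, assuming the splat mask is legal for the target):
///   shuf (shuf X, undef, <1,1,2,3>), undef, <0,0,0,0>
///     --> shuf X, undef, <1,1,1,1>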
25595static SDValue formSplatFromShuffles(ShuffleVectorSDNode *OuterShuf,
25596 SelectionDAG &DAG) {
25597 if (!OuterShuf->getOperand(Num: 1).isUndef())
25598 return SDValue();
25599 auto *InnerShuf = dyn_cast<ShuffleVectorSDNode>(Val: OuterShuf->getOperand(Num: 0));
25600 if (!InnerShuf || !InnerShuf->getOperand(Num: 1).isUndef())
25601 return SDValue();
25602
25603 ArrayRef<int> OuterMask = OuterShuf->getMask();
25604 ArrayRef<int> InnerMask = InnerShuf->getMask();
25605 unsigned NumElts = OuterMask.size();
25606 assert(NumElts == InnerMask.size() && "Mask length mismatch");
25607 SmallVector<int, 32> CombinedMask(NumElts, -1);
25608 int SplatIndex = -1;
25609 for (unsigned i = 0; i != NumElts; ++i) {
25610 // Undef lanes remain undef.
25611 int OuterMaskElt = OuterMask[i];
25612 if (OuterMaskElt == -1)
25613 continue;
25614
25615 // Peek through the shuffle masks to get the underlying source element.
25616 int InnerMaskElt = InnerMask[OuterMaskElt];
25617 if (InnerMaskElt == -1)
25618 continue;
25619
25620 // Initialize the splatted element.
25621 if (SplatIndex == -1)
25622 SplatIndex = InnerMaskElt;
25623
25624 // Non-matching index - this is not a splat.
25625 if (SplatIndex != InnerMaskElt)
25626 return SDValue();
25627
25628 CombinedMask[i] = InnerMaskElt;
25629 }
25630 assert((all_of(CombinedMask, [](int M) { return M == -1; }) ||
25631 getSplatIndex(CombinedMask) != -1) &&
25632 "Expected a splat mask");
25633
25634 // TODO: The transform may be a win even if the mask is not legal.
25635 EVT VT = OuterShuf->getValueType(ResNo: 0);
25636 assert(VT == InnerShuf->getValueType(0) && "Expected matching shuffle types");
25637 if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(CombinedMask, VT))
25638 return SDValue();
25639
25640 return DAG.getVectorShuffle(VT, dl: SDLoc(OuterShuf), N1: InnerShuf->getOperand(Num: 0),
25641 N2: InnerShuf->getOperand(Num: 1), Mask: CombinedMask);
25642}
25643
25644/// If the shuffle mask is taking exactly one element from the first vector
25645/// operand and passing through all other elements from the second vector
25646/// operand, return the index of the mask element that is choosing an element
25647/// from the first operand. Otherwise, return -1.
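///
/// For example, a hypothetical 4-element mask <4,5,0,7> takes lane 0 of
/// operand 0 into lane 2 and passes every other lane through from operand 1
/// unmoved, so this returns 2; a mask such as <4,0,1,7> returns -1.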
25648static int getShuffleMaskIndexOfOneElementFromOp0IntoOp1(ArrayRef<int> Mask) {
25649 int MaskSize = Mask.size();
25650 int EltFromOp0 = -1;
25651 // TODO: This does not match if there are undef elements in the shuffle mask.
25652 // Should we ignore undefs in the shuffle mask instead? The trade-off is
25653 // removing an instruction (a shuffle), but losing the knowledge that some
25654 // vector lanes are not needed.
25655 for (int i = 0; i != MaskSize; ++i) {
25656 if (Mask[i] >= 0 && Mask[i] < MaskSize) {
25657 // We're looking for a shuffle of exactly one element from operand 0.
25658 if (EltFromOp0 != -1)
25659 return -1;
25660 EltFromOp0 = i;
25661 } else if (Mask[i] != i + MaskSize) {
25662 // Nothing from operand 1 can change lanes.
25663 return -1;
25664 }
25665 }
25666 return EltFromOp0;
25667}
25668
25669/// If a shuffle inserts exactly one element from a source vector operand into
25670/// another vector operand and we can access the specified element as a scalar,
25671/// then we can eliminate the shuffle.
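///
/// For example (hypothetical v4 types):
///   shuffle (insert_vector_elt V1, X, 0), V2, <4,0,6,7>
///   --> insert_vector_elt V2, X, 1
/// Note that X lands at the shuffle's mask position (lane 1), not at the
/// original insertion index (0).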
25672static SDValue replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf,
25673 SelectionDAG &DAG) {
25674 // First, check if we are taking one element of a vector and shuffling that
25675 // element into another vector.
25676 ArrayRef<int> Mask = Shuf->getMask();
25677 SmallVector<int, 16> CommutedMask(Mask);
25678 SDValue Op0 = Shuf->getOperand(Num: 0);
25679 SDValue Op1 = Shuf->getOperand(Num: 1);
25680 int ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(Mask);
25681 if (ShufOp0Index == -1) {
25682 // Commute mask and check again.
25683 ShuffleVectorSDNode::commuteMask(Mask: CommutedMask);
25684 ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(Mask: CommutedMask);
25685 if (ShufOp0Index == -1)
25686 return SDValue();
25687 // Commute operands to match the commuted shuffle mask.
25688 std::swap(a&: Op0, b&: Op1);
25689 Mask = CommutedMask;
25690 }
25691
25692 // The shuffle inserts exactly one element from operand 0 into operand 1.
25693 // Now see if we can access that element as a scalar via a real insert element
25694 // instruction.
25695 // TODO: We can try harder to locate the element as a scalar. Examples: it
25696 // could be an operand of SCALAR_TO_VECTOR, BUILD_VECTOR, or a constant.
25697 assert(Mask[ShufOp0Index] >= 0 && Mask[ShufOp0Index] < (int)Mask.size() &&
25698 "Shuffle mask value must be from operand 0");
25699 if (Op0.getOpcode() != ISD::INSERT_VECTOR_ELT)
25700 return SDValue();
25701
25702 auto *InsIndexC = dyn_cast<ConstantSDNode>(Val: Op0.getOperand(i: 2));
25703 if (!InsIndexC || InsIndexC->getSExtValue() != Mask[ShufOp0Index])
25704 return SDValue();
25705
25706 // There's an existing insertelement with constant insertion index, so we
25707 // don't need to check the legality/profitability of a replacement operation
25708 // that differs at most in the constant value. The target should be able to
25709 // lower any of those in a similar way. If not, legalization will expand this
25710 // to a scalar-to-vector plus shuffle.
25711 //
25712 // Note that the shuffle may move the scalar from the position that the insert
25713 // element used. Therefore, our new insert element occurs at the shuffle's
25714 // mask index value, not the insert's index value.
25715 // shuffle (insertelt v1, x, C), v2, mask --> insertelt v2, x, C'
25716 SDValue NewInsIndex = DAG.getVectorIdxConstant(Val: ShufOp0Index, DL: SDLoc(Shuf));
25717 return DAG.getNode(Opcode: ISD::INSERT_VECTOR_ELT, DL: SDLoc(Shuf), VT: Op0.getValueType(),
25718 N1: Op1, N2: Op0.getOperand(i: 1), N3: NewInsIndex);
25719}
25720
25721/// If we have a unary shuffle of a shuffle, see if it can be folded away
25722/// completely. This has the potential to lose undef knowledge because the first
25723/// shuffle may not have an undef mask element where the second one does. So
25724/// only call this after doing simplifications based on demanded elements.
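///
/// For example, with hypothetical 4-element masks:
///   shuf (shuf0 X, Y, <2,2,3,3>), undef, <1,0,3,2>
/// selects the same source element in every lane as the inner shuffle does on
/// its own, so it folds to (shuf0 X, Y, <2,2,3,3>).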
25725static SDValue simplifyShuffleOfShuffle(ShuffleVectorSDNode *Shuf) {
25726 // shuf (shuf0 X, Y, Mask0), undef, Mask
25727 auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(Val: Shuf->getOperand(Num: 0));
25728 if (!Shuf0 || !Shuf->getOperand(Num: 1).isUndef())
25729 return SDValue();
25730
25731 ArrayRef<int> Mask = Shuf->getMask();
25732 ArrayRef<int> Mask0 = Shuf0->getMask();
25733 for (int i = 0, e = (int)Mask.size(); i != e; ++i) {
25734 // Ignore undef elements.
25735 if (Mask[i] == -1)
25736 continue;
25737 assert(Mask[i] >= 0 && Mask[i] < e && "Unexpected shuffle mask value");
25738
25739 // Is the element of the shuffle operand chosen by this shuffle the same as
25740 // the element chosen by the shuffle operand itself?
25741 if (Mask0[Mask[i]] != Mask0[i])
25742 return SDValue();
25743 }
25744 // Every element of this shuffle is identical to the result of the previous
25745 // shuffle, so we can replace this value.
25746 return Shuf->getOperand(Num: 0);
25747}
25748
25749SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
25750 EVT VT = N->getValueType(ResNo: 0);
25751 unsigned NumElts = VT.getVectorNumElements();
25752
25753 SDValue N0 = N->getOperand(Num: 0);
25754 SDValue N1 = N->getOperand(Num: 1);
25755
25756 assert(N0.getValueType() == VT && "Vector shuffle must be normalized in DAG");
25757
25758 // Canonicalize shuffle undef, undef -> undef
25759 if (N0.isUndef() && N1.isUndef())
25760 return DAG.getUNDEF(VT);
25761
25762 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Val: N);
25763
25764 // Canonicalize shuffle v, v -> v, undef
25765 if (N0 == N1)
25766 return DAG.getVectorShuffle(VT, dl: SDLoc(N), N1: N0, N2: DAG.getUNDEF(VT),
25767 Mask: createUnaryMask(Mask: SVN->getMask(), NumElts));
25768
25769 // Canonicalize shuffle undef, v -> v, undef. Commute the shuffle mask.
25770 if (N0.isUndef())
25771 return DAG.getCommutedVectorShuffle(SV: *SVN);
25772
25773 // Remove references to rhs if it is undef
25774 if (N1.isUndef()) {
25775 bool Changed = false;
25776 SmallVector<int, 8> NewMask;
25777 for (unsigned i = 0; i != NumElts; ++i) {
25778 int Idx = SVN->getMaskElt(Idx: i);
25779 if (Idx >= (int)NumElts) {
25780 Idx = -1;
25781 Changed = true;
25782 }
25783 NewMask.push_back(Elt: Idx);
25784 }
25785 if (Changed)
25786 return DAG.getVectorShuffle(VT, dl: SDLoc(N), N1: N0, N2: N1, Mask: NewMask);
25787 }
25788
25789 if (SDValue InsElt = replaceShuffleOfInsert(Shuf: SVN, DAG))
25790 return InsElt;
25791
25792 // A shuffle of a single vector that is a splatted value can always be folded.
25793 if (SDValue V = combineShuffleOfSplatVal(Shuf: SVN, DAG))
25794 return V;
25795
25796 if (SDValue V = formSplatFromShuffles(OuterShuf: SVN, DAG))
25797 return V;
25798
25799 // If it is a splat, check if the argument vector is another splat or a
25800 // build_vector.
25801 if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) {
25802 int SplatIndex = SVN->getSplatIndex();
25803 if (N0.hasOneUse() && TLI.isExtractVecEltCheap(VT, Index: SplatIndex) &&
25804 TLI.isBinOp(Opcode: N0.getOpcode()) && N0->getNumValues() == 1) {
25805 // splat (vector_bo L, R), Index -->
25806 // splat (scalar_bo (extelt L, Index), (extelt R, Index))
25807 SDValue L = N0.getOperand(i: 0), R = N0.getOperand(i: 1);
25808 SDLoc DL(N);
25809 EVT EltVT = VT.getScalarType();
25810 SDValue Index = DAG.getVectorIdxConstant(Val: SplatIndex, DL);
25811 SDValue ExtL = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: L, N2: Index);
25812 SDValue ExtR = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: R, N2: Index);
25813 SDValue NewBO =
25814 DAG.getNode(Opcode: N0.getOpcode(), DL, VT: EltVT, N1: ExtL, N2: ExtR, Flags: N0->getFlags());
25815 SDValue Insert = DAG.getNode(Opcode: ISD::SCALAR_TO_VECTOR, DL, VT, Operand: NewBO);
25816 SmallVector<int, 16> ZeroMask(VT.getVectorNumElements(), 0);
25817 return DAG.getVectorShuffle(VT, dl: DL, N1: Insert, N2: DAG.getUNDEF(VT), Mask: ZeroMask);
25818 }
25819
25820 // splat(scalar_to_vector(x), 0) -> build_vector(x,...,x)
25821 // splat(insert_vector_elt(v, x, c), c) -> build_vector(x,...,x)
25822 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT)) &&
25823 N0.hasOneUse()) {
25824 if (N0.getOpcode() == ISD::SCALAR_TO_VECTOR && SplatIndex == 0)
25825 return DAG.getSplatBuildVector(VT, DL: SDLoc(N), Op: N0.getOperand(i: 0));
25826
25827 if (N0.getOpcode() == ISD::INSERT_VECTOR_ELT)
25828 if (auto *Idx = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 2)))
25829 if (Idx->getAPIntValue() == SplatIndex)
25830 return DAG.getSplatBuildVector(VT, DL: SDLoc(N), Op: N0.getOperand(i: 1));
25831
      // Look through a bitcast if little-endian and splatting lane 0, through
      // to a scalar_to_vector or a build_vector.
25834 if (N0.getOpcode() == ISD::BITCAST && N0.getOperand(i: 0).hasOneUse() &&
25835 SplatIndex == 0 && DAG.getDataLayout().isLittleEndian() &&
25836 (N0.getOperand(i: 0).getOpcode() == ISD::SCALAR_TO_VECTOR ||
25837 N0.getOperand(i: 0).getOpcode() == ISD::BUILD_VECTOR)) {
25838 EVT N00VT = N0.getOperand(i: 0).getValueType();
25839 if (VT.getScalarSizeInBits() <= N00VT.getScalarSizeInBits() &&
25840 VT.isInteger() && N00VT.isInteger()) {
25841 EVT InVT =
25842 TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: VT.getScalarType());
25843 SDValue Op = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 0).getOperand(i: 0),
25844 DL: SDLoc(N), VT: InVT);
25845 return DAG.getSplatBuildVector(VT, DL: SDLoc(N), Op);
25846 }
25847 }
25848 }
25849
    // If this is a bit convert that changes the element type of the vector but
    // not the number of vector elements, look through it. Be careful not to
    // look through conversions that change things like v4f32 to v2f64.
25853 SDNode *V = N0.getNode();
25854 if (V->getOpcode() == ISD::BITCAST) {
25855 SDValue ConvInput = V->getOperand(Num: 0);
25856 if (ConvInput.getValueType().isVector() &&
25857 ConvInput.getValueType().getVectorNumElements() == NumElts)
25858 V = ConvInput.getNode();
25859 }
25860
25861 if (V->getOpcode() == ISD::BUILD_VECTOR) {
25862 assert(V->getNumOperands() == NumElts &&
25863 "BUILD_VECTOR has wrong number of operands");
25864 SDValue Base;
25865 bool AllSame = true;
25866 for (unsigned i = 0; i != NumElts; ++i) {
25867 if (!V->getOperand(Num: i).isUndef()) {
25868 Base = V->getOperand(Num: i);
25869 break;
25870 }
25871 }
25872 // Splat of <u, u, u, u>, return <u, u, u, u>
25873 if (!Base.getNode())
25874 return N0;
25875 for (unsigned i = 0; i != NumElts; ++i) {
25876 if (V->getOperand(Num: i) != Base) {
25877 AllSame = false;
25878 break;
25879 }
25880 }
25881 // Splat of <x, x, x, x>, return <x, x, x, x>
25882 if (AllSame)
25883 return N0;
25884
25885 // Canonicalize any other splat as a build_vector.
25886 SDValue Splatted = V->getOperand(Num: SplatIndex);
25887 SmallVector<SDValue, 8> Ops(NumElts, Splatted);
25888 SDValue NewBV = DAG.getBuildVector(VT: V->getValueType(ResNo: 0), DL: SDLoc(N), Ops);
25889
25890 // We may have jumped through bitcasts, so the type of the
25891 // BUILD_VECTOR may not match the type of the shuffle.
25892 if (V->getValueType(ResNo: 0) != VT)
25893 NewBV = DAG.getBitcast(VT, V: NewBV);
25894 return NewBV;
25895 }
25896 }
25897
25898 // Simplify source operands based on shuffle mask.
25899 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
25900 return SDValue(N, 0);
25901
25902 // This is intentionally placed after demanded elements simplification because
25903 // it could eliminate knowledge of undef elements created by this shuffle.
25904 if (SDValue ShufOp = simplifyShuffleOfShuffle(Shuf: SVN))
25905 return ShufOp;
25906
25907 // Match shuffles that can be converted to any_vector_extend_in_reg.
25908 if (SDValue V =
25909 combineShuffleToAnyExtendVectorInreg(SVN, DAG, TLI, LegalOperations))
25910 return V;
25911
25912 // Combine "truncate_vector_in_reg" style shuffles.
25913 if (SDValue V = combineTruncationShuffle(SVN, DAG))
25914 return V;
25915
25916 if (N0.getOpcode() == ISD::CONCAT_VECTORS &&
25917 Level < AfterLegalizeVectorOps &&
25918 (N1.isUndef() ||
25919 (N1.getOpcode() == ISD::CONCAT_VECTORS &&
25920 N0.getOperand(i: 0).getValueType() == N1.getOperand(i: 0).getValueType()))) {
25921 if (SDValue V = partitionShuffleOfConcats(N, DAG))
25922 return V;
25923 }
25924
25925 // A shuffle of a concat of the same narrow vector can be reduced to use
25926 // only low-half elements of a concat with undef:
25927 // shuf (concat X, X), undef, Mask --> shuf (concat X, undef), undef, Mask'
25928 if (N0.getOpcode() == ISD::CONCAT_VECTORS && N1.isUndef() &&
25929 N0.getNumOperands() == 2 &&
25930 N0.getOperand(i: 0) == N0.getOperand(i: 1)) {
25931 int HalfNumElts = (int)NumElts / 2;
25932 SmallVector<int, 8> NewMask;
25933 for (unsigned i = 0; i != NumElts; ++i) {
25934 int Idx = SVN->getMaskElt(Idx: i);
25935 if (Idx >= HalfNumElts) {
25936 assert(Idx < (int)NumElts && "Shuffle mask chooses undef op");
25937 Idx -= HalfNumElts;
25938 }
25939 NewMask.push_back(Elt: Idx);
25940 }
25941 if (TLI.isShuffleMaskLegal(NewMask, VT)) {
25942 SDValue UndefVec = DAG.getUNDEF(VT: N0.getOperand(i: 0).getValueType());
25943 SDValue NewCat = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT,
25944 N1: N0.getOperand(i: 0), N2: UndefVec);
25945 return DAG.getVectorShuffle(VT, dl: SDLoc(N), N1: NewCat, N2: N1, Mask: NewMask);
25946 }
25947 }
25948
25949 // See if we can replace a shuffle with an insert_subvector.
25950 // e.g. v2i32 into v8i32:
25951 // shuffle(lhs,concat(rhs0,rhs1,rhs2,rhs3),0,1,2,3,10,11,6,7).
25952 // --> insert_subvector(lhs,rhs1,4).
25953 if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT) &&
25954 TLI.isOperationLegalOrCustom(Op: ISD::INSERT_SUBVECTOR, VT)) {
25955 auto ShuffleToInsert = [&](SDValue LHS, SDValue RHS, ArrayRef<int> Mask) {
25956 // Ensure RHS subvectors are legal.
25957 assert(RHS.getOpcode() == ISD::CONCAT_VECTORS && "Can't find subvectors");
25958 EVT SubVT = RHS.getOperand(i: 0).getValueType();
25959 int NumSubVecs = RHS.getNumOperands();
25960 int NumSubElts = SubVT.getVectorNumElements();
25961 assert((NumElts % NumSubElts) == 0 && "Subvector mismatch");
25962 if (!TLI.isTypeLegal(VT: SubVT))
25963 return SDValue();
25964
      // Don't bother if we have a unary shuffle (matches undef + LHS elts).
25966 if (all_of(Range&: Mask, P: [NumElts](int M) { return M < (int)NumElts; }))
25967 return SDValue();
25968
25969 // Search [NumSubElts] spans for RHS sequence.
25970 // TODO: Can we avoid nested loops to increase performance?
25971 SmallVector<int> InsertionMask(NumElts);
25972 for (int SubVec = 0; SubVec != NumSubVecs; ++SubVec) {
25973 for (int SubIdx = 0; SubIdx != (int)NumElts; SubIdx += NumSubElts) {
25974 // Reset mask to identity.
25975 std::iota(first: InsertionMask.begin(), last: InsertionMask.end(), value: 0);
25976
25977 // Add subvector insertion.
25978 std::iota(first: InsertionMask.begin() + SubIdx,
25979 last: InsertionMask.begin() + SubIdx + NumSubElts,
25980 value: NumElts + (SubVec * NumSubElts));
25981
25982 // See if the shuffle mask matches the reference insertion mask.
25983 bool MatchingShuffle = true;
25984 for (int i = 0; i != (int)NumElts; ++i) {
25985 int ExpectIdx = InsertionMask[i];
25986 int ActualIdx = Mask[i];
25987 if (0 <= ActualIdx && ExpectIdx != ActualIdx) {
25988 MatchingShuffle = false;
25989 break;
25990 }
25991 }
25992
25993 if (MatchingShuffle)
25994 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N), VT, N1: LHS,
25995 N2: RHS.getOperand(i: SubVec),
25996 N3: DAG.getVectorIdxConstant(Val: SubIdx, DL: SDLoc(N)));
25997 }
25998 }
25999 return SDValue();
26000 };
26001 ArrayRef<int> Mask = SVN->getMask();
26002 if (N1.getOpcode() == ISD::CONCAT_VECTORS)
26003 if (SDValue InsertN1 = ShuffleToInsert(N0, N1, Mask))
26004 return InsertN1;
26005 if (N0.getOpcode() == ISD::CONCAT_VECTORS) {
26006 SmallVector<int> CommuteMask(Mask);
26007 ShuffleVectorSDNode::commuteMask(Mask: CommuteMask);
26008 if (SDValue InsertN0 = ShuffleToInsert(N1, N0, CommuteMask))
26009 return InsertN0;
26010 }
26011 }
26012
  // If we're not performing a select/blend shuffle, see if we can convert the
  // shuffle into an AND node, where all the out-of-lane elements are known to
  // be zero.
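  // For example (a hypothetical v4i32 case), if lanes 0 and 3 of Y are known
  // to be zero, then
  //   shuffle X, Y, <0,4,2,7> --> and X, <-1, 0, -1, 0>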
26015 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
26016 bool IsInLaneMask = true;
26017 ArrayRef<int> Mask = SVN->getMask();
26018 SmallVector<int, 16> ClearMask(NumElts, -1);
26019 APInt DemandedLHS = APInt::getZero(numBits: NumElts);
26020 APInt DemandedRHS = APInt::getZero(numBits: NumElts);
26021 for (int I = 0; I != (int)NumElts; ++I) {
26022 int M = Mask[I];
26023 if (M < 0)
26024 continue;
26025 ClearMask[I] = M == I ? I : (I + NumElts);
26026 IsInLaneMask &= (M == I) || (M == (int)(I + NumElts));
26027 if (M != I) {
26028 APInt &Demanded = M < (int)NumElts ? DemandedLHS : DemandedRHS;
26029 Demanded.setBit(M % NumElts);
26030 }
26031 }
26032 // TODO: Should we try to mask with N1 as well?
26033 if (!IsInLaneMask && (!DemandedLHS.isZero() || !DemandedRHS.isZero()) &&
26034 (DemandedLHS.isZero() || DAG.MaskedVectorIsZero(Op: N0, DemandedElts: DemandedLHS)) &&
26035 (DemandedRHS.isZero() || DAG.MaskedVectorIsZero(Op: N1, DemandedElts: DemandedRHS))) {
26036 SDLoc DL(N);
26037 EVT IntVT = VT.changeVectorElementTypeToInteger();
26038 EVT IntSVT = VT.getVectorElementType().changeTypeToInteger();
      // Transform the type to a legal type so that the buildvector constant
      // elements are not illegal. Make sure that the result is larger than the
      // original type, in case the value is split into two (e.g. i64->i32).
26042 if (!TLI.isTypeLegal(VT: IntSVT) && LegalTypes)
26043 IntSVT = TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: IntSVT);
26044 if (IntSVT.getSizeInBits() >= IntVT.getScalarSizeInBits()) {
26045 SDValue ZeroElt = DAG.getConstant(Val: 0, DL, VT: IntSVT);
26046 SDValue AllOnesElt = DAG.getAllOnesConstant(DL, VT: IntSVT);
26047 SmallVector<SDValue, 16> AndMask(NumElts, DAG.getUNDEF(VT: IntSVT));
26048 for (int I = 0; I != (int)NumElts; ++I)
26049 if (0 <= Mask[I])
26050 AndMask[I] = Mask[I] == I ? AllOnesElt : ZeroElt;
26051
26052 // See if a clear mask is legal instead of going via
26053 // XformToShuffleWithZero which loses UNDEF mask elements.
26054 if (TLI.isVectorClearMaskLegal(ClearMask, IntVT))
26055 return DAG.getBitcast(
26056 VT, V: DAG.getVectorShuffle(VT: IntVT, dl: DL, N1: DAG.getBitcast(VT: IntVT, V: N0),
26057 N2: DAG.getConstant(Val: 0, DL, VT: IntVT), Mask: ClearMask));
26058
26059 if (TLI.isOperationLegalOrCustom(Op: ISD::AND, VT: IntVT))
26060 return DAG.getBitcast(
26061 VT, V: DAG.getNode(Opcode: ISD::AND, DL, VT: IntVT, N1: DAG.getBitcast(VT: IntVT, V: N0),
26062 N2: DAG.getBuildVector(VT: IntVT, DL, Ops: AndMask)));
26063 }
26064 }
26065 }
26066
26067 // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
26068 // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
26069 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT))
26070 if (SDValue Res = combineShuffleOfScalars(SVN, DAG, TLI))
26071 return Res;
26072
26073 // If this shuffle only has a single input that is a bitcasted shuffle,
26074 // attempt to merge the 2 shuffles and suitably bitcast the inputs/output
26075 // back to their original types.
26076 if (N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
26077 N1.isUndef() && Level < AfterLegalizeVectorOps &&
26078 TLI.isTypeLegal(VT)) {
26079
26080 SDValue BC0 = peekThroughOneUseBitcasts(V: N0);
26081 if (BC0.getOpcode() == ISD::VECTOR_SHUFFLE && BC0.hasOneUse()) {
26082 EVT SVT = VT.getScalarType();
26083 EVT InnerVT = BC0->getValueType(ResNo: 0);
26084 EVT InnerSVT = InnerVT.getScalarType();
26085
26086 // Determine which shuffle works with the smaller scalar type.
26087 EVT ScaleVT = SVT.bitsLT(VT: InnerSVT) ? VT : InnerVT;
26088 EVT ScaleSVT = ScaleVT.getScalarType();
26089
26090 if (TLI.isTypeLegal(VT: ScaleVT) &&
26091 0 == (InnerSVT.getSizeInBits() % ScaleSVT.getSizeInBits()) &&
26092 0 == (SVT.getSizeInBits() % ScaleSVT.getSizeInBits())) {
26093 int InnerScale = InnerSVT.getSizeInBits() / ScaleSVT.getSizeInBits();
26094 int OuterScale = SVT.getSizeInBits() / ScaleSVT.getSizeInBits();
26095
26096 // Scale the shuffle masks to the smaller scalar type.
26097 ShuffleVectorSDNode *InnerSVN = cast<ShuffleVectorSDNode>(Val&: BC0);
26098 SmallVector<int, 8> InnerMask;
26099 SmallVector<int, 8> OuterMask;
26100 narrowShuffleMaskElts(Scale: InnerScale, Mask: InnerSVN->getMask(), ScaledMask&: InnerMask);
26101 narrowShuffleMaskElts(Scale: OuterScale, Mask: SVN->getMask(), ScaledMask&: OuterMask);
26102
26103 // Merge the shuffle masks.
26104 SmallVector<int, 8> NewMask;
26105 for (int M : OuterMask)
26106 NewMask.push_back(Elt: M < 0 ? -1 : InnerMask[M]);
26107
26108 // Test for shuffle mask legality over both commutations.
26109 SDValue SV0 = BC0->getOperand(Num: 0);
26110 SDValue SV1 = BC0->getOperand(Num: 1);
26111 bool LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
26112 if (!LegalMask) {
26113 std::swap(a&: SV0, b&: SV1);
26114 ShuffleVectorSDNode::commuteMask(Mask: NewMask);
26115 LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
26116 }
26117
26118 if (LegalMask) {
26119 SV0 = DAG.getBitcast(VT: ScaleVT, V: SV0);
26120 SV1 = DAG.getBitcast(VT: ScaleVT, V: SV1);
26121 return DAG.getBitcast(
26122 VT, V: DAG.getVectorShuffle(VT: ScaleVT, dl: SDLoc(N), N1: SV0, N2: SV1, Mask: NewMask));
26123 }
26124 }
26125 }
26126 }
26127
  // Match shuffles of bitcasts, so long as the mask can be widened to the
  // larger element type.
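  // For example (hypothetical types):
  //   shuffle (bitcast v2i64 X to v4i32), (bitcast v2i64 Y to v4i32),
  //           <0,1,6,7>
  //   --> bitcast (shuffle X, Y, <0,3>) to v4i32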
26130 if (SDValue V = combineShuffleOfBitcast(SVN, DAG, TLI, LegalOperations))
26131 return V;
26132
26133 // Compute the combined shuffle mask for a shuffle with SV0 as the first
26134 // operand, and SV1 as the second operand.
26135 // i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask) iff Commute = false
26136 // Merge SVN(N1, OtherSVN) -> shuffle(SV0, SV1, Mask') iff Commute = true
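  // For example (hypothetical v4 masks, Commute = false), merging
  //   SVN = shuffle(OtherSVN, N1, <0,2,5,7>), where
  //   OtherSVN = shuffle(A, B, <4,1,6,3>),
  // peeks through the inner mask for lanes 0 and 1 (both land in B) and keeps
  // lanes 2 and 3 from N1, giving SV0 = B, SV1 = N1 and Mask = <0,2,5,7>.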
26137 auto MergeInnerShuffle =
26138 [NumElts, &VT](bool Commute, ShuffleVectorSDNode *SVN,
26139 ShuffleVectorSDNode *OtherSVN, SDValue N1,
26140 const TargetLowering &TLI, SDValue &SV0, SDValue &SV1,
26141 SmallVectorImpl<int> &Mask) -> bool {
26142 // Don't try to fold splats; they're likely to simplify somehow, or they
26143 // might be free.
26144 if (OtherSVN->isSplat())
26145 return false;
26146
26147 SV0 = SV1 = SDValue();
26148 Mask.clear();
26149
26150 for (unsigned i = 0; i != NumElts; ++i) {
26151 int Idx = SVN->getMaskElt(Idx: i);
26152 if (Idx < 0) {
26153 // Propagate Undef.
26154 Mask.push_back(Elt: Idx);
26155 continue;
26156 }
26157
26158 if (Commute)
26159 Idx = (Idx < (int)NumElts) ? (Idx + NumElts) : (Idx - NumElts);
26160
26161 SDValue CurrentVec;
26162 if (Idx < (int)NumElts) {
26163 // This shuffle index refers to the inner shuffle N0. Lookup the inner
26164 // shuffle mask to identify which vector is actually referenced.
26165 Idx = OtherSVN->getMaskElt(Idx);
26166 if (Idx < 0) {
26167 // Propagate Undef.
26168 Mask.push_back(Elt: Idx);
26169 continue;
26170 }
26171 CurrentVec = (Idx < (int)NumElts) ? OtherSVN->getOperand(Num: 0)
26172 : OtherSVN->getOperand(Num: 1);
26173 } else {
26174 // This shuffle index references an element within N1.
26175 CurrentVec = N1;
26176 }
26177
26178 // Simple case where 'CurrentVec' is UNDEF.
26179 if (CurrentVec.isUndef()) {
26180 Mask.push_back(Elt: -1);
26181 continue;
26182 }
26183
26184 // Canonicalize the shuffle index. We don't know yet if CurrentVec
26185 // will be the first or second operand of the combined shuffle.
26186 Idx = Idx % NumElts;
26187 if (!SV0.getNode() || SV0 == CurrentVec) {
26188 // Ok. CurrentVec is the left hand side.
26189 // Update the mask accordingly.
26190 SV0 = CurrentVec;
26191 Mask.push_back(Elt: Idx);
26192 continue;
26193 }
26194 if (!SV1.getNode() || SV1 == CurrentVec) {
26195 // Ok. CurrentVec is the right hand side.
26196 // Update the mask accordingly.
26197 SV1 = CurrentVec;
26198 Mask.push_back(Elt: Idx + NumElts);
26199 continue;
26200 }
26201
26202 // Last chance - see if the vector is another shuffle and if it
26203 // uses one of the existing candidate shuffle ops.
26204 if (auto *CurrentSVN = dyn_cast<ShuffleVectorSDNode>(Val&: CurrentVec)) {
26205 int InnerIdx = CurrentSVN->getMaskElt(Idx);
26206 if (InnerIdx < 0) {
26207 Mask.push_back(Elt: -1);
26208 continue;
26209 }
26210 SDValue InnerVec = (InnerIdx < (int)NumElts)
26211 ? CurrentSVN->getOperand(Num: 0)
26212 : CurrentSVN->getOperand(Num: 1);
26213 if (InnerVec.isUndef()) {
26214 Mask.push_back(Elt: -1);
26215 continue;
26216 }
26217 InnerIdx %= NumElts;
26218 if (InnerVec == SV0) {
26219 Mask.push_back(Elt: InnerIdx);
26220 continue;
26221 }
26222 if (InnerVec == SV1) {
26223 Mask.push_back(Elt: InnerIdx + NumElts);
26224 continue;
26225 }
26226 }
26227
26228 // Bail out if we cannot convert the shuffle pair into a single shuffle.
26229 return false;
26230 }
26231
26232 if (llvm::all_of(Range&: Mask, P: [](int M) { return M < 0; }))
26233 return true;
26234
26235 // Avoid introducing shuffles with illegal mask.
26236 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
26237 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
26238 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
26239 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2)
26240 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2)
26241 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2)
26242 if (TLI.isShuffleMaskLegal(Mask, VT))
26243 return true;
26244
26245 std::swap(a&: SV0, b&: SV1);
26246 ShuffleVectorSDNode::commuteMask(Mask);
26247 return TLI.isShuffleMaskLegal(Mask, VT);
26248 };
26249
26250 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
26251 // Canonicalize shuffles according to rules:
26252 // shuffle(A, shuffle(A, B)) -> shuffle(shuffle(A,B), A)
26253 // shuffle(B, shuffle(A, B)) -> shuffle(shuffle(A,B), B)
26254 // shuffle(B, shuffle(A, Undef)) -> shuffle(shuffle(A, Undef), B)
26255 if (N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
26256 N0.getOpcode() != ISD::VECTOR_SHUFFLE) {
26257 // The incoming shuffle must be of the same type as the result of the
26258 // current shuffle.
26259 assert(N1->getOperand(0).getValueType() == VT &&
26260 "Shuffle types don't match");
26261
26262 SDValue SV0 = N1->getOperand(Num: 0);
26263 SDValue SV1 = N1->getOperand(Num: 1);
26264 bool HasSameOp0 = N0 == SV0;
26265 bool IsSV1Undef = SV1.isUndef();
26266 if (HasSameOp0 || IsSV1Undef || N0 == SV1)
26267 // Commute the operands of this shuffle so merging below will trigger.
26268 return DAG.getCommutedVectorShuffle(SV: *SVN);
26269 }
26270
26271 // Canonicalize splat shuffles to the RHS to improve merging below.
26272 // shuffle(splat(A,u), shuffle(C,D)) -> shuffle'(shuffle(C,D), splat(A,u))
26273 if (N0.getOpcode() == ISD::VECTOR_SHUFFLE &&
26274 N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
26275 cast<ShuffleVectorSDNode>(Val&: N0)->isSplat() &&
26276 !cast<ShuffleVectorSDNode>(Val&: N1)->isSplat()) {
26277 return DAG.getCommutedVectorShuffle(SV: *SVN);
26278 }
26279
26280 // Try to fold according to rules:
26281 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
26282 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
26283 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
26284 // Don't try to fold shuffles with illegal type.
26285 // Only fold if this shuffle is the only user of the other shuffle.
    // Try matching shuffle(C,shuffle(A,B)) commuted patterns as well.
26287 for (int i = 0; i != 2; ++i) {
26288 if (N->getOperand(Num: i).getOpcode() == ISD::VECTOR_SHUFFLE &&
26289 N->isOnlyUserOf(N: N->getOperand(Num: i).getNode())) {
26290 // The incoming shuffle must be of the same type as the result of the
26291 // current shuffle.
26292 auto *OtherSV = cast<ShuffleVectorSDNode>(Val: N->getOperand(Num: i));
26293 assert(OtherSV->getOperand(0).getValueType() == VT &&
26294 "Shuffle types don't match");
26295
26296 SDValue SV0, SV1;
26297 SmallVector<int, 4> Mask;
26298 if (MergeInnerShuffle(i != 0, SVN, OtherSV, N->getOperand(Num: 1 - i), TLI,
26299 SV0, SV1, Mask)) {
          // Check if all indices in Mask are Undef. If so, propagate Undef.
26301 if (llvm::all_of(Range&: Mask, P: [](int M) { return M < 0; }))
26302 return DAG.getUNDEF(VT);
26303
26304 return DAG.getVectorShuffle(VT, dl: SDLoc(N),
26305 N1: SV0 ? SV0 : DAG.getUNDEF(VT),
26306 N2: SV1 ? SV1 : DAG.getUNDEF(VT), Mask);
26307 }
26308 }
26309 }
26310
    // Merge shuffles through binops if we are able to merge the shuffle with
    // at least one other shuffle.
26313 // shuffle(bop(shuffle(x,y),shuffle(z,w)),undef)
26314 // shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d)))
26315 unsigned SrcOpcode = N0.getOpcode();
26316 if (TLI.isBinOp(Opcode: SrcOpcode) && N->isOnlyUserOf(N: N0.getNode()) &&
26317 (N1.isUndef() ||
26318 (SrcOpcode == N1.getOpcode() && N->isOnlyUserOf(N: N1.getNode())))) {
26319 // Get binop source ops, or just pass on the undef.
26320 SDValue Op00 = N0.getOperand(i: 0);
26321 SDValue Op01 = N0.getOperand(i: 1);
26322 SDValue Op10 = N1.isUndef() ? N1 : N1.getOperand(i: 0);
26323 SDValue Op11 = N1.isUndef() ? N1 : N1.getOperand(i: 1);
26324 // TODO: We might be able to relax the VT check but we don't currently
26325 // have any isBinOp() that has different result/ops VTs so play safe until
26326 // we have test coverage.
26327 if (Op00.getValueType() == VT && Op10.getValueType() == VT &&
26328 Op01.getValueType() == VT && Op11.getValueType() == VT &&
26329 (Op00.getOpcode() == ISD::VECTOR_SHUFFLE ||
26330 Op10.getOpcode() == ISD::VECTOR_SHUFFLE ||
26331 Op01.getOpcode() == ISD::VECTOR_SHUFFLE ||
26332 Op11.getOpcode() == ISD::VECTOR_SHUFFLE)) {
26333 auto CanMergeInnerShuffle = [&](SDValue &SV0, SDValue &SV1,
26334 SmallVectorImpl<int> &Mask, bool LeftOp,
26335 bool Commute) {
26336 SDValue InnerN = Commute ? N1 : N0;
26337 SDValue Op0 = LeftOp ? Op00 : Op01;
26338 SDValue Op1 = LeftOp ? Op10 : Op11;
26339 if (Commute)
26340 std::swap(a&: Op0, b&: Op1);
26341 // Only accept the merged shuffle if we don't introduce undef elements,
26342 // or the inner shuffle already contained undef elements.
26343 auto *SVN0 = dyn_cast<ShuffleVectorSDNode>(Val&: Op0);
26344 return SVN0 && InnerN->isOnlyUserOf(N: SVN0) &&
26345 MergeInnerShuffle(Commute, SVN, SVN0, Op1, TLI, SV0, SV1,
26346 Mask) &&
26347 (llvm::any_of(Range: SVN0->getMask(), P: [](int M) { return M < 0; }) ||
26348 llvm::none_of(Range&: Mask, P: [](int M) { return M < 0; }));
26349 };
26350
26351 // Ensure we don't increase the number of shuffles - we must merge a
26352 // shuffle from at least one of the LHS and RHS ops.
26353 bool MergedLeft = false;
26354 SDValue LeftSV0, LeftSV1;
26355 SmallVector<int, 4> LeftMask;
26356 if (CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, false) ||
26357 CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, true)) {
26358 MergedLeft = true;
26359 } else {
26360 LeftMask.assign(in_start: SVN->getMask().begin(), in_end: SVN->getMask().end());
26361 LeftSV0 = Op00, LeftSV1 = Op10;
26362 }
26363
26364 bool MergedRight = false;
26365 SDValue RightSV0, RightSV1;
26366 SmallVector<int, 4> RightMask;
26367 if (CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, false) ||
26368 CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, true)) {
26369 MergedRight = true;
26370 } else {
26371 RightMask.assign(in_start: SVN->getMask().begin(), in_end: SVN->getMask().end());
26372 RightSV0 = Op01, RightSV1 = Op11;
26373 }
26374
26375 if (MergedLeft || MergedRight) {
26376 SDLoc DL(N);
26377 SDValue LHS = DAG.getVectorShuffle(
26378 VT, dl: DL, N1: LeftSV0 ? LeftSV0 : DAG.getUNDEF(VT),
26379 N2: LeftSV1 ? LeftSV1 : DAG.getUNDEF(VT), Mask: LeftMask);
26380 SDValue RHS = DAG.getVectorShuffle(
26381 VT, dl: DL, N1: RightSV0 ? RightSV0 : DAG.getUNDEF(VT),
26382 N2: RightSV1 ? RightSV1 : DAG.getUNDEF(VT), Mask: RightMask);
26383 return DAG.getNode(Opcode: SrcOpcode, DL, VT, N1: LHS, N2: RHS);
26384 }
26385 }
26386 }
26387 }
26388
26389 if (SDValue V = foldShuffleOfConcatUndefs(Shuf: SVN, DAG))
26390 return V;
26391
26392 // Match shuffles that can be converted to ISD::ZERO_EXTEND_VECTOR_INREG.
26393 // Perform this really late, because it could eliminate knowledge
26394 // of undef elements created by this shuffle.
26395 if (Level < AfterLegalizeTypes)
26396 if (SDValue V = combineShuffleToZeroExtendVectorInReg(SVN, DAG, TLI,
26397 LegalOperations))
26398 return V;
26399
26400 return SDValue();
26401}
26402
26403SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
26404 EVT VT = N->getValueType(ResNo: 0);
26405 if (!VT.isFixedLengthVector())
26406 return SDValue();
26407
26408 // Try to convert a scalar binop with an extracted vector element to a vector
26409 // binop. This is intended to reduce potentially expensive register moves.
26410 // TODO: Check if both operands are extracted.
  // TODO: How to prefer scalar/vector ops with multiple uses of the extract?
26412 // TODO: Generalize this, so it can be called from visitINSERT_VECTOR_ELT().
26413 SDValue Scalar = N->getOperand(Num: 0);
26414 unsigned Opcode = Scalar.getOpcode();
26415 EVT VecEltVT = VT.getScalarType();
26416 if (Scalar.hasOneUse() && Scalar->getNumValues() == 1 &&
26417 TLI.isBinOp(Opcode) && Scalar.getValueType() == VecEltVT &&
26418 Scalar.getOperand(i: 0).getValueType() == VecEltVT &&
26419 Scalar.getOperand(i: 1).getValueType() == VecEltVT &&
26420 Scalar->isOnlyUserOf(N: Scalar.getOperand(i: 0).getNode()) &&
26421 Scalar->isOnlyUserOf(N: Scalar.getOperand(i: 1).getNode()) &&
26422 DAG.isSafeToSpeculativelyExecute(Opcode) && hasOperation(Opcode, VT)) {
26423 // Match an extract element and get a shuffle mask equivalent.
26424 SmallVector<int, 8> ShufMask(VT.getVectorNumElements(), -1);
26425
26426 for (int i : {0, 1}) {
26427 // s2v (bo (extelt V, Idx), C) --> shuffle (bo V, C'), {Idx, -1, -1...}
26428 // s2v (bo C, (extelt V, Idx)) --> shuffle (bo C', V), {Idx, -1, -1...}
26429 SDValue EE = Scalar.getOperand(i);
26430 auto *C = dyn_cast<ConstantSDNode>(Val: Scalar.getOperand(i: i ? 0 : 1));
26431 if (C && EE.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
26432 EE.getOperand(i: 0).getValueType() == VT &&
26433 isa<ConstantSDNode>(Val: EE.getOperand(i: 1))) {
26434 // Mask = {ExtractIndex, undef, undef....}
26435 ShufMask[0] = EE.getConstantOperandVal(i: 1);
26436 // Make sure the shuffle is legal if we are crossing lanes.
26437 if (TLI.isShuffleMaskLegal(ShufMask, VT)) {
26438 SDLoc DL(N);
26439 SDValue V[] = {EE.getOperand(i: 0),
26440 DAG.getConstant(Val: C->getAPIntValue(), DL, VT)};
26441 SDValue VecBO = DAG.getNode(Opcode, DL, VT, N1: V[i], N2: V[1 - i]);
26442 return DAG.getVectorShuffle(VT, dl: DL, N1: VecBO, N2: DAG.getUNDEF(VT),
26443 Mask: ShufMask);
26444 }
26445 }
26446 }
26447 }
26448
26449 // Replace a SCALAR_TO_VECTOR(EXTRACT_VECTOR_ELT(V,C0)) pattern
26450 // with a VECTOR_SHUFFLE and possible truncate.
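  // For example (hypothetical v4i32 types):
  //   scalar_to_vector (extract_vector_elt V, 2)
  //   --> vector_shuffle V, undef, <2,-1,-1,-1>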
26451 if (Opcode != ISD::EXTRACT_VECTOR_ELT ||
26452 !Scalar.getOperand(i: 0).getValueType().isFixedLengthVector())
26453 return SDValue();
26454
26455 // If we have an implicit truncate, truncate here if it is legal.
26456 if (VecEltVT != Scalar.getValueType() &&
26457 Scalar.getValueType().isScalarInteger() && isTypeLegal(VT: VecEltVT)) {
26458 SDValue Val = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(Scalar), VT: VecEltVT, Operand: Scalar);
26459 return DAG.getNode(Opcode: ISD::SCALAR_TO_VECTOR, DL: SDLoc(N), VT, Operand: Val);
26460 }
26461
26462 auto *ExtIndexC = dyn_cast<ConstantSDNode>(Val: Scalar.getOperand(i: 1));
26463 if (!ExtIndexC)
26464 return SDValue();
26465
26466 SDValue SrcVec = Scalar.getOperand(i: 0);
26467 EVT SrcVT = SrcVec.getValueType();
26468 unsigned SrcNumElts = SrcVT.getVectorNumElements();
26469 unsigned VTNumElts = VT.getVectorNumElements();
26470 if (VecEltVT == SrcVT.getScalarType() && VTNumElts <= SrcNumElts) {
26471 // Create a shuffle equivalent for scalar-to-vector: {ExtIndex, -1, -1, ...}
26472 SmallVector<int, 8> Mask(SrcNumElts, -1);
26473 Mask[0] = ExtIndexC->getZExtValue();
26474 SDValue LegalShuffle = TLI.buildLegalVectorShuffle(
26475 VT: SrcVT, DL: SDLoc(N), N0: SrcVec, N1: DAG.getUNDEF(VT: SrcVT), Mask, DAG);
26476 if (!LegalShuffle)
26477 return SDValue();
26478
26479 // If the initial vector is the same size, the shuffle is the result.
26480 if (VT == SrcVT)
26481 return LegalShuffle;
26482
26483 // If not, shorten the shuffled vector.
26484 if (VTNumElts != SrcNumElts) {
26485 SDValue ZeroIdx = DAG.getVectorIdxConstant(Val: 0, DL: SDLoc(N));
26486 EVT SubVT = EVT::getVectorVT(Context&: *DAG.getContext(),
26487 VT: SrcVT.getVectorElementType(), NumElements: VTNumElts);
26488 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N), VT: SubVT, N1: LegalShuffle,
26489 N2: ZeroIdx);
26490 }
26491 }
26492
26493 return SDValue();
26494}
26495
26496SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) {
26497 EVT VT = N->getValueType(ResNo: 0);
26498 SDValue N0 = N->getOperand(Num: 0);
26499 SDValue N1 = N->getOperand(Num: 1);
26500 SDValue N2 = N->getOperand(Num: 2);
26501 uint64_t InsIdx = N->getConstantOperandVal(Num: 2);
26502
26503 // If inserting an UNDEF, just return the original vector.
26504 if (N1.isUndef())
26505 return N0;
26506
26507 // If this is an insert of an extracted vector into an undef vector, we can
26508 // just use the input to the extract if the types match, and can simplify
26509 // in some cases even if they don't.
26510 if (N0.isUndef() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
26511 N1.getOperand(i: 1) == N2) {
26512 EVT SrcVT = N1.getOperand(i: 0).getValueType();
26513 if (SrcVT == VT)
26514 return N1.getOperand(i: 0);
26515 // TODO: To remove the zero check, need to adjust the offset to
26516 // a multiple of the new src type.
26517 if (isNullConstant(V: N2)) {
26518 if (VT.knownBitsGE(VT: SrcVT) &&
26519 !(VT.isFixedLengthVector() && SrcVT.isScalableVector()))
26520 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N),
26521 VT, N1: N0, N2: N1.getOperand(i: 0), N3: N2);
26522 else if (VT.knownBitsLE(VT: SrcVT) &&
26523 !(VT.isScalableVector() && SrcVT.isFixedLengthVector()))
26524 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N),
26525 VT, N1: N1.getOperand(i: 0), N2);
26526 }
26527 }
26528
26529 // Handle case where we've ended up inserting back into the source vector
26530 // we extracted the subvector from.
26531 // insert_subvector(N0, extract_subvector(N0, N2), N2) --> N0
26532 if (N1.getOpcode() == ISD::EXTRACT_SUBVECTOR && N1.getOperand(i: 0) == N0 &&
26533 N1.getOperand(i: 1) == N2)
26534 return N0;
26535
26536 // Simplify scalar inserts into an undef vector:
26537 // insert_subvector undef, (splat X), N2 -> splat X
26538 if (N0.isUndef() && N1.getOpcode() == ISD::SPLAT_VECTOR)
26539 if (DAG.isConstantValueOfAnyType(N: N1.getOperand(i: 0)) || N1.hasOneUse())
26540 return DAG.getNode(Opcode: ISD::SPLAT_VECTOR, DL: SDLoc(N), VT, Operand: N1.getOperand(i: 0));
26541
26542 // If we are inserting a bitcast value into an undef, with the same
26543 // number of elements, just use the bitcast input of the extract.
26544 // i.e. INSERT_SUBVECTOR UNDEF (BITCAST N1) N2 ->
26545 // BITCAST (INSERT_SUBVECTOR UNDEF N1 N2)
26546 if (N0.isUndef() && N1.getOpcode() == ISD::BITCAST &&
26547 N1.getOperand(i: 0).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
26548 N1.getOperand(i: 0).getOperand(i: 1) == N2 &&
26549 N1.getOperand(i: 0).getOperand(i: 0).getValueType().getVectorElementCount() ==
26550 VT.getVectorElementCount() &&
26551 N1.getOperand(i: 0).getOperand(i: 0).getValueType().getSizeInBits() ==
26552 VT.getSizeInBits()) {
26553 return DAG.getBitcast(VT, V: N1.getOperand(i: 0).getOperand(i: 0));
26554 }
26555
  // If both N0 and N1 are bitcast values on which insert_subvector
  // would make sense, pull the bitcast through.
26558 // i.e. INSERT_SUBVECTOR (BITCAST N0) (BITCAST N1) N2 ->
26559 // BITCAST (INSERT_SUBVECTOR N0 N1 N2)
26560 if (N0.getOpcode() == ISD::BITCAST && N1.getOpcode() == ISD::BITCAST) {
26561 SDValue CN0 = N0.getOperand(i: 0);
26562 SDValue CN1 = N1.getOperand(i: 0);
26563 EVT CN0VT = CN0.getValueType();
26564 EVT CN1VT = CN1.getValueType();
26565 if (CN0VT.isVector() && CN1VT.isVector() &&
26566 CN0VT.getVectorElementType() == CN1VT.getVectorElementType() &&
26567 CN0VT.getVectorElementCount() == VT.getVectorElementCount()) {
26568 SDValue NewINSERT = DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N),
26569 VT: CN0.getValueType(), N1: CN0, N2: CN1, N3: N2);
26570 return DAG.getBitcast(VT, V: NewINSERT);
26571 }
26572 }
26573
26574 // Combine INSERT_SUBVECTORs where we are inserting to the same index.
26575 // INSERT_SUBVECTOR( INSERT_SUBVECTOR( Vec, SubOld, Idx ), SubNew, Idx )
26576 // --> INSERT_SUBVECTOR( Vec, SubNew, Idx )
26577 if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
26578 N0.getOperand(i: 1).getValueType() == N1.getValueType() &&
26579 N0.getOperand(i: 2) == N2)
26580 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N), VT, N1: N0.getOperand(i: 0),
26581 N2: N1, N3: N2);
26582
26583 // Eliminate an intermediate insert into an undef vector:
26584 // insert_subvector undef, (insert_subvector undef, X, 0), 0 -->
26585 // insert_subvector undef, X, 0
26586 if (N0.isUndef() && N1.getOpcode() == ISD::INSERT_SUBVECTOR &&
26587 N1.getOperand(i: 0).isUndef() && isNullConstant(V: N1.getOperand(i: 2)) &&
26588 isNullConstant(V: N2))
26589 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N), VT, N1: N0,
26590 N2: N1.getOperand(i: 1), N3: N2);
26591
26592 // Push subvector bitcasts to the output, adjusting the index as we go.
26593 // insert_subvector(bitcast(v), bitcast(s), c1)
26594 // -> bitcast(insert_subvector(v, s, c2))
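  // For example (hypothetical types), with V : v8i16 and S : v4i16:
  //   insert_subvector (bitcast V to v4i32), (bitcast S to v2i32), 2
  //   --> bitcast (insert_subvector V, S, 4) to v4i32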
26595 if ((N0.isUndef() || N0.getOpcode() == ISD::BITCAST) &&
26596 N1.getOpcode() == ISD::BITCAST) {
26597 SDValue N0Src = peekThroughBitcasts(V: N0);
26598 SDValue N1Src = peekThroughBitcasts(V: N1);
26599 EVT N0SrcSVT = N0Src.getValueType().getScalarType();
26600 EVT N1SrcSVT = N1Src.getValueType().getScalarType();
26601 if ((N0.isUndef() || N0SrcSVT == N1SrcSVT) &&
26602 N0Src.getValueType().isVector() && N1Src.getValueType().isVector()) {
26603 EVT NewVT;
26604 SDLoc DL(N);
26605 SDValue NewIdx;
26606 LLVMContext &Ctx = *DAG.getContext();
26607 ElementCount NumElts = VT.getVectorElementCount();
26608 unsigned EltSizeInBits = VT.getScalarSizeInBits();
26609 if ((EltSizeInBits % N1SrcSVT.getSizeInBits()) == 0) {
26610 unsigned Scale = EltSizeInBits / N1SrcSVT.getSizeInBits();
26611 NewVT = EVT::getVectorVT(Context&: Ctx, VT: N1SrcSVT, EC: NumElts * Scale);
26612 NewIdx = DAG.getVectorIdxConstant(Val: InsIdx * Scale, DL);
26613 } else if ((N1SrcSVT.getSizeInBits() % EltSizeInBits) == 0) {
26614 unsigned Scale = N1SrcSVT.getSizeInBits() / EltSizeInBits;
26615 if (NumElts.isKnownMultipleOf(RHS: Scale) && (InsIdx % Scale) == 0) {
26616 NewVT = EVT::getVectorVT(Context&: Ctx, VT: N1SrcSVT,
26617 EC: NumElts.divideCoefficientBy(RHS: Scale));
26618 NewIdx = DAG.getVectorIdxConstant(Val: InsIdx / Scale, DL);
26619 }
26620 }
26621 if (NewIdx && hasOperation(Opcode: ISD::INSERT_SUBVECTOR, VT: NewVT)) {
26622 SDValue Res = DAG.getBitcast(VT: NewVT, V: N0Src);
26623 Res = DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL, VT: NewVT, N1: Res, N2: N1Src, N3: NewIdx);
26624 return DAG.getBitcast(VT, V: Res);
26625 }
26626 }
26627 }
26628
  // Canonicalize insert_subvector dag nodes.
  // Example (when Idx1 < Idx0):
  // (insert_subvector (insert_subvector A, B0, Idx0), B1, Idx1)
  // -> (insert_subvector (insert_subvector A, B1, Idx1), B0, Idx0)
26633 if (N0.getOpcode() == ISD::INSERT_SUBVECTOR && N0.hasOneUse() &&
26634 N1.getValueType() == N0.getOperand(i: 1).getValueType()) {
26635 unsigned OtherIdx = N0.getConstantOperandVal(i: 2);
26636 if (InsIdx < OtherIdx) {
26637 // Swap nodes.
26638 SDValue NewOp = DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N), VT,
26639 N1: N0.getOperand(i: 0), N2: N1, N3: N2);
26640 AddToWorklist(N: NewOp.getNode());
26641 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N0.getNode()),
26642 VT, N1: NewOp, N2: N0.getOperand(i: 1), N3: N0.getOperand(i: 2));
26643 }
26644 }
26645
26646 // If the input vector is a concatenation, and the insert replaces
26647 // one of the pieces, we can optimize into a single concat_vectors.
26648 if (N0.getOpcode() == ISD::CONCAT_VECTORS && N0.hasOneUse() &&
26649 N0.getOperand(i: 0).getValueType() == N1.getValueType() &&
26650 N0.getOperand(i: 0).getValueType().isScalableVector() ==
26651 N1.getValueType().isScalableVector()) {
26652 unsigned Factor = N1.getValueType().getVectorMinNumElements();
26653 SmallVector<SDValue, 8> Ops(N0->op_begin(), N0->op_end());
26654 Ops[InsIdx / Factor] = N1;
26655 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, Ops);
26656 }
26657
26658 // Simplify source operands based on insertion.
26659 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
26660 return SDValue(N, 0);
26661
26662 return SDValue();
26663}
26664
26665SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) {
26666 SDValue N0 = N->getOperand(Num: 0);
26667
26668 // fold (fp_to_fp16 (fp16_to_fp op)) -> op
26669 if (N0->getOpcode() == ISD::FP16_TO_FP)
26670 return N0->getOperand(Num: 0);
26671
26672 return SDValue();
26673}
26674
26675SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
26676 auto Op = N->getOpcode();
26677 assert((Op == ISD::FP16_TO_FP || Op == ISD::BF16_TO_FP) &&
26678 "opcode should be FP16_TO_FP or BF16_TO_FP.");
26679 SDValue N0 = N->getOperand(Num: 0);
26680
26681 // fold fp16_to_fp(op & 0xffff) -> fp16_to_fp(op) or
26682 // fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
26683 if (!TLI.shouldKeepZExtForFP16Conv() && N0->getOpcode() == ISD::AND) {
26684 ConstantSDNode *AndConst = getAsNonOpaqueConstant(N: N0.getOperand(i: 1));
26685 if (AndConst && AndConst->getAPIntValue() == 0xffff) {
26686 return DAG.getNode(Opcode: Op, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Operand: N0.getOperand(i: 0));
26687 }
26688 }
26689
26690 // Sometimes constants manage to survive very late in the pipeline, e.g.,
26691 // because they are wrapped inside the <1 x f16> type. Try one last time to
26692 // get rid of them.
26693 SDValue Folded = DAG.FoldConstantArithmetic(Opcode: N->getOpcode(), DL: SDLoc(N),
26694 VT: N->getValueType(ResNo: 0), Ops: {N0});
26695 return Folded;
26696}
26697
26698SDValue DAGCombiner::visitFP_TO_BF16(SDNode *N) {
26699 SDValue N0 = N->getOperand(Num: 0);
26700
26701 // fold (fp_to_bf16 (bf16_to_fp op)) -> op
26702 if (N0->getOpcode() == ISD::BF16_TO_FP)
26703 return N0->getOperand(Num: 0);
26704
26705 return SDValue();
26706}
26707
26708SDValue DAGCombiner::visitBF16_TO_FP(SDNode *N) {
26709 // fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
26710 return visitFP16_TO_FP(N);
26711}
26712
26713SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
26714 SDValue N0 = N->getOperand(Num: 0);
26715 EVT VT = N0.getValueType();
26716 unsigned Opcode = N->getOpcode();
26717
26718 // VECREDUCE over 1-element vector is just an extract.
26719 if (VT.getVectorElementCount().isScalar()) {
26720 SDLoc dl(N);
26721 SDValue Res =
26722 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT: VT.getVectorElementType(), N1: N0,
26723 N2: DAG.getVectorIdxConstant(Val: 0, DL: dl));
26724 if (Res.getValueType() != N->getValueType(ResNo: 0))
26725 Res = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL: dl, VT: N->getValueType(ResNo: 0), Operand: Res);
26726 return Res;
26727 }
26728
  // On a boolean vector an and/or reduction is the same as a umin/umax
  // reduction. Convert them if the latter is legal while the former isn't.
26731 if (Opcode == ISD::VECREDUCE_AND || Opcode == ISD::VECREDUCE_OR) {
26732 unsigned NewOpcode = Opcode == ISD::VECREDUCE_AND
26733 ? ISD::VECREDUCE_UMIN : ISD::VECREDUCE_UMAX;
26734 if (!TLI.isOperationLegalOrCustom(Op: Opcode, VT) &&
26735 TLI.isOperationLegalOrCustom(Op: NewOpcode, VT) &&
26736 DAG.ComputeNumSignBits(Op: N0) == VT.getScalarSizeInBits())
26737 return DAG.getNode(Opcode: NewOpcode, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Operand: N0);
26738 }
26739
26740 // vecreduce_or(insert_subvector(zero or undef, val)) -> vecreduce_or(val)
26741 // vecreduce_and(insert_subvector(ones or undef, val)) -> vecreduce_and(val)
26742 if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
26743 TLI.isTypeLegal(VT: N0.getOperand(i: 1).getValueType())) {
26744 SDValue Vec = N0.getOperand(i: 0);
26745 SDValue Subvec = N0.getOperand(i: 1);
26746 if ((Opcode == ISD::VECREDUCE_OR &&
26747 (N0.getOperand(i: 0).isUndef() || isNullOrNullSplat(V: Vec))) ||
26748 (Opcode == ISD::VECREDUCE_AND &&
26749 (N0.getOperand(i: 0).isUndef() || isAllOnesOrAllOnesSplat(V: Vec))))
26750 return DAG.getNode(Opcode, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Operand: Subvec);
26751 }
26752
26753 return SDValue();
26754}
26755
26756SDValue DAGCombiner::visitVP_FSUB(SDNode *N) {
26757 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
26758
26759 // FSUB -> FMA combines:
26760 if (SDValue Fused = visitFSUBForFMACombine<VPMatchContext>(N)) {
26761 AddToWorklist(N: Fused.getNode());
26762 return Fused;
26763 }
26764 return SDValue();
26765}
26766
26767SDValue DAGCombiner::visitVPOp(SDNode *N) {
26768
26769 if (N->getOpcode() == ISD::VP_GATHER)
26770 if (SDValue SD = visitVPGATHER(N))
26771 return SD;
26772
26773 if (N->getOpcode() == ISD::VP_SCATTER)
26774 if (SDValue SD = visitVPSCATTER(N))
26775 return SD;
26776
26777 if (N->getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_LOAD)
26778 if (SDValue SD = visitVP_STRIDED_LOAD(N))
26779 return SD;
26780
26781 if (N->getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_STORE)
26782 if (SDValue SD = visitVP_STRIDED_STORE(N))
26783 return SD;
26784
26785 // VP operations in which all vector elements are disabled - either by
26786 // determining that the mask is all false or that the EVL is 0 - can be
26787 // eliminated.
26788 bool AreAllEltsDisabled = false;
26789 if (auto EVLIdx = ISD::getVPExplicitVectorLengthIdx(Opcode: N->getOpcode()))
26790 AreAllEltsDisabled |= isNullConstant(V: N->getOperand(Num: *EVLIdx));
26791 if (auto MaskIdx = ISD::getVPMaskIdx(Opcode: N->getOpcode()))
26792 AreAllEltsDisabled |=
26793 ISD::isConstantSplatVectorAllZeros(N: N->getOperand(Num: *MaskIdx).getNode());
26794
26795 // This is the only generic VP combine we support for now.
26796 if (!AreAllEltsDisabled) {
26797 switch (N->getOpcode()) {
26798 case ISD::VP_FADD:
26799 return visitVP_FADD(N);
26800 case ISD::VP_FSUB:
26801 return visitVP_FSUB(N);
26802 case ISD::VP_FMA:
26803 return visitFMA<VPMatchContext>(N);
26804 case ISD::VP_SELECT:
26805 return visitVP_SELECT(N);
26806 case ISD::VP_MUL:
26807 return visitMUL<VPMatchContext>(N);
26808 default:
26809 break;
26810 }
26811 return SDValue();
26812 }
26813
26814 // Binary operations can be replaced by UNDEF.
26815 if (ISD::isVPBinaryOp(Opcode: N->getOpcode()))
26816 return DAG.getUNDEF(VT: N->getValueType(ResNo: 0));
26817
26818 // VP Memory operations can be replaced by either the chain (stores) or the
26819 // chain + undef (loads).
26820 if (const auto *MemSD = dyn_cast<MemSDNode>(Val: N)) {
26821 if (MemSD->writeMem())
26822 return MemSD->getChain();
26823 return CombineTo(N, Res0: DAG.getUNDEF(VT: N->getValueType(ResNo: 0)), Res1: MemSD->getChain());
26824 }
26825
26826 // Reduction operations return the start operand when no elements are active.
26827 if (ISD::isVPReduction(Opcode: N->getOpcode()))
26828 return N->getOperand(Num: 0);
26829
26830 return SDValue();
26831}
26832
26833SDValue DAGCombiner::visitGET_FPENV_MEM(SDNode *N) {
26834 SDValue Chain = N->getOperand(Num: 0);
26835 SDValue Ptr = N->getOperand(Num: 1);
26836 EVT MemVT = cast<FPStateAccessSDNode>(Val: N)->getMemoryVT();
26837
  // Check if the memory where the FP state is written is used only in a
  // single load operation.
26840 LoadSDNode *LdNode = nullptr;
26841 for (auto *U : Ptr->uses()) {
26842 if (U == N)
26843 continue;
26844 if (auto *Ld = dyn_cast<LoadSDNode>(Val: U)) {
26845 if (LdNode && LdNode != Ld)
26846 return SDValue();
26847 LdNode = Ld;
26848 continue;
26849 }
26850 return SDValue();
26851 }
26852 if (!LdNode || !LdNode->isSimple() || LdNode->isIndexed() ||
26853 !LdNode->getOffset().isUndef() || LdNode->getMemoryVT() != MemVT ||
26854 !LdNode->getChain().reachesChainWithoutSideEffects(Dest: SDValue(N, 0)))
26855 return SDValue();
26856
26857 // Check if the loaded value is used only in a store operation.
26858 StoreSDNode *StNode = nullptr;
26859 for (auto I = LdNode->use_begin(), E = LdNode->use_end(); I != E; ++I) {
26860 SDUse &U = I.getUse();
26861 if (U.getResNo() == 0) {
26862 if (auto *St = dyn_cast<StoreSDNode>(Val: U.getUser())) {
26863 if (StNode)
26864 return SDValue();
26865 StNode = St;
26866 } else {
26867 return SDValue();
26868 }
26869 }
26870 }
26871 if (!StNode || !StNode->isSimple() || StNode->isIndexed() ||
26872 !StNode->getOffset().isUndef() || StNode->getMemoryVT() != MemVT ||
26873 !StNode->getChain().reachesChainWithoutSideEffects(Dest: SDValue(LdNode, 1)))
26874 return SDValue();
26875
  // Create a new GET_FPENV_MEM node, which uses the store address to write
  // the FP environment.
26878 SDValue Res = DAG.getGetFPEnv(Chain, dl: SDLoc(N), Ptr: StNode->getBasePtr(), MemVT,
26879 MMO: StNode->getMemOperand());
26880 CombineTo(N: StNode, Res, AddTo: false);
26881 return Res;
26882}
26883
26884SDValue DAGCombiner::visitSET_FPENV_MEM(SDNode *N) {
26885 SDValue Chain = N->getOperand(Num: 0);
26886 SDValue Ptr = N->getOperand(Num: 1);
26887 EVT MemVT = cast<FPStateAccessSDNode>(Val: N)->getMemoryVT();
26888
  // Check if the address of the FP state is also used only in a single store
  // operation.
26890 StoreSDNode *StNode = nullptr;
26891 for (auto *U : Ptr->uses()) {
26892 if (U == N)
26893 continue;
26894 if (auto *St = dyn_cast<StoreSDNode>(Val: U)) {
26895 if (StNode && StNode != St)
26896 return SDValue();
26897 StNode = St;
26898 continue;
26899 }
26900 return SDValue();
26901 }
26902 if (!StNode || !StNode->isSimple() || StNode->isIndexed() ||
26903 !StNode->getOffset().isUndef() || StNode->getMemoryVT() != MemVT ||
26904 !Chain.reachesChainWithoutSideEffects(Dest: SDValue(StNode, 0)))
26905 return SDValue();
26906
26907 // Check if the stored value is loaded from some location and the loaded
26908 // value is used only in the store operation.
26909 SDValue StValue = StNode->getValue();
26910 auto *LdNode = dyn_cast<LoadSDNode>(Val&: StValue);
26911 if (!LdNode || !LdNode->isSimple() || LdNode->isIndexed() ||
26912 !LdNode->getOffset().isUndef() || LdNode->getMemoryVT() != MemVT ||
26913 !StNode->getChain().reachesChainWithoutSideEffects(Dest: SDValue(LdNode, 1)))
26914 return SDValue();
26915
  // Create a new SET_FPENV_MEM node, which uses the load address to read the
  // FP environment.
26918 SDValue Res =
26919 DAG.getSetFPEnv(Chain: LdNode->getChain(), dl: SDLoc(N), Ptr: LdNode->getBasePtr(), MemVT,
26920 MMO: LdNode->getMemOperand());
26921 return Res;
26922}
26923
/// Returns a vector_shuffle if it is able to transform an AND to a
/// vector_shuffle with the destination vector and a zero vector.
26926/// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0>. ==>
26927/// vector_shuffle V, Zero, <0, 4, 2, 4>
26928SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) {
26929 assert(N->getOpcode() == ISD::AND && "Unexpected opcode!");
26930
26931 EVT VT = N->getValueType(ResNo: 0);
26932 SDValue LHS = N->getOperand(Num: 0);
26933 SDValue RHS = peekThroughBitcasts(V: N->getOperand(Num: 1));
26934 SDLoc DL(N);
26935
26936 // Make sure we're not running after operation legalization where it
26937 // may have custom lowered the vector shuffles.
26938 if (LegalOperations)
26939 return SDValue();
26940
26941 if (RHS.getOpcode() != ISD::BUILD_VECTOR)
26942 return SDValue();
26943
26944 EVT RVT = RHS.getValueType();
26945 unsigned NumElts = RHS.getNumOperands();
26946
  // Attempt to create a valid clear mask, splitting the mask into sub
  // elements and checking to see if each is all zeros or all ones - suitable
  // for shuffle masking.
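  // For example (hypothetical constant, little-endian), an RHS of
  //   v2i64 <0x00000000FFFFFFFF, 0xFFFFFFFF00000000>
  // with Split = 2 gives the v4i32 clear-mask indices <0, 5, 6, 3>, i.e. the
  // second and third i32 sub-elements are taken from the zero vector.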
26950 auto BuildClearMask = [&](int Split) {
26951 int NumSubElts = NumElts * Split;
26952 int NumSubBits = RVT.getScalarSizeInBits() / Split;
26953
26954 SmallVector<int, 8> Indices;
26955 for (int i = 0; i != NumSubElts; ++i) {
26956 int EltIdx = i / Split;
26957 int SubIdx = i % Split;
26958 SDValue Elt = RHS.getOperand(i: EltIdx);
26959 // X & undef --> 0 (not undef). So this lane must be converted to choose
26960 // from the zero constant vector (same as if the element had all 0-bits).
26961 if (Elt.isUndef()) {
26962 Indices.push_back(Elt: i + NumSubElts);
26963 continue;
26964 }
26965
26966 APInt Bits;
26967 if (auto *Cst = dyn_cast<ConstantSDNode>(Val&: Elt))
26968 Bits = Cst->getAPIntValue();
26969 else if (auto *CstFP = dyn_cast<ConstantFPSDNode>(Val&: Elt))
26970 Bits = CstFP->getValueAPF().bitcastToAPInt();
26971 else
26972 return SDValue();
26973
26974 // Extract the sub element from the constant bit mask.
26975 if (DAG.getDataLayout().isBigEndian())
26976 Bits = Bits.extractBits(numBits: NumSubBits, bitPosition: (Split - SubIdx - 1) * NumSubBits);
26977 else
26978 Bits = Bits.extractBits(numBits: NumSubBits, bitPosition: SubIdx * NumSubBits);
26979
26980 if (Bits.isAllOnes())
26981 Indices.push_back(Elt: i);
26982 else if (Bits == 0)
26983 Indices.push_back(Elt: i + NumSubElts);
26984 else
26985 return SDValue();
26986 }
26987
26988 // Let's see if the target supports this vector_shuffle.
26989 EVT ClearSVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: NumSubBits);
26990 EVT ClearVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: ClearSVT, NumElements: NumSubElts);
26991 if (!TLI.isVectorClearMaskLegal(Indices, ClearVT))
26992 return SDValue();
26993
26994 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: ClearVT);
26995 return DAG.getBitcast(VT, V: DAG.getVectorShuffle(VT: ClearVT, dl: DL,
26996 N1: DAG.getBitcast(VT: ClearVT, V: LHS),
26997 N2: Zero, Mask: Indices));
26998 };
26999
27000 // Determine maximum split level (byte level masking).
27001 int MaxSplit = 1;
27002 if (RVT.getScalarSizeInBits() % 8 == 0)
27003 MaxSplit = RVT.getScalarSizeInBits() / 8;
27004
27005 for (int Split = 1; Split <= MaxSplit; ++Split)
27006 if (RVT.getScalarSizeInBits() % Split == 0)
27007 if (SDValue S = BuildClearMask(Split))
27008 return S;
27009
27010 return SDValue();
27011}
27012
27013/// If a vector binop is performed on splat values, it may be profitable to
27014/// extract, scalarize, and insert/splat.
27015static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG,
27016 const SDLoc &DL) {
27017 SDValue N0 = N->getOperand(Num: 0);
27018 SDValue N1 = N->getOperand(Num: 1);
27019 unsigned Opcode = N->getOpcode();
27020 EVT VT = N->getValueType(ResNo: 0);
27021 EVT EltVT = VT.getVectorElementType();
27022 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
27023
27024 // TODO: Remove/replace the extract cost check? If the elements are available
27025 // as scalars, then there may be no extract cost. Should we ask if
27026 // inserting a scalar back into a vector is cheap instead?
27027 int Index0, Index1;
27028 SDValue Src0 = DAG.getSplatSourceVector(V: N0, SplatIndex&: Index0);
27029 SDValue Src1 = DAG.getSplatSourceVector(V: N1, SplatIndex&: Index1);
27030 // Extracting an element from a splat_vector should be free.
27031 // TODO: use DAG.isSplatValue instead?
27032 bool IsBothSplatVector = N0.getOpcode() == ISD::SPLAT_VECTOR &&
27033 N1.getOpcode() == ISD::SPLAT_VECTOR;
27034 if (!Src0 || !Src1 || Index0 != Index1 ||
27035 Src0.getValueType().getVectorElementType() != EltVT ||
27036 Src1.getValueType().getVectorElementType() != EltVT ||
27037 !(IsBothSplatVector || TLI.isExtractVecEltCheap(VT, Index: Index0)) ||
27038 !TLI.isOperationLegalOrCustom(Op: Opcode, VT: EltVT))
27039 return SDValue();
27040
27041 SDValue IndexC = DAG.getVectorIdxConstant(Val: Index0, DL);
27042 SDValue X = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Src0, N2: IndexC);
27043 SDValue Y = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Src1, N2: IndexC);
27044 SDValue ScalarBO = DAG.getNode(Opcode, DL, VT: EltVT, N1: X, N2: Y, Flags: N->getFlags());
27045
27046 // If all lanes but 1 are undefined, no need to splat the scalar result.
27047 // TODO: Keep track of undefs and use that info in the general case.
27048 if (N0.getOpcode() == ISD::BUILD_VECTOR && N0.getOpcode() == N1.getOpcode() &&
27049 count_if(Range: N0->ops(), P: [](SDValue V) { return !V.isUndef(); }) == 1 &&
27050 count_if(Range: N1->ops(), P: [](SDValue V) { return !V.isUndef(); }) == 1) {
27051 // bo (build_vec ..undef, X, undef...), (build_vec ..undef, Y, undef...) -->
27052 // build_vec ..undef, (bo X, Y), undef...
27053 SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), DAG.getUNDEF(VT: EltVT));
27054 Ops[Index0] = ScalarBO;
27055 return DAG.getBuildVector(VT, DL, Ops);
27056 }
27057
27058 // bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index
27059 return DAG.getSplat(VT, DL, Op: ScalarBO);
27060}
27061
27062/// Visit a vector cast operation, like FP_EXTEND.
27063SDValue DAGCombiner::SimplifyVCastOp(SDNode *N, const SDLoc &DL) {
27064 EVT VT = N->getValueType(ResNo: 0);
27065 assert(VT.isVector() && "SimplifyVCastOp only works on vectors!");
27066 EVT EltVT = VT.getVectorElementType();
27067 unsigned Opcode = N->getOpcode();
27068
27069 SDValue N0 = N->getOperand(Num: 0);
27070 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
27071
27072 // TODO: promoting the operation might also be good here?
27073 int Index0;
27074 SDValue Src0 = DAG.getSplatSourceVector(V: N0, SplatIndex&: Index0);
27075 if (Src0 &&
27076 (N0.getOpcode() == ISD::SPLAT_VECTOR ||
27077 TLI.isExtractVecEltCheap(VT, Index: Index0)) &&
27078 TLI.isOperationLegalOrCustom(Op: Opcode, VT: EltVT) &&
27079 TLI.preferScalarizeSplat(N)) {
27080 EVT SrcVT = N0.getValueType();
27081 EVT SrcEltVT = SrcVT.getVectorElementType();
27082 SDValue IndexC = DAG.getVectorIdxConstant(Val: Index0, DL);
27083 SDValue Elt =
27084 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: SrcEltVT, N1: Src0, N2: IndexC);
27085 SDValue ScalarBO = DAG.getNode(Opcode, DL, VT: EltVT, Operand: Elt, Flags: N->getFlags());
27086 if (VT.isScalableVector())
27087 return DAG.getSplatVector(VT, DL, Op: ScalarBO);
27088 SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), ScalarBO);
27089 return DAG.getBuildVector(VT, DL, Ops);
27090 }
27091
27092 return SDValue();
27093}
27094
27095/// Visit a binary vector operation, like ADD.
27096SDValue DAGCombiner::SimplifyVBinOp(SDNode *N, const SDLoc &DL) {
27097 EVT VT = N->getValueType(ResNo: 0);
27098 assert(VT.isVector() && "SimplifyVBinOp only works on vectors!");
27099
27100 SDValue LHS = N->getOperand(Num: 0);
27101 SDValue RHS = N->getOperand(Num: 1);
27102 unsigned Opcode = N->getOpcode();
27103 SDNodeFlags Flags = N->getFlags();
27104
27105 // Move unary shuffles with identical masks after a vector binop:
27106 // VBinOp (shuffle A, Undef, Mask), (shuffle B, Undef, Mask))
27107 // --> shuffle (VBinOp A, B), Undef, Mask
27108 // This does not require type legality checks because we are creating the
27109 // same types of operations that are in the original sequence. We do have to
27110 // restrict ops like integer div that have immediate UB (eg, div-by-zero)
27111 // though. This code is adapted from the identical transform in instcombine.
27112 if (DAG.isSafeToSpeculativelyExecute(Opcode)) {
27113 auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(Val&: LHS);
27114 auto *Shuf1 = dyn_cast<ShuffleVectorSDNode>(Val&: RHS);
27115 if (Shuf0 && Shuf1 && Shuf0->getMask().equals(RHS: Shuf1->getMask()) &&
27116 LHS.getOperand(i: 1).isUndef() && RHS.getOperand(i: 1).isUndef() &&
27117 (LHS.hasOneUse() || RHS.hasOneUse() || LHS == RHS)) {
27118 SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, N1: LHS.getOperand(i: 0),
27119 N2: RHS.getOperand(i: 0), Flags);
27120 SDValue UndefV = LHS.getOperand(i: 1);
27121 return DAG.getVectorShuffle(VT, dl: DL, N1: NewBinOp, N2: UndefV, Mask: Shuf0->getMask());
27122 }
27123
27124 // Try to sink a splat shuffle after a binop with a uniform constant.
27125 // This is limited to cases where neither the shuffle nor the constant have
27126 // undefined elements because that could be poison-unsafe or inhibit
27127 // demanded elements analysis. It is further limited to not change a splat
27128 // of an inserted scalar because that may be optimized better by
27129 // load-folding or other target-specific behaviors.
27130 if (isConstOrConstSplat(N: RHS) && Shuf0 && all_equal(Range: Shuf0->getMask()) &&
27131 Shuf0->hasOneUse() && Shuf0->getOperand(Num: 1).isUndef() &&
27132 Shuf0->getOperand(Num: 0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
27133 // binop (splat X), (splat C) --> splat (binop X, C)
27134 SDValue X = Shuf0->getOperand(Num: 0);
27135 SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, N1: X, N2: RHS, Flags);
27136 return DAG.getVectorShuffle(VT, dl: DL, N1: NewBinOp, N2: DAG.getUNDEF(VT),
27137 Mask: Shuf0->getMask());
27138 }
27139 if (isConstOrConstSplat(N: LHS) && Shuf1 && all_equal(Range: Shuf1->getMask()) &&
27140 Shuf1->hasOneUse() && Shuf1->getOperand(Num: 1).isUndef() &&
27141 Shuf1->getOperand(Num: 0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
27142 // binop (splat C), (splat X) --> splat (binop C, X)
27143 SDValue X = Shuf1->getOperand(Num: 0);
27144 SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, N1: LHS, N2: X, Flags);
27145 return DAG.getVectorShuffle(VT, dl: DL, N1: NewBinOp, N2: DAG.getUNDEF(VT),
27146 Mask: Shuf1->getMask());
27147 }
27148 }
27149
27150 // The following pattern is likely to emerge with vector reduction ops. Moving
27151 // the binary operation ahead of insertion may allow using a narrower vector
27152 // instruction that has better performance than the wide version of the op:
27153 // VBinOp (ins undef, X, Z), (ins undef, Y, Z) --> ins VecC, (VBinOp X, Y), Z
27154 if (LHS.getOpcode() == ISD::INSERT_SUBVECTOR && LHS.getOperand(i: 0).isUndef() &&
27155 RHS.getOpcode() == ISD::INSERT_SUBVECTOR && RHS.getOperand(i: 0).isUndef() &&
27156 LHS.getOperand(i: 2) == RHS.getOperand(i: 2) &&
27157 (LHS.hasOneUse() || RHS.hasOneUse())) {
27158 SDValue X = LHS.getOperand(i: 1);
27159 SDValue Y = RHS.getOperand(i: 1);
27160 SDValue Z = LHS.getOperand(i: 2);
27161 EVT NarrowVT = X.getValueType();
27162 if (NarrowVT == Y.getValueType() &&
27163 TLI.isOperationLegalOrCustomOrPromote(Op: Opcode, VT: NarrowVT,
27164 LegalOnly: LegalOperations)) {
27165 // (binop undef, undef) may not return undef, so compute that result.
27166 SDValue VecC =
27167 DAG.getNode(Opcode, DL, VT, N1: DAG.getUNDEF(VT), N2: DAG.getUNDEF(VT));
27168 SDValue NarrowBO = DAG.getNode(Opcode, DL, VT: NarrowVT, N1: X, N2: Y);
27169 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL, VT, N1: VecC, N2: NarrowBO, N3: Z);
27170 }
27171 }
27172
27173 // Make sure all but the first op are undef or constant.
27174 auto ConcatWithConstantOrUndef = [](SDValue Concat) {
27175 return Concat.getOpcode() == ISD::CONCAT_VECTORS &&
27176 all_of(Range: drop_begin(RangeOrContainer: Concat->ops()), P: [](const SDValue &Op) {
27177 return Op.isUndef() ||
27178 ISD::isBuildVectorOfConstantSDNodes(N: Op.getNode());
27179 });
27180 };
27181
27182 // The following pattern is likely to emerge with vector reduction ops. Moving
27183 // the binary operation ahead of the concat may allow using a narrower vector
27184 // instruction that has better performance than the wide version of the op:
27185 // VBinOp (concat X, undef/constant), (concat Y, undef/constant) -->
27186 // concat (VBinOp X, Y), VecC
27187 if (ConcatWithConstantOrUndef(LHS) && ConcatWithConstantOrUndef(RHS) &&
27188 (LHS.hasOneUse() || RHS.hasOneUse())) {
27189 EVT NarrowVT = LHS.getOperand(i: 0).getValueType();
27190 if (NarrowVT == RHS.getOperand(i: 0).getValueType() &&
27191 TLI.isOperationLegalOrCustomOrPromote(Op: Opcode, VT: NarrowVT)) {
27192 unsigned NumOperands = LHS.getNumOperands();
27193 SmallVector<SDValue, 4> ConcatOps;
27194 for (unsigned i = 0; i != NumOperands; ++i) {
27195 // This constant folds for operands 1 and up.
27196 ConcatOps.push_back(Elt: DAG.getNode(Opcode, DL, VT: NarrowVT, N1: LHS.getOperand(i),
27197 N2: RHS.getOperand(i)));
27198 }
27199
27200 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT, Ops: ConcatOps);
27201 }
27202 }
27203
27204 if (SDValue V = scalarizeBinOpOfSplats(N, DAG, DL))
27205 return V;
27206
27207 return SDValue();
27208}
27209
27210SDValue DAGCombiner::SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1,
27211 SDValue N2) {
27212 assert(N0.getOpcode() == ISD::SETCC &&
27213 "First argument must be a SetCC node!");
27214
27215 SDValue SCC = SimplifySelectCC(DL, N0: N0.getOperand(i: 0), N1: N0.getOperand(i: 1), N2: N1, N3: N2,
27216 CC: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get());
27217
27218 // If we got a simplified select_cc node back from SimplifySelectCC, then
27219 // break it down into a new SETCC node, and a new SELECT node, and then return
27220 // the SELECT node, since we were called with a SELECT node.
27221 if (SCC.getNode()) {
27222 // Check to see if we got a select_cc back (to turn into setcc/select).
27223 // Otherwise, just return whatever node we got back, like fabs.
27224 if (SCC.getOpcode() == ISD::SELECT_CC) {
27225 const SDNodeFlags Flags = N0->getFlags();
27226 SDValue SETCC = DAG.getNode(Opcode: ISD::SETCC, DL: SDLoc(N0),
27227 VT: N0.getValueType(),
27228 N1: SCC.getOperand(i: 0), N2: SCC.getOperand(i: 1),
27229 N3: SCC.getOperand(i: 4), Flags);
27230 AddToWorklist(N: SETCC.getNode());
27231 SDValue SelectNode = DAG.getSelect(DL: SDLoc(SCC), VT: SCC.getValueType(), Cond: SETCC,
27232 LHS: SCC.getOperand(i: 2), RHS: SCC.getOperand(i: 3));
27233 SelectNode->setFlags(Flags);
27234 return SelectNode;
27235 }
27236
27237 return SCC;
27238 }
27239 return SDValue();
27240}
27241
27242/// Given a SELECT or a SELECT_CC node, where LHS and RHS are the two values
27243/// being selected between, see if we can simplify the select. Callers of this
27244/// should assume that TheSelect is deleted if this returns true. As such, they
27245/// should return the appropriate thing (e.g. the node) back to the top-level of
27246/// the DAG combiner loop to avoid it being looked at.
27247bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS,
27248 SDValue RHS) {
27249 // fold (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
27250 // The select + setcc is redundant, because fsqrt returns NaN for X < 0.
27251 if (const ConstantFPSDNode *NaN = isConstOrConstSplatFP(N: LHS)) {
27252 if (NaN->isNaN() && RHS.getOpcode() == ISD::FSQRT) {
27253 // We have: (select (setcc ?, ?, ?), NaN, (fsqrt ?))
27254 SDValue Sqrt = RHS;
27255 ISD::CondCode CC;
27256 SDValue CmpLHS;
27257 const ConstantFPSDNode *Zero = nullptr;
27258
27259 if (TheSelect->getOpcode() == ISD::SELECT_CC) {
27260 CC = cast<CondCodeSDNode>(Val: TheSelect->getOperand(Num: 4))->get();
27261 CmpLHS = TheSelect->getOperand(Num: 0);
27262 Zero = isConstOrConstSplatFP(N: TheSelect->getOperand(Num: 1));
27263 } else {
27264 // SELECT or VSELECT
27265 SDValue Cmp = TheSelect->getOperand(Num: 0);
27266 if (Cmp.getOpcode() == ISD::SETCC) {
27267 CC = cast<CondCodeSDNode>(Val: Cmp.getOperand(i: 2))->get();
27268 CmpLHS = Cmp.getOperand(i: 0);
27269 Zero = isConstOrConstSplatFP(N: Cmp.getOperand(i: 1));
27270 }
27271 }
27272 if (Zero && Zero->isZero() &&
27273 Sqrt.getOperand(i: 0) == CmpLHS && (CC == ISD::SETOLT ||
27274 CC == ISD::SETULT || CC == ISD::SETLT)) {
27275 // We have: (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
27276 CombineTo(N: TheSelect, Res: Sqrt);
27277 return true;
27278 }
27279 }
27280 }
27281 // Cannot simplify select with vector condition
27282 if (TheSelect->getOperand(Num: 0).getValueType().isVector()) return false;
27283
27284 // If this is a select from two identical things, try to pull the operation
27285 // through the select.
27286 if (LHS.getOpcode() != RHS.getOpcode() ||
27287 !LHS.hasOneUse() || !RHS.hasOneUse())
27288 return false;
27289
27290 // If this is a load and the token chain is identical, replace the select
27291 // of two loads with a load through a select of the address to load from.
27292 // This triggers in things like "select bool X, 10.0, 123.0" after the FP
27293 // constants have been dropped into the constant pool.
27294 if (LHS.getOpcode() == ISD::LOAD) {
27295 LoadSDNode *LLD = cast<LoadSDNode>(Val&: LHS);
27296 LoadSDNode *RLD = cast<LoadSDNode>(Val&: RHS);
27297
27298 // Token chains must be identical.
27299 if (LHS.getOperand(i: 0) != RHS.getOperand(i: 0) ||
27300 // Do not let this transformation reduce the number of volatile loads.
27301 // Be conservative for atomics for the moment
27302 // TODO: This does appear to be legal for unordered atomics (see D66309)
27303 !LLD->isSimple() || !RLD->isSimple() ||
27304 // FIXME: If either is a pre/post inc/dec load,
27305 // we'd need to split out the address adjustment.
27306 LLD->isIndexed() || RLD->isIndexed() ||
27307 // If this is an EXTLOAD, the VT's must match.
27308 LLD->getMemoryVT() != RLD->getMemoryVT() ||
27309 // If this is an EXTLOAD, the kind of extension must match.
27310 (LLD->getExtensionType() != RLD->getExtensionType() &&
27311 // The only exception is if one of the extensions is anyext.
27312 LLD->getExtensionType() != ISD::EXTLOAD &&
27313 RLD->getExtensionType() != ISD::EXTLOAD) ||
27314 // FIXME: this discards src value information. This is
27315 // over-conservative. It would be beneficial to be able to remember
27316 // both potential memory locations. Since we are discarding
27317 // src value info, don't do the transformation if the memory
27318 // locations are not in the default address space.
27319 LLD->getPointerInfo().getAddrSpace() != 0 ||
27320 RLD->getPointerInfo().getAddrSpace() != 0 ||
27321 // We can't produce a CMOV of a TargetFrameIndex since we won't
27322 // generate the address generation required.
27323 LLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
27324 RLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
27325 !TLI.isOperationLegalOrCustom(Op: TheSelect->getOpcode(),
27326 VT: LLD->getBasePtr().getValueType()))
27327 return false;
27328
27329 // The loads must not depend on one another.
27330 if (LLD->isPredecessorOf(N: RLD) || RLD->isPredecessorOf(N: LLD))
27331 return false;
27332
27333 // Check that the select condition doesn't reach either load. If so,
27334 // folding this will induce a cycle into the DAG. If not, this is safe to
27335 // xform, so create a select of the addresses.
27336
27337 SmallPtrSet<const SDNode *, 32> Visited;
27338 SmallVector<const SDNode *, 16> Worklist;
27339
27340 // Always fail if LLD and RLD are not independent. TheSelect is a
27341 // predecessor to all Nodes in question so we need not search past it.
27342
27343 Visited.insert(Ptr: TheSelect);
27344 Worklist.push_back(Elt: LLD);
27345 Worklist.push_back(Elt: RLD);
27346
27347 if (SDNode::hasPredecessorHelper(N: LLD, Visited, Worklist) ||
27348 SDNode::hasPredecessorHelper(N: RLD, Visited, Worklist))
27349 return false;
27350
27351 SDValue Addr;
27352 if (TheSelect->getOpcode() == ISD::SELECT) {
27353 // We cannot do this optimization if any pair of {RLD, LLD} is a
27354 // predecessor to {RLD, LLD, CondNode}. As we've already compared the
27355 // Loads, we only need to check if CondNode is a successor to one of the
27356 // loads. We can further avoid this if there's no use of their chain
27357 // value.
27358 SDNode *CondNode = TheSelect->getOperand(Num: 0).getNode();
27359 Worklist.push_back(Elt: CondNode);
27360
27361 if ((LLD->hasAnyUseOfValue(Value: 1) &&
27362 SDNode::hasPredecessorHelper(N: LLD, Visited, Worklist)) ||
27363 (RLD->hasAnyUseOfValue(Value: 1) &&
27364 SDNode::hasPredecessorHelper(N: RLD, Visited, Worklist)))
27365 return false;
27366
27367 Addr = DAG.getSelect(DL: SDLoc(TheSelect),
27368 VT: LLD->getBasePtr().getValueType(),
27369 Cond: TheSelect->getOperand(Num: 0), LHS: LLD->getBasePtr(),
27370 RHS: RLD->getBasePtr());
27371 } else { // Otherwise SELECT_CC
27372 // We cannot do this optimization if any pair of {RLD, LLD} is a
27373 // predecessor to {RLD, LLD, CondLHS, CondRHS}. As we've already compared
27374 // the Loads, we only need to check if CondLHS/CondRHS is a successor to
27375 // one of the loads. We can further avoid this if there's no use of their
27376 // chain value.
27377
27378 SDNode *CondLHS = TheSelect->getOperand(Num: 0).getNode();
27379 SDNode *CondRHS = TheSelect->getOperand(Num: 1).getNode();
27380 Worklist.push_back(Elt: CondLHS);
27381 Worklist.push_back(Elt: CondRHS);
27382
27383 if ((LLD->hasAnyUseOfValue(Value: 1) &&
27384 SDNode::hasPredecessorHelper(N: LLD, Visited, Worklist)) ||
27385 (RLD->hasAnyUseOfValue(Value: 1) &&
27386 SDNode::hasPredecessorHelper(N: RLD, Visited, Worklist)))
27387 return false;
27388
27389 Addr = DAG.getNode(Opcode: ISD::SELECT_CC, DL: SDLoc(TheSelect),
27390 VT: LLD->getBasePtr().getValueType(),
27391 N1: TheSelect->getOperand(Num: 0),
27392 N2: TheSelect->getOperand(Num: 1),
27393 N3: LLD->getBasePtr(), N4: RLD->getBasePtr(),
27394 N5: TheSelect->getOperand(Num: 4));
27395 }
27396
27397 SDValue Load;
27398 // It is safe to replace the two loads if they have different alignments,
27399 // but the new load must be the minimum (most restrictive) alignment of the
27400 // inputs.
27401 Align Alignment = std::min(a: LLD->getAlign(), b: RLD->getAlign());
27402 MachineMemOperand::Flags MMOFlags = LLD->getMemOperand()->getFlags();
27403 if (!RLD->isInvariant())
27404 MMOFlags &= ~MachineMemOperand::MOInvariant;
27405 if (!RLD->isDereferenceable())
27406 MMOFlags &= ~MachineMemOperand::MODereferenceable;
27407 if (LLD->getExtensionType() == ISD::NON_EXTLOAD) {
27408 // FIXME: Discards pointer and AA info.
27409 Load = DAG.getLoad(VT: TheSelect->getValueType(ResNo: 0), dl: SDLoc(TheSelect),
27410 Chain: LLD->getChain(), Ptr: Addr, PtrInfo: MachinePointerInfo(), Alignment,
27411 MMOFlags);
27412 } else {
27413 // FIXME: Discards pointer and AA info.
27414 Load = DAG.getExtLoad(
27415 ExtType: LLD->getExtensionType() == ISD::EXTLOAD ? RLD->getExtensionType()
27416 : LLD->getExtensionType(),
27417 dl: SDLoc(TheSelect), VT: TheSelect->getValueType(ResNo: 0), Chain: LLD->getChain(), Ptr: Addr,
27418 PtrInfo: MachinePointerInfo(), MemVT: LLD->getMemoryVT(), Alignment, MMOFlags);
27419 }
27420
27421 // Users of the select now use the result of the load.
27422 CombineTo(N: TheSelect, Res: Load);
27423
27424 // Users of the old loads now use the new load's chain. We know the
27425 // old-load value is dead now.
27426 CombineTo(N: LHS.getNode(), Res0: Load.getValue(R: 0), Res1: Load.getValue(R: 1));
27427 CombineTo(N: RHS.getNode(), Res0: Load.getValue(R: 0), Res1: Load.getValue(R: 1));
27428 return true;
27429 }
27430
27431 return false;
27432}
27433
27434/// Try to fold an expression of the form (N0 cond N1) ? N2 : N3 to a shift and
27435/// bitwise 'and'.
27436SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
27437 SDValue N1, SDValue N2, SDValue N3,
27438 ISD::CondCode CC) {
27439 // If this is a select where the false operand is zero and the compare is a
27440 // check of the sign bit, see if we can perform the "gzip trick":
27441 // select_cc setlt X, 0, A, 0 -> and (sra X, size(X)-1), A
27442 // select_cc setgt X, 0, A, 0 -> and (not (sra X, size(X)-1)), A
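// For example (illustrative, i32 X): in "select_cc setlt X, 0, A, 0",
// (sra X, 31) is all-ones exactly when X is negative and zero otherwise, so
// ANDing it with A yields A for X < 0 and 0 otherwise, which matches the
// select.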
27443 EVT XType = N0.getValueType();
27444 EVT AType = N2.getValueType();
27445 if (!isNullConstant(V: N3) || !XType.bitsGE(VT: AType))
27446 return SDValue();
27447
27448 // If the comparison is testing for a positive value, we have to invert
27449 // the sign bit mask, so only do that transform if the target has a bitwise
27450 // 'and not' instruction (the invert is free).
27451 if (CC == ISD::SETGT && TLI.hasAndNot(X: N2)) {
27452 // (X > -1) ? A : 0
27453 // (X > 0) ? X : 0 <-- This is canonical signed max.
27454 if (!(isAllOnesConstant(V: N1) || (isNullConstant(V: N1) && N0 == N2)))
27455 return SDValue();
27456 } else if (CC == ISD::SETLT) {
27457 // (X < 0) ? A : 0
27458 // (X < 1) ? X : 0 <-- This is un-canonicalized signed min.
27459 if (!(isNullConstant(V: N1) || (isOneConstant(V: N1) && N0 == N2)))
27460 return SDValue();
27461 } else {
27462 return SDValue();
27463 }
27464
27465 // and (sra X, size(X)-1), A -> "and (srl X, C2), A" iff A is a single-bit
27466 // constant.
27467 auto *N2C = dyn_cast<ConstantSDNode>(Val: N2.getNode());
27468 if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) {
27469 unsigned ShCt = XType.getSizeInBits() - N2C->getAPIntValue().logBase2() - 1;
27470 if (!TLI.shouldAvoidTransformToShift(VT: XType, Amount: ShCt)) {
27471 SDValue ShiftAmt = DAG.getShiftAmountConstant(Val: ShCt, VT: XType, DL);
27472 SDValue Shift = DAG.getNode(Opcode: ISD::SRL, DL, VT: XType, N1: N0, N2: ShiftAmt);
27473 AddToWorklist(N: Shift.getNode());
27474
27475 if (XType.bitsGT(VT: AType)) {
27476 Shift = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: AType, Operand: Shift);
27477 AddToWorklist(N: Shift.getNode());
27478 }
27479
27480 if (CC == ISD::SETGT)
27481 Shift = DAG.getNOT(DL, Val: Shift, VT: AType);
27482
27483 return DAG.getNode(Opcode: ISD::AND, DL, VT: AType, N1: Shift, N2);
27484 }
27485 }
27486
27487 unsigned ShCt = XType.getSizeInBits() - 1;
27488 if (TLI.shouldAvoidTransformToShift(VT: XType, Amount: ShCt))
27489 return SDValue();
27490
27491 SDValue ShiftAmt = DAG.getShiftAmountConstant(Val: ShCt, VT: XType, DL);
27492 SDValue Shift = DAG.getNode(Opcode: ISD::SRA, DL, VT: XType, N1: N0, N2: ShiftAmt);
27493 AddToWorklist(N: Shift.getNode());
27494
27495 if (XType.bitsGT(VT: AType)) {
27496 Shift = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: AType, Operand: Shift);
27497 AddToWorklist(N: Shift.getNode());
27498 }
27499
27500 if (CC == ISD::SETGT)
27501 Shift = DAG.getNOT(DL, Val: Shift, VT: AType);
27502
27503 return DAG.getNode(Opcode: ISD::AND, DL, VT: AType, N1: Shift, N2);
27504}
27505
27506// Fold select(cc, binop(), binop()) -> binop(select(), select()) etc.
27507SDValue DAGCombiner::foldSelectOfBinops(SDNode *N) {
27508 SDValue N0 = N->getOperand(Num: 0);
27509 SDValue N1 = N->getOperand(Num: 1);
27510 SDValue N2 = N->getOperand(Num: 2);
27511 SDLoc DL(N);
27512
27513 unsigned BinOpc = N1.getOpcode();
27514 if (!TLI.isBinOp(Opcode: BinOpc) || (N2.getOpcode() != BinOpc) ||
27515 (N1.getResNo() != N2.getResNo()))
27516 return SDValue();
27517
27518 // The use checks are intentionally on SDNode because we may be dealing
27519 // with opcodes that produce more than one SDValue.
27520 // TODO: Do we really need to check N0 (the condition operand of the select)?
27521 // But removing that clause could cause an infinite loop...
27522 if (!N0->hasOneUse() || !N1->hasOneUse() || !N2->hasOneUse())
27523 return SDValue();
27524
27525 // Binops may include opcodes that return multiple values, so all values
27526 // must be created/propagated from the newly created binops below.
27527 SDVTList OpVTs = N1->getVTList();
27528
27529 // Fold select(cond, binop(x, y), binop(z, y))
27530 // --> binop(select(cond, x, z), y)
27531 if (N1.getOperand(i: 1) == N2.getOperand(i: 1)) {
27532 SDValue N10 = N1.getOperand(i: 0);
27533 SDValue N20 = N2.getOperand(i: 0);
27534 SDValue NewSel = DAG.getSelect(DL, VT: N10.getValueType(), Cond: N0, LHS: N10, RHS: N20);
27535 SDValue NewBinOp = DAG.getNode(Opcode: BinOpc, DL, VTList: OpVTs, N1: NewSel, N2: N1.getOperand(i: 1));
27536 NewBinOp->setFlags(N1->getFlags());
27537 NewBinOp->intersectFlagsWith(Flags: N2->getFlags());
27538 return SDValue(NewBinOp.getNode(), N1.getResNo());
27539 }
27540
27541 // Fold select(cond, binop(x, y), binop(x, z))
27542 // --> binop(x, select(cond, y, z))
27543 if (N1.getOperand(i: 0) == N2.getOperand(i: 0)) {
27544 SDValue N11 = N1.getOperand(i: 1);
27545 SDValue N21 = N2.getOperand(i: 1);
27546 // Second op VT might be different (e.g. shift amount type)
27547 if (N11.getValueType() == N21.getValueType()) {
27548 SDValue NewSel = DAG.getSelect(DL, VT: N11.getValueType(), Cond: N0, LHS: N11, RHS: N21);
27549 SDValue NewBinOp =
27550 DAG.getNode(Opcode: BinOpc, DL, VTList: OpVTs, N1: N1.getOperand(i: 0), N2: NewSel);
27551 NewBinOp->setFlags(N1->getFlags());
27552 NewBinOp->intersectFlagsWith(Flags: N2->getFlags());
27553 return SDValue(NewBinOp.getNode(), N1.getResNo());
27554 }
27555 }
27556
27557 // TODO: Handle isCommutativeBinOp patterns as well?
27558 return SDValue();
27559}
27560
27561// Transform (fneg/fabs (bitconvert x)) to avoid loading constant pool values.
27562SDValue DAGCombiner::foldSignChangeInBitcast(SDNode *N) {
27563 SDValue N0 = N->getOperand(Num: 0);
27564 EVT VT = N->getValueType(ResNo: 0);
27565 bool IsFabs = N->getOpcode() == ISD::FABS;
27566 bool IsFree = IsFabs ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);
27567
27568 if (IsFree || N0.getOpcode() != ISD::BITCAST || !N0.hasOneUse())
27569 return SDValue();
27570
27571 SDValue Int = N0.getOperand(i: 0);
27572 EVT IntVT = Int.getValueType();
27573
27574 // The operand of the cast should be a scalar integer.
27575 if (!IntVT.isInteger() || IntVT.isVector())
27576 return SDValue();
27577
27578 // (fneg (bitconvert x)) -> (bitconvert (xor x sign))
27579 // (fabs (bitconvert x)) -> (bitconvert (and x ~sign))
27580 APInt SignMask;
27581 if (N0.getValueType().isVector()) {
27582 // For vector, create a sign mask (0x80...) or its inverse (for fabs,
27583 // 0x7f...) per element and splat it.
27584 SignMask = APInt::getSignMask(BitWidth: N0.getScalarValueSizeInBits());
27585 if (IsFabs)
27586 SignMask = ~SignMask;
27587 SignMask = APInt::getSplat(NewLen: IntVT.getSizeInBits(), V: SignMask);
27588 } else {
27589 // For scalar, just use the sign mask (0x80... or the inverse, 0x7f...)
27590 SignMask = APInt::getSignMask(BitWidth: IntVT.getSizeInBits());
27591 if (IsFabs)
27592 SignMask = ~SignMask;
27593 }
27594 SDLoc DL(N0);
27595 Int = DAG.getNode(Opcode: IsFabs ? ISD::AND : ISD::XOR, DL, VT: IntVT, N1: Int,
27596 N2: DAG.getConstant(Val: SignMask, DL, VT: IntVT));
27597 AddToWorklist(N: Int.getNode());
27598 return DAG.getBitcast(VT, V: Int);
27599}
27600
27601/// Turn "(a cond b) ? 1.0f : 2.0f" into "load (tmp + ((a cond b) ? 0 : 4))"
27602/// where "tmp" is a constant pool entry containing an array with 1.0 and 2.0
27603/// in it. This may be a win when the constant is not otherwise available
27604/// because it replaces two constant pool loads with one.
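/// For example (illustrative layout): with N2 = 1.0f and N3 = 2.0f, the pool
/// entry holds { 2.0f, 1.0f }; a true condition selects offset 4 (loading
/// 1.0f) and a false condition selects offset 0 (loading 2.0f), so one load
/// replaces two.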
27605SDValue DAGCombiner::convertSelectOfFPConstantsToLoadOffset(
27606 const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
27607 ISD::CondCode CC) {
27608 if (!TLI.reduceSelectOfFPConstantLoads(CmpOpVT: N0.getValueType()))
27609 return SDValue();
27610
27611 // If we are before legalize types, we want the other legalization to happen
27612 // first (for example, to avoid messing with soft float).
27613 auto *TV = dyn_cast<ConstantFPSDNode>(Val&: N2);
27614 auto *FV = dyn_cast<ConstantFPSDNode>(Val&: N3);
27615 EVT VT = N2.getValueType();
27616 if (!TV || !FV || !TLI.isTypeLegal(VT))
27617 return SDValue();
27618
27619 // If a constant can be materialized without loads, this does not make sense.
27620 if (TLI.getOperationAction(Op: ISD::ConstantFP, VT) == TargetLowering::Legal ||
27621 TLI.isFPImmLegal(TV->getValueAPF(), TV->getValueType(ResNo: 0), ForCodeSize) ||
27622 TLI.isFPImmLegal(FV->getValueAPF(), FV->getValueType(ResNo: 0), ForCodeSize))
27623 return SDValue();
27624
27625 // If both constants have multiple uses, then we won't need to do an extra
27626 // load. The values are likely around in registers for other users.
27627 if (!TV->hasOneUse() && !FV->hasOneUse())
27628 return SDValue();
27629
27630 Constant *Elts[] = { const_cast<ConstantFP*>(FV->getConstantFPValue()),
27631 const_cast<ConstantFP*>(TV->getConstantFPValue()) };
27632 Type *FPTy = Elts[0]->getType();
27633 const DataLayout &TD = DAG.getDataLayout();
27634
27635 // Create a ConstantArray of the two constants.
27636 Constant *CA = ConstantArray::get(T: ArrayType::get(ElementType: FPTy, NumElements: 2), V: Elts);
27637 SDValue CPIdx = DAG.getConstantPool(C: CA, VT: TLI.getPointerTy(DL: DAG.getDataLayout()),
27638 Align: TD.getPrefTypeAlign(Ty: FPTy));
27639 Align Alignment = cast<ConstantPoolSDNode>(Val&: CPIdx)->getAlign();
27640
27641 // Get offsets to the 0 and 1 elements of the array, so we can select between
27642 // them.
27643 SDValue Zero = DAG.getIntPtrConstant(Val: 0, DL);
27644 unsigned EltSize = (unsigned)TD.getTypeAllocSize(Ty: Elts[0]->getType());
27645 SDValue One = DAG.getIntPtrConstant(Val: EltSize, DL: SDLoc(FV));
27646 SDValue Cond =
27647 DAG.getSetCC(DL, VT: getSetCCResultType(VT: N0.getValueType()), LHS: N0, RHS: N1, Cond: CC);
27648 AddToWorklist(N: Cond.getNode());
27649 SDValue CstOffset = DAG.getSelect(DL, VT: Zero.getValueType(), Cond, LHS: One, RHS: Zero);
27650 AddToWorklist(N: CstOffset.getNode());
27651 CPIdx = DAG.getNode(Opcode: ISD::ADD, DL, VT: CPIdx.getValueType(), N1: CPIdx, N2: CstOffset);
27652 AddToWorklist(N: CPIdx.getNode());
27653 return DAG.getLoad(VT: TV->getValueType(ResNo: 0), dl: DL, Chain: DAG.getEntryNode(), Ptr: CPIdx,
27654 PtrInfo: MachinePointerInfo::getConstantPool(
27655 MF&: DAG.getMachineFunction()), Alignment);
27656}
27657
27658/// Simplify an expression of the form (N0 cond N1) ? N2 : N3
27659/// where 'cond' is the comparison specified by CC.
27660SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
27661 SDValue N2, SDValue N3, ISD::CondCode CC,
27662 bool NotExtCompare) {
27663 // (x ? y : y) -> y.
27664 if (N2 == N3) return N2;
27665
27666 EVT CmpOpVT = N0.getValueType();
27667 EVT CmpResVT = getSetCCResultType(VT: CmpOpVT);
27668 EVT VT = N2.getValueType();
27669 auto *N1C = dyn_cast<ConstantSDNode>(Val: N1.getNode());
27670 auto *N2C = dyn_cast<ConstantSDNode>(Val: N2.getNode());
27671 auto *N3C = dyn_cast<ConstantSDNode>(Val: N3.getNode());
27672
27673 // Determine if the condition we're dealing with is constant.
27674 if (SDValue SCC = DAG.FoldSetCC(VT: CmpResVT, N1: N0, N2: N1, Cond: CC, dl: DL)) {
27675 AddToWorklist(N: SCC.getNode());
27676 if (auto *SCCC = dyn_cast<ConstantSDNode>(Val&: SCC)) {
27677 // fold select_cc true, x, y -> x
27678 // fold select_cc false, x, y -> y
27679 return !(SCCC->isZero()) ? N2 : N3;
27680 }
27681 }
27682
27683 if (SDValue V =
27684 convertSelectOfFPConstantsToLoadOffset(DL, N0, N1, N2, N3, CC))
27685 return V;
27686
27687 if (SDValue V = foldSelectCCToShiftAnd(DL, N0, N1, N2, N3, CC))
27688 return V;
27689
27690 // fold (select_cc seteq (and x, y), 0, 0, A) -> (and (sra (shl x)) A)
27691 // where y has a single bit set.
27692 // In plain terms: we can turn the SELECT_CC into an AND when the condition
27693 // can be materialized as an all-ones register. Any single bit-test can be
27694 // materialized as an all-ones register with a shift-left followed by an
27695 // arithmetic shift-right.
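// For example (illustrative, i8): for "(x & 4) == 0 ? 0 : A", shifting x left
// by 5 moves bit 2 into the sign bit, and an arithmetic shift right by 7 then
// yields all-ones when that bit was set and zero otherwise; ANDing with A
// produces the selected value.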
27696 if (CC == ISD::SETEQ && N0->getOpcode() == ISD::AND &&
27697 N0->getValueType(ResNo: 0) == VT && isNullConstant(V: N1) && isNullConstant(V: N2)) {
27698 SDValue AndLHS = N0->getOperand(Num: 0);
27699 auto *ConstAndRHS = dyn_cast<ConstantSDNode>(Val: N0->getOperand(Num: 1));
27700 if (ConstAndRHS && ConstAndRHS->getAPIntValue().popcount() == 1) {
27701 // Shift the tested bit over the sign bit.
27702 const APInt &AndMask = ConstAndRHS->getAPIntValue();
27703 if (TLI.shouldFoldSelectWithSingleBitTest(VT, AndMask)) {
27704 unsigned ShCt = AndMask.getBitWidth() - 1;
27705 SDValue ShlAmt = DAG.getShiftAmountConstant(Val: AndMask.countl_zero(), VT,
27706 DL: SDLoc(AndLHS));
27707 SDValue Shl = DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N0), VT, N1: AndLHS, N2: ShlAmt);
27708
27709 // Now arithmetic right shift it all the way over, so the result is
27710 // either all-ones, or zero.
27711 SDValue ShrAmt = DAG.getShiftAmountConstant(Val: ShCt, VT, DL: SDLoc(Shl));
27712 SDValue Shr = DAG.getNode(Opcode: ISD::SRA, DL: SDLoc(N0), VT, N1: Shl, N2: ShrAmt);
27713
27714 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shr, N2: N3);
27715 }
27716 }
27717 }
27718
27719 // fold select C, 16, 0 -> shl C, 4
27720 bool Fold = N2C && isNullConstant(V: N3) && N2C->getAPIntValue().isPowerOf2();
27721 bool Swap = N3C && isNullConstant(V: N2) && N3C->getAPIntValue().isPowerOf2();
27722
27723 if ((Fold || Swap) &&
27724 TLI.getBooleanContents(Type: CmpOpVT) ==
27725 TargetLowering::ZeroOrOneBooleanContent &&
27726 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SETCC, VT: CmpOpVT))) {
27727
27728 if (Swap) {
27729 CC = ISD::getSetCCInverse(Operation: CC, Type: CmpOpVT);
27730 std::swap(a&: N2C, b&: N3C);
27731 }
27732
27733 // If the caller doesn't want us to simplify this into a zext of a compare,
27734 // don't do it.
27735 if (NotExtCompare && N2C->isOne())
27736 return SDValue();
27737
27738 SDValue Temp, SCC;
27739 // zext (setcc n0, n1)
27740 if (LegalTypes) {
27741 SCC = DAG.getSetCC(DL, VT: CmpResVT, LHS: N0, RHS: N1, Cond: CC);
27742 Temp = DAG.getZExtOrTrunc(Op: SCC, DL: SDLoc(N2), VT);
27743 } else {
27744 SCC = DAG.getSetCC(DL: SDLoc(N0), VT: MVT::i1, LHS: N0, RHS: N1, Cond: CC);
27745 Temp = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: SDLoc(N2), VT, Operand: SCC);
27746 }
27747
27748 AddToWorklist(N: SCC.getNode());
27749 AddToWorklist(N: Temp.getNode());
27750
27751 if (N2C->isOne())
27752 return Temp;
27753
27754 unsigned ShCt = N2C->getAPIntValue().logBase2();
27755 if (TLI.shouldAvoidTransformToShift(VT, Amount: ShCt))
27756 return SDValue();
27757
27758 // shl setcc result by log2 n2c
27759 return DAG.getNode(
27760 Opcode: ISD::SHL, DL, VT: N2.getValueType(), N1: Temp,
27761 N2: DAG.getShiftAmountConstant(Val: ShCt, VT: N2.getValueType(), DL: SDLoc(Temp)));
27762 }
27763
27764 // select_cc seteq X, 0, sizeof(X), ctlz(X) -> ctlz(X)
27765 // select_cc seteq X, 0, sizeof(X), ctlz_zero_undef(X) -> ctlz(X)
27766 // select_cc seteq X, 0, sizeof(X), cttz(X) -> cttz(X)
27767 // select_cc seteq X, 0, sizeof(X), cttz_zero_undef(X) -> cttz(X)
27768 // select_cc setne X, 0, ctlz(X), sizeof(X) -> ctlz(X)
27769 // select_cc setne X, 0, ctlz_zero_undef(X), sizeof(X) -> ctlz(X)
27770 // select_cc setne X, 0, cttz(X), sizeof(X) -> cttz(X)
27771 // select_cc setne X, 0, cttz_zero_undef(X), sizeof(X) -> cttz(X)
27772 if (N1C && N1C->isZero() && (CC == ISD::SETEQ || CC == ISD::SETNE)) {
27773 SDValue ValueOnZero = N2;
27774 SDValue Count = N3;
27775 // If the condition is NE instead of EQ, swap the operands.
27776 if (CC == ISD::SETNE)
27777 std::swap(a&: ValueOnZero, b&: Count);
27778 // Check if the value on zero is a constant equal to the bits in the type.
27779 if (auto *ValueOnZeroC = dyn_cast<ConstantSDNode>(Val&: ValueOnZero)) {
27780 if (ValueOnZeroC->getAPIntValue() == VT.getSizeInBits()) {
27781 // If the other operand is cttz/cttz_zero_undef of N0, and cttz is
27782 // legal, combine to just cttz.
27783 if ((Count.getOpcode() == ISD::CTTZ ||
27784 Count.getOpcode() == ISD::CTTZ_ZERO_UNDEF) &&
27785 N0 == Count.getOperand(i: 0) &&
27786 (!LegalOperations || TLI.isOperationLegal(Op: ISD::CTTZ, VT)))
27787 return DAG.getNode(Opcode: ISD::CTTZ, DL, VT, Operand: N0);
27788 // If the other operand is ctlz/ctlz_zero_undef of N0, and ctlz is
27789 // legal, combine to just ctlz.
27790 if ((Count.getOpcode() == ISD::CTLZ ||
27791 Count.getOpcode() == ISD::CTLZ_ZERO_UNDEF) &&
27792 N0 == Count.getOperand(i: 0) &&
27793 (!LegalOperations || TLI.isOperationLegal(Op: ISD::CTLZ, VT)))
27794 return DAG.getNode(Opcode: ISD::CTLZ, DL, VT, Operand: N0);
27795 }
27796 }
27797 }
27798
27799 // Fold select_cc setgt X, -1, C, ~C -> xor (ashr X, BW-1), C
27800 // Fold select_cc setlt X, 0, C, ~C -> xor (ashr X, BW-1), ~C
27801 if (!NotExtCompare && N1C && N2C && N3C &&
27802 N2C->getAPIntValue() == ~N3C->getAPIntValue() &&
27803 ((N1C->isAllOnes() && CC == ISD::SETGT) ||
27804 (N1C->isZero() && CC == ISD::SETLT)) &&
27805 !TLI.shouldAvoidTransformToShift(VT, Amount: CmpOpVT.getScalarSizeInBits() - 1)) {
27806 SDValue ASR = DAG.getNode(
27807 Opcode: ISD::SRA, DL, VT: CmpOpVT, N1: N0,
27808 N2: DAG.getConstant(Val: CmpOpVT.getScalarSizeInBits() - 1, DL, VT: CmpOpVT));
27809 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: DAG.getSExtOrTrunc(Op: ASR, DL, VT),
27810 N2: DAG.getSExtOrTrunc(Op: CC == ISD::SETLT ? N3 : N2, DL, VT));
27811 }
27812
27813 if (SDValue S = PerformMinMaxFpToSatCombine(N0, N1, N2, N3, CC, DAG))
27814 return S;
27815 if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N2, N3, CC, DAG))
27816 return S;
27817
27818 return SDValue();
27819}
27820
27821/// This is a stub for TargetLowering::SimplifySetCC.
27822SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
27823 ISD::CondCode Cond, const SDLoc &DL,
27824 bool foldBooleans) {
27825 TargetLowering::DAGCombinerInfo
27826 DagCombineInfo(DAG, Level, false, this);
27827 return TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DCI&: DagCombineInfo, dl: DL);
27828}
27829
27830/// Given an ISD::SDIV node expressing a divide by constant, return
27831/// a DAG expression to select that will generate the same value by multiplying
27832/// by a magic number.
27833/// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
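/// For example (an illustrative sketch of the classic i32 divide-by-3): with
/// the magic constant 0x55555556 = (2^32 + 2) / 3, "x sdiv 3" becomes roughly
///   q = mulhs(x, 0x55555556);    // high 32 bits of the signed 64-bit product
///   q = q + ((uint32_t)x >> 31); // add 1 for negative x to round toward zero
/// with no divide instruction emitted. The exact sequence is chosen by
/// TLI.BuildSDIV and is target-dependent; mulhs here is pseudocode.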
27834SDValue DAGCombiner::BuildSDIV(SDNode *N) {
27835 // When optimizing for minimum size, we don't want to expand a div to a mul
27836 // and a shift.
27837 if (DAG.getMachineFunction().getFunction().hasMinSize())
27838 return SDValue();
27839
27840 SmallVector<SDNode *, 8> Built;
27841 if (SDValue S = TLI.BuildSDIV(N, DAG, IsAfterLegalization: LegalOperations, Created&: Built)) {
27842 for (SDNode *N : Built)
27843 AddToWorklist(N);
27844 return S;
27845 }
27846
27847 return SDValue();
27848}
27849
27850/// Given an ISD::SDIV node expressing a divide by constant power of 2, return a
27851/// DAG expression that will generate the same value by right shifting.
27852SDValue DAGCombiner::BuildSDIVPow2(SDNode *N) {
27853 ConstantSDNode *C = isConstOrConstSplat(N: N->getOperand(Num: 1));
27854 if (!C)
27855 return SDValue();
27856
27857 // Avoid division by zero.
27858 if (C->isZero())
27859 return SDValue();
27860
27861 SmallVector<SDNode *, 8> Built;
27862 if (SDValue S = TLI.BuildSDIVPow2(N, Divisor: C->getAPIntValue(), DAG, Created&: Built)) {
27863 for (SDNode *N : Built)
27864 AddToWorklist(N);
27865 return S;
27866 }
27867
27868 return SDValue();
27869}
27870
27871/// Given an ISD::UDIV node expressing a divide by constant, return a DAG
27872/// expression that will generate the same value by multiplying by a magic
27873/// number.
27874/// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
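/// For example (an illustrative sketch of the classic i32 unsigned
/// divide-by-3): with the magic constant 0xAAAAAAAB = (2^33 + 1) / 3,
/// "x udiv 3" becomes roughly
///   q = mulhu(x, 0xAAAAAAAB) >> 1; // mulhu = high 32 bits of the product
/// Other divisors (e.g. 7) need an extra fix-up add/sub-and-shift; the exact
/// sequence is chosen by TLI.BuildUDIV and is target-dependent.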
27875SDValue DAGCombiner::BuildUDIV(SDNode *N) {
27876 // When optimizing for minimum size, we don't want to expand a div to a mul
27877 // and a shift.
27878 if (DAG.getMachineFunction().getFunction().hasMinSize())
27879 return SDValue();
27880
27881 SmallVector<SDNode *, 8> Built;
27882 if (SDValue S = TLI.BuildUDIV(N, DAG, IsAfterLegalization: LegalOperations, Created&: Built)) {
27883 for (SDNode *N : Built)
27884 AddToWorklist(N);
27885 return S;
27886 }
27887
27888 return SDValue();
27889}
27890
27891/// Given an ISD::SREM node expressing a remainder by constant power of 2,
27892/// return a DAG expression that will generate the same value.
27893SDValue DAGCombiner::BuildSREMPow2(SDNode *N) {
27894 ConstantSDNode *C = isConstOrConstSplat(N: N->getOperand(Num: 1));
27895 if (!C)
27896 return SDValue();
27897
27898 // Avoid division by zero.
27899 if (C->isZero())
27900 return SDValue();
27901
27902 SmallVector<SDNode *, 8> Built;
27903 if (SDValue S = TLI.BuildSREMPow2(N, Divisor: C->getAPIntValue(), DAG, Created&: Built)) {
27904 for (SDNode *N : Built)
27905 AddToWorklist(N);
27906 return S;
27907 }
27908
27909 return SDValue();
27910}
27911
27912// This is basically just a port of takeLog2 from InstCombineMulDivRem.cpp
27913//
27914// Returns the node that represents `Log2(Op)`. This may create a new node. If
27915// we are unable to compute `Log2(Op)`, it returns `SDValue()`.
27916//
27917// All nodes will be created at `DL` and the output will be of type `VT`.
27918//
27919// This will only return `Log2(Op)` if we can prove `Op` is non-zero. Set
27920// `AssumeNonZero` if this function should simply assume (rather than prove)
27921// that `Op` is non-zero.
27922static SDValue takeInexpensiveLog2(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
27923 SDValue Op, unsigned Depth,
27924 bool AssumeNonZero) {
27925 assert(VT.isInteger() && "Only integer types are supported!");
27926
27927 auto PeekThroughCastsAndTrunc = [](SDValue V) {
27928 while (true) {
27929 switch (V.getOpcode()) {
27930 case ISD::TRUNCATE:
27931 case ISD::ZERO_EXTEND:
27932 V = V.getOperand(i: 0);
27933 break;
27934 default:
27935 return V;
27936 }
27937 }
27938 };
27939
27940 if (VT.isScalableVector())
27941 return SDValue();
27942
27943 Op = PeekThroughCastsAndTrunc(Op);
27944
27945 // Helper for determining whether a value is a power-of-2 constant scalar or a
27946 // vector of such elements.
27947 SmallVector<APInt> Pow2Constants;
27948 auto IsPowerOfTwo = [&Pow2Constants](ConstantSDNode *C) {
27949 if (C->isZero() || C->isOpaque())
27950 return false;
27951 // TODO: We may also be able to support negative powers of 2 here.
27952 if (C->getAPIntValue().isPowerOf2()) {
27953 Pow2Constants.emplace_back(Args: C->getAPIntValue());
27954 return true;
27955 }
27956 return false;
27957 };
27958
27959 if (ISD::matchUnaryPredicate(Op, Match: IsPowerOfTwo)) {
27960 if (!VT.isVector())
27961 return DAG.getConstant(Val: Pow2Constants.back().logBase2(), DL, VT);
27962 // We need to create a build vector
27963 if (Op.getOpcode() == ISD::SPLAT_VECTOR)
27964 return DAG.getSplat(VT, DL,
27965 Op: DAG.getConstant(Val: Pow2Constants.back().logBase2(), DL,
27966 VT: VT.getScalarType()));
27967 SmallVector<SDValue> Log2Ops;
27968 for (const APInt &Pow2 : Pow2Constants)
27969 Log2Ops.emplace_back(
27970 Args: DAG.getConstant(Val: Pow2.logBase2(), DL, VT: VT.getScalarType()));
27971 return DAG.getBuildVector(VT, DL, Ops: Log2Ops);
27972 }
27973
27974 if (Depth >= DAG.MaxRecursionDepth)
27975 return SDValue();
27976
27977 auto CastToVT = [&](EVT NewVT, SDValue ToCast) {
27978 ToCast = PeekThroughCastsAndTrunc(ToCast);
27979 EVT CurVT = ToCast.getValueType();
27980 if (NewVT == CurVT)
27981 return ToCast;
27982
27983 if (NewVT.getSizeInBits() == CurVT.getSizeInBits())
27984 return DAG.getBitcast(VT: NewVT, V: ToCast);
27985
27986 return DAG.getZExtOrTrunc(Op: ToCast, DL, VT: NewVT);
27987 };
27988
27989 // log2(X << Y) -> log2(X) + Y
27990 if (Op.getOpcode() == ISD::SHL) {
27991 // 1 << Y and X nuw/nsw << Y are all non-zero.
27992 if (AssumeNonZero || Op->getFlags().hasNoUnsignedWrap() ||
27993 Op->getFlags().hasNoSignedWrap() || isOneConstant(V: Op.getOperand(i: 0)))
27994 if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op: Op.getOperand(i: 0),
27995 Depth: Depth + 1, AssumeNonZero))
27996 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: LogX,
27997 N2: CastToVT(VT, Op.getOperand(i: 1)));
27998 }
27999
28000 // c ? X : Y -> c ? Log2(X) : Log2(Y)
28001 if ((Op.getOpcode() == ISD::SELECT || Op.getOpcode() == ISD::VSELECT) &&
28002 Op.hasOneUse()) {
28003 if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op: Op.getOperand(i: 1),
28004 Depth: Depth + 1, AssumeNonZero))
28005 if (SDValue LogY = takeInexpensiveLog2(DAG, DL, VT, Op: Op.getOperand(i: 2),
28006 Depth: Depth + 1, AssumeNonZero))
28007 return DAG.getSelect(DL, VT, Cond: Op.getOperand(i: 0), LHS: LogX, RHS: LogY);
28008 }
28009
28010 // log2(umin(X, Y)) -> umin(log2(X), log2(Y))
28011 // log2(umax(X, Y)) -> umax(log2(X), log2(Y))
28012 if ((Op.getOpcode() == ISD::UMIN || Op.getOpcode() == ISD::UMAX) &&
28013 Op.hasOneUse()) {
28014 // Use AssumeNonZero as false here. Otherwise we can hit a case where
28015 // log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because of overflow).
28016 if (SDValue LogX =
28017 takeInexpensiveLog2(DAG, DL, VT, Op: Op.getOperand(i: 0), Depth: Depth + 1,
28018 /*AssumeNonZero*/ false))
28019 if (SDValue LogY =
28020 takeInexpensiveLog2(DAG, DL, VT, Op: Op.getOperand(i: 1), Depth: Depth + 1,
28021 /*AssumeNonZero*/ false))
28022 return DAG.getNode(Opcode: Op.getOpcode(), DL, VT, N1: LogX, N2: LogY);
28023 }
28024
28025 return SDValue();
28026}
28027
28028/// Determines the LogBase2 value for a non-null input value using the
28029/// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
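/// For example (illustrative, i32): for V = 8, ctlz(8) = 28, so
/// LogBase2(8) = (32 - 1) - 28 = 3.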
28030SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL,
28031 bool KnownNonZero, bool InexpensiveOnly,
28032 std::optional<EVT> OutVT) {
28033 EVT VT = OutVT ? *OutVT : V.getValueType();
28034 SDValue InexpensiveLogBase2 =
28035 takeInexpensiveLog2(DAG, DL, VT, Op: V, /*Depth*/ 0, AssumeNonZero: KnownNonZero);
28036 if (InexpensiveLogBase2 || InexpensiveOnly || !DAG.isKnownToBeAPowerOfTwo(Val: V))
28037 return InexpensiveLogBase2;
28038
28039 SDValue Ctlz = DAG.getNode(Opcode: ISD::CTLZ, DL, VT, Operand: V);
28040 SDValue Base = DAG.getConstant(Val: VT.getScalarSizeInBits() - 1, DL, VT);
28041 SDValue LogBase2 = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Base, N2: Ctlz);
28042 return LogBase2;
28043}
28044
28045/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
28046/// For the reciprocal, we need to find the zero of the function:
28047/// F(X) = 1/X - A [which has a zero at X = 1/A]
28048/// =>
28049/// X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
28050/// does not require additional intermediate precision]
28051/// For the last iteration, put numerator N into it to gain more precision:
28052/// Result = N X_i + X_i (N - N A X_i)
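/// As an illustrative numeric sketch with A = 3 and initial estimate
/// X_0 = 0.3: X_1 = 0.3 * (2 - 3 * 0.3) = 0.33 and
/// X_2 = 0.33 * (2 - 3 * 0.33) = 0.3333, converging quadratically to 1/3.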
28053SDValue DAGCombiner::BuildDivEstimate(SDValue N, SDValue Op,
28054 SDNodeFlags Flags) {
28055 if (LegalDAG)
28056 return SDValue();
28057
28058 // TODO: Handle extended types?
28059 EVT VT = Op.getValueType();
28060 if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
28061 VT.getScalarType() != MVT::f64)
28062 return SDValue();
28063
28064 // If estimates are explicitly disabled for this function, we're done.
28065 MachineFunction &MF = DAG.getMachineFunction();
28066 int Enabled = TLI.getRecipEstimateDivEnabled(VT, MF);
28067 if (Enabled == TLI.ReciprocalEstimate::Disabled)
28068 return SDValue();
28069
28070 // Estimates may be explicitly enabled for this type with a custom number of
28071 // refinement steps.
28072 int Iterations = TLI.getDivRefinementSteps(VT, MF);
28073 if (SDValue Est = TLI.getRecipEstimate(Operand: Op, DAG, Enabled, RefinementSteps&: Iterations)) {
28074 AddToWorklist(N: Est.getNode());
28075
28076 SDLoc DL(Op);
28077 if (Iterations) {
28078 SDValue FPOne = DAG.getConstantFP(Val: 1.0, DL, VT);
28079
28080 // Newton iterations: Est = Est + Est (N - Arg * Est)
28081 // If this is the last iteration, also multiply by the numerator.
28082 for (int i = 0; i < Iterations; ++i) {
28083 SDValue MulEst = Est;
28084
28085 if (i == Iterations - 1) {
28086 MulEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N, N2: Est, Flags);
28087 AddToWorklist(N: MulEst.getNode());
28088 }
28089
28090 SDValue NewEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Op, N2: MulEst, Flags);
28091 AddToWorklist(N: NewEst.getNode());
28092
28093 NewEst = DAG.getNode(Opcode: ISD::FSUB, DL, VT,
28094 N1: (i == Iterations - 1 ? N : FPOne), N2: NewEst, Flags);
28095 AddToWorklist(N: NewEst.getNode());
28096
28097 NewEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: NewEst, Flags);
28098 AddToWorklist(N: NewEst.getNode());
28099
28100 Est = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: MulEst, N2: NewEst, Flags);
28101 AddToWorklist(N: Est.getNode());
28102 }
28103 } else {
28104 // If no iterations are available, multiply with N.
28105 Est = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: N, Flags);
28106 AddToWorklist(N: Est.getNode());
28107 }
28108
28109 return Est;
28110 }
28111
28112 return SDValue();
28113}
28114
28115/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
28116/// For the reciprocal sqrt, we need to find the zero of the function:
28117/// F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
28118/// =>
28119/// X_{i+1} = X_i (1.5 - A X_i^2 / 2)
28120/// As a result, we precompute A/2 prior to the iteration loop.
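/// As an illustrative numeric sketch with A = 4 and initial estimate X_0 = 0.6:
/// X_1 = 0.6 * (1.5 - 4 * 0.36 / 2) = 0.468 and
/// X_2 = 0.468 * (1.5 - 4 * 0.219024 / 2) ~= 0.497, converging to 1/sqrt(4) = 0.5.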
28121SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
28122 unsigned Iterations,
28123 SDNodeFlags Flags, bool Reciprocal) {
28124 EVT VT = Arg.getValueType();
28125 SDLoc DL(Arg);
28126 SDValue ThreeHalves = DAG.getConstantFP(Val: 1.5, DL, VT);
28127
28128 // We now need 0.5 * Arg which we can write as (1.5 * Arg - Arg) so that
28129 // this entire sequence requires only one FP constant.
28130 SDValue HalfArg = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: ThreeHalves, N2: Arg, Flags);
28131 HalfArg = DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1: HalfArg, N2: Arg, Flags);
28132
28133 // Newton iterations: Est = Est * (1.5 - HalfArg * Est * Est)
28134 for (unsigned i = 0; i < Iterations; ++i) {
28135 SDValue NewEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: Est, Flags);
28136 NewEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: HalfArg, N2: NewEst, Flags);
28137 NewEst = DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1: ThreeHalves, N2: NewEst, Flags);
28138 Est = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: NewEst, Flags);
28139 }
28140
28141 // If non-reciprocal square root is requested, multiply the result by Arg.
28142 if (!Reciprocal)
28143 Est = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: Arg, Flags);
28144
28145 return Est;
28146}
28147
28148/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
28149/// For the reciprocal sqrt, we need to find the zero of the function:
28150/// F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
28151/// =>
28152/// X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0))
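/// This is algebraically the same update as the one-constant form, since
/// X_i * (1.5 - 0.5 * A * X_i^2) == (-0.5 * X_i) * (A * X_i^2 + (-3.0));
/// which of the two forms is used is a target decision (see UseOneConstNR in
/// buildSqrtEstimateImpl).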
28153SDValue DAGCombiner::buildSqrtNRTwoConst(SDValue Arg, SDValue Est,
28154 unsigned Iterations,
28155 SDNodeFlags Flags, bool Reciprocal) {
28156 EVT VT = Arg.getValueType();
28157 SDLoc DL(Arg);
28158 SDValue MinusThree = DAG.getConstantFP(Val: -3.0, DL, VT);
28159 SDValue MinusHalf = DAG.getConstantFP(Val: -0.5, DL, VT);
28160
28161 // This routine must enter the loop below to work correctly
28162 // when (Reciprocal == false).
28163 assert(Iterations > 0);
28164
28165 // Newton iterations for reciprocal square root:
28166 // E = (E * -0.5) * ((A * E) * E + -3.0)
28167 for (unsigned i = 0; i < Iterations; ++i) {
28168 SDValue AE = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Arg, N2: Est, Flags);
28169 SDValue AEE = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: AE, N2: Est, Flags);
28170 SDValue RHS = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: AEE, N2: MinusThree, Flags);
28171
28172 // When calculating a square root at the last iteration build:
28173 // S = ((A * E) * -0.5) * ((A * E) * E + -3.0)
28174 // (notice a common subexpression)
28175 SDValue LHS;
28176 if (Reciprocal || (i + 1) < Iterations) {
28177 // RSQRT: LHS = (E * -0.5)
28178 LHS = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: MinusHalf, Flags);
28179 } else {
28180 // SQRT: LHS = (A * E) * -0.5
28181 LHS = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: AE, N2: MinusHalf, Flags);
28182 }
28183
28184 Est = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: LHS, N2: RHS, Flags);
28185 }
28186
28187 return Est;
28188}
28189
28190/// Build code to calculate either rsqrt(Op) or sqrt(Op). In the latter case
28191/// Op*rsqrt(Op) is actually computed, so additional postprocessing is needed if
28192/// Op can be zero.
SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
                                           bool Reciprocal) {
  if (LegalDAG)
    return SDValue();

  // TODO: Handle extended types?
  EVT VT = Op.getValueType();
  if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
      VT.getScalarType() != MVT::f64)
    return SDValue();

  // If estimates are explicitly disabled for this function, we're done.
  MachineFunction &MF = DAG.getMachineFunction();
  int Enabled = TLI.getRecipEstimateSqrtEnabled(VT, MF);
  if (Enabled == TLI.ReciprocalEstimate::Disabled)
    return SDValue();

  // Estimates may be explicitly enabled for this type with a custom number of
  // refinement steps.
  int Iterations = TLI.getSqrtRefinementSteps(VT, MF);

  bool UseOneConstNR = false;
  if (SDValue Est =
          TLI.getSqrtEstimate(Op, DAG, Enabled, Iterations, UseOneConstNR,
                              Reciprocal)) {
    AddToWorklist(Est.getNode());

    if (Iterations > 0)
      Est = UseOneConstNR
                ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
                : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
    if (!Reciprocal) {
      SDLoc DL(Op);
      // Try the target-specific test first.
      SDValue Test = TLI.getSqrtInputTest(Op, DAG, DAG.getDenormalMode(VT));

      // The estimate is now completely wrong if the input was exactly 0.0 or
      // possibly a denormal. Force the answer to 0.0 or the value provided by
      // the target for those cases.
      Est = DAG.getNode(
          Test.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
          Test, TLI.getSqrtResultForDenormInput(Op, DAG), Est);
    }
    return Est;
  }

  return SDValue();
}

SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags) {
  return buildSqrtEstimateImpl(Op, Flags, /*Reciprocal=*/true);
}

SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags Flags) {
  return buildSqrtEstimateImpl(Op, Flags, /*Reciprocal=*/false);
}

/// Return true if there is any possibility that the two addresses overlap.
bool DAGCombiner::mayAlias(SDNode *Op0, SDNode *Op1) const {

  struct MemUseCharacteristics {
    bool IsVolatile;
    bool IsAtomic;
    SDValue BasePtr;
    int64_t Offset;
    LocationSize NumBytes;
    MachineMemOperand *MMO;
  };

  auto getCharacteristics = [](SDNode *N) -> MemUseCharacteristics {
    if (const auto *LSN = dyn_cast<LSBaseSDNode>(N)) {
      int64_t Offset = 0;
      if (auto *C = dyn_cast<ConstantSDNode>(LSN->getOffset()))
        Offset = (LSN->getAddressingMode() == ISD::PRE_INC) ? C->getSExtValue()
                 : (LSN->getAddressingMode() == ISD::PRE_DEC)
                     ? -1 * C->getSExtValue()
                     : 0;
      TypeSize Size = LSN->getMemoryVT().getStoreSize();
      return {LSN->isVolatile(),           LSN->isAtomic(),
              LSN->getBasePtr(),           Offset /*base offset*/,
              LocationSize::precise(Size), LSN->getMemOperand()};
    }
    if (const auto *LN = cast<LifetimeSDNode>(N))
      return {false /*isVolatile*/,
              /*isAtomic*/ false,
              LN->getOperand(1),
              (LN->hasOffset()) ? LN->getOffset() : 0,
              (LN->hasOffset()) ? LocationSize::precise(LN->getSize())
                                : LocationSize::beforeOrAfterPointer(),
              (MachineMemOperand *)nullptr};
    // Default.
    return {false /*isVolatile*/,
            /*isAtomic*/ false,
            SDValue(),
            (int64_t)0 /*offset*/,
            LocationSize::beforeOrAfterPointer() /*size*/,
            (MachineMemOperand *)nullptr};
  };

  MemUseCharacteristics MUC0 = getCharacteristics(Op0),
                        MUC1 = getCharacteristics(Op1);

  // If they are to the same address, then they must be aliases.
  if (MUC0.BasePtr.getNode() && MUC0.BasePtr == MUC1.BasePtr &&
      MUC0.Offset == MUC1.Offset)
    return true;

  // If they are both volatile then they cannot be reordered.
  if (MUC0.IsVolatile && MUC1.IsVolatile)
    return true;

  // Be conservative about atomics for the moment.
  // TODO: This is way overconservative for unordered atomics (see D66309)
  if (MUC0.IsAtomic && MUC1.IsAtomic)
    return true;

  if (MUC0.MMO && MUC1.MMO) {
    if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
        (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
      return false;
  }

  // If NumBytes is scalable and the offset is not 0, conservatively return
  // may-alias.
  if ((MUC0.NumBytes.hasValue() && MUC0.NumBytes.isScalable() &&
       MUC0.Offset != 0) ||
      (MUC1.NumBytes.hasValue() && MUC1.NumBytes.isScalable() &&
       MUC1.Offset != 0))
    return true;
  // Try to prove that there is aliasing, or that there is no aliasing. Either
  // way, we can return now. If nothing can be proved, proceed with more tests.
  bool IsAlias;
  if (BaseIndexOffset::computeAliasing(Op0, MUC0.NumBytes, Op1, MUC1.NumBytes,
                                       DAG, IsAlias))
    return IsAlias;

  // The following all rely on MMO0 and MMO1 being valid. Fail conservatively
  // if either is not known.
  if (!MUC0.MMO || !MUC1.MMO)
    return true;

  // If one operation reads from invariant memory, and the other may store,
  // they cannot alias. These should really be checking the equivalent of
  // mayWrite, but it only matters for memory nodes other than load/store.
  if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
      (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
    return false;

  // If we know required SrcValue1 and SrcValue2 have relatively large
  // alignment compared to the size and offset of the access, we may be able
  // to prove they do not alias. This check is conservative for now to catch
  // cases created by splitting vector types; it only works when the offsets
  // are multiples of the size of the data.
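  //
  // As an illustrative example of the check below (hypothetical numbers): two
  // 4-byte accesses on a 16-byte aligned base value at offsets 4 and 8 give
  // OffAlign0 = 4 and OffAlign1 = 8; since 4 + 4 <= 8, the accesses cover the
  // disjoint byte ranges [4,8) and [8,12) and are reported as not aliasing.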
  int64_t SrcValOffset0 = MUC0.MMO->getOffset();
  int64_t SrcValOffset1 = MUC1.MMO->getOffset();
  Align OrigAlignment0 = MUC0.MMO->getBaseAlign();
  Align OrigAlignment1 = MUC1.MMO->getBaseAlign();
  LocationSize Size0 = MUC0.NumBytes;
  LocationSize Size1 = MUC1.NumBytes;

  if (OrigAlignment0 == OrigAlignment1 && SrcValOffset0 != SrcValOffset1 &&
      Size0.hasValue() && Size1.hasValue() && !Size0.isScalable() &&
      !Size1.isScalable() && Size0 == Size1 &&
      OrigAlignment0 > Size0.getValue().getKnownMinValue() &&
      SrcValOffset0 % Size0.getValue().getKnownMinValue() == 0 &&
      SrcValOffset1 % Size1.getValue().getKnownMinValue() == 0) {
    int64_t OffAlign0 = SrcValOffset0 % OrigAlignment0.value();
    int64_t OffAlign1 = SrcValOffset1 % OrigAlignment1.value();

    // There is no overlap between these relatively aligned accesses of
    // similar size. Return no alias.
    if ((OffAlign0 + static_cast<int64_t>(
                         Size0.getValue().getKnownMinValue())) <= OffAlign1 ||
        (OffAlign1 + static_cast<int64_t>(
                         Size1.getValue().getKnownMinValue())) <= OffAlign0)
      return false;
  }

  bool UseAA = CombinerGlobalAA.getNumOccurrences() > 0
                   ? CombinerGlobalAA
                   : DAG.getSubtarget().useAA();
#ifndef NDEBUG
  if (CombinerAAOnlyFunc.getNumOccurrences() &&
      CombinerAAOnlyFunc != DAG.getMachineFunction().getName())
    UseAA = false;
#endif

  if (UseAA && AA && MUC0.MMO->getValue() && MUC1.MMO->getValue() &&
      Size0.hasValue() && Size1.hasValue() &&
      // Can't represent a scalable size + fixed offset in LocationSize
      (!Size0.isScalable() || SrcValOffset0 == 0) &&
      (!Size1.isScalable() || SrcValOffset1 == 0)) {
    // Use alias analysis information.
    int64_t MinOffset = std::min(SrcValOffset0, SrcValOffset1);
    int64_t Overlap0 =
        Size0.getValue().getKnownMinValue() + SrcValOffset0 - MinOffset;
    int64_t Overlap1 =
        Size1.getValue().getKnownMinValue() + SrcValOffset1 - MinOffset;
    LocationSize Loc0 =
        Size0.isScalable() ? Size0 : LocationSize::precise(Overlap0);
    LocationSize Loc1 =
        Size1.isScalable() ? Size1 : LocationSize::precise(Overlap1);
    if (AA->isNoAlias(
            MemoryLocation(MUC0.MMO->getValue(), Loc0,
                           UseTBAA ? MUC0.MMO->getAAInfo() : AAMDNodes()),
            MemoryLocation(MUC1.MMO->getValue(), Loc1,
                           UseTBAA ? MUC1.MMO->getAAInfo() : AAMDNodes())))
      return false;
  }

  // Otherwise we have to assume they alias.
  return true;
}

/// Walk up chain skipping non-aliasing memory nodes,
/// looking for aliasing nodes and adding them to the Aliases vector.
void DAGCombiner::GatherAllAliases(SDNode *N, SDValue OriginalChain,
                                   SmallVectorImpl<SDValue> &Aliases) {
  SmallVector<SDValue, 8> Chains;    // List of chains to visit.
  SmallPtrSet<SDNode *, 16> Visited; // Visited node set.

  // Get alias information for node.
  // TODO: relax aliasing for unordered atomics (see D66309)
  const bool IsLoad = isa<LoadSDNode>(N) && cast<LoadSDNode>(N)->isSimple();

  // Starting off.
  Chains.push_back(OriginalChain);
  unsigned Depth = 0;

  // Attempt to improve chain by a single step.
  auto ImproveChain = [&](SDValue &C) -> bool {
    switch (C.getOpcode()) {
    case ISD::EntryToken:
      // No need to mark EntryToken.
      C = SDValue();
      return true;
    case ISD::LOAD:
    case ISD::STORE: {
      // Get alias information for C.
      // TODO: Relax aliasing for unordered atomics (see D66309)
      bool IsOpLoad = isa<LoadSDNode>(C.getNode()) &&
                      cast<LSBaseSDNode>(C.getNode())->isSimple();
      if ((IsLoad && IsOpLoad) || !mayAlias(N, C.getNode())) {
        // Look further up the chain.
        C = C.getOperand(0);
        return true;
      }
      // Alias, so stop here.
      return false;
    }

    case ISD::CopyFromReg:
      // Always forward past CopyFromReg.
      C = C.getOperand(0);
      return true;

    case ISD::LIFETIME_START:
    case ISD::LIFETIME_END: {
      // We can forward past any lifetime start/end that can be proven not to
      // alias the memory access.
      if (!mayAlias(N, C.getNode())) {
        // Look further up the chain.
        C = C.getOperand(0);
        return true;
      }
      return false;
    }
    default:
      return false;
    }
  };

  // Look at each chain and determine if it is an alias. If so, add it to the
  // aliases list. If not, then continue up the chain looking for the next
  // candidate.
  while (!Chains.empty()) {
    SDValue Chain = Chains.pop_back_val();

    // Don't bother if we've seen Chain before.
    if (!Visited.insert(Chain.getNode()).second)
      continue;

    // For TokenFactor nodes, look at each operand and only continue up the
    // chain until we reach the depth limit.
    //
    // FIXME: The depth check could be made to return the last non-aliasing
    // chain we found before we hit a tokenfactor rather than the original
    // chain.
    if (Depth > TLI.getGatherAllAliasesMaxDepth()) {
      Aliases.clear();
      Aliases.push_back(OriginalChain);
      return;
    }

    if (Chain.getOpcode() == ISD::TokenFactor) {
      // We have to check each of the operands of the token factor for "small"
      // token factors, so we queue them up. Adding the operands to the queue
      // (stack) in reverse order maintains the original order and increases
      // the likelihood that getNode will find a matching token factor (CSE).
      if (Chain.getNumOperands() > 16) {
        Aliases.push_back(Chain);
        continue;
      }
      for (unsigned n = Chain.getNumOperands(); n;)
        Chains.push_back(Chain.getOperand(--n));
      ++Depth;
      continue;
    }
    // Everything else.
    if (ImproveChain(Chain)) {
      // Updated chain found; consider the new chain if one exists.
      if (Chain.getNode())
        Chains.push_back(Chain);
      ++Depth;
      continue;
    }
    // No improved chain is possible, so treat it as an alias.
    Aliases.push_back(Chain);
  }
}
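
// Illustrative example (hypothetical chain, for orientation only): given a
// load N whose chain operand is store S1 (proven not to alias N), whose chain
// in turn is store S2 (which may alias N), the walk above forwards past S1 and
// stops at S2, so Aliases ends up holding only S2's chain value.
// FindBetterChain can then re-chain N directly to S2, removing the artificial
// ordering between N and S1.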

/// Walk up chain skipping non-aliasing memory nodes, looking for a better
/// chain (aliasing node).
SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) {
  if (OptLevel == CodeGenOptLevel::None)
    return OldChain;

  // Ops for replacing token factor.
  SmallVector<SDValue, 8> Aliases;

  // Accumulate all the aliases to this node.
  GatherAllAliases(N, OldChain, Aliases);

  // If no operands then chain to entry token.
  if (Aliases.empty())
    return DAG.getEntryNode();

  // If a single operand then chain to it. We don't need to revisit it.
  if (Aliases.size() == 1)
    return Aliases[0];

  // Construct a custom tailored token factor.
  return DAG.getTokenFactor(SDLoc(N), Aliases);
}

// This function tries to collect a bunch of potentially interesting
// nodes to improve the chains of, all at once. This might seem
// redundant, as this function gets called when visiting every store
// node, so why not let the work be done on each store as it's visited?
//
// I believe this is mainly important because mergeConsecutiveStores
// is unable to deal with merging stores of different sizes, so unless
// we improve the chains of all the potential candidates up-front
// before running mergeConsecutiveStores, it might only see some of
// the nodes that will eventually be candidates, and then not be able
// to go from a partially-merged state to the desired final
// fully-merged state.

bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) {
  SmallVector<StoreSDNode *, 8> ChainedStores;
  StoreSDNode *STChain = St;
  // Intervals records which offsets from BaseIndex have been covered. In
  // the common case, every store writes to the immediately preceding address
  // and is thus merged with the previous interval at insertion time.
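  //
  // For example (hypothetical offsets): with a 4-byte store at offset 0 and
  // chained 4-byte stores at offsets 4 and 8, the inserted intervals [0,4),
  // [4,8) and [8,12) coalesce into a single [0,12) interval, while a 4-byte
  // store at offset 6 would overlap an existing interval and terminate the
  // collection loop below.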

  using IMap = llvm::IntervalMap<int64_t, std::monostate, 8,
                                 IntervalMapHalfOpenInfo<int64_t>>;
  IMap::Allocator A;
  IMap Intervals(A);

  // This holds the base pointer, index, and the offset in bytes from the base
  // pointer.
  const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);

  // We must have a base and an offset.
  if (!BasePtr.getBase().getNode())
    return false;

  // Do not handle stores to undef base pointers.
  if (BasePtr.getBase().isUndef())
    return false;

  // Do not handle stores to opaque types.
  if (St->getMemoryVT().isZeroSized())
    return false;

  // BaseIndexOffset assumes that offsets are fixed-size, which
  // is not valid for scalable vectors where the offsets are
  // scaled by `vscale`, so bail out early.
  if (St->getMemoryVT().isScalableVT())
    return false;

  // Add ST's interval.
  Intervals.insert(0, (St->getMemoryVT().getSizeInBits() + 7) / 8,
                   std::monostate{});

  while (StoreSDNode *Chain = dyn_cast<StoreSDNode>(STChain->getChain())) {
    if (Chain->getMemoryVT().isScalableVector())
      return false;

    // If the chain has more than one use, then we can't reorder the mem ops.
    if (!SDValue(Chain, 0)->hasOneUse())
      break;
    // TODO: Relax for unordered atomics (see D66309)
    if (!Chain->isSimple() || Chain->isIndexed())
      break;

    // Find the base pointer and offset for this memory node.
    const BaseIndexOffset Ptr = BaseIndexOffset::match(Chain, DAG);
    // Check that the base pointer is the same as the original one.
    int64_t Offset;
    if (!BasePtr.equalBaseIndex(Ptr, DAG, Offset))
      break;
    int64_t Length = (Chain->getMemoryVT().getSizeInBits() + 7) / 8;
    // Make sure we don't overlap with other intervals by checking the ones to
    // the left or right before inserting.
    auto I = Intervals.find(Offset);
    // If there's a next interval, we should end before it.
    if (I != Intervals.end() && I.start() < (Offset + Length))
      break;
    // If there's a previous interval, we should start after it.
    if (I != Intervals.begin() && (--I).stop() <= Offset)
      break;
    Intervals.insert(Offset, Offset + Length, std::monostate{});

    ChainedStores.push_back(Chain);
    STChain = Chain;
  }

  // If we didn't find a chained store, exit.
  if (ChainedStores.empty())
    return false;

  // Improve all chained stores (St and ChainedStores members) starting from
  // where the store chain ended and return a single TokenFactor.
  SDValue NewChain = STChain->getChain();
  SmallVector<SDValue, 8> TFOps;
  for (unsigned I = ChainedStores.size(); I;) {
    StoreSDNode *S = ChainedStores[--I];
    SDValue BetterChain = FindBetterChain(S, NewChain);
    S = cast<StoreSDNode>(DAG.UpdateNodeOperands(
        S, BetterChain, S->getOperand(1), S->getOperand(2), S->getOperand(3)));
    TFOps.push_back(SDValue(S, 0));
    ChainedStores[I] = S;
  }

  // Improve St's chain. Use a new node to avoid creating a loop from CombineTo.
  SDValue BetterChain = FindBetterChain(St, NewChain);
  SDValue NewST;
  if (St->isTruncatingStore())
    NewST = DAG.getTruncStore(BetterChain, SDLoc(St), St->getValue(),
                              St->getBasePtr(), St->getMemoryVT(),
                              St->getMemOperand());
  else
    NewST = DAG.getStore(BetterChain, SDLoc(St), St->getValue(),
                         St->getBasePtr(), St->getMemOperand());

  TFOps.push_back(NewST);

  // If we improved every element of TFOps, then we've lost the dependence on
  // NewChain to successors of St and we need to add it back to TFOps. Do so at
  // the beginning to keep relative order consistent with FindBetterChains.
  auto hasImprovedChain = [&](SDValue ST) -> bool {
    return ST->getOperand(0) != NewChain;
  };
  bool AddNewChain = llvm::all_of(TFOps, hasImprovedChain);
  if (AddNewChain)
    TFOps.insert(TFOps.begin(), NewChain);

  SDValue TF = DAG.getTokenFactor(SDLoc(STChain), TFOps);
  CombineTo(St, TF);

  // Add TF and its operands to the worklist.
  AddToWorklist(TF.getNode());
  for (const SDValue &Op : TF->ops())
    AddToWorklist(Op.getNode());
  AddToWorklist(STChain);
  return true;
}

bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
  if (OptLevel == CodeGenOptLevel::None)
    return false;

  const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);

  // We must have a base and an offset.
  if (!BasePtr.getBase().getNode())
    return false;

  // Do not handle stores to undef base pointers.
  if (BasePtr.getBase().isUndef())
    return false;

  // Directly improve a chain of disjoint stores starting at St.
  if (parallelizeChainedStores(St))
    return true;

  // Improve St's chain.
  SDValue BetterChain = FindBetterChain(St, St->getChain());
  if (St->getChain() != BetterChain) {
    replaceStoreChain(St, BetterChain);
    return true;
  }
  return false;
}

/// This is the entry point for the file.
void SelectionDAG::Combine(CombineLevel Level, AliasAnalysis *AA,
                           CodeGenOptLevel OptLevel) {
  /// This is the main entry point to this class.
  DAGCombiner(*this, AA, OptLevel).Run(Level);
}