1//===- DAGCombiner.cpp - Implement a DAG node combiner --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass combines dag nodes to form fewer, simpler DAG nodes. It can be run
10// both before and after the DAG is legalized.
11//
12// This pass is not a substitute for the LLVM IR instcombine pass. This pass is
13// primarily intended to handle simplification opportunities that are implicit
14// in the LLVM IR and exposed by the various codegen lowering phases.
15//
16//===----------------------------------------------------------------------===//
17
18#include "llvm/ADT/APFloat.h"
19#include "llvm/ADT/APInt.h"
20#include "llvm/ADT/APSInt.h"
21#include "llvm/ADT/ArrayRef.h"
22#include "llvm/ADT/DenseMap.h"
23#include "llvm/ADT/IntervalMap.h"
24#include "llvm/ADT/STLExtras.h"
25#include "llvm/ADT/SetVector.h"
26#include "llvm/ADT/SmallBitVector.h"
27#include "llvm/ADT/SmallPtrSet.h"
28#include "llvm/ADT/SmallSet.h"
29#include "llvm/ADT/SmallVector.h"
30#include "llvm/ADT/Statistic.h"
31#include "llvm/Analysis/AliasAnalysis.h"
32#include "llvm/Analysis/MemoryLocation.h"
33#include "llvm/Analysis/TargetLibraryInfo.h"
34#include "llvm/Analysis/ValueTracking.h"
35#include "llvm/Analysis/VectorUtils.h"
36#include "llvm/CodeGen/ByteProvider.h"
37#include "llvm/CodeGen/DAGCombine.h"
38#include "llvm/CodeGen/ISDOpcodes.h"
39#include "llvm/CodeGen/MachineFrameInfo.h"
40#include "llvm/CodeGen/MachineFunction.h"
41#include "llvm/CodeGen/MachineMemOperand.h"
42#include "llvm/CodeGen/SDPatternMatch.h"
43#include "llvm/CodeGen/SelectionDAG.h"
44#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
45#include "llvm/CodeGen/SelectionDAGNodes.h"
46#include "llvm/CodeGen/SelectionDAGTargetInfo.h"
47#include "llvm/CodeGen/TargetLowering.h"
48#include "llvm/CodeGen/TargetRegisterInfo.h"
49#include "llvm/CodeGen/TargetSubtargetInfo.h"
50#include "llvm/CodeGen/ValueTypes.h"
51#include "llvm/CodeGenTypes/MachineValueType.h"
52#include "llvm/IR/Attributes.h"
53#include "llvm/IR/Constant.h"
54#include "llvm/IR/DataLayout.h"
55#include "llvm/IR/DerivedTypes.h"
56#include "llvm/IR/Function.h"
57#include "llvm/IR/Metadata.h"
58#include "llvm/Support/Casting.h"
59#include "llvm/Support/CodeGen.h"
60#include "llvm/Support/CommandLine.h"
61#include "llvm/Support/Compiler.h"
62#include "llvm/Support/Debug.h"
63#include "llvm/Support/DebugCounter.h"
64#include "llvm/Support/ErrorHandling.h"
65#include "llvm/Support/KnownBits.h"
66#include "llvm/Support/MathExtras.h"
67#include "llvm/Support/raw_ostream.h"
68#include "llvm/Target/TargetMachine.h"
69#include "llvm/Target/TargetOptions.h"
70#include <algorithm>
71#include <cassert>
72#include <cstdint>
73#include <functional>
74#include <iterator>
75#include <optional>
76#include <string>
77#include <tuple>
78#include <utility>
79#include <variant>
80
81#include "MatchContext.h"
82
83using namespace llvm;
84using namespace llvm::SDPatternMatch;
85
86#define DEBUG_TYPE "dagcombine"
87
88STATISTIC(NodesCombined , "Number of dag nodes combined");
89STATISTIC(PreIndexedNodes , "Number of pre-indexed nodes created");
90STATISTIC(PostIndexedNodes, "Number of post-indexed nodes created");
91STATISTIC(OpsNarrowed , "Number of load/op/store narrowed");
92STATISTIC(LdStFP2Int , "Number of fp load/store pairs transformed to int");
93STATISTIC(SlicedLoads, "Number of load sliced");
94STATISTIC(NumFPLogicOpsConv, "Number of logic ops converted to fp ops");
95
96DEBUG_COUNTER(DAGCombineCounter, "dagcombine",
97 "Controls whether a DAG combine is performed for a node");
98
99static cl::opt<bool>
100CombinerGlobalAA("combiner-global-alias-analysis", cl::Hidden,
101 cl::desc("Enable DAG combiner's use of IR alias analysis"));
102
103static cl::opt<bool>
104UseTBAA("combiner-use-tbaa", cl::Hidden, cl::init(Val: true),
105 cl::desc("Enable DAG combiner's use of TBAA"));
106
107#ifndef NDEBUG
108static cl::opt<std::string>
109CombinerAAOnlyFunc("combiner-aa-only-func", cl::Hidden,
110 cl::desc("Only use DAG-combiner alias analysis in this"
111 " function"));
112#endif
113
114/// Hidden option to stress test load slicing, i.e., when this option
115/// is enabled, load slicing bypasses most of its profitability guards.
116static cl::opt<bool>
117StressLoadSlicing("combiner-stress-load-slicing", cl::Hidden,
118 cl::desc("Bypass the profitability model of load slicing"),
119 cl::init(Val: false));
120
121static cl::opt<bool>
122 MaySplitLoadIndex("combiner-split-load-index", cl::Hidden, cl::init(Val: true),
123 cl::desc("DAG combiner may split indexing from loads"));
124
125static cl::opt<bool>
126 EnableStoreMerging("combiner-store-merging", cl::Hidden, cl::init(Val: true),
127 cl::desc("DAG combiner enable merging multiple stores "
128 "into a wider store"));
129
130static cl::opt<unsigned> TokenFactorInlineLimit(
131 "combiner-tokenfactor-inline-limit", cl::Hidden, cl::init(Val: 2048),
132 cl::desc("Limit the number of operands to inline for Token Factors"));
133
134static cl::opt<unsigned> StoreMergeDependenceLimit(
135 "combiner-store-merge-dependence-limit", cl::Hidden, cl::init(Val: 10),
136 cl::desc("Limit the number of times for the same StoreNode and RootNode "
137 "to bail out in store merging dependence check"));
138
139static cl::opt<bool> EnableReduceLoadOpStoreWidth(
140 "combiner-reduce-load-op-store-width", cl::Hidden, cl::init(Val: true),
141 cl::desc("DAG combiner enable reducing the width of load/op/store "
142 "sequence"));
143static cl::opt<bool> ReduceLoadOpStoreWidthForceNarrowingProfitable(
144 "combiner-reduce-load-op-store-width-force-narrowing-profitable",
145 cl::Hidden, cl::init(Val: false),
146 cl::desc("DAG combiner force override the narrowing profitable check when "
147 "reducing the width of load/op/store sequences"));
148
149static cl::opt<bool> EnableShrinkLoadReplaceStoreWithStore(
150 "combiner-shrink-load-replace-store-with-store", cl::Hidden, cl::init(Val: true),
151 cl::desc("DAG combiner enable load/<replace bytes>/store with "
152 "a narrower store"));
153
154static cl::opt<bool> DisableCombines("combiner-disabled", cl::Hidden,
155 cl::init(Val: false),
156 cl::desc("Disable the DAG combiner"));
157
158namespace {
159
160 class DAGCombiner {
161 SelectionDAG &DAG;
162 const TargetLowering &TLI;
163 const SelectionDAGTargetInfo *STI;
164 CombineLevel Level = BeforeLegalizeTypes;
165 CodeGenOptLevel OptLevel;
166 bool LegalDAG = false;
167 bool LegalOperations = false;
168 bool LegalTypes = false;
169 bool ForCodeSize;
170 bool DisableGenericCombines;
171
172 /// Worklist of all of the nodes that need to be simplified.
173 ///
174 /// This must behave as a stack -- new nodes to process are pushed onto the
175 /// back and when processing we pop off of the back.
176 ///
177 /// The worklist will not contain duplicates but may contain null entries
178 /// due to nodes being deleted from the underlying DAG. For fast lookup and
179 /// deduplication, the index of the node in this vector is stored in the
180 /// node in SDNode::CombinerWorklistIndex.
181 SmallVector<SDNode *, 64> Worklist;
182
    /// This records all nodes attempted to be added to the worklist since we
    /// considered a new worklist entry. Since we do not add duplicate nodes
    /// to the worklist, this is different from the tail of the worklist.
186 SmallSetVector<SDNode *, 32> PruningList;
187
188 /// Map from candidate StoreNode to the pair of RootNode and count.
189 /// The count is used to track how many times we have seen the StoreNode
190 /// with the same RootNode bail out in dependence check. If we have seen
191 /// the bail out for the same pair many times over a limit, we won't
192 /// consider the StoreNode with the same RootNode as store merging
193 /// candidate again.
194 DenseMap<SDNode *, std::pair<SDNode *, unsigned>> StoreRootCountMap;
195
196 // BatchAA - Used for DAG load/store alias analysis.
197 BatchAAResults *BatchAA;
198
    /// This caches all chains that have already been processed in
    /// DAGCombiner::getStoreMergeCandidates() and found to have no mergeable
    /// store candidates.
202 SmallPtrSet<SDNode *, 4> ChainsWithoutMergeableStores;
203
204 /// When an instruction is simplified, add all users of the instruction to
205 /// the work lists because they might get more simplified now.
206 void AddUsersToWorklist(SDNode *N) {
207 for (SDNode *Node : N->users())
208 AddToWorklist(N: Node);
209 }
210
211 /// Convenient shorthand to add a node and all of its user to the worklist.
212 void AddToWorklistWithUsers(SDNode *N) {
213 AddUsersToWorklist(N);
214 AddToWorklist(N);
215 }
216
217 // Prune potentially dangling nodes. This is called after
218 // any visit to a node, but should also be called during a visit after any
219 // failed combine which may have created a DAG node.
220 void clearAddedDanglingWorklistEntries() {
221 // Check any nodes added to the worklist to see if they are prunable.
222 while (!PruningList.empty()) {
223 auto *N = PruningList.pop_back_val();
224 if (N->use_empty())
225 recursivelyDeleteUnusedNodes(N);
226 }
227 }
228
229 SDNode *getNextWorklistEntry() {
230 // Before we do any work, remove nodes that are not in use.
231 clearAddedDanglingWorklistEntries();
232 SDNode *N = nullptr;
233 // The Worklist holds the SDNodes in order, but it may contain null
234 // entries.
235 while (!N && !Worklist.empty()) {
236 N = Worklist.pop_back_val();
237 }
238
239 if (N) {
240 assert(N->getCombinerWorklistIndex() >= 0 &&
241 "Found a worklist entry without a corresponding map entry!");
242 // Set to -2 to indicate that we combined the node.
243 N->setCombinerWorklistIndex(-2);
244 }
245 return N;
246 }
247
248 /// Call the node-specific routine that folds each particular type of node.
249 SDValue visit(SDNode *N);
250
251 public:
252 DAGCombiner(SelectionDAG &D, BatchAAResults *BatchAA, CodeGenOptLevel OL)
253 : DAG(D), TLI(D.getTargetLoweringInfo()),
254 STI(D.getSubtarget().getSelectionDAGInfo()), OptLevel(OL),
255 BatchAA(BatchAA) {
256 ForCodeSize = DAG.shouldOptForSize();
257 DisableGenericCombines =
258 DisableCombines || (STI && STI->disableGenericCombines(OptLevel));
259
260 MaximumLegalStoreInBits = 0;
261 // We use the minimum store size here, since that's all we can guarantee
262 // for the scalable vector types.
263 for (MVT VT : MVT::all_valuetypes())
264 if (EVT(VT).isSimple() && VT != MVT::Other &&
265 TLI.isTypeLegal(VT: EVT(VT)) &&
266 VT.getSizeInBits().getKnownMinValue() >= MaximumLegalStoreInBits)
267 MaximumLegalStoreInBits = VT.getSizeInBits().getKnownMinValue();
268 }
269
270 void ConsiderForPruning(SDNode *N) {
271 // Mark this for potential pruning.
272 PruningList.insert(X: N);
273 }
274
275 /// Add to the worklist making sure its instance is at the back (next to be
276 /// processed.)
277 void AddToWorklist(SDNode *N, bool IsCandidateForPruning = true,
278 bool SkipIfCombinedBefore = false) {
279 assert(N->getOpcode() != ISD::DELETED_NODE &&
280 "Deleted Node added to Worklist");
281
282 // Skip handle nodes as they can't usefully be combined and confuse the
283 // zero-use deletion strategy.
284 if (N->getOpcode() == ISD::HANDLENODE)
285 return;
286
287 if (SkipIfCombinedBefore && N->getCombinerWorklistIndex() == -2)
288 return;
289
290 if (IsCandidateForPruning)
291 ConsiderForPruning(N);
292
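      // A negative index means the node is not currently queued; append it and
      // record its position so removeFromWorklist can null it out in O(1).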
293 if (N->getCombinerWorklistIndex() < 0) {
294 N->setCombinerWorklistIndex(Worklist.size());
295 Worklist.push_back(Elt: N);
296 }
297 }
298
299 /// Remove all instances of N from the worklist.
300 void removeFromWorklist(SDNode *N) {
301 PruningList.remove(X: N);
302 StoreRootCountMap.erase(Val: N);
303
304 int WorklistIndex = N->getCombinerWorklistIndex();
305 // If not in the worklist, the index might be -1 or -2 (was combined
306 // before). As the node gets deleted anyway, there's no need to update
307 // the index.
308 if (WorklistIndex < 0)
309 return; // Not in the worklist.
310
311 // Null out the entry rather than erasing it to avoid a linear operation.
312 Worklist[WorklistIndex] = nullptr;
313 N->setCombinerWorklistIndex(-1);
314 }
315
316 void deleteAndRecombine(SDNode *N);
317 bool recursivelyDeleteUnusedNodes(SDNode *N);
318
319 /// Replaces all uses of the results of one DAG node with new values.
320 SDValue CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
321 bool AddTo = true);
322
323 /// Replaces all uses of the results of one DAG node with new values.
324 SDValue CombineTo(SDNode *N, SDValue Res, bool AddTo = true) {
325 return CombineTo(N, To: &Res, NumTo: 1, AddTo);
326 }
327
328 /// Replaces all uses of the results of one DAG node with new values.
329 SDValue CombineTo(SDNode *N, SDValue Res0, SDValue Res1,
330 bool AddTo = true) {
331 SDValue To[] = { Res0, Res1 };
332 return CombineTo(N, To, NumTo: 2, AddTo);
333 }
334
335 SDValue CombineTo(SDNode *N, SmallVectorImpl<SDValue> *To,
336 bool AddTo = true) {
337 return CombineTo(N, To: To->data(), NumTo: To->size(), AddTo);
338 }
339
340 void CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO);
341
342 private:
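    /// The widest store width (in bits) that is legal for the target; for
    /// scalable vector types this uses the known-minimum size. Computed once
    /// in the constructor.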
343 unsigned MaximumLegalStoreInBits;
344
345 /// Check the specified integer node value to see if it can be simplified or
346 /// if things it uses can be simplified by bit propagation.
347 /// If so, return true.
348 bool SimplifyDemandedBits(SDValue Op) {
349 unsigned BitWidth = Op.getScalarValueSizeInBits();
350 APInt DemandedBits = APInt::getAllOnes(numBits: BitWidth);
351 return SimplifyDemandedBits(Op, DemandedBits);
352 }
353
354 bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits) {
355 EVT VT = Op.getValueType();
356 APInt DemandedElts = VT.isFixedLengthVector()
357 ? APInt::getAllOnes(numBits: VT.getVectorNumElements())
358 : APInt(1, 1);
359 return SimplifyDemandedBits(Op, DemandedBits, DemandedElts, AssumeSingleUse: false);
360 }
361
362 /// Check the specified vector node value to see if it can be simplified or
363 /// if things it uses can be simplified as it only uses some of the
364 /// elements. If so, return true.
365 bool SimplifyDemandedVectorElts(SDValue Op) {
366 // TODO: For now just pretend it cannot be simplified.
367 if (Op.getValueType().isScalableVector())
368 return false;
369
370 unsigned NumElts = Op.getValueType().getVectorNumElements();
371 APInt DemandedElts = APInt::getAllOnes(numBits: NumElts);
372 return SimplifyDemandedVectorElts(Op, DemandedElts);
373 }
374
375 bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
376 const APInt &DemandedElts,
377 bool AssumeSingleUse = false);
378 bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedElts,
379 bool AssumeSingleUse = false);
380
381 bool CombineToPreIndexedLoadStore(SDNode *N);
382 bool CombineToPostIndexedLoadStore(SDNode *N);
383 SDValue SplitIndexingFromLoad(LoadSDNode *LD);
384 bool SliceUpLoad(SDNode *N);
385
    // Looks up the chain to find a unique (unaliased) store feeding the passed
    // load. If no such store is found, returns a nullptr.
    // Note: This will look past a CALLSEQ_START if the load is chained to it
    // so that it can find stack stores for byval params.
390 StoreSDNode *getUniqueStoreFeeding(LoadSDNode *LD, int64_t &Offset);
391 // Scalars have size 0 to distinguish from singleton vectors.
392 SDValue ForwardStoreValueToDirectLoad(LoadSDNode *LD);
393 bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val);
394 bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val);
395
396 void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad);
397 SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace);
398 SDValue SExtPromoteOperand(SDValue Op, EVT PVT);
399 SDValue ZExtPromoteOperand(SDValue Op, EVT PVT);
400 SDValue PromoteIntBinOp(SDValue Op);
401 SDValue PromoteIntShiftOp(SDValue Op);
402 SDValue PromoteExtend(SDValue Op);
403 bool PromoteLoad(SDValue Op);
404
405 SDValue foldShiftToAvg(SDNode *N, const SDLoc &DL);
406 // Fold `a bitwiseop (~b +/- c)` -> `a bitwiseop ~(b -/+ c)`
407 SDValue foldBitwiseOpWithNeg(SDNode *N, const SDLoc &DL, EVT VT);
408
409 SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
410 SDValue RHS, SDValue True, SDValue False,
411 ISD::CondCode CC);
412
413 /// Call the node-specific routine that knows how to fold each
414 /// particular type of node. If that doesn't do anything, try the
415 /// target-specific DAG combines.
416 SDValue combine(SDNode *N);
417
418 // Visitation implementation - Implement dag node combining for different
419 // node types. The semantics are as follows:
420 // Return Value:
421 // SDValue.getNode() == 0 - No change was made
422 // SDValue.getNode() == N - N was replaced, is dead and has been handled.
423 // otherwise - N should be replaced by the returned Operand.
424 //
425 SDValue visitTokenFactor(SDNode *N);
426 SDValue visitMERGE_VALUES(SDNode *N);
427 SDValue visitADD(SDNode *N);
428 SDValue visitADDLike(SDNode *N);
429 SDValue visitADDLikeCommutative(SDValue N0, SDValue N1,
430 SDNode *LocReference);
431 SDValue visitPTRADD(SDNode *N);
432 SDValue visitSUB(SDNode *N);
433 SDValue visitADDSAT(SDNode *N);
434 SDValue visitSUBSAT(SDNode *N);
435 SDValue visitADDC(SDNode *N);
436 SDValue visitADDO(SDNode *N);
437 SDValue visitUADDOLike(SDValue N0, SDValue N1, SDNode *N);
438 SDValue visitSUBC(SDNode *N);
439 SDValue visitSUBO(SDNode *N);
440 SDValue visitADDE(SDNode *N);
441 SDValue visitUADDO_CARRY(SDNode *N);
442 SDValue visitSADDO_CARRY(SDNode *N);
443 SDValue visitUADDO_CARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
444 SDNode *N);
445 SDValue visitSADDO_CARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
446 SDNode *N);
447 SDValue visitSUBE(SDNode *N);
448 SDValue visitUSUBO_CARRY(SDNode *N);
449 SDValue visitSSUBO_CARRY(SDNode *N);
450 template <class MatchContextClass> SDValue visitMUL(SDNode *N);
451 SDValue visitMULFIX(SDNode *N);
452 SDValue useDivRem(SDNode *N);
453 SDValue visitSDIV(SDNode *N);
454 SDValue visitSDIVLike(SDValue N0, SDValue N1, SDNode *N);
455 SDValue visitUDIV(SDNode *N);
456 SDValue visitUDIVLike(SDValue N0, SDValue N1, SDNode *N);
457 SDValue visitREM(SDNode *N);
458 SDValue visitMULHU(SDNode *N);
459 SDValue visitMULHS(SDNode *N);
460 SDValue visitAVG(SDNode *N);
461 SDValue visitABD(SDNode *N);
462 SDValue visitSMUL_LOHI(SDNode *N);
463 SDValue visitUMUL_LOHI(SDNode *N);
464 SDValue visitMULO(SDNode *N);
465 SDValue visitIMINMAX(SDNode *N);
466 SDValue visitAND(SDNode *N);
467 SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *N);
468 SDValue visitOR(SDNode *N);
469 SDValue visitORLike(SDValue N0, SDValue N1, const SDLoc &DL);
470 SDValue visitXOR(SDNode *N);
471 SDValue SimplifyVCastOp(SDNode *N, const SDLoc &DL);
472 SDValue SimplifyVBinOp(SDNode *N, const SDLoc &DL);
473 SDValue visitSHL(SDNode *N);
474 SDValue visitSRA(SDNode *N);
475 SDValue visitSRL(SDNode *N);
476 SDValue visitFunnelShift(SDNode *N);
477 SDValue visitSHLSAT(SDNode *N);
478 SDValue visitRotate(SDNode *N);
479 SDValue visitABS(SDNode *N);
480 SDValue visitCLMUL(SDNode *N);
481 SDValue visitBSWAP(SDNode *N);
482 SDValue visitBITREVERSE(SDNode *N);
483 SDValue visitCTLZ(SDNode *N);
484 SDValue visitCTLZ_ZERO_UNDEF(SDNode *N);
485 SDValue visitCTTZ(SDNode *N);
486 SDValue visitCTTZ_ZERO_UNDEF(SDNode *N);
487 SDValue visitCTPOP(SDNode *N);
488 SDValue visitSELECT(SDNode *N);
489 SDValue visitVSELECT(SDNode *N);
490 SDValue visitVP_SELECT(SDNode *N);
491 SDValue visitSELECT_CC(SDNode *N);
492 SDValue visitSETCC(SDNode *N);
493 SDValue visitSETCCCARRY(SDNode *N);
494 SDValue visitSIGN_EXTEND(SDNode *N);
495 SDValue visitZERO_EXTEND(SDNode *N);
496 SDValue visitANY_EXTEND(SDNode *N);
497 SDValue visitAssertExt(SDNode *N);
498 SDValue visitAssertAlign(SDNode *N);
499 SDValue visitSIGN_EXTEND_INREG(SDNode *N);
500 SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
501 SDValue visitTRUNCATE(SDNode *N);
502 SDValue visitTRUNCATE_USAT_U(SDNode *N);
503 SDValue visitBITCAST(SDNode *N);
504 SDValue visitFREEZE(SDNode *N);
505 SDValue visitBUILD_PAIR(SDNode *N);
506 SDValue visitFADD(SDNode *N);
507 SDValue visitVP_FADD(SDNode *N);
508 SDValue visitVP_FSUB(SDNode *N);
509 SDValue visitSTRICT_FADD(SDNode *N);
510 SDValue visitFSUB(SDNode *N);
511 SDValue visitFMUL(SDNode *N);
512 template <class MatchContextClass> SDValue visitFMA(SDNode *N);
513 SDValue visitFMAD(SDNode *N);
514 SDValue visitFMULADD(SDNode *N);
515 SDValue visitFDIV(SDNode *N);
516 SDValue visitFREM(SDNode *N);
517 SDValue visitFSQRT(SDNode *N);
518 SDValue visitFCOPYSIGN(SDNode *N);
519 SDValue visitFPOW(SDNode *N);
520 SDValue visitFCANONICALIZE(SDNode *N);
521 SDValue visitSINT_TO_FP(SDNode *N);
522 SDValue visitUINT_TO_FP(SDNode *N);
523 SDValue visitFP_TO_SINT(SDNode *N);
524 SDValue visitFP_TO_UINT(SDNode *N);
525 SDValue visitXROUND(SDNode *N);
526 SDValue visitFP_ROUND(SDNode *N);
527 SDValue visitFP_EXTEND(SDNode *N);
528 SDValue visitFNEG(SDNode *N);
529 SDValue visitFABS(SDNode *N);
530 SDValue visitFCEIL(SDNode *N);
531 SDValue visitFTRUNC(SDNode *N);
532 SDValue visitFFREXP(SDNode *N);
533 SDValue visitFFLOOR(SDNode *N);
534 SDValue visitFMinMax(SDNode *N);
535 SDValue visitBRCOND(SDNode *N);
536 SDValue visitBR_CC(SDNode *N);
537 SDValue visitLOAD(SDNode *N);
538
539 SDValue replaceStoreChain(StoreSDNode *ST, SDValue BetterChain);
540 SDValue replaceStoreOfFPConstant(StoreSDNode *ST);
541 SDValue replaceStoreOfInsertLoad(StoreSDNode *ST);
542
543 bool refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(SDNode *N);
544
545 SDValue visitSTORE(SDNode *N);
546 SDValue visitATOMIC_STORE(SDNode *N);
547 SDValue visitLIFETIME_END(SDNode *N);
548 SDValue visitINSERT_VECTOR_ELT(SDNode *N);
549 SDValue visitEXTRACT_VECTOR_ELT(SDNode *N);
550 SDValue visitBUILD_VECTOR(SDNode *N);
551 SDValue visitCONCAT_VECTORS(SDNode *N);
552 SDValue visitVECTOR_INTERLEAVE(SDNode *N);
553 SDValue visitEXTRACT_SUBVECTOR(SDNode *N);
554 SDValue visitVECTOR_SHUFFLE(SDNode *N);
555 SDValue visitSCALAR_TO_VECTOR(SDNode *N);
556 SDValue visitINSERT_SUBVECTOR(SDNode *N);
557 SDValue visitVECTOR_COMPRESS(SDNode *N);
558 SDValue visitMLOAD(SDNode *N);
559 SDValue visitMSTORE(SDNode *N);
560 SDValue visitMGATHER(SDNode *N);
561 SDValue visitMSCATTER(SDNode *N);
562 SDValue visitMHISTOGRAM(SDNode *N);
563 SDValue visitPARTIAL_REDUCE_MLA(SDNode *N);
564 SDValue visitVPGATHER(SDNode *N);
565 SDValue visitVPSCATTER(SDNode *N);
566 SDValue visitVP_STRIDED_LOAD(SDNode *N);
567 SDValue visitVP_STRIDED_STORE(SDNode *N);
568 SDValue visitFP_TO_FP16(SDNode *N);
569 SDValue visitFP16_TO_FP(SDNode *N);
570 SDValue visitFP_TO_BF16(SDNode *N);
571 SDValue visitBF16_TO_FP(SDNode *N);
572 SDValue visitVECREDUCE(SDNode *N);
573 SDValue visitVPOp(SDNode *N);
574 SDValue visitGET_FPENV_MEM(SDNode *N);
575 SDValue visitSET_FPENV_MEM(SDNode *N);
576
577 template <class MatchContextClass>
578 SDValue visitFADDForFMACombine(SDNode *N);
579 template <class MatchContextClass>
580 SDValue visitFSUBForFMACombine(SDNode *N);
581 SDValue visitFMULForFMADistributiveCombine(SDNode *N);
582
583 SDValue XformToShuffleWithZero(SDNode *N);
584 bool reassociationCanBreakAddressingModePattern(unsigned Opc,
585 const SDLoc &DL,
586 SDNode *N,
587 SDValue N0,
588 SDValue N1);
589 SDValue reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, SDValue N0,
590 SDValue N1, SDNodeFlags Flags);
591 SDValue reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
592 SDValue N1, SDNodeFlags Flags);
593 SDValue reassociateReduction(unsigned RedOpc, unsigned Opc, const SDLoc &DL,
594 EVT VT, SDValue N0, SDValue N1,
595 SDNodeFlags Flags = SDNodeFlags());
596
597 SDValue visitShiftByConstant(SDNode *N);
598
599 SDValue foldSelectOfConstants(SDNode *N);
600 SDValue foldVSelectOfConstants(SDNode *N);
601 SDValue foldBinOpIntoSelect(SDNode *BO);
602 bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS);
603 SDValue hoistLogicOpWithSameOpcodeHands(SDNode *N);
604 SDValue SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2);
605 SDValue SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
606 SDValue N2, SDValue N3, ISD::CondCode CC,
607 bool NotExtCompare = false);
608 SDValue convertSelectOfFPConstantsToLoadOffset(
609 const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
610 ISD::CondCode CC);
611 SDValue foldSignChangeInBitcast(SDNode *N);
612 SDValue foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0, SDValue N1,
613 SDValue N2, SDValue N3, ISD::CondCode CC);
614 SDValue foldSelectOfBinops(SDNode *N);
615 SDValue foldSextSetcc(SDNode *N);
616 SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
617 const SDLoc &DL);
618 SDValue foldSubToUSubSat(EVT DstVT, SDNode *N, const SDLoc &DL);
619 SDValue foldABSToABD(SDNode *N, const SDLoc &DL);
620 SDValue foldSelectToABD(SDValue LHS, SDValue RHS, SDValue True,
621 SDValue False, ISD::CondCode CC, const SDLoc &DL);
622 SDValue foldSelectToUMin(SDValue LHS, SDValue RHS, SDValue True,
623 SDValue False, ISD::CondCode CC, const SDLoc &DL);
624 SDValue unfoldMaskedMerge(SDNode *N);
625 SDValue unfoldExtremeBitClearingToShifts(SDNode *N);
626 SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
627 const SDLoc &DL, bool foldBooleans);
628 SDValue rebuildSetCC(SDValue N);
629
630 bool isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
631 SDValue &CC, bool MatchStrict = false) const;
632 bool isOneUseSetCC(SDValue N) const;
633
634 SDValue foldAddToAvg(SDNode *N, const SDLoc &DL);
635 SDValue foldSubToAvg(SDNode *N, const SDLoc &DL);
636
637 SDValue foldCTLZToCTLS(SDValue Src, const SDLoc &DL);
638
639 SDValue SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
640 unsigned HiOp);
641 SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
642 SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
643 const TargetLowering &TLI);
644 SDValue foldPartialReduceMLAMulOp(SDNode *N);
645 SDValue foldPartialReduceAdd(SDNode *N);
646
647 SDValue CombineExtLoad(SDNode *N);
648 SDValue CombineZExtLogicopShiftLoad(SDNode *N);
649 SDValue combineRepeatedFPDivisors(SDNode *N);
650 SDValue combineFMulOrFDivWithIntPow2(SDNode *N);
651 SDValue replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf);
652 SDValue mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex);
653 SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex);
654 SDValue combineInsertEltToLoad(SDNode *N, unsigned InsIndex);
655 SDValue BuildSDIV(SDNode *N);
656 SDValue BuildSDIVPow2(SDNode *N);
657 SDValue BuildUDIV(SDNode *N);
658 SDValue BuildSREMPow2(SDNode *N);
659 SDValue buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N);
660 SDValue BuildLogBase2(SDValue V, const SDLoc &DL,
661 bool KnownNeverZero = false,
662 bool InexpensiveOnly = false,
663 std::optional<EVT> OutVT = std::nullopt);
664 SDValue BuildDivEstimate(SDValue N, SDValue Op, SDNodeFlags Flags);
665 SDValue buildRsqrtEstimate(SDValue Op);
666 SDValue buildSqrtEstimate(SDValue Op);
667 SDValue buildSqrtEstimateImpl(SDValue Op, bool Recip);
668 SDValue buildSqrtNROneConst(SDValue Arg, SDValue Est, unsigned Iterations,
669 bool Reciprocal);
670 SDValue buildSqrtNRTwoConst(SDValue Arg, SDValue Est, unsigned Iterations,
671 bool Reciprocal);
672 SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
673 bool DemandHighBits = true);
674 SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
675 SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
676 SDValue InnerPos, SDValue InnerNeg, bool FromAdd,
677 bool HasPos, unsigned PosOpcode,
678 unsigned NegOpcode, const SDLoc &DL);
679 SDValue MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, SDValue Neg,
680 SDValue InnerPos, SDValue InnerNeg, bool FromAdd,
681 bool HasPos, unsigned PosOpcode,
682 unsigned NegOpcode, const SDLoc &DL);
683 SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL,
684 bool FromAdd);
685 SDValue MatchLoadCombine(SDNode *N);
686 SDValue mergeTruncStores(StoreSDNode *N);
687 SDValue reduceLoadWidth(SDNode *N);
688 SDValue ReduceLoadOpStoreWidth(SDNode *N);
689 SDValue splitMergedValStore(StoreSDNode *ST);
690 SDValue TransformFPLoadStorePair(SDNode *N);
691 SDValue convertBuildVecZextToZext(SDNode *N);
692 SDValue convertBuildVecZextToBuildVecWithZeros(SDNode *N);
693 SDValue reduceBuildVecExtToExtBuildVec(SDNode *N);
694 SDValue reduceBuildVecTruncToBitCast(SDNode *N);
695 SDValue reduceBuildVecToShuffle(SDNode *N);
696 SDValue createBuildVecShuffle(const SDLoc &DL, SDNode *N,
697 ArrayRef<int> VectorMask, SDValue VecIn1,
698 SDValue VecIn2, unsigned LeftIdx,
699 bool DidSplitVec);
700 SDValue matchVSelectOpSizesWithSetCC(SDNode *Cast);
701
702 /// Walk up chain skipping non-aliasing memory nodes,
703 /// looking for aliasing nodes and adding them to the Aliases vector.
704 void GatherAllAliases(SDNode *N, SDValue OriginalChain,
705 SmallVectorImpl<SDValue> &Aliases);
706
707 /// Return true if there is any possibility that the two addresses overlap.
708 bool mayAlias(SDNode *Op0, SDNode *Op1) const;
709
710 /// Walk up chain skipping non-aliasing memory nodes, looking for a better
711 /// chain (aliasing node.)
712 SDValue FindBetterChain(SDNode *N, SDValue Chain);
713
714 /// Try to replace a store and any possibly adjacent stores on
715 /// consecutive chains with better chains. Return true only if St is
716 /// replaced.
717 ///
718 /// Notice that other chains may still be replaced even if the function
719 /// returns false.
720 bool findBetterNeighborChains(StoreSDNode *St);
721
    // Helper for findBetterNeighborChains. Walk up the store chain and add
    // additional chained stores that do not overlap and can be parallelized.
724 bool parallelizeChainedStores(StoreSDNode *St);
725
726 /// Holds a pointer to an LSBaseSDNode as well as information on where it
727 /// is located in a sequence of memory operations connected by a chain.
728 struct MemOpLink {
729 // Ptr to the mem node.
730 LSBaseSDNode *MemNode;
731
732 // Offset from the base ptr.
733 int64_t OffsetFromBase;
734
735 MemOpLink(LSBaseSDNode *N, int64_t Offset)
736 : MemNode(N), OffsetFromBase(Offset) {}
737 };
738
739 // Classify the origin of a stored value.
740 enum class StoreSource { Unknown, Constant, Extract, Load };
741 StoreSource getStoreSource(SDValue StoreVal) {
742 switch (StoreVal.getOpcode()) {
743 case ISD::Constant:
744 case ISD::ConstantFP:
745 return StoreSource::Constant;
746 case ISD::BUILD_VECTOR:
747 if (ISD::isBuildVectorOfConstantSDNodes(N: StoreVal.getNode()) ||
748 ISD::isBuildVectorOfConstantFPSDNodes(N: StoreVal.getNode()))
749 return StoreSource::Constant;
750 return StoreSource::Unknown;
751 case ISD::EXTRACT_VECTOR_ELT:
752 case ISD::EXTRACT_SUBVECTOR:
753 return StoreSource::Extract;
754 case ISD::LOAD:
755 return StoreSource::Load;
756 default:
757 return StoreSource::Unknown;
758 }
759 }
760
761 /// This is a helper function for visitMUL to check the profitability
762 /// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
763 /// MulNode is the original multiply, AddNode is (add x, c1),
764 /// and ConstNode is c2.
765 bool isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
766 SDValue ConstNode);
767
768 /// This is a helper function for visitAND and visitZERO_EXTEND. Returns
769 /// true if the (and (load x) c) pattern matches an extload. ExtVT returns
770 /// the type of the loaded value to be extended.
771 bool isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
772 EVT LoadResultTy, EVT &ExtVT);
773
774 /// Helper function to calculate whether the given Load/Store can have its
775 /// width reduced to ExtVT.
776 bool isLegalNarrowLdSt(LSBaseSDNode *LDSTN, ISD::LoadExtType ExtType,
777 EVT &MemVT, unsigned ShAmt = 0);
778
779 /// Used by BackwardsPropagateMask to find suitable loads.
780 bool SearchForAndLoads(SDNode *N, SmallVectorImpl<LoadSDNode*> &Loads,
781 SmallPtrSetImpl<SDNode*> &NodesWithConsts,
782 ConstantSDNode *Mask, SDNode *&NodeToMask);
783 /// Attempt to propagate a given AND node back to load leaves so that they
784 /// can be combined into narrow loads.
785 bool BackwardsPropagateMask(SDNode *N);
786
787 /// Helper function for mergeConsecutiveStores which merges the component
788 /// store chains.
789 SDValue getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
790 unsigned NumStores);
791
792 /// Helper function for mergeConsecutiveStores which checks if all the store
793 /// nodes have the same underlying object. We can still reuse the first
794 /// store's pointer info if all the stores are from the same object.
795 bool hasSameUnderlyingObj(ArrayRef<MemOpLink> StoreNodes);
796
797 /// This is a helper function for mergeConsecutiveStores. When the source
798 /// elements of the consecutive stores are all constants or all extracted
799 /// vector elements, try to merge them into one larger store introducing
800 /// bitcasts if necessary. \return True if a merged store was created.
801 bool mergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes,
802 EVT MemVT, unsigned NumStores,
803 bool IsConstantSrc, bool UseVector,
804 bool UseTrunc);
805
806 /// This is a helper function for mergeConsecutiveStores. Stores that
807 /// potentially may be merged with St are placed in StoreNodes. On success,
808 /// returns a chain predecessor to all store candidates.
809 SDNode *getStoreMergeCandidates(StoreSDNode *St,
810 SmallVectorImpl<MemOpLink> &StoreNodes);
811
812 /// Helper function for mergeConsecutiveStores. Checks if candidate stores
813 /// have indirect dependency through their operands. RootNode is the
814 /// predecessor to all stores calculated by getStoreMergeCandidates and is
815 /// used to prune the dependency check. \return True if safe to merge.
816 bool checkMergeStoreCandidatesForDependencies(
817 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
818 SDNode *RootNode);
819
820 /// Helper function for tryStoreMergeOfLoads. Checks if the load/store
821 /// chain has a call in it. \return True if a call is found.
822 bool hasCallInLdStChain(StoreSDNode *St, LoadSDNode *Ld);
823
824 /// This is a helper function for mergeConsecutiveStores. Given a list of
825 /// store candidates, find the first N that are consecutive in memory.
826 /// Returns 0 if there are not at least 2 consecutive stores to try merging.
827 unsigned getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
828 int64_t ElementSizeBytes) const;
829
830 /// This is a helper function for mergeConsecutiveStores. It is used for
831 /// store chains that are composed entirely of constant values.
832 bool tryStoreMergeOfConstants(SmallVectorImpl<MemOpLink> &StoreNodes,
833 unsigned NumConsecutiveStores,
834 EVT MemVT, SDNode *Root, bool AllowVectors);
835
836 /// This is a helper function for mergeConsecutiveStores. It is used for
837 /// store chains that are composed entirely of extracted vector elements.
838 /// When extracting multiple vector elements, try to store them in one
839 /// vector store rather than a sequence of scalar stores.
840 bool tryStoreMergeOfExtracts(SmallVectorImpl<MemOpLink> &StoreNodes,
841 unsigned NumConsecutiveStores, EVT MemVT,
842 SDNode *Root);
843
844 /// This is a helper function for mergeConsecutiveStores. It is used for
845 /// store chains that are composed entirely of loaded values.
846 bool tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
847 unsigned NumConsecutiveStores, EVT MemVT,
848 SDNode *Root, bool AllowVectors,
849 bool IsNonTemporalStore, bool IsNonTemporalLoad);
850
851 /// Merge consecutive store operations into a wide store.
852 /// This optimization uses wide integers or vectors when possible.
853 /// \return true if stores were merged.
854 bool mergeConsecutiveStores(StoreSDNode *St);
855
856 /// Try to transform a truncation where C is a constant:
857 /// (trunc (and X, C)) -> (and (trunc X), (trunc C))
858 ///
859 /// \p N needs to be a truncation and its first operand an AND. Other
860 /// requirements are checked by the function (e.g. that trunc is
861 /// single-use) and if missed an empty SDValue is returned.
862 SDValue distributeTruncateThroughAnd(SDNode *N);
863
    /// Helper function to determine whether the target supports the operation
    /// given by \p Opcode for type \p VT, that is, whether the operation
    /// is legal or custom before legalizing operations, and whether it is
    /// legal (but not custom) after legalization.
868 bool hasOperation(unsigned Opcode, EVT VT) {
869 return TLI.isOperationLegalOrCustom(Op: Opcode, VT, LegalOnly: LegalOperations);
870 }
871
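    /// Return true if ISD::UMIN is legal or custom for \p VT after type
    /// legalization (VT itself only needs to be legal or promotable to a
    /// legal integer type).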
872 bool hasUMin(EVT VT) const {
873 auto LK = TLI.getTypeConversion(Context&: *DAG.getContext(), VT);
874 return (LK.first == TargetLoweringBase::TypeLegal ||
875 LK.first == TargetLoweringBase::TypePromoteInteger) &&
876 TLI.isOperationLegalOrCustom(Op: ISD::UMIN, VT: LK.second);
877 }
878
879 public:
880 /// Runs the dag combiner on all nodes in the work list
881 void Run(CombineLevel AtLevel);
882
883 SelectionDAG &getDAG() const { return DAG; }
884
885 /// Convenience wrapper around TargetLowering::getShiftAmountTy.
886 EVT getShiftAmountTy(EVT LHSTy) {
887 return TLI.getShiftAmountTy(LHSTy, DL: DAG.getDataLayout());
888 }
889
890 /// This method returns true if we are running before type legalization or
891 /// if the specified VT is legal.
892 bool isTypeLegal(const EVT &VT) {
893 if (!LegalTypes) return true;
894 return TLI.isTypeLegal(VT);
895 }
896
897 /// Convenience wrapper around TargetLowering::getSetCCResultType
898 EVT getSetCCResultType(EVT VT) const {
899 return TLI.getSetCCResultType(DL: DAG.getDataLayout(), Context&: *DAG.getContext(), VT);
900 }
901
902 void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
903 SDValue OrigLoad, SDValue ExtLoad,
904 ISD::NodeType ExtType);
905 };
906
907/// This class is a DAGUpdateListener that removes any deleted
908/// nodes from the worklist.
909class WorklistRemover : public SelectionDAG::DAGUpdateListener {
910 DAGCombiner &DC;
911
912public:
913 explicit WorklistRemover(DAGCombiner &dc)
914 : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}
915
916 void NodeDeleted(SDNode *N, SDNode *E) override {
917 DC.removeFromWorklist(N);
918 }
919};
920
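/// This class is a DAGUpdateListener that adds newly-inserted nodes to the
/// pruning list so they can be deleted later if they end up unused.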
921class WorklistInserter : public SelectionDAG::DAGUpdateListener {
922 DAGCombiner &DC;
923
924public:
925 explicit WorklistInserter(DAGCombiner &dc)
926 : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}
927
928 // FIXME: Ideally we could add N to the worklist, but this causes exponential
929 // compile time costs in large DAGs, e.g. Halide.
930 void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); }
931};
932
933} // end anonymous namespace
934
935//===----------------------------------------------------------------------===//
936// TargetLowering::DAGCombinerInfo implementation
937//===----------------------------------------------------------------------===//
938
939void TargetLowering::DAGCombinerInfo::AddToWorklist(SDNode *N) {
940 ((DAGCombiner*)DC)->AddToWorklist(N);
941}
942
943SDValue TargetLowering::DAGCombinerInfo::
944CombineTo(SDNode *N, ArrayRef<SDValue> To, bool AddTo) {
945 return ((DAGCombiner*)DC)->CombineTo(N, To: &To[0], NumTo: To.size(), AddTo);
946}
947
948SDValue TargetLowering::DAGCombinerInfo::
949CombineTo(SDNode *N, SDValue Res, bool AddTo) {
950 return ((DAGCombiner*)DC)->CombineTo(N, Res, AddTo);
951}
952
953SDValue TargetLowering::DAGCombinerInfo::
954CombineTo(SDNode *N, SDValue Res0, SDValue Res1, bool AddTo) {
955 return ((DAGCombiner*)DC)->CombineTo(N, Res0, Res1, AddTo);
956}
957
958bool TargetLowering::DAGCombinerInfo::
959recursivelyDeleteUnusedNodes(SDNode *N) {
960 return ((DAGCombiner*)DC)->recursivelyDeleteUnusedNodes(N);
961}
962
963void TargetLowering::DAGCombinerInfo::
964CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
965 return ((DAGCombiner*)DC)->CommitTargetLoweringOpt(TLO);
966}
967
968//===----------------------------------------------------------------------===//
969// Helper Functions
970//===----------------------------------------------------------------------===//
971
972void DAGCombiner::deleteAndRecombine(SDNode *N) {
973 removeFromWorklist(N);
974
975 // If the operands of this node are only used by the node, they will now be
976 // dead. Make sure to re-visit them and recursively delete dead nodes.
977 for (const SDValue &Op : N->ops())
978 // For an operand generating multiple values, one of the values may
979 // become dead allowing further simplification (e.g. split index
980 // arithmetic from an indexed load).
981 if (Op->hasOneUse() || Op->getNumValues() > 1)
982 AddToWorklist(N: Op.getNode());
983
984 DAG.DeleteNode(N);
985}
986
// APInts must be the same size for most operations; this helper
// function zero extends the shorter of the pair so that they match.
// We provide an Offset so that we can create bitwidths that won't overflow.
990static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) {
991 unsigned Bits = Offset + std::max(a: LHS.getBitWidth(), b: RHS.getBitWidth());
992 LHS = LHS.zext(width: Bits);
993 RHS = RHS.zext(width: Bits);
994}
995
996// Return true if this node is a setcc, or is a select_cc
997// that selects between the target values used for true and false, making it
998// equivalent to a setcc. Also, set the incoming LHS, RHS, and CC references to
999// the appropriate nodes based on the type of node we are checking. This
1000// simplifies life a bit for the callers.
1001bool DAGCombiner::isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
1002 SDValue &CC, bool MatchStrict) const {
1003 if (N.getOpcode() == ISD::SETCC) {
1004 LHS = N.getOperand(i: 0);
1005 RHS = N.getOperand(i: 1);
1006 CC = N.getOperand(i: 2);
1007 return true;
1008 }
1009
1010 if (MatchStrict &&
1011 (N.getOpcode() == ISD::STRICT_FSETCC ||
1012 N.getOpcode() == ISD::STRICT_FSETCCS)) {
1013 LHS = N.getOperand(i: 1);
1014 RHS = N.getOperand(i: 2);
1015 CC = N.getOperand(i: 3);
1016 return true;
1017 }
1018
1019 if (N.getOpcode() != ISD::SELECT_CC || !TLI.isConstTrueVal(N: N.getOperand(i: 2)) ||
1020 !TLI.isConstFalseVal(N: N.getOperand(i: 3)))
1021 return false;
1022
1023 if (TLI.getBooleanContents(Type: N.getValueType()) ==
1024 TargetLowering::UndefinedBooleanContent)
1025 return false;
1026
1027 LHS = N.getOperand(i: 0);
1028 RHS = N.getOperand(i: 1);
1029 CC = N.getOperand(i: 4);
1030 return true;
1031}
1032
1033/// Return true if this is a SetCC-equivalent operation with only one use.
1034/// If this is true, it allows the users to invert the operation for free when
1035/// it is profitable to do so.
1036bool DAGCombiner::isOneUseSetCC(SDValue N) const {
1037 SDValue N0, N1, N2;
1038 if (isSetCCEquivalent(N, LHS&: N0, RHS&: N1, CC&: N2) && N->hasOneUse())
1039 return true;
1040 return false;
1041}
1042
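// Return true if N is a constant splat vector whose splat value is the
// all-ones mask for the width of ScalarTy (i8, i16, or i32), i.e. the mask
// that a zero-extension from ScalarTy would produce.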
1043static bool isConstantSplatVectorMaskForType(SDNode *N, EVT ScalarTy) {
1044 if (!ScalarTy.isSimple())
1045 return false;
1046
1047 uint64_t MaskForTy = 0ULL;
1048 switch (ScalarTy.getSimpleVT().SimpleTy) {
1049 case MVT::i8:
1050 MaskForTy = 0xFFULL;
1051 break;
1052 case MVT::i16:
1053 MaskForTy = 0xFFFFULL;
1054 break;
1055 case MVT::i32:
1056 MaskForTy = 0xFFFFFFFFULL;
1057 break;
1058 default:
1059 return false;
1060 break;
1061 }
1062
1063 APInt Val;
1064 if (ISD::isConstantSplatVector(N, SplatValue&: Val))
1065 return Val.getLimitedValue() == MaskForTy;
1066
1067 return false;
1068}
1069
1070// Determines if it is a constant integer or a splat/build vector of constant
1071// integers (and undefs).
1072// Do not permit build vector implicit truncation unless AllowTruncation is set.
1073static bool isConstantOrConstantVector(SDValue N, bool NoOpaques = false,
1074 bool AllowTruncation = false) {
1075 if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Val&: N))
1076 return !(Const->isOpaque() && NoOpaques);
1077 if (N.getOpcode() != ISD::BUILD_VECTOR && N.getOpcode() != ISD::SPLAT_VECTOR)
1078 return false;
1079 unsigned BitWidth = N.getScalarValueSizeInBits();
1080 for (const SDValue &Op : N->op_values()) {
1081 if (Op.isUndef())
1082 continue;
1083 ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Val: Op);
1084 if (!Const || (Const->isOpaque() && NoOpaques))
1085 return false;
1086 // When AllowTruncation is true, allow constants that have been promoted
1087 // during type legalization as long as the value fits in the target type.
1088 if ((AllowTruncation &&
1089 Const->getAPIntValue().getActiveBits() > BitWidth) ||
1090 (!AllowTruncation && Const->getAPIntValue().getBitWidth() != BitWidth))
1091 return false;
1092 }
1093 return true;
1094}
1095
// Determines if a BUILD_VECTOR is composed of all-constants possibly mixed with
// undefs.
1098static bool isAnyConstantBuildVector(SDValue V, bool NoOpaques = false) {
1099 if (V.getOpcode() != ISD::BUILD_VECTOR)
1100 return false;
1101 return isConstantOrConstantVector(N: V, NoOpaques) ||
1102 ISD::isBuildVectorOfConstantFPSDNodes(N: V.getNode());
1103}
1104
// Determine if this is an indexed load with an opaque target constant index.
1106static bool canSplitIdx(LoadSDNode *LD) {
1107 return MaySplitLoadIndex &&
1108 (LD->getOperand(Num: 2).getOpcode() != ISD::TargetConstant ||
1109 !cast<ConstantSDNode>(Val: LD->getOperand(Num: 2))->isOpaque());
1110}
1111
1112bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
1113 const SDLoc &DL,
1114 SDNode *N,
1115 SDValue N0,
1116 SDValue N1) {
1117 // Currently this only tries to ensure we don't undo the GEP splits done by
1118 // CodeGenPrepare when shouldConsiderGEPOffsetSplit is true. To ensure this,
1119 // we check if the following transformation would be problematic:
1120 // (load/store (add, (add, x, offset1), offset2)) ->
1121 // (load/store (add, x, offset1+offset2)).
1122
1123 // (load/store (add, (add, x, y), offset2)) ->
1124 // (load/store (add, (add, x, offset2), y)).
1125
1126 if (!N0.isAnyAdd())
1127 return false;
1128
1129 // Check for vscale addressing modes.
1130 // (load/store (add/sub (add x, y), vscale))
1131 // (load/store (add/sub (add x, y), (lsl vscale, C)))
1132 // (load/store (add/sub (add x, y), (mul vscale, C)))
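  // Compute the constant multiple of vscale: a bare VSCALE contributes its
  // operand C, a left shift contributes a factor of (1 << C), and a multiply
  // contributes its constant operand.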
1133 if ((N1.getOpcode() == ISD::VSCALE ||
1134 ((N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::MUL) &&
1135 N1.getOperand(i: 0).getOpcode() == ISD::VSCALE &&
1136 isa<ConstantSDNode>(Val: N1.getOperand(i: 1)))) &&
1137 N1.getValueType().getFixedSizeInBits() <= 64) {
1138 int64_t ScalableOffset = N1.getOpcode() == ISD::VSCALE
1139 ? N1.getConstantOperandVal(i: 0)
1140 : (N1.getOperand(i: 0).getConstantOperandVal(i: 0) *
1141 (N1.getOpcode() == ISD::SHL
1142 ? (1LL << N1.getConstantOperandVal(i: 1))
1143 : N1.getConstantOperandVal(i: 1)));
1144 if (Opc == ISD::SUB)
1145 ScalableOffset = -ScalableOffset;
1146 if (all_of(Range: N->users(), P: [&](SDNode *Node) {
1147 if (auto *LoadStore = dyn_cast<MemSDNode>(Val: Node);
1148 LoadStore && LoadStore->getBasePtr().getNode() == N) {
1149 TargetLoweringBase::AddrMode AM;
1150 AM.HasBaseReg = true;
1151 AM.ScalableOffset = ScalableOffset;
1152 EVT VT = LoadStore->getMemoryVT();
1153 unsigned AS = LoadStore->getAddressSpace();
1154 Type *AccessTy = VT.getTypeForEVT(Context&: *DAG.getContext());
1155 return TLI.isLegalAddressingMode(DL: DAG.getDataLayout(), AM, Ty: AccessTy,
1156 AddrSpace: AS);
1157 }
1158 return false;
1159 }))
1160 return true;
1161 }
1162
1163 if (Opc != ISD::ADD && Opc != ISD::PTRADD)
1164 return false;
1165
1166 auto *C2 = dyn_cast<ConstantSDNode>(Val&: N1);
1167 if (!C2)
1168 return false;
1169
1170 const APInt &C2APIntVal = C2->getAPIntValue();
1171 if (C2APIntVal.getSignificantBits() > 64)
1172 return false;
1173
1174 if (auto *C1 = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1))) {
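    // If the inner add has no other uses, folding the constants together
    // cannot break an addressing computation that is shared with other
    // memory operations.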
1175 if (N0.hasOneUse())
1176 return false;
1177
1178 const APInt &C1APIntVal = C1->getAPIntValue();
1179 const APInt CombinedValueIntVal = C1APIntVal + C2APIntVal;
1180 if (CombinedValueIntVal.getSignificantBits() > 64)
1181 return false;
1182 const int64_t CombinedValue = CombinedValueIntVal.getSExtValue();
1183
1184 for (SDNode *Node : N->users()) {
1185 if (auto *LoadStore = dyn_cast<MemSDNode>(Val: Node)) {
1186 // Is x[offset2] already not a legal addressing mode? If so then
1187 // reassociating the constants breaks nothing (we test offset2 because
1188 // that's the one we hope to fold into the load or store).
1189 TargetLoweringBase::AddrMode AM;
1190 AM.HasBaseReg = true;
1191 AM.BaseOffs = C2APIntVal.getSExtValue();
1192 EVT VT = LoadStore->getMemoryVT();
1193 unsigned AS = LoadStore->getAddressSpace();
1194 Type *AccessTy = VT.getTypeForEVT(Context&: *DAG.getContext());
1195 if (!TLI.isLegalAddressingMode(DL: DAG.getDataLayout(), AM, Ty: AccessTy, AddrSpace: AS))
1196 continue;
1197
1198 // Would x[offset1+offset2] still be a legal addressing mode?
1199 AM.BaseOffs = CombinedValue;
1200 if (!TLI.isLegalAddressingMode(DL: DAG.getDataLayout(), AM, Ty: AccessTy, AddrSpace: AS))
1201 return true;
1202 }
1203 }
1204 } else {
1205 if (auto *GA = dyn_cast<GlobalAddressSDNode>(Val: N0.getOperand(i: 1)))
1206 if (GA->getOpcode() == ISD::GlobalAddress && TLI.isOffsetFoldingLegal(GA))
1207 return false;
1208
1209 for (SDNode *Node : N->users()) {
1210 auto *LoadStore = dyn_cast<MemSDNode>(Val: Node);
1211 if (!LoadStore || !LoadStore->hasUniqueMemOperand())
1212 return false;
1213
1214 // Is x[offset2] a legal addressing mode? If so then
1215 // reassociating the constants breaks address pattern
1216 TargetLoweringBase::AddrMode AM;
1217 AM.HasBaseReg = true;
1218 AM.BaseOffs = C2APIntVal.getSExtValue();
1219 EVT VT = LoadStore->getMemoryVT();
1220 unsigned AS = LoadStore->getAddressSpace();
1221 Type *AccessTy = VT.getTypeForEVT(Context&: *DAG.getContext());
1222 if (!TLI.isLegalAddressingMode(DL: DAG.getDataLayout(), AM, Ty: AccessTy, AddrSpace: AS))
1223 return false;
1224 }
1225 return true;
1226 }
1227
1228 return false;
1229}
1230
1231/// Helper for DAGCombiner::reassociateOps. Try to reassociate (Opc N0, N1) if
1232/// \p N0 is the same kind of operation as \p Opc.
1233SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL,
1234 SDValue N0, SDValue N1,
1235 SDNodeFlags Flags) {
1236 EVT VT = N0.getValueType();
1237
1238 if (N0.getOpcode() != Opc)
1239 return SDValue();
1240
1241 SDValue N00 = N0.getOperand(i: 0);
1242 SDValue N01 = N0.getOperand(i: 1);
1243
1244 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N01)) {
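    // nuw is only preserved when both the inner add and the outer operation
    // are known not to wrap unsigned.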
1245 SDNodeFlags NewFlags;
1246 if (N0.getOpcode() == ISD::ADD && N0->getFlags().hasNoUnsignedWrap() &&
1247 Flags.hasNoUnsignedWrap())
1248 NewFlags |= SDNodeFlags::NoUnsignedWrap;
1249
1250 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N1)) {
1251 // Reassociate: (op (op x, c1), c2) -> (op x, (op c1, c2))
1252 if (SDValue OpNode = DAG.FoldConstantArithmetic(Opcode: Opc, DL, VT, Ops: {N01, N1})) {
1253 NewFlags.setDisjoint(Flags.hasDisjoint() &&
1254 N0->getFlags().hasDisjoint());
1255 return DAG.getNode(Opcode: Opc, DL, VT, N1: N00, N2: OpNode, Flags: NewFlags);
1256 }
1257 return SDValue();
1258 }
1259 if (TLI.isReassocProfitable(DAG, N0, N1)) {
1260 // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1)
1261 // iff (op x, c1) has one use
1262 SDValue OpNode = DAG.getNode(Opcode: Opc, DL: SDLoc(N0), VT, N1: N00, N2: N1, Flags: NewFlags);
1263 return DAG.getNode(Opcode: Opc, DL, VT, N1: OpNode, N2: N01, Flags: NewFlags);
1264 }
1265 }
1266
1267 // Check for repeated operand logic simplifications.
1268 if (Opc == ISD::AND || Opc == ISD::OR) {
1269 // (N00 & N01) & N00 --> N00 & N01
1270 // (N00 & N01) & N01 --> N00 & N01
1271 // (N00 | N01) | N00 --> N00 | N01
1272 // (N00 | N01) | N01 --> N00 | N01
1273 if (N1 == N00 || N1 == N01)
1274 return N0;
1275 }
1276 if (Opc == ISD::XOR) {
1277 // (N00 ^ N01) ^ N00 --> N01
1278 if (N1 == N00)
1279 return N01;
1280 // (N00 ^ N01) ^ N01 --> N00
1281 if (N1 == N01)
1282 return N00;
1283 }
1284
1285 if (TLI.isReassocProfitable(DAG, N0, N1)) {
1286 if (N1 != N01) {
      // Reassociate if (op N00, N1) already exists.
      if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N00, N1})) {
        // If (Op (Op N00, N1), N01) already exists, we need to stop
        // reassociating to avoid an infinite loop.
        if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N01}))
          return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N01);
      }
1294 }
1295
1296 if (N1 != N00) {
      // Reassociate if (op N01, N1) already exists.
      if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N01, N1})) {
        // If (Op (Op N01, N1), N00) already exists, we need to stop
        // reassociating to avoid an infinite loop.
        if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N00}))
          return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N00);
      }
1304 }
1305
    // Reassociate the operands from (OR/AND (OR/AND(N00, N01)), N1) to (OR/AND
    // (OR/AND(N00, N1)), N01) when N00 and N1 are comparisons with the same
    // predicate or to (OR/AND (OR/AND(N1, N01)), N00) when N01 and N1 are
1309 // comparisons with the same predicate. This enables optimizations as the
1310 // following one:
1311 // CMP(A,C)||CMP(B,C) => CMP(MIN/MAX(A,B), C)
1312 // CMP(A,C)&&CMP(B,C) => CMP(MIN/MAX(A,B), C)
1313 if (Opc == ISD::AND || Opc == ISD::OR) {
1314 if (N1->getOpcode() == ISD::SETCC && N00->getOpcode() == ISD::SETCC &&
1315 N01->getOpcode() == ISD::SETCC) {
1316 ISD::CondCode CC1 = cast<CondCodeSDNode>(Val: N1.getOperand(i: 2))->get();
1317 ISD::CondCode CC00 = cast<CondCodeSDNode>(Val: N00.getOperand(i: 2))->get();
1318 ISD::CondCode CC01 = cast<CondCodeSDNode>(Val: N01.getOperand(i: 2))->get();
1319 if (CC1 == CC00 && CC1 != CC01) {
1320 SDValue OpNode = DAG.getNode(Opcode: Opc, DL: SDLoc(N0), VT, N1: N00, N2: N1, Flags);
1321 return DAG.getNode(Opcode: Opc, DL, VT, N1: OpNode, N2: N01, Flags);
1322 }
1323 if (CC1 == CC01 && CC1 != CC00) {
1324 SDValue OpNode = DAG.getNode(Opcode: Opc, DL: SDLoc(N0), VT, N1: N01, N2: N1, Flags);
1325 return DAG.getNode(Opcode: Opc, DL, VT, N1: OpNode, N2: N00, Flags);
1326 }
1327 }
1328 }
1329 }
1330
1331 return SDValue();
1332}
1333
1334/// Try to reassociate commutative (Opc N0, N1) if either \p N0 or \p N1 is the
1335/// same kind of operation as \p Opc.
1336SDValue DAGCombiner::reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
1337 SDValue N1, SDNodeFlags Flags) {
1338 assert(TLI.isCommutativeBinOp(Opc) && "Operation not commutative.");
1339
1340 // Floating-point reassociation is not allowed without loose FP math.
1341 if (N0.getValueType().isFloatingPoint() ||
1342 N1.getValueType().isFloatingPoint())
1343 if (!Flags.hasAllowReassociation() || !Flags.hasNoSignedZeros())
1344 return SDValue();
1345
1346 if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N0, N1, Flags))
1347 return Combined;
1348 if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N0: N1, N1: N0, Flags))
1349 return Combined;
1350 return SDValue();
1351}
1352
1353// Try to fold Opc(vecreduce(x), vecreduce(y)) -> vecreduce(Opc(x, y))
1354// Note that we only expect Flags to be passed from FP operations. For integer
1355// operations they need to be dropped.
1356SDValue DAGCombiner::reassociateReduction(unsigned RedOpc, unsigned Opc,
1357 const SDLoc &DL, EVT VT, SDValue N0,
1358 SDValue N1, SDNodeFlags Flags) {
1359 if (N0.getOpcode() == RedOpc && N1.getOpcode() == RedOpc &&
1360 N0.getOperand(i: 0).getValueType() == N1.getOperand(i: 0).getValueType() &&
1361 N0->hasOneUse() && N1->hasOneUse() &&
1362 TLI.isOperationLegalOrCustom(Op: Opc, VT: N0.getOperand(i: 0).getValueType()) &&
1363 TLI.shouldReassociateReduction(RedOpc, VT: N0.getOperand(i: 0).getValueType())) {
1364 SelectionDAG::FlagInserter FlagsInserter(DAG, Flags);
1365 return DAG.getNode(Opcode: RedOpc, DL, VT,
1366 Operand: DAG.getNode(Opcode: Opc, DL, VT: N0.getOperand(i: 0).getValueType(),
1367 N1: N0.getOperand(i: 0), N2: N1.getOperand(i: 0)));
1368 }
1369
1370 // Reassociate op(op(vecreduce(a), b), op(vecreduce(c), d)) into
1371 // op(vecreduce(op(a, c)), op(b, d)), to combine the reductions into a
1372 // single node.
1373 SDValue A, B, C, D, RedA, RedB;
1374 if (sd_match(N: N0, P: m_OneUse(P: m_c_BinOp(
1375 Opc,
1376 L: m_AllOf(preds: m_OneUse(P: m_UnaryOp(Opc: RedOpc, Op: m_Value(N&: A))),
1377 preds: m_Value(N&: RedA)),
1378 R: m_Value(N&: B)))) &&
1379 sd_match(N: N1, P: m_OneUse(P: m_c_BinOp(
1380 Opc,
1381 L: m_AllOf(preds: m_OneUse(P: m_UnaryOp(Opc: RedOpc, Op: m_Value(N&: C))),
1382 preds: m_Value(N&: RedB)),
1383 R: m_Value(N&: D)))) &&
1384 !sd_match(N: B, P: m_UnaryOp(Opc: RedOpc, Op: m_Value())) &&
1385 !sd_match(N: D, P: m_UnaryOp(Opc: RedOpc, Op: m_Value())) &&
1386 A.getValueType() == C.getValueType() &&
1387 hasOperation(Opcode: Opc, VT: A.getValueType()) &&
1388 TLI.shouldReassociateReduction(RedOpc, VT)) {
1389 if ((Opc == ISD::FADD || Opc == ISD::FMUL) &&
1390 (!N0->getFlags().hasAllowReassociation() ||
1391 !N1->getFlags().hasAllowReassociation() ||
1392 !RedA->getFlags().hasAllowReassociation() ||
1393 !RedB->getFlags().hasAllowReassociation()))
1394 return SDValue();
1395 SelectionDAG::FlagInserter FlagsInserter(
1396 DAG, Flags & N0->getFlags() & N1->getFlags() & RedA->getFlags() &
1397 RedB->getFlags());
1398 SDValue Op = DAG.getNode(Opcode: Opc, DL, VT: A.getValueType(), N1: A, N2: C);
1399 SDValue Red = DAG.getNode(Opcode: RedOpc, DL, VT, Operand: Op);
1400 SDValue Op2 = DAG.getNode(Opcode: Opc, DL, VT, N1: B, N2: D);
1401 return DAG.getNode(Opcode: Opc, DL, VT, N1: Red, N2: Op2);
1402 }
1403 return SDValue();
1404}
1405
1406SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
1407 bool AddTo) {
1408 assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
1409 ++NodesCombined;
1410 LLVM_DEBUG(dbgs() << "\nReplacing.1 "; N->dump(&DAG); dbgs() << "\nWith: ";
1411 To[0].dump(&DAG);
1412 dbgs() << " and " << NumTo - 1 << " other values\n");
1413 for (unsigned i = 0, e = NumTo; i != e; ++i)
1414 assert((!To[i].getNode() ||
1415 N->getValueType(i) == To[i].getValueType()) &&
1416 "Cannot combine value to value of different type!");
1417
1418 WorklistRemover DeadNodes(*this);
1419 DAG.ReplaceAllUsesWith(From: N, To);
1420 if (AddTo) {
1421 // Push the new nodes and any users onto the worklist
1422 for (unsigned i = 0, e = NumTo; i != e; ++i) {
1423 if (To[i].getNode())
1424 AddToWorklistWithUsers(N: To[i].getNode());
1425 }
1426 }
1427
1428 // Finally, if the node is now dead, remove it from the graph. The node
1429 // may not be dead if the replacement process recursively simplified to
1430 // something else needing this node.
1431 if (N->use_empty())
1432 deleteAndRecombine(N);
1433 return SDValue(N, 0);
1434}
1435
1436void DAGCombiner::
1437CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
1438 // Replace the old value with the new one.
1439 ++NodesCombined;
1440 LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.dump(&DAG);
1441 dbgs() << "\nWith: "; TLO.New.dump(&DAG); dbgs() << '\n');
1442
1443 // Replace all uses.
1444 DAG.ReplaceAllUsesOfValueWith(From: TLO.Old, To: TLO.New);
1445
1446 // Push the new node and any (possibly new) users onto the worklist.
1447 AddToWorklistWithUsers(N: TLO.New.getNode());
1448
1449 // Finally, if the node is now dead, remove it from the graph.
1450 recursivelyDeleteUnusedNodes(N: TLO.Old.getNode());
1451}
1452
1453/// Check the specified integer node value to see if it can be simplified or if
1454/// things it uses can be simplified by bit propagation. If so, return true.
1455bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
1456 const APInt &DemandedElts,
1457 bool AssumeSingleUse) {
1458 TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1459 KnownBits Known;
1460 if (!TLI.SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, Depth: 0,
1461 AssumeSingleUse))
1462 return false;
1463
1464 // Revisit the node.
1465 AddToWorklist(N: Op.getNode());
1466
1467 CommitTargetLoweringOpt(TLO);
1468 return true;
1469}
1470
1471/// Check the specified vector node value to see if it can be simplified or
1472/// if things it uses can be simplified as it only uses some of the elements.
1473/// If so, return true.
1474bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op,
1475 const APInt &DemandedElts,
1476 bool AssumeSingleUse) {
1477 TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1478 APInt KnownUndef, KnownZero;
1479 if (!TLI.SimplifyDemandedVectorElts(Op, DemandedEltMask: DemandedElts, KnownUndef, KnownZero,
1480 TLO, Depth: 0, AssumeSingleUse))
1481 return false;
1482
1483 // Revisit the node.
1484 AddToWorklist(N: Op.getNode());
1485
1486 CommitTargetLoweringOpt(TLO);
1487 return true;
1488}
1489
1490void DAGCombiner::ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad) {
1491 SDLoc DL(Load);
1492 EVT VT = Load->getValueType(ResNo: 0);
1493 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: SDValue(ExtLoad, 0));
1494
1495 LLVM_DEBUG(dbgs() << "\nReplacing.9 "; Load->dump(&DAG); dbgs() << "\nWith: ";
1496 Trunc.dump(&DAG); dbgs() << '\n');
1497
1498 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Load, 0), To: Trunc);
1499 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Load, 1), To: SDValue(ExtLoad, 1));
1500
1501 AddToWorklist(N: Trunc.getNode());
1502 recursivelyDeleteUnusedNodes(N: Load);
1503}
1504
1505SDValue DAGCombiner::PromoteOperand(SDValue Op, EVT PVT, bool &Replace) {
1506 Replace = false;
1507 SDLoc DL(Op);
1508 if (ISD::isUNINDEXEDLoad(N: Op.getNode())) {
1509 LoadSDNode *LD = cast<LoadSDNode>(Val&: Op);
1510 EVT MemVT = LD->getMemoryVT();
1511 ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(N: LD) ? ISD::EXTLOAD
1512 : LD->getExtensionType();
1513 Replace = true;
1514 return DAG.getExtLoad(ExtType, dl: DL, VT: PVT,
1515 Chain: LD->getChain(), Ptr: LD->getBasePtr(),
1516 MemVT, MMO: LD->getMemOperand());
1517 }
1518
1519 unsigned Opc = Op.getOpcode();
1520 switch (Opc) {
1521 default: break;
1522 case ISD::AssertSext:
1523 if (SDValue Op0 = SExtPromoteOperand(Op: Op.getOperand(i: 0), PVT))
1524 return DAG.getNode(Opcode: ISD::AssertSext, DL, VT: PVT, N1: Op0, N2: Op.getOperand(i: 1));
1525 break;
1526 case ISD::AssertZext:
1527 if (SDValue Op0 = ZExtPromoteOperand(Op: Op.getOperand(i: 0), PVT))
1528 return DAG.getNode(Opcode: ISD::AssertZext, DL, VT: PVT, N1: Op0, N2: Op.getOperand(i: 1));
1529 break;
1530 case ISD::Constant: {
1531 unsigned ExtOpc =
1532 Op.getValueType().isByteSized() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
1533 return DAG.getNode(Opcode: ExtOpc, DL, VT: PVT, Operand: Op);
1534 }
1535 }
1536
1537 if (!TLI.isOperationLegal(Op: ISD::ANY_EXTEND, VT: PVT))
1538 return SDValue();
1539 return DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: PVT, Operand: Op);
1540}
1541
1542SDValue DAGCombiner::SExtPromoteOperand(SDValue Op, EVT PVT) {
1543 if (!TLI.isOperationLegal(Op: ISD::SIGN_EXTEND_INREG, VT: PVT))
1544 return SDValue();
1545 EVT OldVT = Op.getValueType();
1546 SDLoc DL(Op);
1547 bool Replace = false;
1548 SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1549 if (!NewOp.getNode())
1550 return SDValue();
1551 AddToWorklist(N: NewOp.getNode());
1552
1553 if (Replace)
1554 ReplaceLoadWithPromotedLoad(Load: Op.getNode(), ExtLoad: NewOp.getNode());
1555 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT: NewOp.getValueType(), N1: NewOp,
1556 N2: DAG.getValueType(OldVT));
1557}
1558
1559SDValue DAGCombiner::ZExtPromoteOperand(SDValue Op, EVT PVT) {
1560 EVT OldVT = Op.getValueType();
1561 SDLoc DL(Op);
1562 bool Replace = false;
1563 SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1564 if (!NewOp.getNode())
1565 return SDValue();
1566 AddToWorklist(N: NewOp.getNode());
1567
1568 if (Replace)
1569 ReplaceLoadWithPromotedLoad(Load: Op.getNode(), ExtLoad: NewOp.getNode());
1570 return DAG.getZeroExtendInReg(Op: NewOp, DL, VT: OldVT);
1571}
1572
1573/// Promote the specified integer binary operation if the target indicates it is
1574/// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1575/// i32 since i16 instructions are longer.
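/// For example, on such a target an i16 add is rewritten roughly as
///   (i16 add x, y) -> (i16 trunc (i32 add (i32 anyext x), (i32 anyext y)))
/// (an illustrative sketch; loads are promoted via extending loads instead).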
1576SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) {
1577 if (!LegalOperations)
1578 return SDValue();
1579
1580 EVT VT = Op.getValueType();
1581 if (VT.isVector() || !VT.isInteger())
1582 return SDValue();
1583
1584 // If operation type is 'undesirable', e.g. i16 on x86, consider
1585 // promoting it.
1586 unsigned Opc = Op.getOpcode();
1587 if (TLI.isTypeDesirableForOp(Opc, VT))
1588 return SDValue();
1589
1590 EVT PVT = VT;
1591 // Consult target whether it is a good idea to promote this operation and
1592 // what's the right type to promote it to.
1593 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1594 assert(PVT != VT && "Don't know what type to promote to!");
1595
1596 LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1597
1598 bool Replace0 = false;
1599 SDValue N0 = Op.getOperand(i: 0);
1600 SDValue NN0 = PromoteOperand(Op: N0, PVT, Replace&: Replace0);
1601
1602 bool Replace1 = false;
1603 SDValue N1 = Op.getOperand(i: 1);
1604 SDValue NN1 = PromoteOperand(Op: N1, PVT, Replace&: Replace1);
1605 SDLoc DL(Op);
1606
1607 SDValue RV =
1608 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: DAG.getNode(Opcode: Opc, DL, VT: PVT, N1: NN0, N2: NN1));
1609
1610 // We are always replacing N0/N1's use in N and only need additional
1611 // replacements if there are additional uses.
1612 // Note: We are checking uses of the *nodes* (SDNode) rather than values
1613 // (SDValue) here because the node may reference multiple values
1614 // (for example, the chain value of a load node).
1615 Replace0 &= !N0->hasOneUse();
1616 Replace1 &= (N0 != N1) && !N1->hasOneUse();
1617
1618 // Combine Op here so it is preserved past replacements.
1619 CombineTo(N: Op.getNode(), Res: RV);
1620
1621 // If operands have a use ordering, make sure we deal with
1622 // predecessor first.
1623 if (Replace0 && Replace1 && N0->isPredecessorOf(N: N1.getNode())) {
1624 std::swap(a&: N0, b&: N1);
1625 std::swap(a&: NN0, b&: NN1);
1626 }
1627
1628 if (Replace0) {
1629 AddToWorklist(N: NN0.getNode());
1630 ReplaceLoadWithPromotedLoad(Load: N0.getNode(), ExtLoad: NN0.getNode());
1631 }
1632 if (Replace1) {
1633 AddToWorklist(N: NN1.getNode());
1634 ReplaceLoadWithPromotedLoad(Load: N1.getNode(), ExtLoad: NN1.getNode());
1635 }
1636 return Op;
1637 }
1638 return SDValue();
1639}
1640
1641/// Promote the specified integer shift operation if the target indicates it is
1642/// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1643/// i32 since i16 instructions are longer.
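/// For example, a logical right shift is promoted roughly as
///   (i16 srl x, c) -> (i16 trunc (i32 srl (i32 zext x), c))
/// while an arithmetic right shift sign-extends the shifted value instead
/// (illustrative sketch; the shift amount is left unchanged).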
1644SDValue DAGCombiner::PromoteIntShiftOp(SDValue Op) {
1645 if (!LegalOperations)
1646 return SDValue();
1647
1648 EVT VT = Op.getValueType();
1649 if (VT.isVector() || !VT.isInteger())
1650 return SDValue();
1651
1652 // If operation type is 'undesirable', e.g. i16 on x86, consider
1653 // promoting it.
1654 unsigned Opc = Op.getOpcode();
1655 if (TLI.isTypeDesirableForOp(Opc, VT))
1656 return SDValue();
1657
1658 EVT PVT = VT;
1659 // Consult target whether it is a good idea to promote this operation and
1660 // what's the right type to promote it to.
1661 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1662 assert(PVT != VT && "Don't know what type to promote to!");
1663
1664 LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1665
1666 SDNodeFlags TruncFlags;
1667 bool Replace = false;
1668 SDValue N0 = Op.getOperand(i: 0);
1669 if (Opc == ISD::SRA) {
1670 N0 = SExtPromoteOperand(Op: N0, PVT);
1671 } else if (Opc == ISD::SRL) {
1672 N0 = ZExtPromoteOperand(Op: N0, PVT);
1673 } else {
1674 if (Op->getFlags().hasNoUnsignedWrap()) {
1675 N0 = ZExtPromoteOperand(Op: N0, PVT);
1676 TruncFlags = SDNodeFlags::NoUnsignedWrap;
1677 } else if (Op->getFlags().hasNoSignedWrap()) {
1678 N0 = SExtPromoteOperand(Op: N0, PVT);
1679 TruncFlags = SDNodeFlags::NoSignedWrap;
1680 } else {
1681 N0 = PromoteOperand(Op: N0, PVT, Replace);
1682 }
1683 }
1684
1685 if (!N0.getNode())
1686 return SDValue();
1687
1688 SDLoc DL(Op);
1689 SDValue N1 = Op.getOperand(i: 1);
1690 SDValue RV = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT,
1691 Operand: DAG.getNode(Opcode: Opc, DL, VT: PVT, N1: N0, N2: N1), Flags: TruncFlags);
1692
1693 if (Replace)
1694 ReplaceLoadWithPromotedLoad(Load: Op.getOperand(i: 0).getNode(), ExtLoad: N0.getNode());
1695
1696 // Deal with Op being deleted.
1697 if (Op && Op.getOpcode() != ISD::DELETED_NODE)
1698 return RV;
1699 }
1700 return SDValue();
1701}
1702
1703SDValue DAGCombiner::PromoteExtend(SDValue Op) {
1704 if (!LegalOperations)
1705 return SDValue();
1706
1707 EVT VT = Op.getValueType();
1708 if (VT.isVector() || !VT.isInteger())
1709 return SDValue();
1710
1711 // If operation type is 'undesirable', e.g. i16 on x86, consider
1712 // promoting it.
1713 unsigned Opc = Op.getOpcode();
1714 if (TLI.isTypeDesirableForOp(Opc, VT))
1715 return SDValue();
1716
1717 EVT PVT = VT;
1718 // Consult target whether it is a good idea to promote this operation and
1719 // what's the right type to promote it to.
1720 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1721 assert(PVT != VT && "Don't know what type to promote to!");
1722 // fold (aext (aext x)) -> (aext x)
1723 // fold (aext (zext x)) -> (zext x)
1724 // fold (aext (sext x)) -> (sext x)
1725 LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1726 return DAG.getNode(Opcode: Op.getOpcode(), DL: SDLoc(Op), VT, Operand: Op.getOperand(i: 0));
1727 }
1728 return SDValue();
1729}
1730
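/// Promote the specified integer load if the target indicates it is
/// beneficial, roughly (i16 load p) -> (i16 trunc (i32 extload p)), replacing
/// both the value and the chain results of the original load (illustrative
/// sketch; see the body below for the exact replacement).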
1731bool DAGCombiner::PromoteLoad(SDValue Op) {
1732 if (!LegalOperations)
1733 return false;
1734
1735 if (!ISD::isUNINDEXEDLoad(N: Op.getNode()))
1736 return false;
1737
1738 EVT VT = Op.getValueType();
1739 if (VT.isVector() || !VT.isInteger())
1740 return false;
1741
1742 // If operation type is 'undesirable', e.g. i16 on x86, consider
1743 // promoting it.
1744 unsigned Opc = Op.getOpcode();
1745 if (TLI.isTypeDesirableForOp(Opc, VT))
1746 return false;
1747
1748 EVT PVT = VT;
1749 // Consult target whether it is a good idea to promote this operation and
1750 // what's the right type to promote it to.
1751 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1752 assert(PVT != VT && "Don't know what type to promote to!");
1753
1754 SDLoc DL(Op);
1755 SDNode *N = Op.getNode();
1756 LoadSDNode *LD = cast<LoadSDNode>(Val: N);
1757 EVT MemVT = LD->getMemoryVT();
1758 ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(N: LD) ? ISD::EXTLOAD
1759 : LD->getExtensionType();
1760 SDValue NewLD = DAG.getExtLoad(ExtType, dl: DL, VT: PVT,
1761 Chain: LD->getChain(), Ptr: LD->getBasePtr(),
1762 MemVT, MMO: LD->getMemOperand());
1763 SDValue Result = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: NewLD);
1764
1765 LLVM_DEBUG(dbgs() << "\nPromoting "; N->dump(&DAG); dbgs() << "\nTo: ";
1766 Result.dump(&DAG); dbgs() << '\n');
1767
1768 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result);
1769 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: NewLD.getValue(R: 1));
1770
1771 AddToWorklist(N: Result.getNode());
1772 recursivelyDeleteUnusedNodes(N);
1773 return true;
1774 }
1775
1776 return false;
1777}
1778
1779/// Recursively delete a node which has no uses and any operands for
1780/// which it is the only use.
1781///
1782/// Note that this both deletes the nodes and removes them from the worklist.
1783 /// It also adds any nodes that have had a user deleted to the worklist, as
1784 /// they may now have only one use and be subject to other combines.
1785bool DAGCombiner::recursivelyDeleteUnusedNodes(SDNode *N) {
1786 if (!N->use_empty())
1787 return false;
1788
1789 SmallSetVector<SDNode *, 16> Nodes;
1790 Nodes.insert(X: N);
1791 do {
1792 N = Nodes.pop_back_val();
1793 if (!N)
1794 continue;
1795
1796 if (N->use_empty()) {
1797 for (const SDValue &ChildN : N->op_values())
1798 Nodes.insert(X: ChildN.getNode());
1799
1800 removeFromWorklist(N);
1801 DAG.DeleteNode(N);
1802 } else {
1803 AddToWorklist(N);
1804 }
1805 } while (!Nodes.empty());
1806 return true;
1807}
1808
1809//===----------------------------------------------------------------------===//
1810// Main DAG Combiner implementation
1811//===----------------------------------------------------------------------===//
1812
1813void DAGCombiner::Run(CombineLevel AtLevel) {
1814 // Set the instance variables so that the various visit routines may use them.
1815 Level = AtLevel;
1816 LegalDAG = Level >= AfterLegalizeDAG;
1817 LegalOperations = Level >= AfterLegalizeVectorOps;
1818 LegalTypes = Level >= AfterLegalizeTypes;
1819
1820 WorklistInserter AddNodes(*this);
1821
1822 // Add all the dag nodes to the worklist.
1823 //
1824 // Note: Not all nodes are added to the PruningList here; this is because the
1825 // only nodes which can be deleted are those which have no uses, and all other
1826 // nodes which would otherwise be added to the worklist by the first call to
1827 // getNextWorklistEntry are already present in it.
1828 for (SDNode &Node : DAG.allnodes())
1829 AddToWorklist(N: &Node, /* IsCandidateForPruning */ Node.use_empty());
1830
1831 // Create a dummy node (which is not added to allnodes) that adds a reference
1832 // to the root node, preventing it from being deleted and tracking any
1833 // changes to the root.
1834 HandleSDNode Dummy(DAG.getRoot());
1835
1836 // While we have a valid worklist entry node, try to combine it.
1837 while (SDNode *N = getNextWorklistEntry()) {
1838 // If N has no uses, it is dead. Make sure to revisit all N's operands once
1839 // N is deleted from the DAG, since they too may now be dead or may have a
1840 // reduced number of uses, allowing other xforms.
1841 if (recursivelyDeleteUnusedNodes(N))
1842 continue;
1843
1844 WorklistRemover DeadNodes(*this);
1845
1846 // If this combine is running after legalizing the DAG, re-legalize any
1847 // nodes pulled off the worklist.
1848 if (LegalDAG) {
1849 SmallSetVector<SDNode *, 16> UpdatedNodes;
1850 bool NIsValid = DAG.LegalizeOp(N, UpdatedNodes);
1851
1852 for (SDNode *LN : UpdatedNodes)
1853 AddToWorklistWithUsers(N: LN);
1854
1855 if (!NIsValid)
1856 continue;
1857 }
1858
1859 LLVM_DEBUG(dbgs() << "\nCombining: "; N->dump(&DAG));
1860
1861 // Add any operands of the new node which have not yet been combined to the
1862 // worklist as well. getNextWorklistEntry flags nodes that have been
1863 // combined before. Because the worklist uniques things already, this won't
1864 // repeatedly process the same operand.
1865 for (const SDValue &ChildN : N->op_values())
1866 AddToWorklist(N: ChildN.getNode(), /*IsCandidateForPruning=*/true,
1867 /*SkipIfCombinedBefore=*/true);
1868
1869 SDValue RV = combine(N);
1870
1871 if (!RV.getNode())
1872 continue;
1873
1874 ++NodesCombined;
1875
1876 // Invalidate cached info.
1877 ChainsWithoutMergeableStores.clear();
1878
1879 // If we get back the same node we passed in, rather than a new node or
1880 // zero, we know that the node must have defined multiple values and
1881 // CombineTo was used. Since CombineTo takes care of the worklist
1882 // mechanics for us, we have no work to do in this case.
1883 if (RV.getNode() == N)
1884 continue;
1885
1886 assert(N->getOpcode() != ISD::DELETED_NODE &&
1887 RV.getOpcode() != ISD::DELETED_NODE &&
1888 "Node was deleted but visit returned new node!");
1889
1890 LLVM_DEBUG(dbgs() << " ... into: "; RV.dump(&DAG));
1891
1892 if (N->getNumValues() == RV->getNumValues())
1893 DAG.ReplaceAllUsesWith(From: N, To: RV.getNode());
1894 else {
1895 assert(N->getValueType(0) == RV.getValueType() &&
1896 N->getNumValues() == 1 && "Type mismatch");
1897 DAG.ReplaceAllUsesWith(From: N, To: &RV);
1898 }
1899
1900 // Push the new node and any users onto the worklist. Omit this if the
1901 // new node is the EntryToken (e.g. if a store managed to get optimized
1902 // out), because re-visiting the EntryToken and its users will not uncover
1903 // any additional opportunities, but there may be a large number of such
1904 // users, potentially causing compile time explosion.
1905 if (RV.getOpcode() != ISD::EntryToken)
1906 AddToWorklistWithUsers(N: RV.getNode());
1907
1908 // Finally, if the node is now dead, remove it from the graph. The node
1909 // may not be dead if the replacement process recursively simplified to
1910 // something else needing this node. This will also take care of adding any
1911 // operands which have lost a user to the worklist.
1912 recursivelyDeleteUnusedNodes(N);
1913 }
1914
1915 // If the root changed (e.g. it was a dead load), update the root.
1916 DAG.setRoot(Dummy.getValue());
1917 DAG.RemoveDeadNodes();
1918}
1919
1920SDValue DAGCombiner::visit(SDNode *N) {
1921 // clang-format off
1922 switch (N->getOpcode()) {
1923 default: break;
1924 case ISD::TokenFactor: return visitTokenFactor(N);
1925 case ISD::MERGE_VALUES: return visitMERGE_VALUES(N);
1926 case ISD::ADD: return visitADD(N);
1927 case ISD::PTRADD: return visitPTRADD(N);
1928 case ISD::SUB: return visitSUB(N);
1929 case ISD::SADDSAT:
1930 case ISD::UADDSAT: return visitADDSAT(N);
1931 case ISD::SSUBSAT:
1932 case ISD::USUBSAT: return visitSUBSAT(N);
1933 case ISD::ADDC: return visitADDC(N);
1934 case ISD::SADDO:
1935 case ISD::UADDO: return visitADDO(N);
1936 case ISD::SUBC: return visitSUBC(N);
1937 case ISD::SSUBO:
1938 case ISD::USUBO: return visitSUBO(N);
1939 case ISD::ADDE: return visitADDE(N);
1940 case ISD::UADDO_CARRY: return visitUADDO_CARRY(N);
1941 case ISD::SADDO_CARRY: return visitSADDO_CARRY(N);
1942 case ISD::SUBE: return visitSUBE(N);
1943 case ISD::USUBO_CARRY: return visitUSUBO_CARRY(N);
1944 case ISD::SSUBO_CARRY: return visitSSUBO_CARRY(N);
1945 case ISD::SMULFIX:
1946 case ISD::SMULFIXSAT:
1947 case ISD::UMULFIX:
1948 case ISD::UMULFIXSAT: return visitMULFIX(N);
1949 case ISD::MUL: return visitMUL<EmptyMatchContext>(N);
1950 case ISD::SDIV: return visitSDIV(N);
1951 case ISD::UDIV: return visitUDIV(N);
1952 case ISD::SREM:
1953 case ISD::UREM: return visitREM(N);
1954 case ISD::MULHU: return visitMULHU(N);
1955 case ISD::MULHS: return visitMULHS(N);
1956 case ISD::AVGFLOORS:
1957 case ISD::AVGFLOORU:
1958 case ISD::AVGCEILS:
1959 case ISD::AVGCEILU: return visitAVG(N);
1960 case ISD::ABDS:
1961 case ISD::ABDU: return visitABD(N);
1962 case ISD::SMUL_LOHI: return visitSMUL_LOHI(N);
1963 case ISD::UMUL_LOHI: return visitUMUL_LOHI(N);
1964 case ISD::SMULO:
1965 case ISD::UMULO: return visitMULO(N);
1966 case ISD::SMIN:
1967 case ISD::SMAX:
1968 case ISD::UMIN:
1969 case ISD::UMAX: return visitIMINMAX(N);
1970 case ISD::AND: return visitAND(N);
1971 case ISD::OR: return visitOR(N);
1972 case ISD::XOR: return visitXOR(N);
1973 case ISD::SHL: return visitSHL(N);
1974 case ISD::SRA: return visitSRA(N);
1975 case ISD::SRL: return visitSRL(N);
1976 case ISD::ROTR:
1977 case ISD::ROTL: return visitRotate(N);
1978 case ISD::FSHL:
1979 case ISD::FSHR: return visitFunnelShift(N);
1980 case ISD::SSHLSAT:
1981 case ISD::USHLSAT: return visitSHLSAT(N);
1982 case ISD::ABS: return visitABS(N);
1983 case ISD::CLMUL:
1984 case ISD::CLMULR:
1985 case ISD::CLMULH: return visitCLMUL(N);
1986 case ISD::BSWAP: return visitBSWAP(N);
1987 case ISD::BITREVERSE: return visitBITREVERSE(N);
1988 case ISD::CTLZ: return visitCTLZ(N);
1989 case ISD::CTLZ_ZERO_UNDEF: return visitCTLZ_ZERO_UNDEF(N);
1990 case ISD::CTTZ: return visitCTTZ(N);
1991 case ISD::CTTZ_ZERO_UNDEF: return visitCTTZ_ZERO_UNDEF(N);
1992 case ISD::CTPOP: return visitCTPOP(N);
1993 case ISD::SELECT: return visitSELECT(N);
1994 case ISD::VSELECT: return visitVSELECT(N);
1995 case ISD::SELECT_CC: return visitSELECT_CC(N);
1996 case ISD::SETCC: return visitSETCC(N);
1997 case ISD::SETCCCARRY: return visitSETCCCARRY(N);
1998 case ISD::SIGN_EXTEND: return visitSIGN_EXTEND(N);
1999 case ISD::ZERO_EXTEND: return visitZERO_EXTEND(N);
2000 case ISD::ANY_EXTEND: return visitANY_EXTEND(N);
2001 case ISD::AssertSext:
2002 case ISD::AssertZext: return visitAssertExt(N);
2003 case ISD::AssertAlign: return visitAssertAlign(N);
2004 case ISD::SIGN_EXTEND_INREG: return visitSIGN_EXTEND_INREG(N);
2005 case ISD::SIGN_EXTEND_VECTOR_INREG:
2006 case ISD::ZERO_EXTEND_VECTOR_INREG:
2007 case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
2008 case ISD::TRUNCATE: return visitTRUNCATE(N);
2009 case ISD::TRUNCATE_USAT_U: return visitTRUNCATE_USAT_U(N);
2010 case ISD::BITCAST: return visitBITCAST(N);
2011 case ISD::BUILD_PAIR: return visitBUILD_PAIR(N);
2012 case ISD::FADD: return visitFADD(N);
2013 case ISD::STRICT_FADD: return visitSTRICT_FADD(N);
2014 case ISD::FSUB: return visitFSUB(N);
2015 case ISD::FMUL: return visitFMUL(N);
2016 case ISD::FMA: return visitFMA<EmptyMatchContext>(N);
2017 case ISD::FMAD: return visitFMAD(N);
2018 case ISD::FMULADD: return visitFMULADD(N);
2019 case ISD::FDIV: return visitFDIV(N);
2020 case ISD::FREM: return visitFREM(N);
2021 case ISD::FSQRT: return visitFSQRT(N);
2022 case ISD::FCOPYSIGN: return visitFCOPYSIGN(N);
2023 case ISD::FPOW: return visitFPOW(N);
2024 case ISD::SINT_TO_FP: return visitSINT_TO_FP(N);
2025 case ISD::UINT_TO_FP: return visitUINT_TO_FP(N);
2026 case ISD::FP_TO_SINT: return visitFP_TO_SINT(N);
2027 case ISD::FP_TO_UINT: return visitFP_TO_UINT(N);
2028 case ISD::LROUND:
2029 case ISD::LLROUND:
2030 case ISD::LRINT:
2031 case ISD::LLRINT: return visitXROUND(N);
2032 case ISD::FP_ROUND: return visitFP_ROUND(N);
2033 case ISD::FP_EXTEND: return visitFP_EXTEND(N);
2034 case ISD::FNEG: return visitFNEG(N);
2035 case ISD::FABS: return visitFABS(N);
2036 case ISD::FFLOOR: return visitFFLOOR(N);
2037 case ISD::FMINNUM:
2038 case ISD::FMAXNUM:
2039 case ISD::FMINIMUM:
2040 case ISD::FMAXIMUM:
2041 case ISD::FMINIMUMNUM:
2042 case ISD::FMAXIMUMNUM: return visitFMinMax(N);
2043 case ISD::FCEIL: return visitFCEIL(N);
2044 case ISD::FTRUNC: return visitFTRUNC(N);
2045 case ISD::FFREXP: return visitFFREXP(N);
2046 case ISD::BRCOND: return visitBRCOND(N);
2047 case ISD::BR_CC: return visitBR_CC(N);
2048 case ISD::LOAD: return visitLOAD(N);
2049 case ISD::STORE: return visitSTORE(N);
2050 case ISD::ATOMIC_STORE: return visitATOMIC_STORE(N);
2051 case ISD::INSERT_VECTOR_ELT: return visitINSERT_VECTOR_ELT(N);
2052 case ISD::EXTRACT_VECTOR_ELT: return visitEXTRACT_VECTOR_ELT(N);
2053 case ISD::BUILD_VECTOR: return visitBUILD_VECTOR(N);
2054 case ISD::CONCAT_VECTORS: return visitCONCAT_VECTORS(N);
2055 case ISD::VECTOR_INTERLEAVE: return visitVECTOR_INTERLEAVE(N);
2056 case ISD::EXTRACT_SUBVECTOR: return visitEXTRACT_SUBVECTOR(N);
2057 case ISD::VECTOR_SHUFFLE: return visitVECTOR_SHUFFLE(N);
2058 case ISD::SCALAR_TO_VECTOR: return visitSCALAR_TO_VECTOR(N);
2059 case ISD::INSERT_SUBVECTOR: return visitINSERT_SUBVECTOR(N);
2060 case ISD::MGATHER: return visitMGATHER(N);
2061 case ISD::MLOAD: return visitMLOAD(N);
2062 case ISD::MSCATTER: return visitMSCATTER(N);
2063 case ISD::MSTORE: return visitMSTORE(N);
2064 case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM: return visitMHISTOGRAM(N);
2065 case ISD::PARTIAL_REDUCE_SMLA:
2066 case ISD::PARTIAL_REDUCE_UMLA:
2067 case ISD::PARTIAL_REDUCE_SUMLA:
2068 case ISD::PARTIAL_REDUCE_FMLA:
2069 return visitPARTIAL_REDUCE_MLA(N);
2070 case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N);
2071 case ISD::LIFETIME_END: return visitLIFETIME_END(N);
2072 case ISD::FP_TO_FP16: return visitFP_TO_FP16(N);
2073 case ISD::FP16_TO_FP: return visitFP16_TO_FP(N);
2074 case ISD::FP_TO_BF16: return visitFP_TO_BF16(N);
2075 case ISD::BF16_TO_FP: return visitBF16_TO_FP(N);
2076 case ISD::FREEZE: return visitFREEZE(N);
2077 case ISD::GET_FPENV_MEM: return visitGET_FPENV_MEM(N);
2078 case ISD::SET_FPENV_MEM: return visitSET_FPENV_MEM(N);
2079 case ISD::FCANONICALIZE: return visitFCANONICALIZE(N);
2080 case ISD::VECREDUCE_FADD:
2081 case ISD::VECREDUCE_FMUL:
2082 case ISD::VECREDUCE_ADD:
2083 case ISD::VECREDUCE_MUL:
2084 case ISD::VECREDUCE_AND:
2085 case ISD::VECREDUCE_OR:
2086 case ISD::VECREDUCE_XOR:
2087 case ISD::VECREDUCE_SMAX:
2088 case ISD::VECREDUCE_SMIN:
2089 case ISD::VECREDUCE_UMAX:
2090 case ISD::VECREDUCE_UMIN:
2091 case ISD::VECREDUCE_FMAX:
2092 case ISD::VECREDUCE_FMIN:
2093 case ISD::VECREDUCE_FMAXIMUM:
2094 case ISD::VECREDUCE_FMINIMUM: return visitVECREDUCE(N);
2095#define BEGIN_REGISTER_VP_SDNODE(SDOPC, ...) case ISD::SDOPC:
2096#include "llvm/IR/VPIntrinsics.def"
2097 return visitVPOp(N);
2098 }
2099 // clang-format on
2100 return SDValue();
2101}
2102
2103SDValue DAGCombiner::combine(SDNode *N) {
2104 if (!DebugCounter::shouldExecute(Counter&: DAGCombineCounter))
2105 return SDValue();
2106
2107 SDValue RV;
2108 if (!DisableGenericCombines)
2109 RV = visit(N);
2110
2111 // If nothing happened, try a target-specific DAG combine.
2112 if (!RV.getNode()) {
2113 assert(N->getOpcode() != ISD::DELETED_NODE &&
2114 "Node was deleted but visit returned NULL!");
2115
2116 if (N->getOpcode() >= ISD::BUILTIN_OP_END ||
2117 TLI.hasTargetDAGCombine(NT: (ISD::NodeType)N->getOpcode())) {
2118
2119 // Expose the DAG combiner to the target combiner impls.
2120 TargetLowering::DAGCombinerInfo
2121 DagCombineInfo(DAG, Level, false, this);
2122
2123 RV = TLI.PerformDAGCombine(N, DCI&: DagCombineInfo);
2124 }
2125 }
2126
2127 // If nothing happened still, try promoting the operation.
2128 if (!RV.getNode()) {
2129 switch (N->getOpcode()) {
2130 default: break;
2131 case ISD::ADD:
2132 case ISD::SUB:
2133 case ISD::MUL:
2134 case ISD::AND:
2135 case ISD::OR:
2136 case ISD::XOR:
2137 RV = PromoteIntBinOp(Op: SDValue(N, 0));
2138 break;
2139 case ISD::SHL:
2140 case ISD::SRA:
2141 case ISD::SRL:
2142 RV = PromoteIntShiftOp(Op: SDValue(N, 0));
2143 break;
2144 case ISD::SIGN_EXTEND:
2145 case ISD::ZERO_EXTEND:
2146 case ISD::ANY_EXTEND:
2147 RV = PromoteExtend(Op: SDValue(N, 0));
2148 break;
2149 case ISD::LOAD:
2150 if (PromoteLoad(Op: SDValue(N, 0)))
2151 RV = SDValue(N, 0);
2152 break;
2153 }
2154 }
2155
2156 // If N is a commutative binary node, try to eliminate it if the commuted
2157 // version is already present in the DAG.
2158 if (!RV.getNode() && TLI.isCommutativeBinOp(Opcode: N->getOpcode())) {
2159 SDValue N0 = N->getOperand(Num: 0);
2160 SDValue N1 = N->getOperand(Num: 1);
2161
2162 // Constant operands are canonicalized to RHS.
2163 if (N0 != N1 && (isa<ConstantSDNode>(Val: N0) || !isa<ConstantSDNode>(Val: N1))) {
2164 SDValue Ops[] = {N1, N0};
2165 SDNode *CSENode = DAG.getNodeIfExists(Opcode: N->getOpcode(), VTList: N->getVTList(), Ops,
2166 Flags: N->getFlags());
2167 if (CSENode)
2168 return SDValue(CSENode, 0);
2169 }
2170 }
2171
2172 return RV;
2173}
2174
2175/// Given a node, return its input chain if it has one, otherwise return a null
2176/// sd operand.
2177static SDValue getInputChainForNode(SDNode *N) {
2178 if (unsigned NumOps = N->getNumOperands()) {
2179 if (N->getOperand(Num: 0).getValueType() == MVT::Other)
2180 return N->getOperand(Num: 0);
2181 if (N->getOperand(Num: NumOps-1).getValueType() == MVT::Other)
2182 return N->getOperand(Num: NumOps-1);
2183 for (unsigned i = 1; i < NumOps-1; ++i)
2184 if (N->getOperand(Num: i).getValueType() == MVT::Other)
2185 return N->getOperand(Num: i);
2186 }
2187 return SDValue();
2188}
2189
2190SDValue DAGCombiner::visitFCANONICALIZE(SDNode *N) {
2191 SDValue Operand = N->getOperand(Num: 0);
2192 EVT VT = Operand.getValueType();
2193 SDLoc dl(N);
2194
2195 // Canonicalize undef to quiet NaN.
2196 if (Operand.isUndef()) {
2197 APFloat CanonicalQNaN = APFloat::getQNaN(Sem: VT.getFltSemantics());
2198 return DAG.getConstantFP(Val: CanonicalQNaN, DL: dl, VT);
2199 }
2200 return SDValue();
2201}
2202
2203SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
2204 // If N has two operands and one of them already has the other as its input
2205 // chain, this token factor is redundant; return the dependent operand.
2206 if (N->getNumOperands() == 2) {
2207 if (getInputChainForNode(N: N->getOperand(Num: 0).getNode()) == N->getOperand(Num: 1))
2208 return N->getOperand(Num: 0);
2209 if (getInputChainForNode(N: N->getOperand(Num: 1).getNode()) == N->getOperand(Num: 0))
2210 return N->getOperand(Num: 1);
2211 }
2212
2213 // Don't simplify token factors if optnone.
2214 if (OptLevel == CodeGenOptLevel::None)
2215 return SDValue();
2216
2217 // Don't simplify the token factor if the node itself has too many operands.
2218 if (N->getNumOperands() > TokenFactorInlineLimit)
2219 return SDValue();
2220
2221 // If the sole user is a token factor, we should make sure we have a
2222 // chance to merge them together. This prevents TF chains from inhibiting
2223 // optimizations.
2224 if (N->hasOneUse() && N->user_begin()->getOpcode() == ISD::TokenFactor)
2225 AddToWorklist(N: *(N->user_begin()));
2226
2227 SmallVector<SDNode *, 8> TFs; // List of token factors to visit.
2228 SmallVector<SDValue, 8> Ops; // Ops for replacing token factor.
2229 SmallPtrSet<SDNode*, 16> SeenOps;
2230 bool Changed = false; // If we should replace this token factor.
2231
2232 // Start out with this token factor.
2233 TFs.push_back(Elt: N);
2234
2235 // Iterate through the token factors. The TFs list grows as new token factors
2236 // are encountered.
2237 for (unsigned i = 0; i < TFs.size(); ++i) {
2238 // Limit number of nodes to inline, to avoid quadratic compile times.
2239 // We have to add the outstanding Token Factors to Ops, otherwise we might
2240 // drop Ops from the resulting Token Factors.
2241 if (Ops.size() > TokenFactorInlineLimit) {
2242 for (unsigned j = i; j < TFs.size(); j++)
2243 Ops.emplace_back(Args&: TFs[j], Args: 0);
2244 // Drop unprocessed Token Factors from TFs, so we do not add them to the
2245 // combiner worklist later.
2246 TFs.resize(N: i);
2247 break;
2248 }
2249
2250 SDNode *TF = TFs[i];
2251 // Check each of the operands.
2252 for (const SDValue &Op : TF->op_values()) {
2253 switch (Op.getOpcode()) {
2254 case ISD::EntryToken:
2255 // Entry tokens don't need to be added to the list. They are
2256 // redundant.
2257 Changed = true;
2258 break;
2259
2260 case ISD::TokenFactor:
2261 if (Op.hasOneUse() && !is_contained(Range&: TFs, Element: Op.getNode())) {
2262 // Queue up for processing.
2263 TFs.push_back(Elt: Op.getNode());
2264 Changed = true;
2265 break;
2266 }
2267 [[fallthrough]];
2268
2269 default:
2270 // Only add if it isn't already in the list.
2271 if (SeenOps.insert(Ptr: Op.getNode()).second)
2272 Ops.push_back(Elt: Op);
2273 else
2274 Changed = true;
2275 break;
2276 }
2277 }
2278 }
2279
2280 // Re-visit inlined Token Factors, to clean them up in case they have been
2281 // removed. Skip the first Token Factor, as this is the current node.
2282 for (unsigned i = 1, e = TFs.size(); i < e; i++)
2283 AddToWorklist(N: TFs[i]);
2284
2285 // Remove nodes that are chained to another node in the list. Do so by
2286 // walking up chains breadth-first, stopping when we've seen another operand.
2287 // In general we must climb to the EntryNode, but we can exit early if we
2288 // find that all remaining work is associated with just one operand, as no
2289 // further pruning is possible.
2290
2291 // List of nodes to search through and original Ops from which they originate.
2292 SmallVector<std::pair<SDNode *, unsigned>, 8> Worklist;
2293 SmallVector<unsigned, 8> OpWorkCount; // Count of work for each Op.
2294 SmallPtrSet<SDNode *, 16> SeenChains;
2295 bool DidPruneOps = false;
2296
2297 unsigned NumLeftToConsider = 0;
2298 for (const SDValue &Op : Ops) {
2299 Worklist.push_back(Elt: std::make_pair(x: Op.getNode(), y: NumLeftToConsider++));
2300 OpWorkCount.push_back(Elt: 1);
2301 }
2302
2303 auto AddToWorklist = [&](unsigned CurIdx, SDNode *Op, unsigned OpNumber) {
2304 // If this is an Op, we can remove the op from the list. Re-mark any
2305 // search associated with it as coming from the current OpNumber.
2306 if (SeenOps.contains(Ptr: Op)) {
2307 Changed = true;
2308 DidPruneOps = true;
2309 unsigned OrigOpNumber = 0;
2310 while (OrigOpNumber < Ops.size() && Ops[OrigOpNumber].getNode() != Op)
2311 OrigOpNumber++;
2312 assert((OrigOpNumber != Ops.size()) &&
2313 "expected to find TokenFactor Operand");
2314 // Re-mark worklist from OrigOpNumber to OpNumber
2315 for (unsigned i = CurIdx + 1; i < Worklist.size(); ++i) {
2316 if (Worklist[i].second == OrigOpNumber) {
2317 Worklist[i].second = OpNumber;
2318 }
2319 }
2320 OpWorkCount[OpNumber] += OpWorkCount[OrigOpNumber];
2321 OpWorkCount[OrigOpNumber] = 0;
2322 NumLeftToConsider--;
2323 }
2324 // Add if it's a new chain
2325 if (SeenChains.insert(Ptr: Op).second) {
2326 OpWorkCount[OpNumber]++;
2327 Worklist.push_back(Elt: std::make_pair(x&: Op, y&: OpNumber));
2328 }
2329 };
2330
2331 for (unsigned i = 0; i < Worklist.size() && i < 1024; ++i) {
2332 // We need to consider at least 2 Ops to prune.
2333 if (NumLeftToConsider <= 1)
2334 break;
2335 auto CurNode = Worklist[i].first;
2336 auto CurOpNumber = Worklist[i].second;
2337 assert((OpWorkCount[CurOpNumber] > 0) &&
2338 "Node should not appear in worklist");
2339 switch (CurNode->getOpcode()) {
2340 case ISD::EntryToken:
2341 // Hitting EntryToken is the only way for the search to terminate without
2342 // hitting another operand's search. Prevent us from marking this operand
2343 // as considered.
2345 NumLeftToConsider++;
2346 break;
2347 case ISD::TokenFactor:
2348 for (const SDValue &Op : CurNode->op_values())
2349 AddToWorklist(i, Op.getNode(), CurOpNumber);
2350 break;
2351 case ISD::LIFETIME_START:
2352 case ISD::LIFETIME_END:
2353 case ISD::CopyFromReg:
2354 case ISD::CopyToReg:
2355 AddToWorklist(i, CurNode->getOperand(Num: 0).getNode(), CurOpNumber);
2356 break;
2357 default:
2358 if (auto *MemNode = dyn_cast<MemSDNode>(Val: CurNode))
2359 AddToWorklist(i, MemNode->getChain().getNode(), CurOpNumber);
2360 break;
2361 }
2362 OpWorkCount[CurOpNumber]--;
2363 if (OpWorkCount[CurOpNumber] == 0)
2364 NumLeftToConsider--;
2365 }
2366
2367 // If we've changed things around then replace token factor.
2368 if (Changed) {
2369 SDValue Result;
2370 if (Ops.empty()) {
2371 // The entry token is the only possible outcome.
2372 Result = DAG.getEntryNode();
2373 } else {
2374 if (DidPruneOps) {
2375 SmallVector<SDValue, 8> PrunedOps;
2376
2377 for (const SDValue &Op : Ops) {
2378 if (SeenChains.count(Ptr: Op.getNode()) == 0)
2379 PrunedOps.push_back(Elt: Op);
2380 }
2381 Result = DAG.getTokenFactor(DL: SDLoc(N), Vals&: PrunedOps);
2382 } else {
2383 Result = DAG.getTokenFactor(DL: SDLoc(N), Vals&: Ops);
2384 }
2385 }
2386 return Result;
2387 }
2388 return SDValue();
2389}
2390
2391/// MERGE_VALUES can always be eliminated.
2392SDValue DAGCombiner::visitMERGE_VALUES(SDNode *N) {
2393 WorklistRemover DeadNodes(*this);
2394 // Replacing results may cause a different MERGE_VALUES to suddenly
2395 // be CSE'd with N, and carry its uses with it. Iterate until no
2396 // uses remain, to ensure that the node can be safely deleted.
2397 // First add the users of this node to the work list so that they
2398 // can be tried again once they have new operands.
2399 AddUsersToWorklist(N);
2400 do {
2401 // Do as a single replacement to avoid rewalking use lists.
2402 SmallVector<SDValue, 8> Ops(N->ops());
2403 DAG.ReplaceAllUsesWith(From: N, To: Ops.data());
2404 } while (!N->use_empty());
2405 deleteAndRecombine(N);
2406 return SDValue(N, 0); // Return N so it doesn't get rechecked!
2407}
2408
2409 /// If \p N is a ConstantSDNode with isOpaque() == false, return it cast to a
2410 /// ConstantSDNode pointer; otherwise return nullptr.
2411static ConstantSDNode *getAsNonOpaqueConstant(SDValue N) {
2412 ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Val&: N);
2413 return Const != nullptr && !Const->isOpaque() ? Const : nullptr;
2414}
2415
2416// isTruncateOf - If N is a truncate of some other value, return true, record
2417// the value being truncated in Op and which of Op's bits are zero/one in Known.
2418// This function computes KnownBits to avoid a duplicated call to
2419// computeKnownBits in the caller.
2420static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
2421 KnownBits &Known) {
2422 if (N->getOpcode() == ISD::TRUNCATE) {
2423 Op = N->getOperand(Num: 0);
2424 Known = DAG.computeKnownBits(Op);
2425 if (N->getFlags().hasNoUnsignedWrap())
2426 Known.Zero.setBitsFrom(N.getScalarValueSizeInBits());
2427 return true;
2428 }
2429
2430 if (N.getValueType().getScalarType() != MVT::i1 ||
2431 !sd_match(
2432 N, P: m_c_SetCC(LHS: m_Value(N&: Op), RHS: m_Zero(), CC: m_SpecificCondCode(CC: ISD::SETNE))))
2433 return false;
2434
2435 Known = DAG.computeKnownBits(Op);
2436 return (Known.Zero | 1).isAllOnes();
2437}
2438
2439/// Return true if 'Use' is a load or a store that uses N as its base pointer
2440/// and that N may be folded in the load / store addressing mode.
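/// For example, if N is (add basereg, 16) and 'Use' is a load with N as its
/// base pointer, folding is possible when the target reports a legal
/// [reg + imm] addressing mode for the loaded type (illustrative; the actual
/// query is TLI.isLegalAddressingMode below).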
2441static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, SelectionDAG &DAG,
2442 const TargetLowering &TLI) {
2443 EVT VT;
2444 unsigned AS;
2445
2446 if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Val: Use)) {
2447 if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2448 return false;
2449 VT = LD->getMemoryVT();
2450 AS = LD->getAddressSpace();
2451 } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(Val: Use)) {
2452 if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2453 return false;
2454 VT = ST->getMemoryVT();
2455 AS = ST->getAddressSpace();
2456 } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(Val: Use)) {
2457 if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2458 return false;
2459 VT = LD->getMemoryVT();
2460 AS = LD->getAddressSpace();
2461 } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(Val: Use)) {
2462 if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2463 return false;
2464 VT = ST->getMemoryVT();
2465 AS = ST->getAddressSpace();
2466 } else {
2467 return false;
2468 }
2469
2470 TargetLowering::AddrMode AM;
2471 if (N->isAnyAdd()) {
2472 AM.HasBaseReg = true;
2473 ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
2474 if (Offset)
2475 // [reg +/- imm]
2476 AM.BaseOffs = Offset->getSExtValue();
2477 else
2478 // [reg +/- reg]
2479 AM.Scale = 1;
2480 } else if (N->getOpcode() == ISD::SUB) {
2481 AM.HasBaseReg = true;
2482 ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
2483 if (Offset)
2484 // [reg +/- imm]
2485 AM.BaseOffs = -Offset->getSExtValue();
2486 else
2487 // [reg +/- reg]
2488 AM.Scale = 1;
2489 } else {
2490 return false;
2491 }
2492
2493 return TLI.isLegalAddressingMode(DL: DAG.getDataLayout(), AM,
2494 Ty: VT.getTypeForEVT(Context&: *DAG.getContext()), AddrSpace: AS);
2495}
2496
2497/// This inverts a canonicalization in IR that replaces a variable select arm
2498/// with an identity constant. Codegen improves if we re-use the variable
2499/// operand rather than load a constant. This can also be converted into a
2500/// masked vector operation if the target supports it.
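/// For example, with 0 as the identity constant of add (illustrative):
///   add X, (vselect Cond, 0, Y) --> vselect Cond, X, (add X, Y)
/// so the select now chooses between the unmodified X and the binop result.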
2501static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
2502 bool ShouldCommuteOperands) {
2503 SDValue N0 = N->getOperand(Num: 0);
2504 SDValue N1 = N->getOperand(Num: 1);
2505
2506 // Match a select as operand 1. The identity constant that we are looking for
2507 // is only valid as operand 1 of a non-commutative binop.
2508 if (ShouldCommuteOperands)
2509 std::swap(a&: N0, b&: N1);
2510
2511 SDValue Cond, TVal, FVal;
2512 if (!sd_match(N: N1, P: m_OneUse(P: m_SelectLike(Cond: m_Value(N&: Cond), T: m_Value(N&: TVal),
2513 F: m_Value(N&: FVal)))))
2514 return SDValue();
2515
2516 // We can't hoist all instructions because of immediate UB (not speculatable).
2517 // For example div/rem by zero.
2518 if (!DAG.isSafeToSpeculativelyExecuteNode(N))
2519 return SDValue();
2520
2521 unsigned SelOpcode = N1.getOpcode();
2522 unsigned Opcode = N->getOpcode();
2523 EVT VT = N->getValueType(ResNo: 0);
2524 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2525
2526 // This transform increases uses of N0, so freeze it to be safe.
2527 // binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal)
2528 unsigned OpNo = ShouldCommuteOperands ? 0 : 1;
2529 if (isNeutralConstant(Opc: Opcode, Flags: N->getFlags(), V: TVal, OperandNo: OpNo) &&
2530 TLI.shouldFoldSelectWithIdentityConstant(BinOpcode: Opcode, VT, SelectOpcode: SelOpcode, X: N0,
2531 Y: FVal)) {
2532 SDValue F0 = DAG.getFreeze(V: N0);
2533 SDValue NewBO = DAG.getNode(Opcode, DL: SDLoc(N), VT, N1: F0, N2: FVal, Flags: N->getFlags());
2534 return DAG.getSelect(DL: SDLoc(N), VT, Cond, LHS: F0, RHS: NewBO);
2535 }
2536 // binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0
2537 if (isNeutralConstant(Opc: Opcode, Flags: N->getFlags(), V: FVal, OperandNo: OpNo) &&
2538 TLI.shouldFoldSelectWithIdentityConstant(BinOpcode: Opcode, VT, SelectOpcode: SelOpcode, X: N0,
2539 Y: TVal)) {
2540 SDValue F0 = DAG.getFreeze(V: N0);
2541 SDValue NewBO = DAG.getNode(Opcode, DL: SDLoc(N), VT, N1: F0, N2: TVal, Flags: N->getFlags());
2542 return DAG.getSelect(DL: SDLoc(N), VT, Cond, LHS: NewBO, RHS: F0);
2543 }
2544
2545 return SDValue();
2546}
2547
2548SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
2549 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2550 assert(TLI.isBinOp(BO->getOpcode()) && BO->getNumValues() == 1 &&
2551 "Unexpected binary operator");
2552
2553 if (SDValue Sel = foldSelectWithIdentityConstant(N: BO, DAG, ShouldCommuteOperands: false))
2554 return Sel;
2555
2556 if (TLI.isCommutativeBinOp(Opcode: BO->getOpcode()))
2557 if (SDValue Sel = foldSelectWithIdentityConstant(N: BO, DAG, ShouldCommuteOperands: true))
2558 return Sel;
2559
2560 // Don't do this unless the old select is going away. We want to eliminate the
2561 // binary operator, not replace a binop with a select.
2562 // TODO: Handle ISD::SELECT_CC.
2563 unsigned SelOpNo = 0;
2564 SDValue Sel = BO->getOperand(Num: 0);
2565 auto BinOpcode = BO->getOpcode();
2566 if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
2567 SelOpNo = 1;
2568 Sel = BO->getOperand(Num: 1);
2569
2570 // Peek through trunc to shift amount type.
2571 if ((BinOpcode == ISD::SHL || BinOpcode == ISD::SRA ||
2572 BinOpcode == ISD::SRL) && Sel.hasOneUse()) {
2573 // This is valid when the bits that are truncated away are already known zero.
2574 SDValue Op;
2575 KnownBits Known;
2576 if (isTruncateOf(DAG, N: Sel, Op, Known) &&
2577 Known.countMaxActiveBits() < Sel.getScalarValueSizeInBits())
2578 Sel = Op;
2579 }
2580 }
2581
2582 if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse())
2583 return SDValue();
2584
2585 SDValue CT = Sel.getOperand(i: 1);
2586 if (!isConstantOrConstantVector(N: CT, NoOpaques: true) &&
2587 !DAG.isConstantFPBuildVectorOrConstantFP(N: CT))
2588 return SDValue();
2589
2590 SDValue CF = Sel.getOperand(i: 2);
2591 if (!isConstantOrConstantVector(N: CF, NoOpaques: true) &&
2592 !DAG.isConstantFPBuildVectorOrConstantFP(N: CF))
2593 return SDValue();
2594
2595 // Bail out if any constants are opaque because we can't constant fold those.
2596 // The exception is "and" and "or" with either 0 or -1 in which case we can
2597 // propagate non constant operands into select. I.e.:
2598 // and (select Cond, 0, -1), X --> select Cond, 0, X
2599 // or X, (select Cond, -1, 0) --> select Cond, -1, X
2600 bool CanFoldNonConst =
2601 (BinOpcode == ISD::AND || BinOpcode == ISD::OR) &&
2602 ((isNullOrNullSplat(V: CT) && isAllOnesOrAllOnesSplat(V: CF)) ||
2603 (isNullOrNullSplat(V: CF) && isAllOnesOrAllOnesSplat(V: CT)));
2604
2605 SDValue CBO = BO->getOperand(Num: SelOpNo ^ 1);
2606 if (!CanFoldNonConst &&
2607 !isConstantOrConstantVector(N: CBO, NoOpaques: true) &&
2608 !DAG.isConstantFPBuildVectorOrConstantFP(N: CBO))
2609 return SDValue();
2610
2611 SDLoc DL(Sel);
2612 SDValue NewCT, NewCF;
2613 EVT VT = BO->getValueType(ResNo: 0);
2614
2615 if (CanFoldNonConst) {
2616 // If CBO is an opaque constant, we can't rely on getNode to constant fold.
2617 if ((BinOpcode == ISD::AND && isNullOrNullSplat(V: CT)) ||
2618 (BinOpcode == ISD::OR && isAllOnesOrAllOnesSplat(V: CT)))
2619 NewCT = CT;
2620 else
2621 NewCT = CBO;
2622
2623 if ((BinOpcode == ISD::AND && isNullOrNullSplat(V: CF)) ||
2624 (BinOpcode == ISD::OR && isAllOnesOrAllOnesSplat(V: CF)))
2625 NewCF = CF;
2626 else
2627 NewCF = CBO;
2628 } else {
2629 // We have a select-of-constants followed by a binary operator with a
2630 // constant. Eliminate the binop by pulling the constant math into the
2631 // select. Example:
2632 //   add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO
2633 NewCT = SelOpNo ? DAG.FoldConstantArithmetic(Opcode: BinOpcode, DL, VT, Ops: {CBO, CT})
2634 : DAG.FoldConstantArithmetic(Opcode: BinOpcode, DL, VT, Ops: {CT, CBO});
2635 if (!NewCT)
2636 return SDValue();
2637
2638 NewCF = SelOpNo ? DAG.FoldConstantArithmetic(Opcode: BinOpcode, DL, VT, Ops: {CBO, CF})
2639 : DAG.FoldConstantArithmetic(Opcode: BinOpcode, DL, VT, Ops: {CF, CBO});
2640 if (!NewCF)
2641 return SDValue();
2642 }
2643
2644 return DAG.getSelect(DL, VT, Cond: Sel.getOperand(i: 0), LHS: NewCT, RHS: NewCF, Flags: BO->getFlags());
2645}
2646
2647static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, const SDLoc &DL,
2648 SelectionDAG &DAG) {
2649 assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2650 "Expecting add or sub");
2651
2652 // Match a constant operand and a zext operand for the math instruction:
2653 // add Z, C
2654 // sub C, Z
2655 bool IsAdd = N->getOpcode() == ISD::ADD;
2656 SDValue C = IsAdd ? N->getOperand(Num: 1) : N->getOperand(Num: 0);
2657 SDValue Z = IsAdd ? N->getOperand(Num: 0) : N->getOperand(Num: 1);
2658 auto *CN = dyn_cast<ConstantSDNode>(Val&: C);
2659 if (!CN || Z.getOpcode() != ISD::ZERO_EXTEND)
2660 return SDValue();
2661
2662 // Match the zext operand as a setcc of a boolean.
2663 if (Z.getOperand(i: 0).getValueType() != MVT::i1)
2664 return SDValue();
2665
2666 // Match the compare as: setcc (X & 1), 0, eq.
2667 if (!sd_match(N: Z.getOperand(i: 0), P: m_SetCC(LHS: m_And(L: m_Value(), R: m_One()), RHS: m_Zero(),
2668 CC: m_SpecificCondCode(CC: ISD::SETEQ))))
2669 return SDValue();
2670
2671 // We are adding/subtracting a constant and an inverted low bit. Turn that
2672 // into a subtract/add of the low bit with incremented/decremented constant:
2673 // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
2674 // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
2675 EVT VT = C.getValueType();
2676 SDValue LowBit = DAG.getZExtOrTrunc(Op: Z.getOperand(i: 0).getOperand(i: 0), DL, VT);
2677 SDValue C1 = IsAdd ? DAG.getConstant(Val: CN->getAPIntValue() + 1, DL, VT)
2678 : DAG.getConstant(Val: CN->getAPIntValue() - 1, DL, VT);
2679 return DAG.getNode(Opcode: IsAdd ? ISD::SUB : ISD::ADD, DL, VT, N1: C1, N2: LowBit);
2680}
2681
2682// Attempt to form avgceil(A, B) from (A | B) - ((A ^ B) >> 1)
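// A quick sanity check of the identity (illustrative): for A = 5, B = 6,
// (5 | 6) - ((5 ^ 6) >> 1) = 7 - 1 = 6 = ceil((5 + 6) / 2).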
2683SDValue DAGCombiner::foldSubToAvg(SDNode *N, const SDLoc &DL) {
2684 SDValue N0 = N->getOperand(Num: 0);
2685 EVT VT = N0.getValueType();
2686 SDValue A, B;
2687
2688 if ((!LegalOperations || hasOperation(Opcode: ISD::AVGCEILU, VT)) &&
2689 sd_match(N, P: m_Sub(L: m_Or(L: m_Value(N&: A), R: m_Value(N&: B)),
2690 R: m_Srl(L: m_Xor(L: m_Deferred(V&: A), R: m_Deferred(V&: B)), R: m_One())))) {
2691 return DAG.getNode(Opcode: ISD::AVGCEILU, DL, VT, N1: A, N2: B);
2692 }
2693 if ((!LegalOperations || hasOperation(Opcode: ISD::AVGCEILS, VT)) &&
2694 sd_match(N, P: m_Sub(L: m_Or(L: m_Value(N&: A), R: m_Value(N&: B)),
2695 R: m_Sra(L: m_Xor(L: m_Deferred(V&: A), R: m_Deferred(V&: B)), R: m_One())))) {
2696 return DAG.getNode(Opcode: ISD::AVGCEILS, DL, VT, N1: A, N2: B);
2697 }
2698 return SDValue();
2699}
2700
2701/// Try to fold a pointer arithmetic node.
2702/// This needs to be done separately from normal addition, because pointer
2703/// addition is not commutative.
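/// For example, (ptradd p, x) carries the provenance of its pointer operand p,
/// so unlike a plain ISD::ADD the operands cannot simply be swapped.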
2704SDValue DAGCombiner::visitPTRADD(SDNode *N) {
2705 SDValue N0 = N->getOperand(Num: 0);
2706 SDValue N1 = N->getOperand(Num: 1);
2707 EVT PtrVT = N0.getValueType();
2708 EVT IntVT = N1.getValueType();
2709 SDLoc DL(N);
2710
2711 // This is already ensured by an assert in SelectionDAG::getNode(). Several
2712 // combines here depend on this assumption.
2713 assert(PtrVT == IntVT &&
2714 "PTRADD with different operand types is not supported");
2715
2716 // fold (ptradd x, 0) -> x
2717 if (isNullConstant(V: N1))
2718 return N0;
2719
2720 // fold (ptradd 0, x) -> x
2721 if (PtrVT == IntVT && isNullConstant(V: N0))
2722 return N1;
2723
2724 if (N0.getOpcode() == ISD::PTRADD &&
2725 !reassociationCanBreakAddressingModePattern(Opc: ISD::PTRADD, DL, N, N0, N1)) {
2726 SDValue X = N0.getOperand(i: 0);
2727 SDValue Y = N0.getOperand(i: 1);
2728 SDValue Z = N1;
2729 bool N0OneUse = N0.hasOneUse();
2730 bool YIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(N: Y);
2731 bool ZIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(N: Z);
2732
2733 // (ptradd (ptradd x, y), z) -> (ptradd x, (add y, z)) if:
2734 // * y is a constant and (ptradd x, y) has one use; or
2735 // * y and z are both constants.
2736 if ((YIsConstant && N0OneUse) || (YIsConstant && ZIsConstant)) {
2737 // If both additions in the original were NUW, the new ones are as well.
2738 SDNodeFlags Flags =
2739 (N->getFlags() & N0->getFlags()) & SDNodeFlags::NoUnsignedWrap;
2740 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT: IntVT, Ops: {Y, Z}, Flags);
2741 AddToWorklist(N: Add.getNode());
2742 // We can't set InBounds even if both original ptradds were InBounds and
2743 // NUW: SDAG usually represents pointers as integers, therefore, the
2744 // matched pattern behaves as if it had implicit casts:
2745 // (ptradd inbounds (inttoptr (ptrtoint (ptradd inbounds x, y))), z)
2746 // The outer inbounds ptradd might therefore rely on a provenance that x
2747 // does not have.
2748 return DAG.getMemBasePlusOffset(Base: X, Offset: Add, DL, Flags);
2749 }
2750 }
2751
2752 // The following combines can turn in-bounds pointer arithmetic out of bounds.
2753 // That is problematic for settings like AArch64's CPA, which checks that
2754 // intermediate results of pointer arithmetic remain in bounds. The target
2755 // therefore needs to opt-in to enable them.
2756 if (!TLI.canTransformPtrArithOutOfBounds(
2757 F: DAG.getMachineFunction().getFunction(), PtrVT))
2758 return SDValue();
2759
2760 if (N0.getOpcode() == ISD::PTRADD && isa<ConstantSDNode>(Val: N1)) {
2761 // Fold (ptradd (ptradd GA, v), c) -> (ptradd (ptradd GA, c) v) with
2762 // global address GA and constant c, such that c can be folded into GA.
2763 // TODO: Support constant vector splats.
2764 SDValue GAValue = N0.getOperand(i: 0);
2765 if (const GlobalAddressSDNode *GA =
2766 dyn_cast<GlobalAddressSDNode>(Val&: GAValue)) {
2767 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2768 if (!LegalOperations && TLI.isOffsetFoldingLegal(GA)) {
2769 // If both additions in the original were NUW, reassociation preserves
2770 // that.
2771 SDNodeFlags Flags =
2772 (N->getFlags() & N0->getFlags()) & SDNodeFlags::NoUnsignedWrap;
2773 // We can't set InBounds even if both original ptradds were InBounds and
2774 // NUW: SDAG usually represents pointers as integers, therefore, the
2775 // matched pattern behaves as if it had implicit casts:
2776 // (ptradd inbounds (inttoptr (ptrtoint (ptradd inbounds GA, v))), c)
2777 // The outer inbounds ptradd might therefore rely on a provenance that
2778 // GA does not have.
2779 SDValue Inner = DAG.getMemBasePlusOffset(Base: GAValue, Offset: N1, DL, Flags);
2780 AddToWorklist(N: Inner.getNode());
2781 return DAG.getMemBasePlusOffset(Base: Inner, Offset: N0.getOperand(i: 1), DL, Flags);
2782 }
2783 }
2784 }
2785
2786 if (N1.getOpcode() == ISD::ADD && N1.hasOneUse()) {
2787 // (ptradd x, (add y, z)) -> (ptradd (ptradd x, y), z) if z is a constant,
2788 // y is not, and (add y, z) is used only once.
2789 // (ptradd x, (add y, z)) -> (ptradd (ptradd x, z), y) if y is a constant,
2790 // z is not, and (add y, z) is used only once.
2791 // The goal is to move constant offsets to the outermost ptradd, to create
2792 // more opportunities to fold offsets into memory instructions.
2793 // Together with another combine above, this also implements
2794 // (ptradd (ptradd x, y), z) -> (ptradd (ptradd x, z), y)).
2795 SDValue X = N0;
2796 SDValue Y = N1.getOperand(i: 0);
2797 SDValue Z = N1.getOperand(i: 1);
2798 bool YIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(N: Y);
2799 bool ZIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(N: Z);
2800
2801 // If both additions in the original were NUW, reassociation preserves that.
2802 SDNodeFlags CommonFlags = N->getFlags() & N1->getFlags();
2803 SDNodeFlags ReassocFlags = CommonFlags & SDNodeFlags::NoUnsignedWrap;
2804 if (CommonFlags.hasNoUnsignedWrap()) {
2805 // If both operations are NUW and the PTRADD is inbounds, the offsets are
2806 // both non-negative, so the reassociated PTRADDs are also inbounds.
2807 ReassocFlags |= N->getFlags() & SDNodeFlags::InBounds;
2808 }
2809
2810 if (ZIsConstant != YIsConstant) {
2811 if (YIsConstant)
2812 std::swap(a&: Y, b&: Z);
2813 SDValue Inner = DAG.getMemBasePlusOffset(Base: X, Offset: Y, DL, Flags: ReassocFlags);
2814 AddToWorklist(N: Inner.getNode());
2815 return DAG.getMemBasePlusOffset(Base: Inner, Offset: Z, DL, Flags: ReassocFlags);
2816 }
2817 }
2818
2819 // Transform (ptradd a, b) -> (or disjoint a, b) if it is equivalent and if
2820 // that transformation can't block an offset folding at any use of the ptradd.
2821 // This should be done late, after legalization, so that it doesn't block
2822 // other ptradd combines that could enable more offset folding.
2823 if (LegalOperations && DAG.haveNoCommonBitsSet(A: N0, B: N1)) {
2824 bool TransformCannotBreakAddrMode = none_of(Range: N->users(), P: [&](SDNode *User) {
2825 return canFoldInAddressingMode(N, Use: User, DAG, TLI);
2826 });
2827
2828 if (TransformCannotBreakAddrMode)
2829 return DAG.getNode(Opcode: ISD::OR, DL, VT: PtrVT, N1: N0, N2: N1, Flags: SDNodeFlags::Disjoint);
2830 }
2831
2832 return SDValue();
2833}
2834
2835/// Try to fold a 'not' shifted sign-bit with add/sub with constant operand into
2836/// a shift and add with a different constant.
2837static SDValue foldAddSubOfSignBit(SDNode *N, const SDLoc &DL,
2838 SelectionDAG &DAG) {
2839 assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2840 "Expecting add or sub");
2841
2842 // We need a constant operand for the add/sub, and the other operand is a
2843 // logical shift right: add (srl), C or sub C, (srl).
2844 bool IsAdd = N->getOpcode() == ISD::ADD;
2845 SDValue ConstantOp = IsAdd ? N->getOperand(Num: 1) : N->getOperand(Num: 0);
2846 SDValue ShiftOp = IsAdd ? N->getOperand(Num: 0) : N->getOperand(Num: 1);
2847 if (!DAG.isConstantIntBuildVectorOrConstantInt(N: ConstantOp) ||
2848 ShiftOp.getOpcode() != ISD::SRL)
2849 return SDValue();
2850
2851 // The shift must be of a 'not' value.
2852 SDValue Not = ShiftOp.getOperand(i: 0);
2853 if (!Not.hasOneUse() || !isBitwiseNot(V: Not))
2854 return SDValue();
2855
2856 // The shift must be moving the sign bit to the least-significant-bit.
2857 EVT VT = ShiftOp.getValueType();
2858 SDValue ShAmt = ShiftOp.getOperand(i: 1);
2859 ConstantSDNode *ShAmtC = isConstOrConstSplat(N: ShAmt);
2860 if (!ShAmtC || ShAmtC->getAPIntValue() != (VT.getScalarSizeInBits() - 1))
2861 return SDValue();
2862
2863 // Eliminate the 'not' by adjusting the shift and add/sub constant:
2864 // add (srl (not X), 31), C --> add (sra X, 31), (C + 1)
2865 // sub C, (srl (not X), 31) --> add (srl X, 31), (C - 1)
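      // This works because (srl (not X), W-1) == 1 - (srl X, W-1) and
      // (sra X, W-1) == 0 - (srl X, W-1) for bit width W, so the 'not' is
      // absorbed by adjusting the constant and the shift kind accordingly.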
2866 if (SDValue NewC = DAG.FoldConstantArithmetic(
2867 Opcode: IsAdd ? ISD::ADD : ISD::SUB, DL, VT,
2868 Ops: {ConstantOp, DAG.getConstant(Val: 1, DL, VT)})) {
2869 SDValue NewShift = DAG.getNode(Opcode: IsAdd ? ISD::SRA : ISD::SRL, DL, VT,
2870 N1: Not.getOperand(i: 0), N2: ShAmt);
2871 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: NewShift, N2: NewC);
2872 }
2873
2874 return SDValue();
2875}
2876
2877static bool
2878areBitwiseNotOfEachother(SDValue Op0, SDValue Op1) {
2879 return (isBitwiseNot(V: Op0) && Op0.getOperand(i: 0) == Op1) ||
2880 (isBitwiseNot(V: Op1) && Op1.getOperand(i: 0) == Op0);
2881}
2882
2883/// Try to fold a node that behaves like an ADD (note that N isn't necessarily
2884/// an ISD::ADD here, it could for example be an ISD::OR if we know that there
2885/// are no common bits set in the operands).
2886SDValue DAGCombiner::visitADDLike(SDNode *N) {
2887 SDValue N0 = N->getOperand(Num: 0);
2888 SDValue N1 = N->getOperand(Num: 1);
2889 EVT VT = N0.getValueType();
2890 SDLoc DL(N);
2891
2892 // fold (add x, undef) -> undef
2893 if (N0.isUndef())
2894 return N0;
2895 if (N1.isUndef())
2896 return N1;
2897
2898 // fold (add c1, c2) -> c1+c2
2899 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL, VT, Ops: {N0, N1}))
2900 return C;
2901
2902 // canonicalize constant to RHS
2903 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
2904 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
2905 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: N0);
2906
2907 if (areBitwiseNotOfEachother(Op0: N0, Op1: N1))
2908 return DAG.getConstant(Val: APInt::getAllOnes(numBits: VT.getScalarSizeInBits()), DL, VT);
2909
2910 // fold vector ops
2911 if (VT.isVector()) {
2912 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
2913 return FoldedVOp;
2914
2915 // fold (add x, 0) -> x, vector edition
2916 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
2917 return N0;
2918 }
2919
2920 // fold (add x, 0) -> x
2921 if (isNullConstant(V: N1))
2922 return N0;
2923
2924 if (N0.getOpcode() == ISD::SUB) {
2925 SDValue N00 = N0.getOperand(i: 0);
2926 SDValue N01 = N0.getOperand(i: 1);
2927
2928 // fold ((A-c1)+c2) -> (A+(c2-c1))
2929 if (SDValue Sub = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N1, N01}))
2930 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: Sub);
2931
2932 // fold ((c1-A)+c2) -> (c1+c2)-A
2933 if (SDValue Add = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL, VT, Ops: {N1, N00}))
2934 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Add, N2: N0.getOperand(i: 1));
2935 }
2936
2937 // add (sext i1 X), 1 -> zext (not i1 X)
2938 // We don't transform this pattern:
2939 // add (zext i1 X), -1 -> sext (not i1 X)
2940 // because most (?) targets generate better code for the zext form.
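      // Both i1 values check out: X=0 gives 0 + 1 == 1 == zext(not 0), and
      // X=1 gives -1 + 1 == 0 == zext(not 1).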
2941 if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
2942 isOneOrOneSplat(V: N1)) {
2943 SDValue X = N0.getOperand(i: 0);
2944 if ((!LegalOperations ||
2945 (TLI.isOperationLegal(Op: ISD::XOR, VT: X.getValueType()) &&
2946 TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT))) &&
2947 X.getScalarValueSizeInBits() == 1) {
2948 SDValue Not = DAG.getNOT(DL, Val: X, VT: X.getValueType());
2949 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: Not);
2950 }
2951 }
2952
2953 // Fold (add (or x, c0), c1) -> (add x, (c0 + c1))
2954 // iff (or x, c0) is equivalent to (add x, c0).
2955 // Fold (add (xor x, c0), c1) -> (add x, (c0 + c1))
2956 // iff (xor x, c0) is equivalent to (add x, c0).
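      // For example, if x is known to be a multiple of 4, then (or x, 3) adds
      // 3 without carries, so (add (or x, 3), 8) becomes (add x, 11).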
2957 if (DAG.isADDLike(Op: N0)) {
2958 SDValue N01 = N0.getOperand(i: 1);
2959 if (SDValue Add = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL, VT, Ops: {N1, N01}))
2960 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: Add);
2961 }
2962
2963 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
2964 return NewSel;
2965
2966 // reassociate add
2967 if (!reassociationCanBreakAddressingModePattern(Opc: ISD::ADD, DL, N, N0, N1)) {
2968 if (SDValue RADD = reassociateOps(Opc: ISD::ADD, DL, N0, N1, Flags: N->getFlags()))
2969 return RADD;
2970
2971 // Reassociate (add (or x, c), y) -> (add add(x, y), c)) if (or x, c) is
2972 // equivalent to (add x, c).
2973 // Reassociate (add (xor x, c), y) -> (add add(x, y), c)) if (xor x, c) is
2974 // equivalent to (add x, c).
2975 // Do this optimization only when adding c does not introduce instructions
2976 // for adding carries.
2977 auto ReassociateAddOr = [&](SDValue N0, SDValue N1) {
2978 if (DAG.isADDLike(Op: N0) && N0.hasOneUse() &&
2979 isConstantOrConstantVector(N: N0.getOperand(i: 1), /* NoOpaque */ NoOpaques: true)) {
2980        // If N0's type is not split by legalization, or the constant is a sign
2981        // mask, this does not introduce an add carry.
2982 auto TyActn = TLI.getTypeAction(Context&: *DAG.getContext(), VT: N0.getValueType());
2983 bool NoAddCarry = TyActn == TargetLoweringBase::TypeLegal ||
2984 TyActn == TargetLoweringBase::TypePromoteInteger ||
2985 isMinSignedConstant(V: N0.getOperand(i: 1));
2986 if (NoAddCarry)
2987 return DAG.getNode(
2988 Opcode: ISD::ADD, DL, VT,
2989 N1: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: N0.getOperand(i: 0)),
2990 N2: N0.getOperand(i: 1));
2991 }
2992 return SDValue();
2993 };
2994 if (SDValue Add = ReassociateAddOr(N0, N1))
2995 return Add;
2996 if (SDValue Add = ReassociateAddOr(N1, N0))
2997 return Add;
2998
2999 // Fold add(vecreduce(x), vecreduce(y)) -> vecreduce(add(x, y))
3000 if (SDValue SD =
3001 reassociateReduction(RedOpc: ISD::VECREDUCE_ADD, Opc: ISD::ADD, DL, VT, N0, N1))
3002 return SD;
3003 }
3004
3005 SDValue A, B, C, D;
3006
3007 // fold ((0-A) + B) -> B-A
3008 if (sd_match(N: N0, P: m_Neg(V: m_Value(N&: A))))
3009 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1, N2: A);
3010
3011 // fold (A + (0-B)) -> A-B
3012 if (sd_match(N: N1, P: m_Neg(V: m_Value(N&: B))))
3013 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: B);
3014
3015 // fold (A+(B-A)) -> B
3016 if (sd_match(N: N1, P: m_Sub(L: m_Value(N&: B), R: m_Specific(N: N0))))
3017 return B;
3018
3019 // fold ((B-A)+A) -> B
3020 if (sd_match(N: N0, P: m_Sub(L: m_Value(N&: B), R: m_Specific(N: N1))))
3021 return B;
3022
3023 // fold ((A-B)+(C-A)) -> (C-B)
3024 if (sd_match(N: N0, P: m_Sub(L: m_Value(N&: A), R: m_Value(N&: B))) &&
3025 sd_match(N: N1, P: m_Sub(L: m_Value(N&: C), R: m_Specific(N: A))))
3026 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: C, N2: B);
3027
3028 // fold ((A-B)+(B-C)) -> (A-C)
3029 if (sd_match(N: N0, P: m_Sub(L: m_Value(N&: A), R: m_Value(N&: B))) &&
3030 sd_match(N: N1, P: m_Sub(L: m_Specific(N: B), R: m_Value(N&: C))))
3031 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: A, N2: C);
3032
3033 // fold (A+(B-(A+C))) to (B-C)
3034 // fold (A+(B-(C+A))) to (B-C)
3035 if (sd_match(N: N1, P: m_Sub(L: m_Value(N&: B), R: m_Add(L: m_Specific(N: N0), R: m_Value(N&: C)))))
3036 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: B, N2: C);
3037
3038 // fold (A+((B-A)+or-C)) to (B+or-C)
3039 if (sd_match(N: N1,
3040 P: m_AnyOf(preds: m_Add(L: m_Sub(L: m_Value(N&: B), R: m_Specific(N: N0)), R: m_Value(N&: C)),
3041 preds: m_Sub(L: m_Sub(L: m_Value(N&: B), R: m_Specific(N: N0)), R: m_Value(N&: C)))))
3042 return DAG.getNode(Opcode: N1.getOpcode(), DL, VT, N1: B, N2: C);
3043
3044 // fold (A-B)+(C-D) to (A+C)-(B+D) when A or C is constant
3045 if (sd_match(N: N0, P: m_OneUse(P: m_Sub(L: m_Value(N&: A), R: m_Value(N&: B)))) &&
3046 sd_match(N: N1, P: m_OneUse(P: m_Sub(L: m_Value(N&: C), R: m_Value(N&: D)))) &&
3047 (isConstantOrConstantVector(N: A) || isConstantOrConstantVector(N: C)))
3048 return DAG.getNode(Opcode: ISD::SUB, DL, VT,
3049 N1: DAG.getNode(Opcode: ISD::ADD, DL: SDLoc(N0), VT, N1: A, N2: C),
3050 N2: DAG.getNode(Opcode: ISD::ADD, DL: SDLoc(N1), VT, N1: B, N2: D));
3051
3052 // fold (add (umax X, C), -C) --> (usubsat X, C)
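      // umax(X, C) - C is X - C when X > C and 0 otherwise, which is exactly
      // the unsigned saturating subtraction usubsat(X, C).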
3053 if (N0.getOpcode() == ISD::UMAX && hasOperation(Opcode: ISD::USUBSAT, VT)) {
3054 auto MatchUSUBSAT = [](ConstantSDNode *Max, ConstantSDNode *Op) {
3055 return (!Max && !Op) ||
3056 (Max && Op && Max->getAPIntValue() == (-Op->getAPIntValue()));
3057 };
3058 if (ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchUSUBSAT,
3059 /*AllowUndefs*/ true))
3060 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: N0.getOperand(i: 0),
3061 N2: N0.getOperand(i: 1));
3062 }
3063
3064 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
3065 return SDValue(N, 0);
3066
3067 if (isOneOrOneSplat(V: N1)) {
3068 // fold (add (xor a, -1), 1) -> (sub 0, a)
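        // (xor a, -1) is ~a, and ~a + 1 == -a in two's complement.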
3069 if (isBitwiseNot(V: N0))
3070 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: DAG.getConstant(Val: 0, DL, VT),
3071 N2: N0.getOperand(i: 0));
3072
3073 // fold (add (add (xor a, -1), b), 1) -> (sub b, a)
3074 if (N0.getOpcode() == ISD::ADD) {
3075 SDValue A, Xor;
3076
3077 if (isBitwiseNot(V: N0.getOperand(i: 0))) {
3078 A = N0.getOperand(i: 1);
3079 Xor = N0.getOperand(i: 0);
3080 } else if (isBitwiseNot(V: N0.getOperand(i: 1))) {
3081 A = N0.getOperand(i: 0);
3082 Xor = N0.getOperand(i: 1);
3083 }
3084
3085 if (Xor)
3086 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: A, N2: Xor.getOperand(i: 0));
3087 }
3088
3089 // Look for:
3090 // add (add x, y), 1
3091 // And if the target does not like this form then turn into:
3092 // sub y, (xor x, -1)
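        // This uses (xor x, -1) == -x - 1, so y - (xor x, -1) == x + y + 1,
        // which matches the original (add (add x, y), 1).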
3093 if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
3094 N0.hasOneUse() &&
3095 // Limit this to after legalization if the add has wrap flags
3096 (Level >= AfterLegalizeDAG || (!N->getFlags().hasNoUnsignedWrap() &&
3097 !N->getFlags().hasNoSignedWrap()))) {
3098 SDValue Not = DAG.getNOT(DL, Val: N0.getOperand(i: 0), VT);
3099 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 1), N2: Not);
3100 }
3101 }
3102
3103 // (x - y) + -1 -> add (xor y, -1), x
3104 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
3105 isAllOnesOrAllOnesSplat(V: N1, /*AllowUndefs=*/true)) {
3106 SDValue Not = DAG.getNOT(DL, Val: N0.getOperand(i: 1), VT);
3107 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Not, N2: N0.getOperand(i: 0));
3108 }
3109
3110 // Fold add(mul(add(A, CA), CM), CB) -> add(mul(A, CM), CM*CA+CB).
3111 // This can help if the inner add has multiple uses.
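      // This is plain distribution: (A + CA) * CM + CB == A * CM + (CA * CM + CB).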
3112 APInt CM, CA;
3113 if (ConstantSDNode *CB = dyn_cast<ConstantSDNode>(Val&: N1)) {
3114 if (VT.getScalarSizeInBits() <= 64) {
3115 if (sd_match(N: N0, P: m_OneUse(P: m_Mul(L: m_Add(L: m_Value(N&: A), R: m_ConstInt(V&: CA)),
3116 R: m_ConstInt(V&: CM)))) &&
3117 TLI.isLegalAddImmediate(
3118 (CA * CM + CB->getAPIntValue()).getSExtValue())) {
3119 SDNodeFlags Flags;
3120        // If all the inputs are nuw, the outputs can be nuw. If all the inputs
3121        // are _also_ nsw, the outputs can be too.
3122 if (N->getFlags().hasNoUnsignedWrap() &&
3123 N0->getFlags().hasNoUnsignedWrap() &&
3124 N0.getOperand(i: 0)->getFlags().hasNoUnsignedWrap()) {
3125 Flags |= SDNodeFlags::NoUnsignedWrap;
3126 if (N->getFlags().hasNoSignedWrap() &&
3127 N0->getFlags().hasNoSignedWrap() &&
3128 N0.getOperand(i: 0)->getFlags().hasNoSignedWrap())
3129 Flags |= SDNodeFlags::NoSignedWrap;
3130 }
3131 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL: SDLoc(N1), VT, N1: A,
3132 N2: DAG.getConstant(Val: CM, DL, VT), Flags);
3133 return DAG.getNode(
3134 Opcode: ISD::ADD, DL, VT, N1: Mul,
3135 N2: DAG.getConstant(Val: CA * CM + CB->getAPIntValue(), DL, VT), Flags);
3136 }
3137 // Also look in case there is an intermediate add.
3138 if (sd_match(N: N0, P: m_OneUse(P: m_Add(
3139 L: m_OneUse(P: m_Mul(L: m_Add(L: m_Value(N&: A), R: m_ConstInt(V&: CA)),
3140 R: m_ConstInt(V&: CM))),
3141 R: m_Value(N&: B)))) &&
3142 TLI.isLegalAddImmediate(
3143 (CA * CM + CB->getAPIntValue()).getSExtValue())) {
3144 SDNodeFlags Flags;
3145        // If all the inputs are nuw, the outputs can be nuw. If all the inputs
3146        // are _also_ nsw, the outputs can be too.
3147 SDValue OMul =
3148 N0.getOperand(i: 0) == B ? N0.getOperand(i: 1) : N0.getOperand(i: 0);
3149 if (N->getFlags().hasNoUnsignedWrap() &&
3150 N0->getFlags().hasNoUnsignedWrap() &&
3151 OMul->getFlags().hasNoUnsignedWrap() &&
3152 OMul.getOperand(i: 0)->getFlags().hasNoUnsignedWrap()) {
3153 Flags |= SDNodeFlags::NoUnsignedWrap;
3154 if (N->getFlags().hasNoSignedWrap() &&
3155 N0->getFlags().hasNoSignedWrap() &&
3156 OMul->getFlags().hasNoSignedWrap() &&
3157 OMul.getOperand(i: 0)->getFlags().hasNoSignedWrap())
3158 Flags |= SDNodeFlags::NoSignedWrap;
3159 }
3160 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL: SDLoc(N1), VT, N1: A,
3161 N2: DAG.getConstant(Val: CM, DL, VT), Flags);
3162 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL: SDLoc(N1), VT, N1: Mul, N2: B, Flags);
3163 return DAG.getNode(
3164 Opcode: ISD::ADD, DL, VT, N1: Add,
3165 N2: DAG.getConstant(Val: CA * CM + CB->getAPIntValue(), DL, VT), Flags);
3166 }
3167 }
3168 }
3169
3170 if (SDValue Combined = visitADDLikeCommutative(N0, N1, LocReference: N))
3171 return Combined;
3172
3173 if (SDValue Combined = visitADDLikeCommutative(N0: N1, N1: N0, LocReference: N))
3174 return Combined;
3175
3176 return SDValue();
3177}
3178
3179// Attempt to form avgfloor(A, B) from (A & B) + ((A ^ B) >> 1)
3180// Attempt to form avgfloor(A, B) from ((A >> 1) + (B >> 1)) + (A & B & 1)
3181// Attempt to form avgceil(A, B) from ((A >> 1) + (B >> 1)) + ((A | B) & 1)
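    // These rely on the overflow-free average identities:
    //   (A & B) + ((A ^ B) >> 1)              == floor((A + B) / 2)
    //   ((A >> 1) + (B >> 1)) + (A & B & 1)   == floor((A + B) / 2)
    //   ((A >> 1) + (B >> 1)) + ((A | B) & 1) == ceil((A + B) / 2)
    // where the shifts are logical for the unsigned forms and arithmetic for
    // the signed forms.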
3182SDValue DAGCombiner::foldAddToAvg(SDNode *N, const SDLoc &DL) {
3183 SDValue N0 = N->getOperand(Num: 0);
3184 EVT VT = N0.getValueType();
3185 SDValue A, B;
3186
3187 if ((!LegalOperations || hasOperation(Opcode: ISD::AVGFLOORU, VT)) &&
3188 (sd_match(N,
3189 P: m_Add(L: m_And(L: m_Value(N&: A), R: m_Value(N&: B)),
3190 R: m_Srl(L: m_Xor(L: m_Deferred(V&: A), R: m_Deferred(V&: B)), R: m_One()))) ||
3191 sd_match(N, P: m_ReassociatableAdd(
3192 Patterns: m_ReassociatableAnd(Patterns: m_Value(N&: A), Patterns: m_Value(N&: B), Patterns: m_One()),
3193 Patterns: m_Srl(L: m_Deferred(V&: A), R: m_One()),
3194 Patterns: m_Srl(L: m_Deferred(V&: B), R: m_One()))))) {
3195 return DAG.getNode(Opcode: ISD::AVGFLOORU, DL, VT, N1: A, N2: B);
3196 }
3197 if ((!LegalOperations || hasOperation(Opcode: ISD::AVGFLOORS, VT)) &&
3198 (sd_match(N,
3199 P: m_Add(L: m_And(L: m_Value(N&: A), R: m_Value(N&: B)),
3200 R: m_Sra(L: m_Xor(L: m_Deferred(V&: A), R: m_Deferred(V&: B)), R: m_One()))) ||
3201 sd_match(N, P: m_ReassociatableAdd(
3202 Patterns: m_ReassociatableAnd(Patterns: m_Value(N&: A), Patterns: m_Value(N&: B), Patterns: m_One()),
3203 Patterns: m_Sra(L: m_Deferred(V&: A), R: m_One()),
3204 Patterns: m_Sra(L: m_Deferred(V&: B), R: m_One()))))) {
3205 return DAG.getNode(Opcode: ISD::AVGFLOORS, DL, VT, N1: A, N2: B);
3206 }
3207
3208 if ((!LegalOperations || hasOperation(Opcode: ISD::AVGCEILU, VT)) &&
3209 sd_match(N,
3210 P: m_ReassociatableAdd(Patterns: m_And(L: m_Or(L: m_Value(N&: A), R: m_Value(N&: B)), R: m_One()),
3211 Patterns: m_Srl(L: m_Deferred(V&: A), R: m_One()),
3212 Patterns: m_Srl(L: m_Deferred(V&: B), R: m_One())))) {
3213 return DAG.getNode(Opcode: ISD::AVGCEILU, DL, VT, N1: A, N2: B);
3214 }
3215 if ((!LegalOperations || hasOperation(Opcode: ISD::AVGCEILS, VT)) &&
3216 sd_match(N,
3217 P: m_ReassociatableAdd(Patterns: m_And(L: m_Or(L: m_Value(N&: A), R: m_Value(N&: B)), R: m_One()),
3218 Patterns: m_Sra(L: m_Deferred(V&: A), R: m_One()),
3219 Patterns: m_Sra(L: m_Deferred(V&: B), R: m_One())))) {
3220 return DAG.getNode(Opcode: ISD::AVGCEILS, DL, VT, N1: A, N2: B);
3221 }
3222
3223 return SDValue();
3224}
3225
3226SDValue DAGCombiner::visitADD(SDNode *N) {
3227 SDValue N0 = N->getOperand(Num: 0);
3228 SDValue N1 = N->getOperand(Num: 1);
3229 EVT VT = N0.getValueType();
3230 SDLoc DL(N);
3231
3232 if (SDValue Combined = visitADDLike(N))
3233 return Combined;
3234
3235 if (SDValue V = foldAddSubBoolOfMaskedVal(N, DL, DAG))
3236 return V;
3237
3238 if (SDValue V = foldAddSubOfSignBit(N, DL, DAG))
3239 return V;
3240
3241 if (SDValue V = MatchRotate(LHS: N0, RHS: N1, DL: SDLoc(N), /*FromAdd=*/true))
3242 return V;
3243
3244 // Try to match AVGFLOOR fixedwidth pattern
3245 if (SDValue V = foldAddToAvg(N, DL))
3246 return V;
3247
3248 // fold (a+b) -> (a|b) iff a and b share no bits.
3249 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::OR, VT)) &&
3250 DAG.haveNoCommonBitsSet(A: N0, B: N1))
3251 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: N0, N2: N1, Flags: SDNodeFlags::Disjoint);
3252
3253 // Fold (add (vscale * C0), (vscale * C1)) to (vscale * (C0 + C1)).
3254 if (N0.getOpcode() == ISD::VSCALE && N1.getOpcode() == ISD::VSCALE) {
3255 const APInt &C0 = N0->getConstantOperandAPInt(Num: 0);
3256 const APInt &C1 = N1->getConstantOperandAPInt(Num: 0);
3257 return DAG.getVScale(DL, VT, MulImm: C0 + C1);
3258 }
3259
3260 // fold a+vscale(c1)+vscale(c2) -> a+vscale(c1+c2)
3261 if (N0.getOpcode() == ISD::ADD &&
3262 N0.getOperand(i: 1).getOpcode() == ISD::VSCALE &&
3263 N1.getOpcode() == ISD::VSCALE) {
3264 const APInt &VS0 = N0.getOperand(i: 1)->getConstantOperandAPInt(Num: 0);
3265 const APInt &VS1 = N1->getConstantOperandAPInt(Num: 0);
3266 SDValue VS = DAG.getVScale(DL, VT, MulImm: VS0 + VS1);
3267 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: VS);
3268 }
3269
3270  // Fold (add step_vector(c1), step_vector(c2)) to step_vector(c1+c2)
3271 if (N0.getOpcode() == ISD::STEP_VECTOR &&
3272 N1.getOpcode() == ISD::STEP_VECTOR) {
3273 const APInt &C0 = N0->getConstantOperandAPInt(Num: 0);
3274 const APInt &C1 = N1->getConstantOperandAPInt(Num: 0);
3275 APInt NewStep = C0 + C1;
3276 return DAG.getStepVector(DL, ResVT: VT, StepVal: NewStep);
3277 }
3278
3279 // Fold a + step_vector(c1) + step_vector(c2) to a + step_vector(c1+c2)
3280 if (N0.getOpcode() == ISD::ADD &&
3281 N0.getOperand(i: 1).getOpcode() == ISD::STEP_VECTOR &&
3282 N1.getOpcode() == ISD::STEP_VECTOR) {
3283 const APInt &SV0 = N0.getOperand(i: 1)->getConstantOperandAPInt(Num: 0);
3284 const APInt &SV1 = N1->getConstantOperandAPInt(Num: 0);
3285 APInt NewStep = SV0 + SV1;
3286 SDValue SV = DAG.getStepVector(DL, ResVT: VT, StepVal: NewStep);
3287 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: SV);
3288 }
3289
3290 return SDValue();
3291}
3292
3293SDValue DAGCombiner::visitADDSAT(SDNode *N) {
3294 unsigned Opcode = N->getOpcode();
3295 SDValue N0 = N->getOperand(Num: 0);
3296 SDValue N1 = N->getOperand(Num: 1);
3297 EVT VT = N0.getValueType();
3298 bool IsSigned = Opcode == ISD::SADDSAT;
3299 SDLoc DL(N);
3300
3301 // fold (add_sat x, undef) -> -1
3302 if (N0.isUndef() || N1.isUndef())
3303 return DAG.getAllOnesConstant(DL, VT);
3304
3305 // fold (add_sat c1, c2) -> c3
3306 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
3307 return C;
3308
3309 // canonicalize constant to RHS
3310 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
3311 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
3312 return DAG.getNode(Opcode, DL, VT, N1, N2: N0);
3313
3314 // fold vector ops
3315 if (VT.isVector()) {
3316 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
3317 return FoldedVOp;
3318
3319 // fold (add_sat x, 0) -> x, vector edition
3320 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
3321 return N0;
3322 }
3323
3324 // fold (add_sat x, 0) -> x
3325 if (isNullConstant(V: N1))
3326 return N0;
3327
3328 // If it cannot overflow, transform into an add.
3329 if (DAG.willNotOverflowAdd(IsSigned, N0, N1))
3330 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1);
3331
3332 return SDValue();
3333}
3334
3335static SDValue getAsCarry(const TargetLowering &TLI, SDValue V,
3336 bool ForceCarryReconstruction = false) {
3337 bool Masked = false;
3338
3339 // First, peel away TRUNCATE/ZERO_EXTEND/AND nodes due to legalization.
3340 while (true) {
3341 if (ForceCarryReconstruction && V.getValueType() == MVT::i1)
3342 return V;
3343
3344 if (V.getOpcode() == ISD::TRUNCATE || V.getOpcode() == ISD::ZERO_EXTEND) {
3345 V = V.getOperand(i: 0);
3346 continue;
3347 }
3348
3349 if (V.getOpcode() == ISD::AND && isOneConstant(V: V.getOperand(i: 1))) {
3350 if (ForceCarryReconstruction)
3351 return V;
3352
3353 Masked = true;
3354 V = V.getOperand(i: 0);
3355 continue;
3356 }
3357
3358 break;
3359 }
3360
3361 // If this is not a carry, return.
3362 if (V.getResNo() != 1)
3363 return SDValue();
3364
3365 if (V.getOpcode() != ISD::UADDO_CARRY && V.getOpcode() != ISD::USUBO_CARRY &&
3366 V.getOpcode() != ISD::UADDO && V.getOpcode() != ISD::USUBO)
3367 return SDValue();
3368
3369 EVT VT = V->getValueType(ResNo: 0);
3370 if (!TLI.isOperationLegalOrCustom(Op: V.getOpcode(), VT))
3371 return SDValue();
3372
3373 // If the result is masked, then no matter what kind of bool it is we can
3374  // return. If it isn't, then we need to make sure the boolean is represented
3375  // as either 0 or 1 and not some other value.
3376 if (Masked ||
3377 TLI.getBooleanContents(Type: V.getValueType()) ==
3378 TargetLoweringBase::ZeroOrOneBooleanContent)
3379 return V;
3380
3381 return SDValue();
3382}
3383
3384/// Given the operands of an add/sub operation, see if the 2nd operand is a
3385/// masked 0/1 whose source operand is actually known to be 0/-1. If so, invert
3386/// the opcode and bypass the mask operation.
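    /// For example, if X is known to be all sign bits (0 or -1), then (X & 1)
    /// is 0 or 1 respectively, i.e. (X & 1) == -X, so N0 + (X & 1) == N0 - X.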
3387static SDValue foldAddSubMasked1(bool IsAdd, SDValue N0, SDValue N1,
3388 SelectionDAG &DAG, const SDLoc &DL) {
3389 if (N1.getOpcode() == ISD::ZERO_EXTEND)
3390 N1 = N1.getOperand(i: 0);
3391
3392 if (N1.getOpcode() != ISD::AND || !isOneOrOneSplat(V: N1->getOperand(Num: 1)))
3393 return SDValue();
3394
3395 EVT VT = N0.getValueType();
3396 SDValue N10 = N1.getOperand(i: 0);
3397 if (N10.getValueType() != VT && N10.getOpcode() == ISD::TRUNCATE)
3398 N10 = N10.getOperand(i: 0);
3399
3400 if (N10.getValueType() != VT)
3401 return SDValue();
3402
3403 if (DAG.ComputeNumSignBits(Op: N10) != VT.getScalarSizeInBits())
3404 return SDValue();
3405
3406 // add N0, (and (AssertSext X, i1), 1) --> sub N0, X
3407 // sub N0, (and (AssertSext X, i1), 1) --> add N0, X
3408 return DAG.getNode(Opcode: IsAdd ? ISD::SUB : ISD::ADD, DL, VT, N1: N0, N2: N10);
3409}
3410
3411/// Helper for doing combines based on N0 and N1 being added to each other.
3412SDValue DAGCombiner::visitADDLikeCommutative(SDValue N0, SDValue N1,
3413 SDNode *LocReference) {
3414 EVT VT = N0.getValueType();
3415 SDLoc DL(LocReference);
3416
3417 // fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n))
3418 SDValue Y, N;
3419 if (sd_match(N: N1, P: m_Shl(L: m_Neg(V: m_Value(N&: Y)), R: m_Value(N))))
3420 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0,
3421 N2: DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Y, N2: N));
3422
3423 if (SDValue V = foldAddSubMasked1(IsAdd: true, N0, N1, DAG, DL))
3424 return V;
3425
3426 // Look for:
3427 // add (add x, 1), y
3428 // And if the target does not like this form then turn into:
3429 // sub y, (xor x, -1)
3430 if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
3431 N0.hasOneUse() && isOneOrOneSplat(V: N0.getOperand(i: 1)) &&
3432 // Limit this to after legalization if the add has wrap flags
3433 (Level >= AfterLegalizeDAG || (!N0->getFlags().hasNoUnsignedWrap() &&
3434 !N0->getFlags().hasNoSignedWrap()))) {
3435 SDValue Not = DAG.getNOT(DL, Val: N0.getOperand(i: 0), VT);
3436 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1, N2: Not);
3437 }
3438
3439 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse()) {
3440 // Hoist one-use subtraction by non-opaque constant:
3441 // (x - C) + y -> (x + y) - C
3442 // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
3443 if (isConstantOrConstantVector(N: N0.getOperand(i: 1), /*NoOpaques=*/true)) {
3444 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: N1);
3445 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Add, N2: N0.getOperand(i: 1));
3446 }
3447 // Hoist one-use subtraction from non-opaque constant:
3448 // (C - x) + y -> (y - x) + C
3449 if (isConstantOrConstantVector(N: N0.getOperand(i: 0), /*NoOpaques=*/true)) {
3450 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1, N2: N0.getOperand(i: 1));
3451 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Sub, N2: N0.getOperand(i: 0));
3452 }
3453 }
3454
3455 // add (mul x, C), x -> mul x, C+1
3456 if (N0.getOpcode() == ISD::MUL && N0.getOperand(i: 0) == N1 &&
3457 isConstantOrConstantVector(N: N0.getOperand(i: 1), /*NoOpaques=*/true) &&
3458 N0.hasOneUse()) {
3459 SDValue NewC = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 1),
3460 N2: DAG.getConstant(Val: 1, DL, VT));
3461 return DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: N0.getOperand(i: 0), N2: NewC);
3462 }
3463
3464 // If the target's bool is represented as 0/1, prefer to make this 'sub 0/1'
3465 // rather than 'add 0/-1' (the zext should get folded).
3466 // add (sext i1 Y), X --> sub X, (zext i1 Y)
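      // This holds because (sext i1 Y) is 0 or -1 while (zext i1 Y) is 0 or 1,
      // so adding the former is the same as subtracting the latter.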
3467 if (N0.getOpcode() == ISD::SIGN_EXTEND &&
3468 N0.getOperand(i: 0).getScalarValueSizeInBits() == 1 &&
3469 TLI.getBooleanContents(Type: VT) == TargetLowering::ZeroOrOneBooleanContent) {
3470 SDValue ZExt = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0.getOperand(i: 0));
3471 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1, N2: ZExt);
3472 }
3473
3474 // add X, (sextinreg Y i1) -> sub X, (and Y 1)
3475 if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
3476 VTSDNode *TN = cast<VTSDNode>(Val: N1.getOperand(i: 1));
3477 if (TN->getVT() == MVT::i1) {
3478 SDValue ZExt = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N1.getOperand(i: 0),
3479 N2: DAG.getConstant(Val: 1, DL, VT));
3480 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: ZExt);
3481 }
3482 }
3483
3484 // (add X, (uaddo_carry Y, 0, Carry)) -> (uaddo_carry X, Y, Carry)
3485 if (N1.getOpcode() == ISD::UADDO_CARRY && isNullConstant(V: N1.getOperand(i: 1)) &&
3486 N1.getResNo() == 0)
3487 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL, VTList: N1->getVTList(),
3488 N1: N0, N2: N1.getOperand(i: 0), N3: N1.getOperand(i: 2));
3489
3490 // (add X, Carry) -> (uaddo_carry X, 0, Carry)
3491 if (TLI.isOperationLegalOrCustom(Op: ISD::UADDO_CARRY, VT))
3492 if (SDValue Carry = getAsCarry(TLI, V: N1))
3493 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL,
3494 VTList: DAG.getVTList(VT1: VT, VT2: Carry.getValueType()), N1: N0,
3495 N2: DAG.getConstant(Val: 0, DL, VT), N3: Carry);
3496
3497 return SDValue();
3498}
3499
3500SDValue DAGCombiner::visitADDC(SDNode *N) {
3501 SDValue N0 = N->getOperand(Num: 0);
3502 SDValue N1 = N->getOperand(Num: 1);
3503 EVT VT = N0.getValueType();
3504 SDLoc DL(N);
3505
3506 // If the flag result is dead, turn this into an ADD.
3507 if (!N->hasAnyUseOfValue(Value: 1))
3508 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1),
3509 Res1: DAG.getNode(Opcode: ISD::CARRY_FALSE, DL, VT: MVT::Glue));
3510
3511 // canonicalize constant to RHS.
3512 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(Val&: N0);
3513 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
3514 if (N0C && !N1C)
3515 return DAG.getNode(Opcode: ISD::ADDC, DL, VTList: N->getVTList(), N1, N2: N0);
3516
3517 // fold (addc x, 0) -> x + no carry out
3518 if (isNullConstant(V: N1))
3519 return CombineTo(N, Res0: N0, Res1: DAG.getNode(Opcode: ISD::CARRY_FALSE,
3520 DL, VT: MVT::Glue));
3521
3522 // If it cannot overflow, transform into an add.
3523 if (DAG.computeOverflowForUnsignedAdd(N0, N1) == SelectionDAG::OFK_Never)
3524 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1),
3525 Res1: DAG.getNode(Opcode: ISD::CARRY_FALSE, DL, VT: MVT::Glue));
3526
3527 return SDValue();
3528}
3529
3530/**
3531 * Flips a boolean if it is cheaper to compute. If the Force parameter is set,
3532 * then the flip also occurs if computing the inverse is the same cost.
3533 * This function returns an empty SDValue in case it cannot flip the boolean
3534 * without increasing the cost of the computation. If you want to flip a boolean
3535 * no matter what, use DAG.getLogicalNOT.
3536 */
3537static SDValue extractBooleanFlip(SDValue V, SelectionDAG &DAG,
3538 const TargetLowering &TLI,
3539 bool Force) {
3540 if (Force && isa<ConstantSDNode>(Val: V))
3541 return DAG.getLogicalNOT(DL: SDLoc(V), Val: V, VT: V.getValueType());
3542
3543 if (V.getOpcode() != ISD::XOR)
3544 return SDValue();
3545
3546 if (DAG.isBoolConstant(N: V.getOperand(i: 1)) == true)
3547 return V.getOperand(i: 0);
3548 if (Force && isConstOrConstSplat(N: V.getOperand(i: 1), AllowUndefs: false))
3549 return DAG.getLogicalNOT(DL: SDLoc(V), Val: V, VT: V.getValueType());
3550 return SDValue();
3551}
3552
3553SDValue DAGCombiner::visitADDO(SDNode *N) {
3554 SDValue N0 = N->getOperand(Num: 0);
3555 SDValue N1 = N->getOperand(Num: 1);
3556 EVT VT = N0.getValueType();
3557 bool IsSigned = (ISD::SADDO == N->getOpcode());
3558
3559 EVT CarryVT = N->getValueType(ResNo: 1);
3560 SDLoc DL(N);
3561
3562 // If the flag result is dead, turn this into an ADD.
3563 if (!N->hasAnyUseOfValue(Value: 1))
3564 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1),
3565 Res1: DAG.getUNDEF(VT: CarryVT));
3566
3567 // canonicalize constant to RHS.
3568 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
3569 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
3570 return DAG.getNode(Opcode: N->getOpcode(), DL, VTList: N->getVTList(), N1, N2: N0);
3571
3572 // fold (addo x, 0) -> x + no carry out
3573 if (isNullOrNullSplat(V: N1))
3574 return CombineTo(N, Res0: N0, Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
3575
3576 // If it cannot overflow, transform into an add.
3577 if (DAG.willNotOverflowAdd(IsSigned, N0, N1))
3578 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1),
3579 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
3580
3581 if (IsSigned) {
3582 // fold (saddo (xor a, -1), 1) -> (ssub 0, a).
3583 if (isBitwiseNot(V: N0) && isOneOrOneSplat(V: N1))
3584 return DAG.getNode(Opcode: ISD::SSUBO, DL, VTList: N->getVTList(),
3585 N1: DAG.getConstant(Val: 0, DL, VT), N2: N0.getOperand(i: 0));
3586 } else {
3587 // fold (uaddo (xor a, -1), 1) -> (usub 0, a) and flip carry.
3588 if (isBitwiseNot(V: N0) && isOneOrOneSplat(V: N1)) {
3589 SDValue Sub = DAG.getNode(Opcode: ISD::USUBO, DL, VTList: N->getVTList(),
3590 N1: DAG.getConstant(Val: 0, DL, VT), N2: N0.getOperand(i: 0));
3591 return CombineTo(
3592 N, Res0: Sub, Res1: DAG.getLogicalNOT(DL, Val: Sub.getValue(R: 1), VT: Sub->getValueType(ResNo: 1)));
3593 }
3594
3595 if (SDValue Combined = visitUADDOLike(N0, N1, N))
3596 return Combined;
3597
3598 if (SDValue Combined = visitUADDOLike(N0: N1, N1: N0, N))
3599 return Combined;
3600 }
3601
3602 return SDValue();
3603}
3604
3605SDValue DAGCombiner::visitUADDOLike(SDValue N0, SDValue N1, SDNode *N) {
3606 EVT VT = N0.getValueType();
3607 if (VT.isVector())
3608 return SDValue();
3609
3610 // (uaddo X, (uaddo_carry Y, 0, Carry)) -> (uaddo_carry X, Y, Carry)
3611 // If Y + 1 cannot overflow.
3612 if (N1.getOpcode() == ISD::UADDO_CARRY && isNullConstant(V: N1.getOperand(i: 1))) {
3613 SDValue Y = N1.getOperand(i: 0);
3614 SDValue One = DAG.getConstant(Val: 1, DL: SDLoc(N), VT: Y.getValueType());
3615 if (DAG.computeOverflowForUnsignedAdd(N0: Y, N1: One) == SelectionDAG::OFK_Never)
3616 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: Y,
3617 N3: N1.getOperand(i: 2));
3618 }
3619
3620 // (uaddo X, Carry) -> (uaddo_carry X, 0, Carry)
3621 if (TLI.isOperationLegalOrCustom(Op: ISD::UADDO_CARRY, VT))
3622 if (SDValue Carry = getAsCarry(TLI, V: N1))
3623 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL: SDLoc(N), VTList: N->getVTList(), N1: N0,
3624 N2: DAG.getConstant(Val: 0, DL: SDLoc(N), VT), N3: Carry);
3625
3626 return SDValue();
3627}
3628
3629SDValue DAGCombiner::visitADDE(SDNode *N) {
3630 SDValue N0 = N->getOperand(Num: 0);
3631 SDValue N1 = N->getOperand(Num: 1);
3632 SDValue CarryIn = N->getOperand(Num: 2);
3633
3634 // canonicalize constant to RHS
3635 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(Val&: N0);
3636 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
3637 if (N0C && !N1C)
3638 return DAG.getNode(Opcode: ISD::ADDE, DL: SDLoc(N), VTList: N->getVTList(),
3639 N1, N2: N0, N3: CarryIn);
3640
3641 // fold (adde x, y, false) -> (addc x, y)
3642 if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
3643 return DAG.getNode(Opcode: ISD::ADDC, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: N1);
3644
3645 return SDValue();
3646}
3647
3648SDValue DAGCombiner::visitUADDO_CARRY(SDNode *N) {
3649 SDValue N0 = N->getOperand(Num: 0);
3650 SDValue N1 = N->getOperand(Num: 1);
3651 SDValue CarryIn = N->getOperand(Num: 2);
3652 SDLoc DL(N);
3653
3654 // canonicalize constant to RHS
3655 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(Val&: N0);
3656 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
3657 if (N0C && !N1C)
3658 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL, VTList: N->getVTList(), N1, N2: N0, N3: CarryIn);
3659
3660 // fold (uaddo_carry x, y, false) -> (uaddo x, y)
3661 if (isNullConstant(V: CarryIn)) {
3662 if (!LegalOperations ||
3663 TLI.isOperationLegalOrCustom(Op: ISD::UADDO, VT: N->getValueType(ResNo: 0)))
3664 return DAG.getNode(Opcode: ISD::UADDO, DL, VTList: N->getVTList(), N1: N0, N2: N1);
3665 }
3666
3667 // fold (uaddo_carry 0, 0, X) -> (and (ext/trunc X), 1) and no carry.
3668 if (isNullConstant(V: N0) && isNullConstant(V: N1)) {
3669 EVT VT = N0.getValueType();
3670 EVT CarryVT = CarryIn.getValueType();
3671 SDValue CarryExt = DAG.getBoolExtOrTrunc(Op: CarryIn, SL: DL, VT, OpVT: CarryVT);
3672 AddToWorklist(N: CarryExt.getNode());
3673 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::AND, DL, VT, N1: CarryExt,
3674 N2: DAG.getConstant(Val: 1, DL, VT)),
3675 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
3676 }
3677
3678 if (SDValue Combined = visitUADDO_CARRYLike(N0, N1, CarryIn, N))
3679 return Combined;
3680
3681 if (SDValue Combined = visitUADDO_CARRYLike(N0: N1, N1: N0, CarryIn, N))
3682 return Combined;
3683
3684 // We want to avoid useless duplication.
3685 // TODO: This is done automatically for binary operations. As UADDO_CARRY is
3686  // not a binary operation, it is not really possible to leverage this
3687  // existing mechanism for it. However, if more operations require the same
3688  // deduplication logic, then it may be worth generalizing.
3689 SDValue Ops[] = {N1, N0, CarryIn};
3690 SDNode *CSENode =
3691 DAG.getNodeIfExists(Opcode: ISD::UADDO_CARRY, VTList: N->getVTList(), Ops, Flags: N->getFlags());
3692 if (CSENode)
3693 return SDValue(CSENode, 0);
3694
3695 return SDValue();
3696}
3697
3698/**
3699 * If we are facing some sort of diamond carry propagation pattern, try to
3700 * break it up to generate something like:
3701 * (uaddo_carry X, 0, (uaddo_carry A, B, Z):Carry)
3702 *
3703 * The end result is usually an increase in operations required, but because the
3704 * carry is now linearized, other transforms can kick in and optimize the DAG.
3705 *
3706 * Patterns typically look something like
3707 * (uaddo A, B)
3708 * / \
3709 * Carry Sum
3710 * | \
3711 * | (uaddo_carry *, 0, Z)
3712 * | /
3713 * \ Carry
3714 * | /
3715 * (uaddo_carry X, *, *)
3716 *
3717 * But numerous variations exist. Our goal is to identify A, B, X and Z and
3718 * produce a combine with a single path for carry propagation.
3719 */
3720static SDValue combineUADDO_CARRYDiamond(DAGCombiner &Combiner,
3721 SelectionDAG &DAG, SDValue X,
3722 SDValue Carry0, SDValue Carry1,
3723 SDNode *N) {
3724 if (Carry1.getResNo() != 1 || Carry0.getResNo() != 1)
3725 return SDValue();
3726 if (Carry1.getOpcode() != ISD::UADDO)
3727 return SDValue();
3728
3729 SDValue Z;
3730
3731 /**
3732 * First look for a suitable Z. It will present itself in the form of
3733 * (uaddo_carry Y, 0, Z) or its equivalent (uaddo Y, 1) for Z=true
3734 */
3735 if (Carry0.getOpcode() == ISD::UADDO_CARRY &&
3736 isNullConstant(V: Carry0.getOperand(i: 1))) {
3737 Z = Carry0.getOperand(i: 2);
3738 } else if (Carry0.getOpcode() == ISD::UADDO &&
3739 isOneConstant(V: Carry0.getOperand(i: 1))) {
3740 EVT VT = Carry0->getValueType(ResNo: 1);
3741 Z = DAG.getConstant(Val: 1, DL: SDLoc(Carry0.getOperand(i: 1)), VT);
3742 } else {
3743 // We couldn't find a suitable Z.
3744 return SDValue();
3745 }
3746
3747
3748 auto cancelDiamond = [&](SDValue A,SDValue B) {
3749 SDLoc DL(N);
3750 SDValue NewY =
3751 DAG.getNode(Opcode: ISD::UADDO_CARRY, DL, VTList: Carry0->getVTList(), N1: A, N2: B, N3: Z);
3752 Combiner.AddToWorklist(N: NewY.getNode());
3753 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL, VTList: N->getVTList(), N1: X,
3754 N2: DAG.getConstant(Val: 0, DL, VT: X.getValueType()),
3755 N3: NewY.getValue(R: 1));
3756 };
3757
3758 /**
3759 * (uaddo A, B)
3760 * |
3761 * Sum
3762 * |
3763 * (uaddo_carry *, 0, Z)
3764 */
3765 if (Carry0.getOperand(i: 0) == Carry1.getValue(R: 0)) {
3766 return cancelDiamond(Carry1.getOperand(i: 0), Carry1.getOperand(i: 1));
3767 }
3768
3769 /**
3770 * (uaddo_carry A, 0, Z)
3771 * |
3772 * Sum
3773 * |
3774 * (uaddo *, B)
3775 */
3776 if (Carry1.getOperand(i: 0) == Carry0.getValue(R: 0)) {
3777 return cancelDiamond(Carry0.getOperand(i: 0), Carry1.getOperand(i: 1));
3778 }
3779
3780 if (Carry1.getOperand(i: 1) == Carry0.getValue(R: 0)) {
3781 return cancelDiamond(Carry1.getOperand(i: 0), Carry0.getOperand(i: 0));
3782 }
3783
3784 return SDValue();
3785}
3786
3787// If we are facing some sort of diamond carry/borrow in/out pattern, try to
3788// match patterns like:
3789//
3790// (uaddo A, B) CarryIn
3791// | \ |
3792// | \ |
3793// PartialSum PartialCarryOutX /
3794// | | /
3795// | ____|____________/
3796// | / |
3797// (uaddo *, *) \________
3798// | \ \
3799// | \ |
3800// | PartialCarryOutY |
3801// | \ |
3802// | \ /
3803// AddCarrySum | ______/
3804// | /
3805// CarryOut = (or *, *)
3806//
3807// And generate UADDO_CARRY (or USUBO_CARRY) with two result values:
3808//
3809// {AddCarrySum, CarryOut} = (uaddo_carry A, B, CarryIn)
3810//
3811// Our goal is to identify A, B, and CarryIn and produce UADDO_CARRY/USUBO_CARRY
3812// with a single path for carry/borrow out propagation.
3813static SDValue combineCarryDiamond(SelectionDAG &DAG, const TargetLowering &TLI,
3814 SDValue N0, SDValue N1, SDNode *N) {
3815 SDValue Carry0 = getAsCarry(TLI, V: N0);
3816 if (!Carry0)
3817 return SDValue();
3818 SDValue Carry1 = getAsCarry(TLI, V: N1);
3819 if (!Carry1)
3820 return SDValue();
3821
3822 unsigned Opcode = Carry0.getOpcode();
3823 if (Opcode != Carry1.getOpcode())
3824 return SDValue();
3825 if (Opcode != ISD::UADDO && Opcode != ISD::USUBO)
3826 return SDValue();
3827 // Guarantee identical type of CarryOut
3828 EVT CarryOutType = N->getValueType(ResNo: 0);
3829 if (CarryOutType != Carry0.getValue(R: 1).getValueType() ||
3830 CarryOutType != Carry1.getValue(R: 1).getValueType())
3831 return SDValue();
3832
3833 // Canonicalize the add/sub of A and B (the top node in the above ASCII art)
3834 // as Carry0 and the add/sub of the carry in as Carry1 (the middle node).
3835 if (Carry1.getNode()->isOperandOf(N: Carry0.getNode()))
3836 std::swap(a&: Carry0, b&: Carry1);
3837
3838  // Check if the nodes are connected in the expected way.
3839 if (Carry1.getOperand(i: 0) != Carry0.getValue(R: 0) &&
3840 Carry1.getOperand(i: 1) != Carry0.getValue(R: 0))
3841 return SDValue();
3842
3843 // The carry in value must be on the righthand side for subtraction.
3844 unsigned CarryInOperandNum =
3845 Carry1.getOperand(i: 0) == Carry0.getValue(R: 0) ? 1 : 0;
3846 if (Opcode == ISD::USUBO && CarryInOperandNum != 1)
3847 return SDValue();
3848 SDValue CarryIn = Carry1.getOperand(i: CarryInOperandNum);
3849
3850 unsigned NewOp = Opcode == ISD::UADDO ? ISD::UADDO_CARRY : ISD::USUBO_CARRY;
3851 if (!TLI.isOperationLegalOrCustom(Op: NewOp, VT: Carry0.getValue(R: 0).getValueType()))
3852 return SDValue();
3853
3854 // Verify that the carry/borrow in is plausibly a carry/borrow bit.
3855 CarryIn = getAsCarry(TLI, V: CarryIn, ForceCarryReconstruction: true);
3856 if (!CarryIn)
3857 return SDValue();
3858
3859 SDLoc DL(N);
3860 CarryIn = DAG.getBoolExtOrTrunc(Op: CarryIn, SL: DL, VT: Carry1->getValueType(ResNo: 1),
3861 OpVT: Carry1->getValueType(ResNo: 0));
3862 SDValue Merged =
3863 DAG.getNode(Opcode: NewOp, DL, VTList: Carry1->getVTList(), N1: Carry0.getOperand(i: 0),
3864 N2: Carry0.getOperand(i: 1), N3: CarryIn);
3865
3866 // Please note that because we have proven that the result of the UADDO/USUBO
3867 // of A and B feeds into the UADDO/USUBO that does the carry/borrow in, we can
3868 // therefore prove that if the first UADDO/USUBO overflows, the second
3869 // UADDO/USUBO cannot. For example consider 8-bit numbers where 0xFF is the
3870 // maximum value.
3871 //
3872 // 0xFF + 0xFF == 0xFE with carry but 0xFE + 1 does not carry
3873 // 0x00 - 0xFF == 1 with a carry/borrow but 1 - 1 == 0 (no carry/borrow)
3874 //
3875 // This is important because it means that OR and XOR can be used to merge
3876 // carry flags; and that AND can return a constant zero.
3877 //
3878 // TODO: match other operations that can merge flags (ADD, etc)
3879 DAG.ReplaceAllUsesOfValueWith(From: Carry1.getValue(R: 0), To: Merged.getValue(R: 0));
3880 if (N->getOpcode() == ISD::AND)
3881 return DAG.getConstant(Val: 0, DL, VT: CarryOutType);
3882 return Merged.getValue(R: 1);
3883}
3884
3885SDValue DAGCombiner::visitUADDO_CARRYLike(SDValue N0, SDValue N1,
3886 SDValue CarryIn, SDNode *N) {
3887 // fold (uaddo_carry (xor a, -1), b, c) -> (usubo_carry b, a, !c) and flip
3888 // carry.
3889 if (isBitwiseNot(V: N0))
3890 if (SDValue NotC = extractBooleanFlip(V: CarryIn, DAG, TLI, Force: true)) {
3891 SDLoc DL(N);
3892 SDValue Sub = DAG.getNode(Opcode: ISD::USUBO_CARRY, DL, VTList: N->getVTList(), N1,
3893 N2: N0.getOperand(i: 0), N3: NotC);
3894 return CombineTo(
3895 N, Res0: Sub, Res1: DAG.getLogicalNOT(DL, Val: Sub.getValue(R: 1), VT: Sub->getValueType(ResNo: 1)));
3896 }
3897
3898 // Iff the flag result is dead:
3899 // (uaddo_carry (add|uaddo X, Y), 0, Carry) -> (uaddo_carry X, Y, Carry)
3900 // Don't do this if the Carry comes from the uaddo. It won't remove the uaddo
3901 // or the dependency between the instructions.
3902 if ((N0.getOpcode() == ISD::ADD ||
3903 (N0.getOpcode() == ISD::UADDO && N0.getResNo() == 0 &&
3904 N0.getValue(R: 1) != CarryIn)) &&
3905 isNullConstant(V: N1) && !N->hasAnyUseOfValue(Value: 1))
3906 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL: SDLoc(N), VTList: N->getVTList(),
3907 N1: N0.getOperand(i: 0), N2: N0.getOperand(i: 1), N3: CarryIn);
3908
3909 /**
3910 * When one of the uaddo_carry arguments is itself a carry, we may be facing
3911 * a diamond carry propagation, in which case we try to transform the DAG
3912 * to ensure linear carry propagation if that is possible.
3913 */
3914 if (auto Y = getAsCarry(TLI, V: N1)) {
3915 // Because both are carries, Y and Z can be swapped.
3916 if (auto R = combineUADDO_CARRYDiamond(Combiner&: *this, DAG, X: N0, Carry0: Y, Carry1: CarryIn, N))
3917 return R;
3918 if (auto R = combineUADDO_CARRYDiamond(Combiner&: *this, DAG, X: N0, Carry0: CarryIn, Carry1: Y, N))
3919 return R;
3920 }
3921
3922 return SDValue();
3923}
3924
3925SDValue DAGCombiner::visitSADDO_CARRYLike(SDValue N0, SDValue N1,
3926 SDValue CarryIn, SDNode *N) {
3927 // fold (saddo_carry (xor a, -1), b, c) -> (ssubo_carry b, a, !c)
3928 if (isBitwiseNot(V: N0)) {
3929 if (SDValue NotC = extractBooleanFlip(V: CarryIn, DAG, TLI, Force: true))
3930 return DAG.getNode(Opcode: ISD::SSUBO_CARRY, DL: SDLoc(N), VTList: N->getVTList(), N1,
3931 N2: N0.getOperand(i: 0), N3: NotC);
3932 }
3933
3934 return SDValue();
3935}
3936
3937SDValue DAGCombiner::visitSADDO_CARRY(SDNode *N) {
3938 SDValue N0 = N->getOperand(Num: 0);
3939 SDValue N1 = N->getOperand(Num: 1);
3940 SDValue CarryIn = N->getOperand(Num: 2);
3941 SDLoc DL(N);
3942
3943 // canonicalize constant to RHS
3944 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(Val&: N0);
3945 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
3946 if (N0C && !N1C)
3947 return DAG.getNode(Opcode: ISD::SADDO_CARRY, DL, VTList: N->getVTList(), N1, N2: N0, N3: CarryIn);
3948
3949 // fold (saddo_carry x, y, false) -> (saddo x, y)
3950 if (isNullConstant(V: CarryIn)) {
3951 if (!LegalOperations ||
3952 TLI.isOperationLegalOrCustom(Op: ISD::SADDO, VT: N->getValueType(ResNo: 0)))
3953 return DAG.getNode(Opcode: ISD::SADDO, DL, VTList: N->getVTList(), N1: N0, N2: N1);
3954 }
3955
3956 if (SDValue Combined = visitSADDO_CARRYLike(N0, N1, CarryIn, N))
3957 return Combined;
3958
3959 if (SDValue Combined = visitSADDO_CARRYLike(N0: N1, N1: N0, CarryIn, N))
3960 return Combined;
3961
3962 return SDValue();
3963}
3964
3965// Attempt to create a USUBSAT(LHS, RHS) node with DstVT, performing a
3966// clamp/truncation if necessary.
3967static SDValue getTruncatedUSUBSAT(EVT DstVT, EVT SrcVT, SDValue LHS,
3968 SDValue RHS, SelectionDAG &DAG,
3969 const SDLoc &DL) {
3970 assert(DstVT.getScalarSizeInBits() <= SrcVT.getScalarSizeInBits() &&
3971 "Illegal truncation");
3972
3973 if (DstVT == SrcVT)
3974 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT: DstVT, N1: LHS, N2: RHS);
3975
3976 // If the LHS is zero-extended then we can perform the USUBSAT as DstVT by
3977 // clamping RHS.
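      // This is sound because any RHS above the DstVT maximum already saturates
      // the subtraction to zero for an LHS that fits in DstVT, and so does the
      // clamped limit.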
3978 APInt UpperBits = APInt::getBitsSetFrom(numBits: SrcVT.getScalarSizeInBits(),
3979 loBit: DstVT.getScalarSizeInBits());
3980 if (!DAG.MaskedValueIsZero(Op: LHS, Mask: UpperBits))
3981 return SDValue();
3982
3983 SDValue SatLimit =
3984 DAG.getConstant(Val: APInt::getLowBitsSet(numBits: SrcVT.getScalarSizeInBits(),
3985 loBitsSet: DstVT.getScalarSizeInBits()),
3986 DL, VT: SrcVT);
3987 RHS = DAG.getNode(Opcode: ISD::UMIN, DL, VT: SrcVT, N1: RHS, N2: SatLimit);
3988 RHS = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: DstVT, Operand: RHS);
3989 LHS = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: DstVT, Operand: LHS);
3990 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT: DstVT, N1: LHS, N2: RHS);
3991}
3992
3993// Try to find umax(a,b) - b or a - umin(a,b) patterns that may be converted to
3994// usubsat(a,b), optionally as a truncated type.
3995SDValue DAGCombiner::foldSubToUSubSat(EVT DstVT, SDNode *N, const SDLoc &DL) {
3996 if (N->getOpcode() != ISD::SUB ||
3997 !(!LegalOperations || hasOperation(Opcode: ISD::USUBSAT, VT: DstVT)))
3998 return SDValue();
3999
4000 EVT SubVT = N->getValueType(ResNo: 0);
4001 SDValue Op0 = N->getOperand(Num: 0);
4002 SDValue Op1 = N->getOperand(Num: 1);
4003
4004 // Try to find umax(a,b) - b or a - umin(a,b) patterns
4005  // that may be converted to usubsat(a,b).
4006 if (Op0.getOpcode() == ISD::UMAX && Op0.hasOneUse()) {
4007 SDValue MaxLHS = Op0.getOperand(i: 0);
4008 SDValue MaxRHS = Op0.getOperand(i: 1);
4009 if (MaxLHS == Op1)
4010 return getTruncatedUSUBSAT(DstVT, SrcVT: SubVT, LHS: MaxRHS, RHS: Op1, DAG, DL);
4011 if (MaxRHS == Op1)
4012 return getTruncatedUSUBSAT(DstVT, SrcVT: SubVT, LHS: MaxLHS, RHS: Op1, DAG, DL);
4013 }
4014
4015 if (Op1.getOpcode() == ISD::UMIN && Op1.hasOneUse()) {
4016 SDValue MinLHS = Op1.getOperand(i: 0);
4017 SDValue MinRHS = Op1.getOperand(i: 1);
4018 if (MinLHS == Op0)
4019 return getTruncatedUSUBSAT(DstVT, SrcVT: SubVT, LHS: Op0, RHS: MinRHS, DAG, DL);
4020 if (MinRHS == Op0)
4021 return getTruncatedUSUBSAT(DstVT, SrcVT: SubVT, LHS: Op0, RHS: MinLHS, DAG, DL);
4022 }
4023
4024 // sub(a,trunc(umin(zext(a),b))) -> usubsat(a,trunc(umin(b,SatLimit)))
4025 if (Op1.getOpcode() == ISD::TRUNCATE &&
4026 Op1.getOperand(i: 0).getOpcode() == ISD::UMIN &&
4027 Op1.getOperand(i: 0).hasOneUse()) {
4028 SDValue MinLHS = Op1.getOperand(i: 0).getOperand(i: 0);
4029 SDValue MinRHS = Op1.getOperand(i: 0).getOperand(i: 1);
4030 if (MinLHS.getOpcode() == ISD::ZERO_EXTEND && MinLHS.getOperand(i: 0) == Op0)
4031 return getTruncatedUSUBSAT(DstVT, SrcVT: MinLHS.getValueType(), LHS: MinLHS, RHS: MinRHS,
4032 DAG, DL);
4033 if (MinRHS.getOpcode() == ISD::ZERO_EXTEND && MinRHS.getOperand(i: 0) == Op0)
4034 return getTruncatedUSUBSAT(DstVT, SrcVT: MinLHS.getValueType(), LHS: MinRHS, RHS: MinLHS,
4035 DAG, DL);
4036 }
4037
4038 return SDValue();
4039}
4040
4041// Refinement of DAG/Type Legalisation (promotion) when CTLZ is used for
4042// counting leading ones. Broadly, it replaces the subtraction with a left
4043// shift.
4044//
4045// * DAG Legalisation Pattern:
4046//
4047// (sub (ctlz (zeroextend (not Src)))
4048// BitWidthDiff)
4049//
4050// if BitWidthDiff == BitWidth(Node) - BitWidth(Src)
4051// -->
4052//
4053// (ctlz_zero_undef (not (shl (anyextend Src)
4054// BitWidthDiff)))
4055//
4056// * Type Legalisation Pattern:
4057//
4058// (sub (ctlz (and (xor Src XorMask)
4059// AndMask))
4060// BitWidthDiff)
4061//
4062// if AndMask has only trailing ones
4063// and MaskBitWidth(AndMask) == BitWidth(Node) - BitWidthDiff
4064// and XorMask has more trailing ones than AndMask
4065// -->
4066//
4067// (ctlz_zero_undef (not (shl Src BitWidthDiff)))
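    //
    // For example, when counting the leading ones of an i8 Src promoted to i32,
    // (ctlz (zext (not Src))) - 24 equals
    // (ctlz_zero_undef (not (shl (anyext Src), 24))): the shift places Src in
    // the top 8 bits and the 'not' fills the low 24 bits with ones, so both
    // expressions count exactly Src's leading ones.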
4068template <class MatchContextClass>
4069static SDValue foldSubCtlzNot(SDNode *N, SelectionDAG &DAG) {
4070 const SDLoc DL(N);
4071 SDValue N0 = N->getOperand(Num: 0);
4072 EVT VT = N0.getValueType();
4073 unsigned BitWidth = VT.getScalarSizeInBits();
4074
4075 MatchContextClass Matcher(DAG, DAG.getTargetLoweringInfo(), N);
4076
4077 APInt AndMask;
4078 APInt XorMask;
4079 APInt BitWidthDiff;
4080
4081 SDValue CtlzOp;
4082 SDValue Src;
4083
4084 if (!sd_context_match(
4085 N, Matcher, m_Sub(L: m_Ctlz(Op: m_Value(N&: CtlzOp)), R: m_ConstInt(V&: BitWidthDiff))))
4086 return SDValue();
4087
4088 if (sd_context_match(CtlzOp, Matcher, m_ZExt(Op: m_Not(V: m_Value(N&: Src))))) {
4089 // DAG Legalisation Pattern:
4090    // (sub (ctlz (zero_extend (not Op))) BitWidthDiff)
4091 if ((BitWidth - Src.getValueType().getScalarSizeInBits()) != BitWidthDiff)
4092 return SDValue();
4093
4094 Src = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT, Operand: Src);
4095 } else if (sd_context_match(CtlzOp, Matcher,
4096 m_And(L: m_Xor(L: m_Value(N&: Src), R: m_ConstInt(V&: XorMask)),
4097 R: m_ConstInt(V&: AndMask)))) {
4098 // Type Legalisation Pattern:
4099 // (sub (ctlz (and (xor Op XorMask) AndMask)) BitWidthDiff)
4100 if (BitWidthDiff.getZExtValue() >= BitWidth)
4101 return SDValue();
4102 unsigned AndMaskWidth = BitWidth - BitWidthDiff.getZExtValue();
4103 if (!(AndMask.isMask(numBits: AndMaskWidth) && XorMask.countr_one() >= AndMaskWidth))
4104 return SDValue();
4105 } else
4106 return SDValue();
4107
4108 SDValue ShiftConst = DAG.getShiftAmountConstant(Val: BitWidthDiff, VT, DL);
4109 SDValue LShift = Matcher.getNode(ISD::SHL, DL, VT, Src, ShiftConst);
4110 SDValue Not =
4111 Matcher.getNode(ISD::XOR, DL, VT, LShift, DAG.getAllOnesConstant(DL, VT));
4112
4113 return Matcher.getNode(ISD::CTLZ_ZERO_UNDEF, DL, VT, Not);
4114}
4115
4116// Fold sub(x, mul(divrem(x,y)[0], y)) to divrem(x, y)[1]
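    // This holds because x - (x / y) * y == x % y for both the signed and
    // unsigned forms, and divrem already produces that remainder as result 1.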
4117static SDValue foldRemainderIdiom(SDNode *N, SelectionDAG &DAG,
4118 const SDLoc &DL) {
4119 assert(N->getOpcode() == ISD::SUB && "Node must be a SUB");
4120 SDValue Sub0 = N->getOperand(Num: 0);
4121 SDValue Sub1 = N->getOperand(Num: 1);
4122
4123 auto CheckAndFoldMulCase = [&](SDValue DivRem, SDValue MaybeY) -> SDValue {
4124 if ((DivRem.getOpcode() == ISD::SDIVREM ||
4125 DivRem.getOpcode() == ISD::UDIVREM) &&
4126 DivRem.getResNo() == 0 && DivRem.getOperand(i: 0) == Sub0 &&
4127 DivRem.getOperand(i: 1) == MaybeY) {
4128 return SDValue(DivRem.getNode(), 1);
4129 }
4130 return SDValue();
4131 };
4132
4133 if (Sub1.getOpcode() == ISD::MUL) {
4134 // (sub x, (mul divrem(x,y)[0], y))
4135 SDValue Mul0 = Sub1.getOperand(i: 0);
4136 SDValue Mul1 = Sub1.getOperand(i: 1);
4137
4138 if (SDValue Res = CheckAndFoldMulCase(Mul0, Mul1))
4139 return Res;
4140
4141 if (SDValue Res = CheckAndFoldMulCase(Mul1, Mul0))
4142 return Res;
4143
4144 } else if (Sub1.getOpcode() == ISD::SHL) {
4145 // Handle (sub x, (shl divrem(x,y)[0], C)) where y = 1 << C
4146 SDValue Shl0 = Sub1.getOperand(i: 0);
4147 SDValue Shl1 = Sub1.getOperand(i: 1);
4148 // Check if Shl0 is divrem(x, Y)[0]
4149 if ((Shl0.getOpcode() == ISD::SDIVREM ||
4150 Shl0.getOpcode() == ISD::UDIVREM) &&
4151 Shl0.getResNo() == 0 && Shl0.getOperand(i: 0) == Sub0) {
4152
4153 SDValue Divisor = Shl0.getOperand(i: 1);
4154
4155 ConstantSDNode *DivC = isConstOrConstSplat(N: Divisor);
4156 ConstantSDNode *ShC = isConstOrConstSplat(N: Shl1);
4157 if (!DivC || !ShC)
4158 return SDValue();
4159
4160 if (DivC->getAPIntValue().isPowerOf2() &&
4161 DivC->getAPIntValue().logBase2() == ShC->getAPIntValue())
4162 return SDValue(Shl0.getNode(), 1);
4163 }
4164 }
4165 return SDValue();
4166}
4167
4168// Since it may not be valid to emit a fold to zero for vector initializers,
4169// check if we can before folding.
4170static SDValue tryFoldToZero(const SDLoc &DL, const TargetLowering &TLI, EVT VT,
4171 SelectionDAG &DAG, bool LegalOperations) {
4172 if (!VT.isVector())
4173 return DAG.getConstant(Val: 0, DL, VT);
4174 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT))
4175 return DAG.getConstant(Val: 0, DL, VT);
4176 return SDValue();
4177}
4178
4179SDValue DAGCombiner::visitSUB(SDNode *N) {
4180 SDValue N0 = N->getOperand(Num: 0);
4181 SDValue N1 = N->getOperand(Num: 1);
4182 EVT VT = N0.getValueType();
4183 unsigned BitWidth = VT.getScalarSizeInBits();
4184 SDLoc DL(N);
4185
4186 if (SDValue V = foldSubCtlzNot<EmptyMatchContext>(N, DAG))
4187 return V;
4188
4189 // fold (sub x, x) -> 0
4190 if (N0 == N1)
4191 return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
4192
4193 // fold (sub c1, c2) -> c3
4194 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N0, N1}))
4195 return C;
4196
4197 // fold vector ops
4198 if (VT.isVector()) {
4199 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4200 return FoldedVOp;
4201
4202 // fold (sub x, 0) -> x, vector edition
4203 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
4204 return N0;
4205 }
4206
4207 // (sub x, ([v]select (ult x, y), 0, y)) -> (umin x, (sub x, y))
4208 // (sub x, ([v]select (uge x, y), y, 0)) -> (umin x, (sub x, y))
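      // When x < y the select yields 0, so the subtraction leaves x unchanged;
      // the wrapped x - y is then greater than x in the unsigned order, so
      // umin also picks x. Otherwise both forms compute x - y.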
4209 if (N1.hasOneUse() && hasUMin(VT)) {
4210 SDValue Y;
4211 auto MS0 = m_Specific(N: N0);
4212 auto MVY = m_Value(N&: Y);
4213 auto MZ = m_Zero();
4214 auto MCC1 = m_SpecificCondCode(CC: ISD::SETULT);
4215 auto MCC2 = m_SpecificCondCode(CC: ISD::SETUGE);
4216
4217 if (sd_match(N: N1, P: m_SelectCCLike(L: MS0, R: MVY, T: MZ, F: m_Deferred(V&: Y), CC: MCC1)) ||
4218 sd_match(N: N1, P: m_SelectCCLike(L: MS0, R: MVY, T: m_Deferred(V&: Y), F: MZ, CC: MCC2)) ||
4219 sd_match(N: N1, P: m_VSelect(Cond: m_SetCC(LHS: MS0, RHS: MVY, CC: MCC1), T: MZ, F: m_Deferred(V&: Y))) ||
4220 sd_match(N: N1, P: m_VSelect(Cond: m_SetCC(LHS: MS0, RHS: MVY, CC: MCC2), T: m_Deferred(V&: Y), F: MZ)))
4221
4222 return DAG.getNode(Opcode: ISD::UMIN, DL, VT, N1: N0,
4223 N2: DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: Y));
4224 }
4225
4226 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
4227 return NewSel;
4228
4229 // fold (sub x, c) -> (add x, -c)
4230 if (ConstantSDNode *N1C = getAsNonOpaqueConstant(N: N1))
4231 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0,
4232 N2: DAG.getConstant(Val: -N1C->getAPIntValue(), DL, VT));
4233
4234 if (isNullOrNullSplat(V: N0)) {
4235 // Right-shifting everything out but the sign bit followed by negation is
4236 // the same as flipping arithmetic/logical shift type without the negation:
4237 // -(X >>u 31) -> (X >>s 31)
4238 // -(X >>s 31) -> (X >>u 31)
4239 if (N1->getOpcode() == ISD::SRA || N1->getOpcode() == ISD::SRL) {
4240 ConstantSDNode *ShiftAmt = isConstOrConstSplat(N: N1.getOperand(i: 1));
4241 if (ShiftAmt && ShiftAmt->getAPIntValue() == (BitWidth - 1)) {
4242 auto NewSh = N1->getOpcode() == ISD::SRA ? ISD::SRL : ISD::SRA;
4243 if (!LegalOperations || TLI.isOperationLegal(Op: NewSh, VT))
4244 return DAG.getNode(Opcode: NewSh, DL, VT, N1: N1.getOperand(i: 0), N2: N1.getOperand(i: 1));
4245 }
4246 }
4247
4248 // 0 - X --> 0 if the sub is NUW.
4249 if (N->getFlags().hasNoUnsignedWrap())
4250 return N0;
4251
4252 if (DAG.MaskedValueIsZero(Op: N1, Mask: ~APInt::getSignMask(BitWidth))) {
4253 // N1 is either 0 or the minimum signed value. If the sub is NSW, then
4254 // N1 must be 0 because negating the minimum signed value is undefined.
4255 if (N->getFlags().hasNoSignedWrap())
4256 return N0;
4257
4258 // 0 - X --> X if X is 0 or the minimum signed value.
4259 return N1;
4260 }
4261
4262 // Convert 0 - abs(x).
4263 if (N1.getOpcode() == ISD::ABS && N1.hasOneUse() &&
4264 !TLI.isOperationLegalOrCustom(Op: ISD::ABS, VT))
4265 if (SDValue Result = TLI.expandABS(N: N1.getNode(), DAG, IsNegative: true))
4266 return Result;
4267
4268 // Similar to the previous rule, but this time targeting an expanded abs.
4269 // (sub 0, (max X, (sub 0, X))) --> (min X, (sub 0, X))
4270 // as well as
4271 // (sub 0, (min X, (sub 0, X))) --> (max X, (sub 0, X))
4272 // Note that these two are applicable to both signed and unsigned min/max.
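// Negation reverses the (signed or unsigned) order and maps the pair
// {X, -X} to itself, so negating its max gives its min and vice versa.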
4273 SDValue X;
4274 SDValue S0;
4275 auto NegPat = m_AllOf(preds: m_Neg(V: m_Deferred(V&: X)), preds: m_Value(N&: S0));
4276 if (sd_match(N: N1, P: m_OneUse(P: m_AnyOf(preds: m_SMax(L: m_Value(N&: X), R: NegPat),
4277 preds: m_UMax(L: m_Value(N&: X), R: NegPat),
4278 preds: m_SMin(L: m_Value(N&: X), R: NegPat),
4279 preds: m_UMin(L: m_Value(N&: X), R: NegPat))))) {
4280 unsigned NewOpc = ISD::getInverseMinMaxOpcode(MinMaxOpc: N1->getOpcode());
4281 if (hasOperation(Opcode: NewOpc, VT))
4282 return DAG.getNode(Opcode: NewOpc, DL, VT, N1: X, N2: S0);
4283 }
4284
4285 // Fold neg(splat(neg(x))) -> splat(x)
4286 if (VT.isVector()) {
4287 SDValue N1S = DAG.getSplatValue(V: N1, LegalTypes: true);
4288 if (N1S && N1S.getOpcode() == ISD::SUB &&
4289 isNullConstant(V: N1S.getOperand(i: 0)))
4290 return DAG.getSplat(VT, DL, Op: N1S.getOperand(i: 1));
4291 }
4292
4293 // sub 0, (and x, 1) --> SIGN_EXTEND_INREG x, i1
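// (and x, 1) is the low bit of x, so 0 - (and x, 1) is 0 or -1, which is
// exactly x sign-extended from i1.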
4294 if (N1.getOpcode() == ISD::AND && N1.hasOneUse() &&
4295 isOneOrOneSplat(V: N1->getOperand(Num: 1))) {
4296 EVT ExtVT = VT.changeElementType(Context&: *DAG.getContext(), EltVT: MVT::i1);
4297 if (TLI.getOperationAction(Op: ISD::SIGN_EXTEND_INREG, VT: ExtVT) ==
4298 TargetLowering::Legal) {
4299 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT, N1: N1->getOperand(Num: 0),
4300 N2: DAG.getValueType(ExtVT));
4301 }
4302 }
4303 }
4304
4305 // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1)
4306 if (isAllOnesOrAllOnesSplat(V: N0))
4307 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0);
4308
4309 // fold (A - (0-B)) -> A+B
4310 if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(V: N1.getOperand(i: 0)))
4311 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1.getOperand(i: 1));
4312
4313 // fold A-(A-B) -> B
4314 if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(i: 0))
4315 return N1.getOperand(i: 1);
4316
4317 // fold (A+B)-A -> B
4318 if (N0.getOpcode() == ISD::ADD && N0.getOperand(i: 0) == N1)
4319 return N0.getOperand(i: 1);
4320
4321 // fold (A+B)-B -> A
4322 if (N0.getOpcode() == ISD::ADD && N0.getOperand(i: 1) == N1)
4323 return N0.getOperand(i: 0);
4324
4325 // fold (A+C1)-C2 -> A+(C1-C2)
4326 if (N0.getOpcode() == ISD::ADD) {
4327 SDValue N01 = N0.getOperand(i: 1);
4328 if (SDValue NewC = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N01, N1}))
4329 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: NewC);
4330 }
4331
4332 // fold C2-(A+C1) -> (C2-C1)-A
4333 if (N1.getOpcode() == ISD::ADD) {
4334 SDValue N11 = N1.getOperand(i: 1);
4335 if (SDValue NewC = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N0, N11}))
4336 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: NewC, N2: N1.getOperand(i: 0));
4337 }
4338
4339 // fold (A-C1)-C2 -> A-(C1+C2)
4340 if (N0.getOpcode() == ISD::SUB) {
4341 SDValue N01 = N0.getOperand(i: 1);
4342 if (SDValue NewC = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL, VT, Ops: {N01, N1}))
4343 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 0), N2: NewC);
4344 }
4345
4346 // fold (c1-A)-c2 -> (c1-c2)-A
4347 if (N0.getOpcode() == ISD::SUB) {
4348 SDValue N00 = N0.getOperand(i: 0);
4349 if (SDValue NewC = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N00, N1}))
4350 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: NewC, N2: N0.getOperand(i: 1));
4351 }
4352
4353 SDValue A, B, C;
4354
4355 // fold ((A+(B+C))-B) -> A+C
4356 if (sd_match(N: N0, P: m_Add(L: m_Value(N&: A), R: m_Add(L: m_Specific(N: N1), R: m_Value(N&: C)))))
4357 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: A, N2: C);
4358
4359 // fold ((A+(B-C))-B) -> A-C
4360 if (sd_match(N: N0, P: m_Add(L: m_Value(N&: A), R: m_Sub(L: m_Specific(N: N1), R: m_Value(N&: C)))))
4361 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: A, N2: C);
4362
4363 // fold ((A-(B-C))-C) -> A-B
4364 if (sd_match(N: N0, P: m_Sub(L: m_Value(N&: A), R: m_Sub(L: m_Value(N&: B), R: m_Specific(N: N1)))))
4365 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: A, N2: B);
4366
4367 // fold (A-(B-C)) -> A+(C-B)
4368 if (sd_match(N: N1, P: m_OneUse(P: m_Sub(L: m_Value(N&: B), R: m_Value(N&: C)))))
4369 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0,
4370 N2: DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: C, N2: B));
4371
4372 // A - (A & B) -> A & (~B)
4373 if (sd_match(N: N1, P: m_And(L: m_Specific(N: N0), R: m_Value(N&: B))) &&
4374 (N1.hasOneUse() || isConstantOrConstantVector(N: B, /*NoOpaques=*/true)))
4375 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0, N2: DAG.getNOT(DL, Val: B, VT));
4376
4377 // fold (A - (-B * C)) -> (A + (B * C))
4378 if (sd_match(N: N1, P: m_OneUse(P: m_Mul(L: m_Neg(V: m_Value(N&: B)), R: m_Value(N&: C)))))
4379 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0,
4380 N2: DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: B, N2: C));
4381
4382 // If either operand of a sub is undef, the result is undef
4383 if (N0.isUndef())
4384 return N0;
4385 if (N1.isUndef())
4386 return N1;
4387
4388 if (SDValue V = foldAddSubBoolOfMaskedVal(N, DL, DAG))
4389 return V;
4390
4391 if (SDValue V = foldAddSubOfSignBit(N, DL, DAG))
4392 return V;
4393
4394 // Try to match the AVGCEIL fixed-width pattern
4395 if (SDValue V = foldSubToAvg(N, DL))
4396 return V;
4397
4398 if (SDValue V = foldAddSubMasked1(IsAdd: false, N0, N1, DAG, DL))
4399 return V;
4400
4401 if (SDValue V = foldSubToUSubSat(DstVT: VT, N, DL))
4402 return V;
4403
4404 if (SDValue V = foldRemainderIdiom(N, DAG, DL))
4405 return V;
4406
4407 // (A - B) - 1 -> add (xor B, -1), A
4408 if (sd_match(N, P: m_Sub(L: m_OneUse(P: m_Sub(L: m_Value(N&: A), R: m_Value(N&: B))),
4409 R: m_One(/*AllowUndefs=*/true))))
4410 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: A, N2: DAG.getNOT(DL, Val: B, VT));
4411
4412 // Look for:
4413 // sub y, (xor x, -1)
4414 // And if the target does not like this form then turn into:
4415 // add (add x, y), 1
4416 if (TLI.preferIncOfAddToSubOfNot(VT) && N1.hasOneUse() && isBitwiseNot(V: N1)) {
4417 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1.getOperand(i: 0));
4418 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Add, N2: DAG.getConstant(Val: 1, DL, VT));
4419 }
4420
4421 // Hoist one-use addition by non-opaque constant:
4422 // (x + C) - y -> (x - y) + C
4423 if (!reassociationCanBreakAddressingModePattern(Opc: ISD::SUB, DL, N, N0, N1) &&
4424 N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
4425 isConstantOrConstantVector(N: N0.getOperand(i: 1), /*NoOpaques=*/true)) {
4426 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 0), N2: N1);
4427 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Sub, N2: N0.getOperand(i: 1));
4428 }
4429 // y - (x + C) -> (y - x) - C
4430 if (N1.getOpcode() == ISD::ADD && N1.hasOneUse() &&
4431 isConstantOrConstantVector(N: N1.getOperand(i: 1), /*NoOpaques=*/true)) {
4432 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1.getOperand(i: 0));
4433 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Sub, N2: N1.getOperand(i: 1));
4434 }
4435 // (x - C) - y -> (x - y) - C
4436 // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
4437 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
4438 isConstantOrConstantVector(N: N0.getOperand(i: 1), /*NoOpaques=*/true)) {
4439 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 0), N2: N1);
4440 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Sub, N2: N0.getOperand(i: 1));
4441 }
4442 // (C - x) - y -> C - (x + y)
4443 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
4444 isConstantOrConstantVector(N: N0.getOperand(i: 0), /*NoOpaques=*/true)) {
4445 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 1), N2: N1);
4446 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 0), N2: Add);
4447 }
4448
4449 // If the target's bool is represented as 0/-1, prefer to make this 'add 0/-1'
4450 // rather than 'sub 0/1' (the sext should get folded).
4451 // sub X, (zext i1 Y) --> add X, (sext i1 Y)
4452 if (N1.getOpcode() == ISD::ZERO_EXTEND &&
4453 N1.getOperand(i: 0).getScalarValueSizeInBits() == 1 &&
4454 TLI.getBooleanContents(Type: VT) ==
4455 TargetLowering::ZeroOrNegativeOneBooleanContent) {
4456 SDValue SExt = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: N1.getOperand(i: 0));
4457 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: SExt);
4458 }
4459
4460 // fold B = sra (A, size(A)-1); sub (xor (A, B), B) -> (abs A)
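// This is the branchless abs idiom: B is 0 or all-ones depending on the sign
// of A, so (A ^ B) - B conditionally inverts A and adds one.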
4461 if ((!LegalOperations || hasOperation(Opcode: ISD::ABS, VT)) &&
4462 sd_match(N: N1, P: m_Sra(L: m_Value(N&: A), R: m_SpecificInt(V: BitWidth - 1))) &&
4463 sd_match(N: N0, P: m_Xor(L: m_Specific(N: A), R: m_Specific(N: N1))))
4464 return DAG.getNode(Opcode: ISD::ABS, DL, VT, Operand: A);
4465
4466 // If the relocation model supports it, consider symbol offsets.
4467 if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Val&: N0))
4468 if (!LegalOperations && TLI.isOffsetFoldingLegal(GA)) {
4469 // fold (sub Sym+c1, Sym+c2) -> c1-c2
4470 if (GlobalAddressSDNode *GB = dyn_cast<GlobalAddressSDNode>(Val&: N1))
4471 if (GA->getGlobal() == GB->getGlobal())
4472 return DAG.getConstant(Val: (uint64_t)GA->getOffset() - GB->getOffset(),
4473 DL, VT);
4474 }
4475
4476 // sub X, (sextinreg Y i1) -> add X, (and Y 1)
4477 if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
4478 VTSDNode *TN = cast<VTSDNode>(Val: N1.getOperand(i: 1));
4479 if (TN->getVT() == MVT::i1) {
4480 SDValue ZExt = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N1.getOperand(i: 0),
4481 N2: DAG.getConstant(Val: 1, DL, VT));
4482 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: ZExt);
4483 }
4484 }
4485
4486 // canonicalize (sub X, (vscale * C)) to (add X, (vscale * -C))
4487 if (N1.getOpcode() == ISD::VSCALE && N1.hasOneUse()) {
4488 const APInt &IntVal = N1.getConstantOperandAPInt(i: 0);
4489 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: DAG.getVScale(DL, VT, MulImm: -IntVal));
4490 }
4491
4492 // canonicalize (sub X, step_vector(C)) to (add X, step_vector(-C))
4493 if (N1.getOpcode() == ISD::STEP_VECTOR && N1.hasOneUse()) {
4494 APInt NewStep = -N1.getConstantOperandAPInt(i: 0);
4495 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0,
4496 N2: DAG.getStepVector(DL, ResVT: VT, StepVal: NewStep));
4497 }
4498
4499 // Prefer an add for more folding potential and possibly better codegen:
4500 // sub N0, (lshr N10, width-1) --> add N0, (ashr N10, width-1)
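// The logical shift yields 0 or 1 and the arithmetic shift yields 0 or -1,
// so subtracting the former equals adding the latter.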
4501 if (!LegalOperations && N1.getOpcode() == ISD::SRL && N1.hasOneUse()) {
4502 SDValue ShAmt = N1.getOperand(i: 1);
4503 ConstantSDNode *ShAmtC = isConstOrConstSplat(N: ShAmt);
4504 if (ShAmtC && ShAmtC->getAPIntValue() == (BitWidth - 1)) {
4505 SDValue SRA = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: N1.getOperand(i: 0), N2: ShAmt);
4506 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: SRA);
4507 }
4508 }
4509
4510 // As with the previous fold, prefer add for more folding potential.
4511 // Subtracting SMIN/0 is the same as adding SMIN/0:
4512 // N0 - (X << BW-1) --> N0 + (X << BW-1)
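// (X << BW-1) is either 0 or the sign-bit-only value 2^(BW-1), which is its
// own two's-complement negation, so subtracting it and adding it agree.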
4513 if (N1.getOpcode() == ISD::SHL) {
4514 ConstantSDNode *ShlC = isConstOrConstSplat(N: N1.getOperand(i: 1));
4515 if (ShlC && ShlC->getAPIntValue() == (BitWidth - 1))
4516 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: N0);
4517 }
4518
4519 // (sub (usubo_carry X, 0, Carry), Y) -> (usubo_carry X, Y, Carry)
4520 if (N0.getOpcode() == ISD::USUBO_CARRY && isNullConstant(V: N0.getOperand(i: 1)) &&
4521 N0.getResNo() == 0 && N0.hasOneUse())
4522 return DAG.getNode(Opcode: ISD::USUBO_CARRY, DL, VTList: N0->getVTList(),
4523 N1: N0.getOperand(i: 0), N2: N1, N3: N0.getOperand(i: 2));
4524
4525 if (TLI.isOperationLegalOrCustom(Op: ISD::UADDO_CARRY, VT)) {
4526 // (sub Carry, X) -> (uaddo_carry (sub 0, X), 0, Carry)
4527 if (SDValue Carry = getAsCarry(TLI, V: N0)) {
4528 SDValue X = N1;
4529 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
4530 SDValue NegX = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Zero, N2: X);
4531 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL,
4532 VTList: DAG.getVTList(VT1: VT, VT2: Carry.getValueType()), N1: NegX, N2: Zero,
4533 N3: Carry);
4534 }
4535 }
4536
4537 // If there's no chance of borrowing from adjacent bits, then sub is xor:
4538 // sub C0, X --> xor X, C0
4539 if (ConstantSDNode *C0 = isConstOrConstSplat(N: N0)) {
4540 if (!C0->isOpaque()) {
4541 const APInt &C0Val = C0->getAPIntValue();
4542 const APInt &MaybeOnes = ~DAG.computeKnownBits(Op: N1).Zero;
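// MaybeOnes is the maximal value N1 can take. If even subtracting that from
// C0 borrows nothing (i.e. matches the xor), then no actual value of N1 can
// cause a borrow, so the sub behaves as an xor.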
4543 if ((C0Val - MaybeOnes) == (C0Val ^ MaybeOnes))
4544 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0);
4545 }
4546 }
4547
4548 // smax(a,b) - smin(a,b) --> abds(a,b)
4549 if ((!LegalOperations || hasOperation(Opcode: ISD::ABDS, VT)) &&
4550 sd_match(N: N0, DAG: &DAG, P: m_SMaxLike(L: m_Value(N&: A), R: m_Value(N&: B))) &&
4551 sd_match(N: N1, DAG: &DAG, P: m_SMinLike(L: m_Specific(N: A), R: m_Specific(N: B))))
4552 return DAG.getNode(Opcode: ISD::ABDS, DL, VT, N1: A, N2: B);
4553
4554 // smin(a,b) - smax(a,b) --> neg(abds(a,b))
4555 if (hasOperation(Opcode: ISD::ABDS, VT) &&
4556 sd_match(N: N0, DAG: &DAG, P: m_SMinLike(L: m_Value(N&: A), R: m_Value(N&: B))) &&
4557 sd_match(N: N1, DAG: &DAG, P: m_SMaxLike(L: m_Specific(N: A), R: m_Specific(N: B))))
4558 return DAG.getNegative(Val: DAG.getNode(Opcode: ISD::ABDS, DL, VT, N1: A, N2: B), DL, VT);
4559
4560 // umax(a,b) - umin(a,b) --> abdu(a,b)
4561 if ((!LegalOperations || hasOperation(Opcode: ISD::ABDU, VT)) &&
4562 sd_match(N: N0, DAG: &DAG, P: m_UMaxLike(L: m_Value(N&: A), R: m_Value(N&: B))) &&
4563 sd_match(N: N1, DAG: &DAG, P: m_UMinLike(L: m_Specific(N: A), R: m_Specific(N: B))))
4564 return DAG.getNode(Opcode: ISD::ABDU, DL, VT, N1: A, N2: B);
4565
4566 // umin(a,b) - umax(a,b) --> neg(abdu(a,b))
4567 if (hasOperation(Opcode: ISD::ABDU, VT) &&
4568 sd_match(N: N0, DAG: &DAG, P: m_UMinLike(L: m_Value(N&: A), R: m_Value(N&: B))) &&
4569 sd_match(N: N1, DAG: &DAG, P: m_UMaxLike(L: m_Specific(N: A), R: m_Specific(N: B))))
4570 return DAG.getNegative(Val: DAG.getNode(Opcode: ISD::ABDU, DL, VT, N1: A, N2: B), DL, VT);
4571
4572 return SDValue();
4573}
4574
4575SDValue DAGCombiner::visitSUBSAT(SDNode *N) {
4576 unsigned Opcode = N->getOpcode();
4577 SDValue N0 = N->getOperand(Num: 0);
4578 SDValue N1 = N->getOperand(Num: 1);
4579 EVT VT = N0.getValueType();
4580 bool IsSigned = Opcode == ISD::SSUBSAT;
4581 SDLoc DL(N);
4582
4583 // fold (sub_sat x, undef) -> 0
4584 if (N0.isUndef() || N1.isUndef())
4585 return DAG.getConstant(Val: 0, DL, VT);
4586
4587 // fold (sub_sat x, x) -> 0
4588 if (N0 == N1)
4589 return DAG.getConstant(Val: 0, DL, VT);
4590
4591 // fold (sub_sat c1, c2) -> c3
4592 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
4593 return C;
4594
4595 // fold vector ops
4596 if (VT.isVector()) {
4597 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4598 return FoldedVOp;
4599
4600 // fold (sub_sat x, 0) -> x, vector edition
4601 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
4602 return N0;
4603 }
4604
4605 // fold (sub_sat x, 0) -> x
4606 if (isNullConstant(V: N1))
4607 return N0;
4608
4609 // If it cannot overflow, transform into a sub.
4610 if (DAG.willNotOverflowSub(IsSigned, N0, N1))
4611 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1);
4612
4613 return SDValue();
4614}
4615
4616SDValue DAGCombiner::visitSUBC(SDNode *N) {
4617 SDValue N0 = N->getOperand(Num: 0);
4618 SDValue N1 = N->getOperand(Num: 1);
4619 EVT VT = N0.getValueType();
4620 SDLoc DL(N);
4621
4622 // If the flag result is dead, turn this into an SUB.
4623 if (!N->hasAnyUseOfValue(Value: 1))
4624 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1),
4625 Res1: DAG.getNode(Opcode: ISD::CARRY_FALSE, DL, VT: MVT::Glue));
4626
4627 // fold (subc x, x) -> 0 + no borrow
4628 if (N0 == N1)
4629 return CombineTo(N, Res0: DAG.getConstant(Val: 0, DL, VT),
4630 Res1: DAG.getNode(Opcode: ISD::CARRY_FALSE, DL, VT: MVT::Glue));
4631
4632 // fold (subc x, 0) -> x + no borrow
4633 if (isNullConstant(V: N1))
4634 return CombineTo(N, Res0: N0, Res1: DAG.getNode(Opcode: ISD::CARRY_FALSE, DL, VT: MVT::Glue));
4635
4636 // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1) + no borrow
4637 if (isAllOnesConstant(V: N0))
4638 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0),
4639 Res1: DAG.getNode(Opcode: ISD::CARRY_FALSE, DL, VT: MVT::Glue));
4640
4641 return SDValue();
4642}
4643
4644SDValue DAGCombiner::visitSUBO(SDNode *N) {
4645 SDValue N0 = N->getOperand(Num: 0);
4646 SDValue N1 = N->getOperand(Num: 1);
4647 EVT VT = N0.getValueType();
4648 bool IsSigned = (ISD::SSUBO == N->getOpcode());
4649
4650 EVT CarryVT = N->getValueType(ResNo: 1);
4651 SDLoc DL(N);
4652
4653 // If the flag result is dead, turn this into an SUB.
4654 if (!N->hasAnyUseOfValue(Value: 1))
4655 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1),
4656 Res1: DAG.getUNDEF(VT: CarryVT));
4657
4658 // fold (subo x, x) -> 0 + no borrow
4659 if (N0 == N1)
4660 return CombineTo(N, Res0: DAG.getConstant(Val: 0, DL, VT),
4661 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
4662
4663 // fold (subo x, c) -> (addo x, -c)
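// Signed only: negating the minimum signed value would itself overflow, and
// for unsigned the borrow from (x - c) is the complement of the carry from
// (x + -c), so the overflow flag would be wrong.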
4664 if (ConstantSDNode *N1C = getAsNonOpaqueConstant(N: N1))
4665 if (IsSigned && !N1C->isMinSignedValue())
4666 return DAG.getNode(Opcode: ISD::SADDO, DL, VTList: N->getVTList(), N1: N0,
4667 N2: DAG.getConstant(Val: -N1C->getAPIntValue(), DL, VT));
4668
4669 // fold (subo x, 0) -> x + no borrow
4670 if (isNullOrNullSplat(V: N1))
4671 return CombineTo(N, Res0: N0, Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
4672
4673 // If it cannot overflow, transform into a sub.
4674 if (DAG.willNotOverflowSub(IsSigned, N0, N1))
4675 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1),
4676 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
4677
4678 // Canonicalize (usubo -1, x) -> ~x, i.e. (xor x, -1) + no borrow
4679 if (!IsSigned && isAllOnesOrAllOnesSplat(V: N0))
4680 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0),
4681 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
4682
4683 return SDValue();
4684}
4685
4686SDValue DAGCombiner::visitSUBE(SDNode *N) {
4687 SDValue N0 = N->getOperand(Num: 0);
4688 SDValue N1 = N->getOperand(Num: 1);
4689 SDValue CarryIn = N->getOperand(Num: 2);
4690
4691 // fold (sube x, y, false) -> (subc x, y)
4692 if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
4693 return DAG.getNode(Opcode: ISD::SUBC, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: N1);
4694
4695 return SDValue();
4696}
4697
4698SDValue DAGCombiner::visitUSUBO_CARRY(SDNode *N) {
4699 SDValue N0 = N->getOperand(Num: 0);
4700 SDValue N1 = N->getOperand(Num: 1);
4701 SDValue CarryIn = N->getOperand(Num: 2);
4702
4703 // fold (usubo_carry x, y, false) -> (usubo x, y)
4704 if (isNullConstant(V: CarryIn)) {
4705 if (!LegalOperations ||
4706 TLI.isOperationLegalOrCustom(Op: ISD::USUBO, VT: N->getValueType(ResNo: 0)))
4707 return DAG.getNode(Opcode: ISD::USUBO, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: N1);
4708 }
4709
4710 return SDValue();
4711}
4712
4713SDValue DAGCombiner::visitSSUBO_CARRY(SDNode *N) {
4714 SDValue N0 = N->getOperand(Num: 0);
4715 SDValue N1 = N->getOperand(Num: 1);
4716 SDValue CarryIn = N->getOperand(Num: 2);
4717
4718 // fold (ssubo_carry x, y, false) -> (ssubo x, y)
4719 if (isNullConstant(V: CarryIn)) {
4720 if (!LegalOperations ||
4721 TLI.isOperationLegalOrCustom(Op: ISD::SSUBO, VT: N->getValueType(ResNo: 0)))
4722 return DAG.getNode(Opcode: ISD::SSUBO, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: N1);
4723 }
4724
4725 return SDValue();
4726}
4727
4728// Notice that "mulfix" can be any of SMULFIX, SMULFIXSAT, UMULFIX and
4729// UMULFIXSAT here.
4730SDValue DAGCombiner::visitMULFIX(SDNode *N) {
4731 SDValue N0 = N->getOperand(Num: 0);
4732 SDValue N1 = N->getOperand(Num: 1);
4733 SDValue Scale = N->getOperand(Num: 2);
4734 EVT VT = N0.getValueType();
4735
4736 // fold (mulfix x, undef, scale) -> 0
4737 if (N0.isUndef() || N1.isUndef())
4738 return DAG.getConstant(Val: 0, DL: SDLoc(N), VT);
4739
4740 // Canonicalize constant to RHS (vector doesn't have to splat)
4741 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
4742 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
4743 return DAG.getNode(Opcode: N->getOpcode(), DL: SDLoc(N), VT, N1, N2: N0, N3: Scale);
4744
4745 // fold (mulfix x, 0, scale) -> 0
4746 if (isNullConstant(V: N1))
4747 return DAG.getConstant(Val: 0, DL: SDLoc(N), VT);
4748
4749 return SDValue();
4750}
4751
4752template <class MatchContextClass> SDValue DAGCombiner::visitMUL(SDNode *N) {
4753 SDValue N0 = N->getOperand(Num: 0);
4754 SDValue N1 = N->getOperand(Num: 1);
4755 EVT VT = N0.getValueType();
4756 unsigned BitWidth = VT.getScalarSizeInBits();
4757 SDLoc DL(N);
4758 bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
4759 MatchContextClass Matcher(DAG, TLI, N);
4760
4761 // fold (mul x, undef) -> 0
4762 if (N0.isUndef() || N1.isUndef())
4763 return DAG.getConstant(Val: 0, DL, VT);
4764
4765 // fold (mul c1, c2) -> c1*c2
4766 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::MUL, DL, VT, Ops: {N0, N1}))
4767 return C;
4768
4769 // canonicalize constant to RHS (vector doesn't have to splat)
4770 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
4771 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
4772 return Matcher.getNode(ISD::MUL, DL, VT, N1, N0);
4773
4774 bool N1IsConst = false;
4775 bool N1IsOpaqueConst = false;
4776 APInt ConstValue1;
4777
4778 // fold vector ops
4779 if (VT.isVector()) {
4780 // TODO: Change this to use SimplifyVBinOp when it supports VP op.
4781 if (!UseVP)
4782 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4783 return FoldedVOp;
4784
4785 N1IsConst = ISD::isConstantSplatVector(N: N1.getNode(), SplatValue&: ConstValue1);
4786 assert((!N1IsConst || ConstValue1.getBitWidth() == BitWidth) &&
4787 "Splat APInt should be element width");
4788 } else {
4789 N1IsConst = isa<ConstantSDNode>(Val: N1);
4790 if (N1IsConst) {
4791 ConstValue1 = N1->getAsAPIntVal();
4792 N1IsOpaqueConst = cast<ConstantSDNode>(Val&: N1)->isOpaque();
4793 }
4794 }
4795
4796 // fold (mul x, 0) -> 0
4797 if (N1IsConst && ConstValue1.isZero())
4798 return N1;
4799
4800 // fold (mul x, 1) -> x
4801 if (N1IsConst && ConstValue1.isOne())
4802 return N0;
4803
4804 if (!UseVP)
4805 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
4806 return NewSel;
4807
4808 // fold (mul x, -1) -> 0-x
4809 if (N1IsConst && ConstValue1.isAllOnes())
4810 return Matcher.getNode(ISD::SUB, DL, VT, DAG.getConstant(Val: 0, DL, VT), N0);
4811
4812 // fold (mul x, (1 << c)) -> x << c
4813 if (isConstantOrConstantVector(N: N1, /*NoOpaques*/ true) &&
4814 (!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
4815 if (SDValue LogBase2 = BuildLogBase2(V: N1, DL)) {
4816 EVT ShiftVT = getShiftAmountTy(LHSTy: N0.getValueType());
4817 SDValue Trunc = DAG.getZExtOrTrunc(Op: LogBase2, DL, VT: ShiftVT);
4818 SDNodeFlags Flags;
4819 Flags.setNoUnsignedWrap(N->getFlags().hasNoUnsignedWrap());
4820 // TODO: Preserve setNoSignedWrap if LogBase2 isn't BitWidth - 1.
4821 return Matcher.getNode(ISD::SHL, DL, VT, N0, Trunc, Flags);
4822 }
4823 }
4824
4825 // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
4826 if (N1IsConst && !N1IsOpaqueConst && ConstValue1.isNegatedPowerOf2()) {
4827 unsigned Log2Val = (-ConstValue1).logBase2();
4828
4829 // FIXME: If the input is something that is easily negated (e.g. a
4830 // single-use add), we should put the negate there.
4831 return Matcher.getNode(
4832 ISD::SUB, DL, VT, DAG.getConstant(Val: 0, DL, VT),
4833 Matcher.getNode(ISD::SHL, DL, VT, N0,
4834 DAG.getShiftAmountConstant(Val: Log2Val, VT, DL)));
4835 }
4836
4837 // Attempt to reuse an existing umul_lohi/smul_lohi node, but only if the
4838 // hi result is in use, in case we hit this mid-legalization.
4839 if (!UseVP) {
4840 for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) {
4841 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: LoHiOpc, VT)) {
4842 SDVTList LoHiVT = DAG.getVTList(VT1: VT, VT2: VT);
4843 // TODO: Can we match commutable operands with getNodeIfExists?
4844 if (SDNode *LoHi = DAG.getNodeIfExists(Opcode: LoHiOpc, VTList: LoHiVT, Ops: {N0, N1}))
4845 if (LoHi->hasAnyUseOfValue(Value: 1))
4846 return SDValue(LoHi, 0);
4847 if (SDNode *LoHi = DAG.getNodeIfExists(Opcode: LoHiOpc, VTList: LoHiVT, Ops: {N1, N0}))
4848 if (LoHi->hasAnyUseOfValue(Value: 1))
4849 return SDValue(LoHi, 0);
4850 }
4851 }
4852 }
4853
4854 // Try to transform:
4855 // (1) multiply-by-(power-of-2 +/- 1) into shift and add/sub.
4856 // mul x, (2^N + 1) --> add (shl x, N), x
4857 // mul x, (2^N - 1) --> sub (shl x, N), x
4858 // Examples: x * 33 --> (x << 5) + x
4859 // x * 15 --> (x << 4) - x
4860 // x * -33 --> -((x << 5) + x)
4861 // x * -15 --> -((x << 4) - x) ; this reduces --> x - (x << 4)
4862 // (2) multiply-by-(power-of-2 +/- power-of-2) into shifts and add/sub.
4863 // mul x, (2^N + 2^M) --> (add (shl x, N), (shl x, M))
4864 // mul x, (2^N - 2^M) --> (sub (shl x, N), (shl x, M))
4865 // Examples: x * 0x8800 --> (x << 15) + (x << 11)
4866 // x * 0xf800 --> (x << 16) - (x << 11)
4867 // x * -0x8800 --> -((x << 15) + (x << 11))
4868 // x * -0xf800 --> -((x << 16) - (x << 11)) ; (x << 11) - (x << 16)
4869 if (!UseVP && N1IsConst &&
4870 TLI.decomposeMulByConstant(Context&: *DAG.getContext(), VT, C: N1)) {
4871 // TODO: We could handle more general decomposition of any constant by
4872 // having the target set a limit on number of ops and making a
4873 // callback to determine that sequence (similar to sqrt expansion).
4874 unsigned MathOp = ISD::DELETED_NODE;
4875 APInt MulC = ConstValue1.abs();
4876 // The constant `2` should be treated as (2^0 + 1).
4877 unsigned TZeros = MulC == 2 ? 0 : MulC.countr_zero();
4878 MulC.lshrInPlace(ShiftAmt: TZeros);
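// e.g. x * 0x8800: TZeros = 11 and MulC becomes 0x11 = 2^4 + 1, so MathOp is
// ADD and ShAmt = 4 + 11 = 15, matching the (x << 15) + (x << 11) example
// above.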
4879 if ((MulC - 1).isPowerOf2())
4880 MathOp = ISD::ADD;
4881 else if ((MulC + 1).isPowerOf2())
4882 MathOp = ISD::SUB;
4883
4884 if (MathOp != ISD::DELETED_NODE) {
4885 unsigned ShAmt =
4886 MathOp == ISD::ADD ? (MulC - 1).logBase2() : (MulC + 1).logBase2();
4887 ShAmt += TZeros;
4888 assert(ShAmt < BitWidth &&
4889 "multiply-by-constant generated out of bounds shift");
4890 SDValue Shl =
4891 DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0, N2: DAG.getConstant(Val: ShAmt, DL, VT));
4892 SDValue R =
4893 TZeros ? DAG.getNode(Opcode: MathOp, DL, VT, N1: Shl,
4894 N2: DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0,
4895 N2: DAG.getConstant(Val: TZeros, DL, VT)))
4896 : DAG.getNode(Opcode: MathOp, DL, VT, N1: Shl, N2: N0);
4897 if (ConstValue1.isNegative())
4898 R = DAG.getNegative(Val: R, DL, VT);
4899 return R;
4900 }
4901 }
4902
4903 // (mul (shl X, c1), c2) -> (mul X, c2 << c1)
4904 if (sd_context_match(N0, Matcher, m_Opc(Opcode: ISD::SHL))) {
4905 SDValue N01 = N0.getOperand(i: 1);
4906 if (SDValue C3 = DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL, VT, Ops: {N1, N01}))
4907 return DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: N0.getOperand(i: 0), N2: C3);
4908 }
4909
4910 // Change (mul (shl X, C), Y) -> (shl (mul X, Y), C) when the shift has one
4911 // use.
4912 {
4913 SDValue Sh, Y;
4914
4915 // Check for both (mul (shl X, C), Y) and (mul Y, (shl X, C)).
4916 if (sd_context_match(N0, Matcher, m_OneUse(P: m_Opc(Opcode: ISD::SHL))) &&
4917 isConstantOrConstantVector(N: N0.getOperand(i: 1))) {
4918 Sh = N0; Y = N1;
4919 } else if (sd_context_match(N1, Matcher, m_OneUse(P: m_Opc(Opcode: ISD::SHL))) &&
4920 isConstantOrConstantVector(N: N1.getOperand(i: 1))) {
4921 Sh = N1; Y = N0;
4922 }
4923
4924 if (Sh.getNode()) {
4925 SDValue Mul = Matcher.getNode(ISD::MUL, DL, VT, Sh.getOperand(i: 0), Y);
4926 return Matcher.getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(i: 1));
4927 }
4928 }
4929
4930 // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
4931 if (sd_context_match(N0, Matcher, m_Opc(Opcode: ISD::ADD)) &&
4932 isConstantOrConstantVector(N: N1) &&
4933 isConstantOrConstantVector(N: N0.getOperand(i: 1)) &&
4934 isMulAddWithConstProfitable(MulNode: N, AddNode: N0, ConstNode: N1))
4935 return Matcher.getNode(
4936 ISD::ADD, DL, VT,
4937 Matcher.getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(i: 0), N1),
4938 Matcher.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(i: 1), N1));
4939
4940 // Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)).
4941 ConstantSDNode *NC1 = isConstOrConstSplat(N: N1);
4942 if (!UseVP && N0.getOpcode() == ISD::VSCALE && NC1) {
4943 const APInt &C0 = N0.getConstantOperandAPInt(i: 0);
4944 const APInt &C1 = NC1->getAPIntValue();
4945 return DAG.getVScale(DL, VT, MulImm: C0 * C1);
4946 }
4947
4948 // Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
4949 APInt MulVal;
4950 if (!UseVP && N0.getOpcode() == ISD::STEP_VECTOR &&
4951 ISD::isConstantSplatVector(N: N1.getNode(), SplatValue&: MulVal)) {
4952 const APInt &C0 = N0.getConstantOperandAPInt(i: 0);
4953 APInt NewStep = C0 * MulVal;
4954 return DAG.getStepVector(DL, ResVT: VT, StepVal: NewStep);
4955 }
4956
4957 // Fold Y = sra (X, size(X)-1); mul (or (Y, 1), X) -> (abs X)
4958 SDValue X;
4959 if (!UseVP && (!LegalOperations || hasOperation(Opcode: ISD::ABS, VT)) &&
4960 sd_context_match(
4961 N, Matcher,
4962 m_Mul(L: m_Or(L: m_Sra(L: m_Value(N&: X), R: m_SpecificInt(V: BitWidth - 1)), R: m_One()),
4963 R: m_Deferred(V&: X)))) {
4964 return Matcher.getNode(ISD::ABS, DL, VT, X);
4965 }
4966
4967 // Fold (mul x, 0/undef) -> 0 and (mul x, 1) -> x
4968 // into and(x, mask):
4970 // We can replace vectors with '0' and '1' factors with a clearing mask.
4971 if (VT.isFixedLengthVector()) {
4972 unsigned NumElts = VT.getVectorNumElements();
4973 SmallBitVector ClearMask;
4974 ClearMask.reserve(N: NumElts);
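// IsClearMask: a zero or undef factor clears that lane (mask bit true), a
// factor of one keeps the lane (mask bit false), and any other constant
// aborts the transform.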
4975 auto IsClearMask = [&ClearMask](ConstantSDNode *V) {
4976 if (!V || V->isZero()) {
4977 ClearMask.push_back(Val: true);
4978 return true;
4979 }
4980 ClearMask.push_back(Val: false);
4981 return V->isOne();
4982 };
4983 if ((!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::AND, VT)) &&
4984 ISD::matchUnaryPredicate(Op: N1, Match: IsClearMask, /*AllowUndefs*/ true)) {
4985 assert(N1.getOpcode() == ISD::BUILD_VECTOR && "Unknown constant vector");
4986 EVT LegalSVT = N1.getOperand(i: 0).getValueType();
4987 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: LegalSVT);
4988 SDValue AllOnes = DAG.getAllOnesConstant(DL, VT: LegalSVT);
4989 SmallVector<SDValue, 16> Mask(NumElts, AllOnes);
4990 for (unsigned I = 0; I != NumElts; ++I)
4991 if (ClearMask[I])
4992 Mask[I] = Zero;
4993 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0, N2: DAG.getBuildVector(VT, DL, Ops: Mask));
4994 }
4995 }
4996
4997 // reassociate mul
4998 // TODO: Change reassociateOps to support vp ops.
4999 if (!UseVP)
5000 if (SDValue RMUL = reassociateOps(Opc: ISD::MUL, DL, N0, N1, Flags: N->getFlags()))
5001 return RMUL;
5002
5003 // Fold mul(vecreduce(x), vecreduce(y)) -> vecreduce(mul(x, y))
5004 // TODO: Change reassociateReduction to support vp ops.
5005 if (!UseVP)
5006 if (SDValue SD =
5007 reassociateReduction(RedOpc: ISD::VECREDUCE_MUL, Opc: ISD::MUL, DL, VT, N0, N1))
5008 return SD;
5009
5010 // Simplify the operands using demanded-bits information.
5011 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
5012 return SDValue(N, 0);
5013
5014 return SDValue();
5015}
5016
5017 /// Return true if a divmod libcall is available.
5018static bool isDivRemLibcallAvailable(SDNode *Node, bool isSigned,
5019 const SelectionDAG &DAG) {
5020 RTLIB::Libcall LC;
5021 EVT NodeType = Node->getValueType(ResNo: 0);
5022 if (!NodeType.isSimple())
5023 return false;
5024 switch (NodeType.getSimpleVT().SimpleTy) {
5025 default: return false; // No libcall for vector types.
5026 case MVT::i8: LC= isSigned ? RTLIB::SDIVREM_I8 : RTLIB::UDIVREM_I8; break;
5027 case MVT::i16: LC= isSigned ? RTLIB::SDIVREM_I16 : RTLIB::UDIVREM_I16; break;
5028 case MVT::i32: LC= isSigned ? RTLIB::SDIVREM_I32 : RTLIB::UDIVREM_I32; break;
5029 case MVT::i64: LC= isSigned ? RTLIB::SDIVREM_I64 : RTLIB::UDIVREM_I64; break;
5030 case MVT::i128: LC= isSigned ? RTLIB::SDIVREM_I128:RTLIB::UDIVREM_I128; break;
5031 }
5032
5033 return DAG.getLibcalls().getLibcallImpl(Call: LC) != RTLIB::Unsupported;
5034}
5035
5036/// Issue divrem if both quotient and remainder are needed.
5037SDValue DAGCombiner::useDivRem(SDNode *Node) {
5038 if (Node->use_empty())
5039 return SDValue(); // This is a dead node, leave it alone.
5040
5041 unsigned Opcode = Node->getOpcode();
5042 bool isSigned = (Opcode == ISD::SDIV) || (Opcode == ISD::SREM);
5043 unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;
5044
5045 // DivMod libcalls can still work on non-legal types if we end up using libcalls.
5046 EVT VT = Node->getValueType(ResNo: 0);
5047 if (VT.isVector() || !VT.isInteger())
5048 return SDValue();
5049
5050 if (!TLI.isTypeLegal(VT) && !TLI.isOperationCustom(Op: DivRemOpc, VT))
5051 return SDValue();
5052
5053 // If DIVREM is going to get expanded into a libcall,
5054 // but there is no libcall available, then don't combine.
5055 if (!TLI.isOperationLegalOrCustom(Op: DivRemOpc, VT) &&
5056 !isDivRemLibcallAvailable(Node, isSigned, DAG))
5057 return SDValue();
5058
5059 // If div is legal, it's better to do the normal expansion
5060 unsigned OtherOpcode = 0;
5061 if ((Opcode == ISD::SDIV) || (Opcode == ISD::UDIV)) {
5062 OtherOpcode = isSigned ? ISD::SREM : ISD::UREM;
5063 if (TLI.isOperationLegalOrCustom(Op: Opcode, VT))
5064 return SDValue();
5065 } else {
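// The node is a REM, so the 'div' to check for legality is OtherOpcode.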
5066 OtherOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
5067 if (TLI.isOperationLegalOrCustom(Op: OtherOpcode, VT))
5068 return SDValue();
5069 }
5070
5071 SDValue Op0 = Node->getOperand(Num: 0);
5072 SDValue Op1 = Node->getOperand(Num: 1);
5073 SDValue combined;
5074 for (SDNode *User : Op0->users()) {
5075 if (User == Node || User->getOpcode() == ISD::DELETED_NODE ||
5076 User->use_empty())
5077 continue;
5078 // Convert the other matching node(s), too;
5079 // otherwise, the DIVREM may get target-legalized into something
5080 // target-specific that we won't be able to recognize.
5081 unsigned UserOpc = User->getOpcode();
5082 if ((UserOpc == Opcode || UserOpc == OtherOpcode || UserOpc == DivRemOpc) &&
5083 User->getOperand(Num: 0) == Op0 &&
5084 User->getOperand(Num: 1) == Op1) {
5085 if (!combined) {
5086 if (UserOpc == OtherOpcode) {
5087 SDVTList VTs = DAG.getVTList(VT1: VT, VT2: VT);
5088 combined = DAG.getNode(Opcode: DivRemOpc, DL: SDLoc(Node), VTList: VTs, N1: Op0, N2: Op1);
5089 } else if (UserOpc == DivRemOpc) {
5090 combined = SDValue(User, 0);
5091 } else {
5092 assert(UserOpc == Opcode);
5093 continue;
5094 }
5095 }
5096 if (UserOpc == ISD::SDIV || UserOpc == ISD::UDIV)
5097 CombineTo(N: User, Res: combined);
5098 else if (UserOpc == ISD::SREM || UserOpc == ISD::UREM)
5099 CombineTo(N: User, Res: combined.getValue(R: 1));
5100 }
5101 }
5102 return combined;
5103}
5104
5105static SDValue simplifyDivRem(SDNode *N, SelectionDAG &DAG) {
5106 SDValue N0 = N->getOperand(Num: 0);
5107 SDValue N1 = N->getOperand(Num: 1);
5108 EVT VT = N->getValueType(ResNo: 0);
5109 SDLoc DL(N);
5110
5111 unsigned Opc = N->getOpcode();
5112 bool IsDiv = (ISD::SDIV == Opc) || (ISD::UDIV == Opc);
5113
5114 // X / undef -> undef
5115 // X % undef -> undef
5116 // X / 0 -> undef
5117 // X % 0 -> undef
5118 // NOTE: This includes vectors where any divisor element is zero/undef.
5119 if (DAG.isUndef(Opcode: Opc, Ops: {N0, N1}))
5120 return DAG.getUNDEF(VT);
5121
5122 // undef / X -> 0
5123 // undef % X -> 0
5124 if (N0.isUndef())
5125 return DAG.getConstant(Val: 0, DL, VT);
5126
5127 // 0 / X -> 0
5128 // 0 % X -> 0
5129 ConstantSDNode *N0C = isConstOrConstSplat(N: N0);
5130 if (N0C && N0C->isZero())
5131 return N0;
5132
5133 // X / X -> 1
5134 // X % X -> 0
5135 if (N0 == N1)
5136 return DAG.getConstant(Val: IsDiv ? 1 : 0, DL, VT);
5137
5138 // X / 1 -> X
5139 // X % 1 -> 0
5140 // If this is a boolean op (single-bit element type), we can't have
5141 // division-by-zero or remainder-by-zero, so assume the divisor is 1.
5142 // TODO: Similarly, if we're zero-extending a boolean divisor, then assume
5143 // it's a 1.
5144 if (isOneOrOneSplat(V: N1) || (VT.getScalarType() == MVT::i1))
5145 return IsDiv ? N0 : DAG.getConstant(Val: 0, DL, VT);
5146
5147 return SDValue();
5148}
5149
5150SDValue DAGCombiner::visitSDIV(SDNode *N) {
5151 SDValue N0 = N->getOperand(Num: 0);
5152 SDValue N1 = N->getOperand(Num: 1);
5153 EVT VT = N->getValueType(ResNo: 0);
5154 EVT CCVT = getSetCCResultType(VT);
5155 SDLoc DL(N);
5156
5157 // fold (sdiv c1, c2) -> c1/c2
5158 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SDIV, DL, VT, Ops: {N0, N1}))
5159 return C;
5160
5161 // fold vector ops
5162 if (VT.isVector())
5163 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5164 return FoldedVOp;
5165
5166 // fold (sdiv X, -1) -> 0-X
5167 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
5168 if (N1C && N1C->isAllOnes())
5169 return DAG.getNegative(Val: N0, DL, VT);
5170
5171 // fold (sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0)
5172 if (N1C && N1C->isMinSignedValue())
5173 return DAG.getSelect(DL, VT, Cond: DAG.getSetCC(DL, VT: CCVT, LHS: N0, RHS: N1, Cond: ISD::SETEQ),
5174 LHS: DAG.getConstant(Val: 1, DL, VT),
5175 RHS: DAG.getConstant(Val: 0, DL, VT));
5176
5177 if (SDValue V = simplifyDivRem(N, DAG))
5178 return V;
5179
5180 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
5181 return NewSel;
5182
5183 // If we know the sign bits of both operands are zero, strength reduce to a
5184 // udiv instead. Handles (X&15) /s 4 -> X&15 >> 2
5185 if (DAG.SignBitIsZero(Op: N1) && DAG.SignBitIsZero(Op: N0))
5186 return DAG.getNode(Opcode: ISD::UDIV, DL, VT: N1.getValueType(), N1: N0, N2: N1);
5187
5188 if (SDValue V = visitSDIVLike(N0, N1, N)) {
5189 // If the corresponding remainder node exists, update its users with
5190 // (Dividend - (Quotient * Divisor)).
5191 if (SDNode *RemNode = DAG.getNodeIfExists(Opcode: ISD::SREM, VTList: N->getVTList(),
5192 Ops: { N0, N1 })) {
5193 // If the sdiv has the exact flag we shouldn't propagate it to the
5194 // remainder node.
5195 if (!N->getFlags().hasExact()) {
5196 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: V, N2: N1);
5197 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: Mul);
5198 AddToWorklist(N: Mul.getNode());
5199 AddToWorklist(N: Sub.getNode());
5200 CombineTo(N: RemNode, Res: Sub);
5201 }
5202 }
5203 return V;
5204 }
5205
5206 // sdiv, srem -> sdivrem
5207 // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
5208 // true. Otherwise, we break the simplification logic in visitREM().
5209 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
5210 if (!N1C || TLI.isIntDivCheap(VT: N->getValueType(ResNo: 0), Attr))
5211 if (SDValue DivRem = useDivRem(Node: N))
5212 return DivRem;
5213
5214 return SDValue();
5215}
5216
5217static bool isDivisorPowerOfTwo(SDValue Divisor) {
5218 // Helper for determining whether a value is a power-2 constant scalar or a
5219 // vector of such elements.
5220 auto IsPowerOfTwo = [](ConstantSDNode *C) {
5221 if (C->isZero() || C->isOpaque())
5222 return false;
5223 if (C->getAPIntValue().isPowerOf2())
5224 return true;
5225 if (C->getAPIntValue().isNegatedPowerOf2())
5226 return true;
5227 return false;
5228 };
5229
5230 return ISD::matchUnaryPredicate(Op: Divisor, Match: IsPowerOfTwo, /*AllowUndefs=*/false,
5231 /*AllowTruncation=*/true);
5232}
5233
5234SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
5235 SDLoc DL(N);
5236 EVT VT = N->getValueType(ResNo: 0);
5237 EVT CCVT = getSetCCResultType(VT);
5238 unsigned BitWidth = VT.getScalarSizeInBits();
5239
5240 // fold (sdiv X, pow2) -> simple ops after legalize
5241 // FIXME: We check for the exact bit here because the generic lowering gives
5242 // better results in that case. The target-specific lowering should learn how
5243 // to handle exact sdivs efficiently.
5244 if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(Divisor: N1)) {
5245 // Target-specific implementation of sdiv x, pow2.
5246 if (SDValue Res = BuildSDIVPow2(N))
5247 return Res;
5248
5249 // Create constants that are functions of the shift amount value.
5250 EVT ShiftAmtTy = getShiftAmountTy(LHSTy: N0.getValueType());
5251 SDValue Bits = DAG.getConstant(Val: BitWidth, DL, VT: ShiftAmtTy);
5252 SDValue C1 = DAG.getNode(Opcode: ISD::CTTZ, DL, VT, Operand: N1);
5253 C1 = DAG.getZExtOrTrunc(Op: C1, DL, VT: ShiftAmtTy);
5254 SDValue Inexact = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftAmtTy, N1: Bits, N2: C1);
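// Inexact = BitWidth - log2(|divisor|). Shifting the sign splat logically
// right by this amount yields |divisor| - 1 when N0 is negative and 0
// otherwise; that is the rounding adjustment added to N0 below.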
5255 if (!isConstantOrConstantVector(N: Inexact))
5256 return SDValue();
5257
5258 // Splat the sign bit into the register
5259 SDValue Sign = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: N0,
5260 N2: DAG.getConstant(Val: BitWidth - 1, DL, VT: ShiftAmtTy));
5261 AddToWorklist(N: Sign.getNode());
5262
5263 // Add (N0 < 0) ? abs2 - 1 : 0;
5264 SDValue Srl = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Sign, N2: Inexact);
5265 AddToWorklist(N: Srl.getNode());
5266 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: Srl);
5267 AddToWorklist(N: Add.getNode());
5268 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: Add, N2: C1);
5269 AddToWorklist(N: Sra.getNode());
5270
5271 // Special case: (sdiv X, 1) -> X
5272 // Special case: (sdiv X, -1) -> 0-X
5273 SDValue One = DAG.getConstant(Val: 1, DL, VT);
5274 SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
5275 SDValue IsOne = DAG.getSetCC(DL, VT: CCVT, LHS: N1, RHS: One, Cond: ISD::SETEQ);
5276 SDValue IsAllOnes = DAG.getSetCC(DL, VT: CCVT, LHS: N1, RHS: AllOnes, Cond: ISD::SETEQ);
5277 SDValue IsOneOrAllOnes = DAG.getNode(Opcode: ISD::OR, DL, VT: CCVT, N1: IsOne, N2: IsAllOnes);
5278 Sra = DAG.getSelect(DL, VT, Cond: IsOneOrAllOnes, LHS: N0, RHS: Sra);
5279
5280 // If dividing by a positive value, we're done. Otherwise, the result must
5281 // be negated.
5282 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
5283 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Zero, N2: Sra);
5284
5285 // FIXME: Use SELECT_CC once we improve SELECT_CC constant-folding.
5286 SDValue IsNeg = DAG.getSetCC(DL, VT: CCVT, LHS: N1, RHS: Zero, Cond: ISD::SETLT);
5287 SDValue Res = DAG.getSelect(DL, VT, Cond: IsNeg, LHS: Sub, RHS: Sra);
5288 return Res;
5289 }
5290
5291 // If integer divide is expensive and we satisfy the requirements, emit an
5292 // alternate sequence. Targets may check function attributes for size/speed
5293 // trade-offs.
5294 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
5295 if (isConstantOrConstantVector(N: N1, /*NoOpaques=*/false,
5296 /*AllowTruncation=*/true) &&
5297 !TLI.isIntDivCheap(VT: N->getValueType(ResNo: 0), Attr))
5298 if (SDValue Op = BuildSDIV(N))
5299 return Op;
5300
5301 return SDValue();
5302}
5303
5304SDValue DAGCombiner::visitUDIV(SDNode *N) {
5305 SDValue N0 = N->getOperand(Num: 0);
5306 SDValue N1 = N->getOperand(Num: 1);
5307 EVT VT = N->getValueType(ResNo: 0);
5308 EVT CCVT = getSetCCResultType(VT);
5309 SDLoc DL(N);
5310
5311 // fold (udiv c1, c2) -> c1/c2
5312 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::UDIV, DL, VT, Ops: {N0, N1}))
5313 return C;
5314
5315 // fold vector ops
5316 if (VT.isVector())
5317 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5318 return FoldedVOp;
5319
5320 // fold (udiv X, -1) -> select(X == -1, 1, 0)
5321 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
5322 if (N1C && N1C->isAllOnes() && CCVT.isVector() == VT.isVector()) {
5323 return DAG.getSelect(DL, VT, Cond: DAG.getSetCC(DL, VT: CCVT, LHS: N0, RHS: N1, Cond: ISD::SETEQ),
5324 LHS: DAG.getConstant(Val: 1, DL, VT),
5325 RHS: DAG.getConstant(Val: 0, DL, VT));
5326 }
5327
5328 if (SDValue V = simplifyDivRem(N, DAG))
5329 return V;
5330
5331 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
5332 return NewSel;
5333
5334 if (SDValue V = visitUDIVLike(N0, N1, N)) {
5335 // If the corresponding remainder node exists, update its users with
5336 // (Dividend - (Quotient * Divisor)).
5337 if (SDNode *RemNode = DAG.getNodeIfExists(Opcode: ISD::UREM, VTList: N->getVTList(),
5338 Ops: { N0, N1 })) {
5339 // If the udiv has the exact flag we shouldn't propagate it to the
5340 // remainder node.
5341 if (!N->getFlags().hasExact()) {
5342 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: V, N2: N1);
5343 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: Mul);
5344 AddToWorklist(N: Mul.getNode());
5345 AddToWorklist(N: Sub.getNode());
5346 CombineTo(N: RemNode, Res: Sub);
5347 }
5348 }
5349 return V;
5350 }
5351
5352 // udiv, urem -> udivrem
5353 // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
5354 // true. Otherwise, we break the simplification logic in visitREM().
5355 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
5356 if (!N1C || TLI.isIntDivCheap(VT: N->getValueType(ResNo: 0), Attr))
5357 if (SDValue DivRem = useDivRem(Node: N))
5358 return DivRem;
5359
5360 // Simplify the operands using demanded-bits information.
5361 // We don't have demanded bits support for UDIV so this just enables constant
5362 // folding based on known bits.
5363 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
5364 return SDValue(N, 0);
5365
5366 return SDValue();
5367}
5368
5369SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
5370 SDLoc DL(N);
5371 EVT VT = N->getValueType(ResNo: 0);
5372
5373 // fold (udiv x, (1 << c)) -> x >>u c
5374 if (isConstantOrConstantVector(N: N1, /*NoOpaques=*/true,
5375 /*AllowTruncation=*/true)) {
5376 if (SDValue LogBase2 = BuildLogBase2(V: N1, DL)) {
5377 AddToWorklist(N: LogBase2.getNode());
5378
5379 EVT ShiftVT = getShiftAmountTy(LHSTy: N0.getValueType());
5380 SDValue Trunc = DAG.getZExtOrTrunc(Op: LogBase2, DL, VT: ShiftVT);
5381 AddToWorklist(N: Trunc.getNode());
5382 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0, N2: Trunc);
5383 }
5384 }
5385
5386 // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
5387 if (N1.getOpcode() == ISD::SHL) {
5388 SDValue N10 = N1.getOperand(i: 0);
5389 if (isConstantOrConstantVector(N: N10, /*NoOpaques=*/true,
5390 /*AllowTruncation=*/true)) {
5391 if (SDValue LogBase2 = BuildLogBase2(V: N10, DL)) {
5392 AddToWorklist(N: LogBase2.getNode());
5393
5394 EVT ADDVT = N1.getOperand(i: 1).getValueType();
5395 SDValue Trunc = DAG.getZExtOrTrunc(Op: LogBase2, DL, VT: ADDVT);
5396 AddToWorklist(N: Trunc.getNode());
5397 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT: ADDVT, N1: N1.getOperand(i: 1), N2: Trunc);
5398 AddToWorklist(N: Add.getNode());
5399 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0, N2: Add);
5400 }
5401 }
5402 }
5403
5404 // fold (udiv x, c) -> alternate
5405 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
5406 if (isConstantOrConstantVector(N: N1, /*NoOpaques=*/false,
5407 /*AllowTruncation=*/true) &&
5408 !TLI.isIntDivCheap(VT: N->getValueType(ResNo: 0), Attr))
5409 if (SDValue Op = BuildUDIV(N))
5410 return Op;
5411
5412 return SDValue();
5413}
5414
5415SDValue DAGCombiner::buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N) {
5416 if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(Divisor: N1) &&
5417 !DAG.doesNodeExist(Opcode: ISD::SDIV, VTList: N->getVTList(), Ops: {N0, N1})) {
5418 // Target-specific implementation of srem x, pow2.
5419 if (SDValue Res = BuildSREMPow2(N))
5420 return Res;
5421 }
5422 return SDValue();
5423}
5424
5425// handles ISD::SREM and ISD::UREM
5426SDValue DAGCombiner::visitREM(SDNode *N) {
5427 unsigned Opcode = N->getOpcode();
5428 SDValue N0 = N->getOperand(Num: 0);
5429 SDValue N1 = N->getOperand(Num: 1);
5430 EVT VT = N->getValueType(ResNo: 0);
5431 EVT CCVT = getSetCCResultType(VT);
5432
5433 bool isSigned = (Opcode == ISD::SREM);
5434 SDLoc DL(N);
5435
5436 // fold (rem c1, c2) -> c1%c2
5437 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
5438 return C;
5439
5440 // fold (urem X, -1) -> select(FX == -1, 0, FX)
5441 // Freeze the numerator to avoid a miscompile with an undefined value.
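// Freezing makes the compare and the select result observe the same value;
// an unfrozen undef X could resolve differently at each use.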
5442 if (!isSigned && llvm::isAllOnesOrAllOnesSplat(V: N1, /*AllowUndefs*/ false) &&
5443 CCVT.isVector() == VT.isVector()) {
5444 SDValue F0 = DAG.getFreeze(V: N0);
5445 SDValue EqualsNeg1 = DAG.getSetCC(DL, VT: CCVT, LHS: F0, RHS: N1, Cond: ISD::SETEQ);
5446 return DAG.getSelect(DL, VT, Cond: EqualsNeg1, LHS: DAG.getConstant(Val: 0, DL, VT), RHS: F0);
5447 }
5448
5449 if (SDValue V = simplifyDivRem(N, DAG))
5450 return V;
5451
5452 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
5453 return NewSel;
5454
5455 if (isSigned) {
5456 // If we know the sign bits of both operands are zero, strength reduce to a
5457 // urem instead. Handles (X & 0x0FFFFFFF) %s 16 -> X&15
5458 if (DAG.SignBitIsZero(Op: N1) && DAG.SignBitIsZero(Op: N0))
5459 return DAG.getNode(Opcode: ISD::UREM, DL, VT, N1: N0, N2: N1);
5460 } else {
5461 if (DAG.isKnownToBeAPowerOfTwo(Val: N1)) {
5462 // fold (urem x, pow2) -> (and x, pow2-1)
5463 SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
5464 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: NegOne);
5465 AddToWorklist(N: Add.getNode());
5466 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0, N2: Add);
5467 }
5468 // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1))
5469 // fold (urem x, (lshr pow2, y)) -> (and x, (add (lshr pow2, y), -1))
5470 // TODO: We should sink the following into isKnownToBePowerOfTwo
5471 // using a OrZero parameter analogous to our handling in ValueTracking.
5472 if ((N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) &&
5473 DAG.isKnownToBeAPowerOfTwo(Val: N1.getOperand(i: 0))) {
5474 SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
5475 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: NegOne);
5476 AddToWorklist(N: Add.getNode());
5477 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0, N2: Add);
5478 }
5479 }
5480
5481 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
5482
5483 // If X/C can be simplified by the division-by-constant logic, lower
5484 // X%C to the equivalent of X-X/C*C.
5485 // Reuse the SDIVLike/UDIVLike combines - to avoid mangling nodes, the
5486 // speculative DIV must not cause a DIVREM conversion. We guard against this
5487 // by skipping the simplification if isIntDivCheap(). When div is not cheap,
5488 // combine will not return a DIVREM. Regardless, checking cheapness here
5489 // makes sense since the simplification results in fatter code.
5490 if (DAG.isKnownNeverZero(Op: N1) && !TLI.isIntDivCheap(VT, Attr)) {
5491 if (isSigned) {
5492 // check if we can build faster implementation for srem
5493 if (SDValue OptimizedRem = buildOptimizedSREM(N0, N1, N))
5494 return OptimizedRem;
5495 }
5496
5497 SDValue OptimizedDiv =
5498 isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N);
5499 if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != N) {
5500 // If the equivalent Div node also exists, update its users.
5501 unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
5502 if (SDNode *DivNode = DAG.getNodeIfExists(Opcode: DivOpcode, VTList: N->getVTList(),
5503 Ops: { N0, N1 }))
5504 CombineTo(N: DivNode, Res: OptimizedDiv);
5505 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: OptimizedDiv, N2: N1);
5506 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: Mul);
5507 AddToWorklist(N: OptimizedDiv.getNode());
5508 AddToWorklist(N: Mul.getNode());
5509 return Sub;
5510 }
5511 }
5512
5513 // srem/urem -> sdivrem/udivrem
5514 if (SDValue DivRem = useDivRem(Node: N))
5515 return DivRem.getValue(R: 1);
5516
5517 // fold urem(urem(A, BCst), Op1Cst) -> urem(A, Op1Cst)
5518 // iff urem(BCst, Op1Cst) == 0
5519 SDValue A;
5520 APInt Op1Cst, BCst;
5521 if (sd_match(N, P: m_URem(L: m_URem(L: m_Value(N&: A), R: m_ConstInt(V&: BCst)),
5522 R: m_ConstInt(V&: Op1Cst))) &&
5523 BCst.urem(RHS: Op1Cst).isZero()) {
5524 return DAG.getNode(Opcode: ISD::UREM, DL, VT, N1: A, N2: DAG.getConstant(Val: Op1Cst, DL, VT));
5525 }
5526
5527 // fold srem(srem(A, BCst), Op1Cst) -> srem(A, Op1Cst)
5528 // iff srem(BCst, Op1Cst) == 0 && Op1Cst != -1
5529 if (sd_match(N, P: m_SRem(L: m_SRem(L: m_Value(N&: A), R: m_ConstInt(V&: BCst)),
5530 R: m_ConstInt(V&: Op1Cst))) &&
5531 BCst.srem(RHS: Op1Cst).isZero() && !Op1Cst.isAllOnes()) {
5532 return DAG.getNode(Opcode: ISD::SREM, DL, VT, N1: A, N2: DAG.getConstant(Val: Op1Cst, DL, VT));
5533 }
5534
5535 return SDValue();
5536}
5537
5538SDValue DAGCombiner::visitMULHS(SDNode *N) {
5539 SDValue N0 = N->getOperand(Num: 0);
5540 SDValue N1 = N->getOperand(Num: 1);
5541 EVT VT = N->getValueType(ResNo: 0);
5542 SDLoc DL(N);
5543
5544 // fold (mulhs c1, c2)
5545 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::MULHS, DL, VT, Ops: {N0, N1}))
5546 return C;
5547
5548 // canonicalize constant to RHS.
5549 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5550 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5551 return DAG.getNode(Opcode: ISD::MULHS, DL, VTList: N->getVTList(), N1, N2: N0);
5552
5553 if (VT.isVector()) {
5554 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5555 return FoldedVOp;
5556
5557 // fold (mulhs x, 0) -> 0
5558 // Do not return N1; it may contain undef elements.
5559 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
5560 return DAG.getConstant(Val: 0, DL, VT);
5561 }
5562
5563 // fold (mulhs x, 0) -> 0
5564 if (isNullConstant(V: N1))
5565 return N1;
5566
5567 // fold (mulhs x, 1) -> (sra x, size(x)-1)
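// The full signed product of x and 1 is just sext(x), so its high half is
// the sign bit of x replicated across the result.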
5568 if (isOneConstant(V: N1))
5569 return DAG.getNode(
5570 Opcode: ISD::SRA, DL, VT, N1: N0,
5571 N2: DAG.getShiftAmountConstant(Val: N0.getScalarValueSizeInBits() - 1, VT, DL));
5572
5573 // fold (mulhs x, undef) -> 0
5574 if (N0.isUndef() || N1.isUndef())
5575 return DAG.getConstant(Val: 0, DL, VT);
5576
5577 // If the type twice as wide is legal, transform the mulhs to a wider multiply
5578 // plus a shift.
5579 if (!TLI.isOperationLegalOrCustom(Op: ISD::MULHS, VT) && VT.isSimple() &&
5580 !VT.isVector()) {
5581 MVT Simple = VT.getSimpleVT();
5582 unsigned SimpleSize = Simple.getSizeInBits();
5583 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SimpleSize*2);
5584 if (TLI.isOperationLegal(Op: ISD::MUL, VT: NewVT)) {
5585 N0 = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT: NewVT, Operand: N0);
5586 N1 = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT: NewVT, Operand: N1);
5587 N1 = DAG.getNode(Opcode: ISD::MUL, DL, VT: NewVT, N1: N0, N2: N1);
5588 N1 = DAG.getNode(Opcode: ISD::SRL, DL, VT: NewVT, N1,
5589 N2: DAG.getShiftAmountConstant(Val: SimpleSize, VT: NewVT, DL));
5590 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N1);
5591 }
5592 }
5593
5594 return SDValue();
5595}
5596
5597SDValue DAGCombiner::visitMULHU(SDNode *N) {
5598 SDValue N0 = N->getOperand(Num: 0);
5599 SDValue N1 = N->getOperand(Num: 1);
5600 EVT VT = N->getValueType(ResNo: 0);
5601 SDLoc DL(N);
5602
5603 // fold (mulhu c1, c2)
5604 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::MULHU, DL, VT, Ops: {N0, N1}))
5605 return C;
5606
5607 // canonicalize constant to RHS.
5608 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5609 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5610 return DAG.getNode(Opcode: ISD::MULHU, DL, VTList: N->getVTList(), N1, N2: N0);
5611
5612 if (VT.isVector()) {
5613 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5614 return FoldedVOp;
5615
5616 // fold (mulhu x, 0) -> 0
5617 // do not return N1, because it may contain undef elements.
5618 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
5619 return DAG.getConstant(Val: 0, DL, VT);
5620 }
5621
5622 // fold (mulhu x, 0) -> 0
5623 if (isNullConstant(V: N1))
5624 return N1;
5625
5626 // fold (mulhu x, 1) -> 0
5627 if (isOneConstant(V: N1))
5628 return DAG.getConstant(Val: 0, DL, VT);
5629
5630 // fold (mulhu x, undef) -> 0
5631 if (N0.isUndef() || N1.isUndef())
5632 return DAG.getConstant(Val: 0, DL, VT);
5633
5634 // fold (mulhu x, (1 << c)) -> x >> (bitwidth - c)
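// Multiplying by 2^c shifts x left by c within the double-width product, so
// the high half holds the top c bits of x, i.e. x >> (bitwidth - c). For
// example, with i32 operands mulhu(x, 16) == x >> 28.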
5635 if (isConstantOrConstantVector(N: N1, /*NoOpaques=*/true,
5636 /*AllowTruncation=*/true) &&
5637 hasOperation(Opcode: ISD::SRL, VT)) {
5638 if (SDValue LogBase2 = BuildLogBase2(V: N1, DL)) {
5639 unsigned NumEltBits = VT.getScalarSizeInBits();
5640 SDValue SRLAmt = DAG.getNode(
5641 Opcode: ISD::SUB, DL, VT, N1: DAG.getConstant(Val: NumEltBits, DL, VT), N2: LogBase2);
5642 EVT ShiftVT = getShiftAmountTy(LHSTy: N0.getValueType());
5643 SDValue Trunc = DAG.getZExtOrTrunc(Op: SRLAmt, DL, VT: ShiftVT);
5644 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0, N2: Trunc);
5645 }
5646 }
5647
5648 // If the type twice as wide is legal, transform the mulhu to a wider multiply
5649 // plus a shift.
5650 if (!TLI.isOperationLegalOrCustom(Op: ISD::MULHU, VT) && VT.isSimple() &&
5651 !VT.isVector()) {
5652 MVT Simple = VT.getSimpleVT();
5653 unsigned SimpleSize = Simple.getSizeInBits();
5654 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SimpleSize*2);
5655 if (TLI.isOperationLegal(Op: ISD::MUL, VT: NewVT)) {
5656 N0 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: NewVT, Operand: N0);
5657 N1 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: NewVT, Operand: N1);
5658 N1 = DAG.getNode(Opcode: ISD::MUL, DL, VT: NewVT, N1: N0, N2: N1);
5659 N1 = DAG.getNode(Opcode: ISD::SRL, DL, VT: NewVT, N1,
5660 N2: DAG.getShiftAmountConstant(Val: SimpleSize, VT: NewVT, DL));
5661 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N1);
5662 }
5663 }
5664
5665 // Simplify the operands using demanded-bits information.
5666 // We don't have demanded bits support for MULHU so this just enables constant
5667 // folding based on known bits.
5668 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
5669 return SDValue(N, 0);
5670
5671 return SDValue();
5672}
5673
5674SDValue DAGCombiner::visitAVG(SDNode *N) {
5675 unsigned Opcode = N->getOpcode();
5676 SDValue N0 = N->getOperand(Num: 0);
5677 SDValue N1 = N->getOperand(Num: 1);
5678 EVT VT = N->getValueType(ResNo: 0);
5679 SDLoc DL(N);
5680 bool IsSigned = Opcode == ISD::AVGCEILS || Opcode == ISD::AVGFLOORS;
5681
5682 // fold (avg c1, c2)
5683 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
5684 return C;
5685
5686 // canonicalize constant to RHS.
5687 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5688 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5689 return DAG.getNode(Opcode, DL, VTList: N->getVTList(), N1, N2: N0);
5690
5691 if (VT.isVector())
5692 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5693 return FoldedVOp;
5694
5695 // fold (avg x, undef) -> x
5696 if (N0.isUndef())
5697 return N1;
5698 if (N1.isUndef())
5699 return N0;
5700
5701 // fold (avg x, x) --> x
5702 if (N0 == N1 && Level >= AfterLegalizeTypes)
5703 return N0;
5704
5705 // fold (avgfloor x, 0) -> x >> 1
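// avgfloor(x, 0) is (x + 0) / 2 rounded down, computed without overflow: an
// arithmetic shift for the signed form, a logical shift for the unsigned form.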
5706 SDValue X, Y;
5707 if (sd_match(N, P: m_c_BinOp(Opc: ISD::AVGFLOORS, L: m_Value(N&: X), R: m_Zero())))
5708 return DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: X,
5709 N2: DAG.getShiftAmountConstant(Val: 1, VT, DL));
5710 if (sd_match(N, P: m_c_BinOp(Opc: ISD::AVGFLOORU, L: m_Value(N&: X), R: m_Zero())))
5711 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: X,
5712 N2: DAG.getShiftAmountConstant(Val: 1, VT, DL));
5713
5714 // fold avgu(zext(x), zext(y)) -> zext(avgu(x, y))
5715 // fold avgs(sext(x), sext(y)) -> sext(avgs(x, y))
5716 if (!IsSigned &&
5717 sd_match(N, P: m_BinOp(Opc: Opcode, L: m_ZExt(Op: m_Value(N&: X)), R: m_ZExt(Op: m_Value(N&: Y)))) &&
5718 X.getValueType() == Y.getValueType() &&
5719 hasOperation(Opcode, VT: X.getValueType())) {
5720 SDValue AvgU = DAG.getNode(Opcode, DL, VT: X.getValueType(), N1: X, N2: Y);
5721 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: AvgU);
5722 }
5723 if (IsSigned &&
5724 sd_match(N, P: m_BinOp(Opc: Opcode, L: m_SExt(Op: m_Value(N&: X)), R: m_SExt(Op: m_Value(N&: Y)))) &&
5725 X.getValueType() == Y.getValueType() &&
5726 hasOperation(Opcode, VT: X.getValueType())) {
5727 SDValue AvgS = DAG.getNode(Opcode, DL, VT: X.getValueType(), N1: X, N2: Y);
5728 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: AvgS);
5729 }
5730
5731 // Fold avgflooru(x,y) -> avgceilu(x,y-1) iff y != 0
5732 // Fold avgflooru(x,y) -> avgceilu(x-1,y) iff x != 0
5733 // Check if avgflooru isn't legal/custom but avgceilu is.
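// avgflooru(x, y) == (x + y) >> 1 and avgceilu(a, b) == (a + b + 1) >> 1,
// both evaluated without overflow, so substituting b = y - 1 (or a = x - 1)
// yields the same value; the non-zero check ensures the subtraction by one
// (implemented as an add of -1) does not wrap around.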
5734 if (Opcode == ISD::AVGFLOORU && !hasOperation(Opcode: ISD::AVGFLOORU, VT) &&
5735 (!LegalOperations || hasOperation(Opcode: ISD::AVGCEILU, VT))) {
5736 if (DAG.isKnownNeverZero(Op: N1))
5737 return DAG.getNode(
5738 Opcode: ISD::AVGCEILU, DL, VT, N1: N0,
5739 N2: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: DAG.getAllOnesConstant(DL, VT)));
5740 if (DAG.isKnownNeverZero(Op: N0))
5741 return DAG.getNode(
5742 Opcode: ISD::AVGCEILU, DL, VT, N1,
5743 N2: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: DAG.getAllOnesConstant(DL, VT)));
5744 }
5745
5746 // Fold avgfloor((add nw x,y), 1) -> avgceil(x,y)
5747 // Fold avgfloor((add nw x,1), y) -> avgceil(x,y)
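// When the inner add cannot wrap, avgfloor(x + y, 1) == (x + y + 1) >> 1,
// which is exactly avgceil(x, y).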
5748 if ((Opcode == ISD::AVGFLOORU && hasOperation(Opcode: ISD::AVGCEILU, VT)) ||
5749 (Opcode == ISD::AVGFLOORS && hasOperation(Opcode: ISD::AVGCEILS, VT))) {
5750 SDValue Add;
5751 if (sd_match(N,
5752 P: m_c_BinOp(Opc: Opcode,
5753 L: m_AllOf(preds: m_Value(N&: Add), preds: m_Add(L: m_Value(N&: X), R: m_Value(N&: Y))),
5754 R: m_One())) ||
5755 sd_match(N, P: m_c_BinOp(Opc: Opcode,
5756 L: m_AllOf(preds: m_Value(N&: Add), preds: m_Add(L: m_Value(N&: X), R: m_One())),
5757 R: m_Value(N&: Y)))) {
5758
5759 if (IsSigned && Add->getFlags().hasNoSignedWrap())
5760 return DAG.getNode(Opcode: ISD::AVGCEILS, DL, VT, N1: X, N2: Y);
5761
5762 if (!IsSigned && Add->getFlags().hasNoUnsignedWrap())
5763 return DAG.getNode(Opcode: ISD::AVGCEILU, DL, VT, N1: X, N2: Y);
5764 }
5765 }
5766
5767 // Fold avgfloors(x,y) -> avgflooru(x,y) if both x and y are non-negative
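// When both sign bits are known to be zero, the signed and unsigned averages
// coincide, so the unsigned form can be used where the target supports it.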
5768 if (Opcode == ISD::AVGFLOORS && hasOperation(Opcode: ISD::AVGFLOORU, VT)) {
5769 if (DAG.SignBitIsZero(Op: N0) && DAG.SignBitIsZero(Op: N1))
5770 return DAG.getNode(Opcode: ISD::AVGFLOORU, DL, VT, N1: N0, N2: N1);
5771 }
5772
5773 return SDValue();
5774}
5775
5776SDValue DAGCombiner::visitABD(SDNode *N) {
5777 unsigned Opcode = N->getOpcode();
5778 SDValue N0 = N->getOperand(Num: 0);
5779 SDValue N1 = N->getOperand(Num: 1);
5780 EVT VT = N->getValueType(ResNo: 0);
5781 SDLoc DL(N);
5782
5783 // fold (abd c1, c2)
5784 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
5785 return C;
5786
5787 // canonicalize constant to RHS.
5788 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5789 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5790 return DAG.getNode(Opcode, DL, VTList: N->getVTList(), N1, N2: N0);
5791
5792 if (VT.isVector())
5793 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5794 return FoldedVOp;
5795
5796 // fold (abd x, undef) -> 0
5797 if (N0.isUndef() || N1.isUndef())
5798 return DAG.getConstant(Val: 0, DL, VT);
5799
5800 // fold (abd x, x) -> 0
5801 if (N0 == N1)
5802 return DAG.getConstant(Val: 0, DL, VT);
5803
5804 SDValue X, Y;
5805
5806 // fold (abds x, 0) -> abs x
5807 if (sd_match(N, P: m_c_BinOp(Opc: ISD::ABDS, L: m_Value(N&: X), R: m_Zero())) &&
5808 (!LegalOperations || hasOperation(Opcode: ISD::ABS, VT)))
5809 return DAG.getNode(Opcode: ISD::ABS, DL, VT, Operand: X);
5810
5811 // fold (abdu x, 0) -> x
5812 if (sd_match(N, P: m_c_BinOp(Opc: ISD::ABDU, L: m_Value(N&: X), R: m_Zero())))
5813 return X;
5814
5815 // fold (abds x, y) -> (abdu x, y) iff both args are known non-negative
5816 if (Opcode == ISD::ABDS && hasOperation(Opcode: ISD::ABDU, VT) &&
5817 DAG.SignBitIsZero(Op: N0) && DAG.SignBitIsZero(Op: N1))
5818 return DAG.getNode(Opcode: ISD::ABDU, DL, VT, N1, N2: N0);
5819
5820 // fold (abd? (?ext x), (?ext y)) -> (zext (abd? x, y))
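// The absolute difference of two zero-extended (resp. sign-extended) values
// always fits in the narrower type when read as an unsigned value, so the
// abd can be computed there and the result zero-extended back.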
5821 if (sd_match(N, P: m_BinOp(Opc: ISD::ABDU, L: m_ZExt(Op: m_Value(N&: X)), R: m_ZExt(Op: m_Value(N&: Y)))) ||
5822 sd_match(N, P: m_BinOp(Opc: ISD::ABDS, L: m_SExt(Op: m_Value(N&: X)), R: m_SExt(Op: m_Value(N&: Y))))) {
5823 EVT SmallVT = X.getScalarValueSizeInBits() > Y.getScalarValueSizeInBits()
5824 ? X.getValueType()
5825 : Y.getValueType();
5826 if (!LegalOperations || hasOperation(Opcode, VT: SmallVT)) {
5827 SDValue ExtedX = DAG.getExtOrTrunc(Op: X, DL: SDLoc(X), VT: SmallVT, Opcode: N0->getOpcode());
5828 SDValue ExtedY = DAG.getExtOrTrunc(Op: Y, DL: SDLoc(Y), VT: SmallVT, Opcode: N0->getOpcode());
5829 SDValue SmallABD = DAG.getNode(Opcode, DL, VT: SmallVT, Ops: {ExtedX, ExtedY});
5830 SDValue ZExted = DAG.getZExtOrTrunc(Op: SmallABD, DL, VT);
5831 return ZExted;
5832 }
5833 }
5834
5835 // fold (abd? (?ext x), y) -> (zext (abd? x, (trunc y)))
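// This applies when the significant bits of y fit in x's narrower type and
// truncating y is free (e.g. y is a suitably small constant); the truncated
// y then represents the same value, so the narrow abd equals the wide one.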
5836 if (sd_match(N, P: m_c_BinOp(Opc: ISD::ABDU, L: m_ZExt(Op: m_Value(N&: X)), R: m_Value(N&: Y))) ||
5837 sd_match(N, P: m_c_BinOp(Opc: ISD::ABDS, L: m_SExt(Op: m_Value(N&: X)), R: m_Value(N&: Y)))) {
5838 EVT SmallVT = X.getValueType();
5839 if (!LegalOperations || hasOperation(Opcode, VT: SmallVT)) {
5840 uint64_t Bits = SmallVT.getScalarSizeInBits();
5841 unsigned RelevantBits =
5842 (Opcode == ISD::ABDS) ? DAG.ComputeMaxSignificantBits(Op: Y)
5843 : DAG.computeKnownBits(Op: Y).countMaxActiveBits();
5844 bool TruncatingYIsCheap = TLI.isTruncateFree(Val: Y, VT2: SmallVT) ||
5845 ISD::matchUnaryPredicate(
5846 Op: Y,
5847 Match: [&](auto *C) {
5848 const APInt &YConst = C->getAsAPIntVal();
5849 return (Opcode == ISD::ABDS)
5850 ? YConst.isSignedIntN(N: Bits)
5851 : YConst.isIntN(N: Bits);
5852 },
5853 /*AllowUndefs=*/true);
5854
5855 if (RelevantBits <= Bits && TruncatingYIsCheap) {
5856 SDValue NewY = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(Y), VT: SmallVT, Operand: Y);
5857 SDValue SmallABD = DAG.getNode(Opcode, DL, VT: SmallVT, Ops: {X, NewY});
5858 return DAG.getZExtOrTrunc(Op: SmallABD, DL, VT);
5859 }
5860 }
5861 }
5862
5863 return SDValue();
5864}
5865
5866/// Perform optimizations common to nodes that compute two values. LoOp and HiOp
5867/// give the opcodes for the two computations that are being performed. Returns
5868/// the simplified value, or a null SDValue if no simplification was made.
5869SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
5870 unsigned HiOp) {
5871 // If the high half is not needed, just compute the low half.
5872 bool HiExists = N->hasAnyUseOfValue(Value: 1);
5873 if (!HiExists && (!LegalOperations ||
5874 TLI.isOperationLegalOrCustom(Op: LoOp, VT: N->getValueType(ResNo: 0)))) {
5875 SDValue Res = DAG.getNode(Opcode: LoOp, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Ops: N->ops());
5876 return CombineTo(N, Res0: Res, Res1: Res);
5877 }
5878
5879 // If the low half is not needed, just compute the high half.
5880 bool LoExists = N->hasAnyUseOfValue(Value: 0);
5881 if (!LoExists && (!LegalOperations ||
5882 TLI.isOperationLegalOrCustom(Op: HiOp, VT: N->getValueType(ResNo: 1)))) {
5883 SDValue Res = DAG.getNode(Opcode: HiOp, DL: SDLoc(N), VT: N->getValueType(ResNo: 1), Ops: N->ops());
5884 return CombineTo(N, Res0: Res, Res1: Res);
5885 }
5886
5887 // If both halves are used, there is nothing to simplify.
5888 if (LoExists && HiExists)
5889 return SDValue();
5890
5891 // If the two computed results can be simplified separately, separate them.
5892 if (LoExists) {
5893 SDValue Lo = DAG.getNode(Opcode: LoOp, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Ops: N->ops());
5894 AddToWorklist(N: Lo.getNode());
5895 SDValue LoOpt = combine(N: Lo.getNode());
5896 if (LoOpt.getNode() && LoOpt.getNode() != Lo.getNode() &&
5897 (!LegalOperations ||
5898 TLI.isOperationLegalOrCustom(Op: LoOpt.getOpcode(), VT: LoOpt.getValueType())))
5899 return CombineTo(N, Res0: LoOpt, Res1: LoOpt);
5900 }
5901
5902 if (HiExists) {
5903 SDValue Hi = DAG.getNode(Opcode: HiOp, DL: SDLoc(N), VT: N->getValueType(ResNo: 1), Ops: N->ops());
5904 AddToWorklist(N: Hi.getNode());
5905 SDValue HiOpt = combine(N: Hi.getNode());
5906 if (HiOpt.getNode() && HiOpt != Hi &&
5907 (!LegalOperations ||
5908 TLI.isOperationLegalOrCustom(Op: HiOpt.getOpcode(), VT: HiOpt.getValueType())))
5909 return CombineTo(N, Res0: HiOpt, Res1: HiOpt);
5910 }
5911
5912 return SDValue();
5913}
5914
5915SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
5916 if (SDValue Res = SimplifyNodeWithTwoResults(N, LoOp: ISD::MUL, HiOp: ISD::MULHS))
5917 return Res;
5918
5919 SDValue N0 = N->getOperand(Num: 0);
5920 SDValue N1 = N->getOperand(Num: 1);
5921 EVT VT = N->getValueType(ResNo: 0);
5922 SDLoc DL(N);
5923
5924 // Constant fold.
5925 if (isa<ConstantSDNode>(Val: N0) && isa<ConstantSDNode>(Val: N1))
5926 return DAG.getNode(Opcode: ISD::SMUL_LOHI, DL, VTList: N->getVTList(), N1: N0, N2: N1);
5927
5928 // canonicalize constant to RHS (vector doesn't have to splat)
5929 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5930 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5931 return DAG.getNode(Opcode: ISD::SMUL_LOHI, DL, VTList: N->getVTList(), N1, N2: N0);
5932
5933 // If the type twice as wide is legal, transform the smul_lohi into a wider
5934 // multiply plus a shift.
5935 if (VT.isSimple() && !VT.isVector()) {
5936 MVT Simple = VT.getSimpleVT();
5937 unsigned SimpleSize = Simple.getSizeInBits();
5938 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SimpleSize*2);
5939 if (TLI.isOperationLegal(Op: ISD::MUL, VT: NewVT)) {
5940 SDValue Lo = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT: NewVT, Operand: N0);
5941 SDValue Hi = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT: NewVT, Operand: N1);
5942 Lo = DAG.getNode(Opcode: ISD::MUL, DL, VT: NewVT, N1: Lo, N2: Hi);
5943 // Compute the high part (result value 1).
5944 Hi = DAG.getNode(Opcode: ISD::SRL, DL, VT: NewVT, N1: Lo,
5945 N2: DAG.getShiftAmountConstant(Val: SimpleSize, VT: NewVT, DL));
5946 Hi = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Hi);
5947 // Compute the low part (result value 0).
5948 Lo = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Lo);
5949 return CombineTo(N, Res0: Lo, Res1: Hi);
5950 }
5951 }
5952
5953 return SDValue();
5954}
5955
5956SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
5957 if (SDValue Res = SimplifyNodeWithTwoResults(N, LoOp: ISD::MUL, HiOp: ISD::MULHU))
5958 return Res;
5959
5960 SDValue N0 = N->getOperand(Num: 0);
5961 SDValue N1 = N->getOperand(Num: 1);
5962 EVT VT = N->getValueType(ResNo: 0);
5963 SDLoc DL(N);
5964
5965 // Constant fold.
5966 if (isa<ConstantSDNode>(Val: N0) && isa<ConstantSDNode>(Val: N1))
5967 return DAG.getNode(Opcode: ISD::UMUL_LOHI, DL, VTList: N->getVTList(), N1: N0, N2: N1);
5968
5969 // canonicalize constant to RHS (vector doesn't have to splat)
5970 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5971 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5972 return DAG.getNode(Opcode: ISD::UMUL_LOHI, DL, VTList: N->getVTList(), N1, N2: N0);
5973
5974 // (umul_lohi N0, 0) -> (0, 0)
5975 if (isNullConstant(V: N1)) {
5976 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
5977 return CombineTo(N, Res0: Zero, Res1: Zero);
5978 }
5979
5980 // (umul_lohi N0, 1) -> (N0, 0)
5981 if (isOneConstant(V: N1)) {
5982 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
5983 return CombineTo(N, Res0: N0, Res1: Zero);
5984 }
5985
5986 // If the type twice as wide is legal, transform the umul_lohi into a wider
5987 // multiply plus a shift.
5988 if (VT.isSimple() && !VT.isVector()) {
5989 MVT Simple = VT.getSimpleVT();
5990 unsigned SimpleSize = Simple.getSizeInBits();
5991 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SimpleSize*2);
5992 if (TLI.isOperationLegal(Op: ISD::MUL, VT: NewVT)) {
5993 SDValue Lo = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: NewVT, Operand: N0);
5994 SDValue Hi = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: NewVT, Operand: N1);
5995 Lo = DAG.getNode(Opcode: ISD::MUL, DL, VT: NewVT, N1: Lo, N2: Hi);
5996 // Compute the high part (result value 1).
5997 Hi = DAG.getNode(Opcode: ISD::SRL, DL, VT: NewVT, N1: Lo,
5998 N2: DAG.getShiftAmountConstant(Val: SimpleSize, VT: NewVT, DL));
5999 Hi = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Hi);
6000 // Compute the low part (result value 0).
6001 Lo = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Lo);
6002 return CombineTo(N, Res0: Lo, Res1: Hi);
6003 }
6004 }
6005
6006 return SDValue();
6007}
6008
6009SDValue DAGCombiner::visitMULO(SDNode *N) {
6010 SDValue N0 = N->getOperand(Num: 0);
6011 SDValue N1 = N->getOperand(Num: 1);
6012 EVT VT = N0.getValueType();
6013 bool IsSigned = (ISD::SMULO == N->getOpcode());
6014
6015 EVT CarryVT = N->getValueType(ResNo: 1);
6016 SDLoc DL(N);
6017
6018 ConstantSDNode *N0C = isConstOrConstSplat(N: N0);
6019 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
6020
6021 // fold operation with constant operands.
6022 // TODO: Move this to FoldConstantArithmetic when it supports nodes with
6023 // multiple results.
6024 if (N0C && N1C) {
6025 bool Overflow;
6026 APInt Result =
6027 IsSigned ? N0C->getAPIntValue().smul_ov(RHS: N1C->getAPIntValue(), Overflow)
6028 : N0C->getAPIntValue().umul_ov(RHS: N1C->getAPIntValue(), Overflow);
6029 return CombineTo(N, Res0: DAG.getConstant(Val: Result, DL, VT),
6030 Res1: DAG.getBoolConstant(V: Overflow, DL, VT: CarryVT, OpVT: CarryVT));
6031 }
6032
6033 // canonicalize constant to RHS.
6034 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
6035 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
6036 return DAG.getNode(Opcode: N->getOpcode(), DL, VTList: N->getVTList(), N1, N2: N0);
6037
6038 // fold (mulo x, 0) -> 0 + no carry out
6039 if (isNullOrNullSplat(V: N1))
6040 return CombineTo(N, Res0: DAG.getConstant(Val: 0, DL, VT),
6041 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
6042
6043 // (mulo x, 2) -> (addo x, x)
6044 // FIXME: This needs a freeze.
6045 if (N1C && N1C->getAPIntValue() == 2 &&
6046 (!IsSigned || VT.getScalarSizeInBits() > 2))
6047 return DAG.getNode(Opcode: IsSigned ? ISD::SADDO : ISD::UADDO, DL,
6048 VTList: N->getVTList(), N1: N0, N2: N0);
6049
6050 // A 1 bit SMULO overflows if both inputs are 1.
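// In i1 the only values are 0 and -1, and (-1) * (-1) == +1 is not
// representable, so overflow occurs exactly when both bits are set.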
6051 if (IsSigned && VT.getScalarSizeInBits() == 1) {
6052 SDValue And = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0, N2: N1);
6053 SDValue Cmp = DAG.getSetCC(DL, VT: CarryVT, LHS: And,
6054 RHS: DAG.getConstant(Val: 0, DL, VT), Cond: ISD::SETNE);
6055 return CombineTo(N, Res0: And, Res1: Cmp);
6056 }
6057
6058 // If it cannot overflow, transform into a mul.
6059 if (DAG.willNotOverflowMul(IsSigned, N0, N1))
6060 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: N0, N2: N1),
6061 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
6062 return SDValue();
6063}
6064
6065// Function to calculate whether the Min/Max pair of SDNodes (potentially
6066// swapped around) make a signed saturate pattern, clamping to between a signed
6067// saturate of -2^(BW-1) and 2^(BW-1)-1, or an unsigned saturate of 0 and 2^BW-1.
6068// Returns the node being clamped and the bitwidth of the clamp in BW. Should
6069// work with both SMIN/SMAX nodes and setcc/select combo. The operands are the
6070// same as SimplifySelectCC. N0<N1 ? N2 : N3.
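// For example, smin(smax(x, -128), 127) clamps x to the signed i8 range, so x
// is returned with BW == 8 and Unsigned == false, while smin(smax(x, 0), 255)
// is the unsigned equivalent with BW == 8 and Unsigned == true.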
6071static SDValue isSaturatingMinMax(SDValue N0, SDValue N1, SDValue N2,
6072 SDValue N3, ISD::CondCode CC, unsigned &BW,
6073 bool &Unsigned, SelectionDAG &DAG) {
6074 auto isSignedMinMax = [&](SDValue N0, SDValue N1, SDValue N2, SDValue N3,
6075 ISD::CondCode CC) {
6076 // The compare and select operand should be the same or the select operands
6077 // should be truncated versions of the comparison.
6078 if (N0 != N2 && (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(i: 0)))
6079 return 0;
6080 // The constants need to be the same or a truncated version of each other.
6081 ConstantSDNode *N1C = isConstOrConstSplat(N: peekThroughTruncates(V: N1));
6082 ConstantSDNode *N3C = isConstOrConstSplat(N: peekThroughTruncates(V: N3));
6083 if (!N1C || !N3C)
6084 return 0;
6085 const APInt &C1 = N1C->getAPIntValue().trunc(width: N1.getScalarValueSizeInBits());
6086 const APInt &C2 = N3C->getAPIntValue().trunc(width: N3.getScalarValueSizeInBits());
6087 if (C1.getBitWidth() < C2.getBitWidth() || C1 != C2.sext(width: C1.getBitWidth()))
6088 return 0;
6089 return CC == ISD::SETLT ? ISD::SMIN : (CC == ISD::SETGT ? ISD::SMAX : 0);
6090 };
6091
6092 // Check the initial value is a SMIN/SMAX equivalent.
6093 unsigned Opcode0 = isSignedMinMax(N0, N1, N2, N3, CC);
6094 if (!Opcode0)
6095 return SDValue();
6096
6097 // We may need only one range check if the fptosi could never produce
6098 // the upper value.
6099 if (N0.getOpcode() == ISD::FP_TO_SINT && Opcode0 == ISD::SMAX) {
6100 if (isNullOrNullSplat(V: N3)) {
6101 EVT IntVT = N0.getValueType().getScalarType();
6102 EVT FPVT = N0.getOperand(i: 0).getValueType().getScalarType();
6103 if (FPVT.isSimple()) {
6104 Type *InputTy = FPVT.getTypeForEVT(Context&: *DAG.getContext());
6105 const fltSemantics &Semantics = InputTy->getFltSemantics();
6106 uint32_t MinBitWidth =
6107 APFloatBase::semanticsIntSizeInBits(Semantics, /*isSigned*/ true);
6108 if (IntVT.getSizeInBits() >= MinBitWidth) {
6109 Unsigned = true;
6110 BW = PowerOf2Ceil(A: MinBitWidth);
6111 return N0;
6112 }
6113 }
6114 }
6115 }
6116
6117 SDValue N00, N01, N02, N03;
6118 ISD::CondCode N0CC;
6119 switch (N0.getOpcode()) {
6120 case ISD::SMIN:
6121 case ISD::SMAX:
6122 N00 = N02 = N0.getOperand(i: 0);
6123 N01 = N03 = N0.getOperand(i: 1);
6124 N0CC = N0.getOpcode() == ISD::SMIN ? ISD::SETLT : ISD::SETGT;
6125 break;
6126 case ISD::SELECT_CC:
6127 N00 = N0.getOperand(i: 0);
6128 N01 = N0.getOperand(i: 1);
6129 N02 = N0.getOperand(i: 2);
6130 N03 = N0.getOperand(i: 3);
6131 N0CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 4))->get();
6132 break;
6133 case ISD::SELECT:
6134 case ISD::VSELECT:
6135 if (N0.getOperand(i: 0).getOpcode() != ISD::SETCC)
6136 return SDValue();
6137 N00 = N0.getOperand(i: 0).getOperand(i: 0);
6138 N01 = N0.getOperand(i: 0).getOperand(i: 1);
6139 N02 = N0.getOperand(i: 1);
6140 N03 = N0.getOperand(i: 2);
6141 N0CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 0).getOperand(i: 2))->get();
6142 break;
6143 default:
6144 return SDValue();
6145 }
6146
6147 unsigned Opcode1 = isSignedMinMax(N00, N01, N02, N03, N0CC);
6148 if (!Opcode1 || Opcode0 == Opcode1)
6149 return SDValue();
6150
6151 ConstantSDNode *MinCOp = isConstOrConstSplat(N: Opcode0 == ISD::SMIN ? N1 : N01);
6152 ConstantSDNode *MaxCOp = isConstOrConstSplat(N: Opcode0 == ISD::SMIN ? N01 : N1);
6153 if (!MinCOp || !MaxCOp || MinCOp->getValueType(ResNo: 0) != MaxCOp->getValueType(ResNo: 0))
6154 return SDValue();
6155
6156 const APInt &MinC = MinCOp->getAPIntValue();
6157 const APInt &MaxC = MaxCOp->getAPIntValue();
6158 APInt MinCPlus1 = MinC + 1;
6159 if (-MaxC == MinCPlus1 && MinCPlus1.isPowerOf2()) {
6160 BW = MinCPlus1.exactLogBase2() + 1;
6161 Unsigned = false;
6162 return N02;
6163 }
6164
6165 if (MaxC == 0 && MinC != 0 && MinCPlus1.isPowerOf2()) {
6166 BW = MinCPlus1.exactLogBase2();
6167 Unsigned = true;
6168 return N02;
6169 }
6170
6171 return SDValue();
6172}
6173
6174static SDValue PerformMinMaxFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
6175 SDValue N3, ISD::CondCode CC,
6176 SelectionDAG &DAG) {
6177 unsigned BW;
6178 bool Unsigned;
6179 SDValue Fp = isSaturatingMinMax(N0, N1, N2, N3, CC, BW, Unsigned, DAG);
6180 if (!Fp || Fp.getOpcode() != ISD::FP_TO_SINT)
6181 return SDValue();
6182 EVT FPVT = Fp.getOperand(i: 0).getValueType();
6183 EVT NewVT = FPVT.changeElementType(Context&: *DAG.getContext(),
6184 EltVT: EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: BW));
6185 unsigned NewOpc = Unsigned ? ISD::FP_TO_UINT_SAT : ISD::FP_TO_SINT_SAT;
6186 if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(Op: NewOpc, FPVT, VT: NewVT))
6187 return SDValue();
6188 SDLoc DL(Fp);
6189 SDValue Sat = DAG.getNode(Opcode: NewOpc, DL, VT: NewVT, N1: Fp.getOperand(i: 0),
6190 N2: DAG.getValueType(NewVT.getScalarType()));
6191 return DAG.getExtOrTrunc(IsSigned: !Unsigned, Op: Sat, DL, VT: N2->getValueType(ResNo: 0));
6192}
6193
6194static SDValue PerformUMinFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
6195 SDValue N3, ISD::CondCode CC,
6196 SelectionDAG &DAG) {
6197 // We are looking for UMIN(FPTOUI(X), (2^n)-1), which may have come via a
6198 // select/vselect/select_cc. The two operand pairs for the select (N2/N3) may
6199 // be truncated versions of the setcc (N0/N1).
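// E.g. a umin(fptoui x, 255) expressed this way becomes
// zext(fp_to_uint_sat x to i8) when the target considers the saturating
// conversion profitable (shouldConvertFpToSat).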
6200 if ((N0 != N2 &&
6201 (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(i: 0))) ||
6202 N0.getOpcode() != ISD::FP_TO_UINT || CC != ISD::SETULT)
6203 return SDValue();
6204 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
6205 ConstantSDNode *N3C = isConstOrConstSplat(N: N3);
6206 if (!N1C || !N3C)
6207 return SDValue();
6208 const APInt &C1 = N1C->getAPIntValue();
6209 const APInt &C3 = N3C->getAPIntValue();
6210 if (!(C1 + 1).isPowerOf2() || C1.getBitWidth() < C3.getBitWidth() ||
6211 C1 != C3.zext(width: C1.getBitWidth()))
6212 return SDValue();
6213
6214 unsigned BW = (C1 + 1).exactLogBase2();
6215 EVT FPVT = N0.getOperand(i: 0).getValueType();
6216 EVT NewVT = FPVT.changeElementType(Context&: *DAG.getContext(),
6217 EltVT: EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: BW));
6218 if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(Op: ISD::FP_TO_UINT_SAT,
6219 FPVT, VT: NewVT))
6220 return SDValue();
6221
6222 SDValue Sat =
6223 DAG.getNode(Opcode: ISD::FP_TO_UINT_SAT, DL: SDLoc(N0), VT: NewVT, N1: N0.getOperand(i: 0),
6224 N2: DAG.getValueType(NewVT.getScalarType()));
6225 return DAG.getZExtOrTrunc(Op: Sat, DL: SDLoc(N0), VT: N3.getValueType());
6226}
6227
6228SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
6229 SDValue N0 = N->getOperand(Num: 0);
6230 SDValue N1 = N->getOperand(Num: 1);
6231 EVT VT = N0.getValueType();
6232 unsigned Opcode = N->getOpcode();
6233 SDLoc DL(N);
6234
6235 // fold operation with constant operands.
6236 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
6237 return C;
6238
6239 // If the operands are the same, this is a no-op.
6240 if (N0 == N1)
6241 return N0;
6242
6243 // canonicalize constant to RHS
6244 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
6245 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
6246 return DAG.getNode(Opcode, DL, VT, N1, N2: N0);
6247
6248 // fold vector ops
6249 if (VT.isVector())
6250 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
6251 return FoldedVOp;
6252
6253 // reassociate minmax
6254 if (SDValue RMINMAX = reassociateOps(Opc: Opcode, DL, N0, N1, Flags: N->getFlags()))
6255 return RMINMAX;
6256
6257 // If both operands are known to have the same sign (both non-negative or both
6258 // negative), flip between UMIN/UMAX and SMIN/SMAX.
6259 // Only do this if at least one of the following holds:
6260 // 1. The current op isn't legal and the flipped is.
6261 // 2. The saturation pattern is broken by canonicalization in InstCombine.
6262 bool IsOpIllegal = !TLI.isOperationLegal(Op: Opcode, VT);
6263 bool IsSatBroken = Opcode == ISD::UMIN && N0.getOpcode() == ISD::SMAX;
6264
6265 if (IsSatBroken || IsOpIllegal) {
6266 auto HasKnownSameSign = [&](SDValue A, SDValue B) {
6267 if (A.isUndef() || B.isUndef())
6268 return true;
6269
6270 KnownBits KA = DAG.computeKnownBits(Op: A);
6271 if (!KA.isNonNegative() && !KA.isNegative())
6272 return false;
6273
6274 KnownBits KB = DAG.computeKnownBits(Op: B);
6275 if (KA.isNonNegative())
6276 return KB.isNonNegative();
6277 return KB.isNegative();
6278 };
6279
6280 if (HasKnownSameSign(N0, N1)) {
6281 unsigned AltOpcode = ISD::getOppositeSignednessMinMaxOpcode(MinMaxOpc: Opcode);
6282 if ((IsSatBroken && IsOpIllegal) || TLI.isOperationLegal(Op: AltOpcode, VT))
6283 return DAG.getNode(Opcode: AltOpcode, DL, VT, N1: N0, N2: N1);
6284 }
6285 }
6286
6287 if (Opcode == ISD::SMIN || Opcode == ISD::SMAX)
6288 if (SDValue S = PerformMinMaxFpToSatCombine(
6289 N0, N1, N2: N0, N3: N1, CC: Opcode == ISD::SMIN ? ISD::SETLT : ISD::SETGT, DAG))
6290 return S;
6291 if (Opcode == ISD::UMIN)
6292 if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N2: N0, N3: N1, CC: ISD::SETULT, DAG))
6293 return S;
6294
6295 // Fold min/max(vecreduce(x), vecreduce(y)) -> vecreduce(min/max(x, y))
6296 auto ReductionOpcode = [](unsigned Opcode) {
6297 switch (Opcode) {
6298 case ISD::SMIN:
6299 return ISD::VECREDUCE_SMIN;
6300 case ISD::SMAX:
6301 return ISD::VECREDUCE_SMAX;
6302 case ISD::UMIN:
6303 return ISD::VECREDUCE_UMIN;
6304 case ISD::UMAX:
6305 return ISD::VECREDUCE_UMAX;
6306 default:
6307 llvm_unreachable("Unexpected opcode");
6308 }
6309 };
6310 if (SDValue SD = reassociateReduction(RedOpc: ReductionOpcode(Opcode), Opc: Opcode,
6311 DL: SDLoc(N), VT, N0, N1))
6312 return SD;
6313
6314 // Fold operation with vscale operands.
6315 if (N0.getOpcode() == ISD::VSCALE && N1.getOpcode() == ISD::VSCALE) {
6316 uint64_t C0 = N0->getConstantOperandVal(Num: 0);
6317 uint64_t C1 = N1->getConstantOperandVal(Num: 0);
6318 if (Opcode == ISD::UMAX)
6319 return C0 > C1 ? N0 : N1;
6320 else if (Opcode == ISD::UMIN)
6321 return C0 > C1 ? N1 : N0;
6322 }
6323
6324 // If we know the range of vscale, see if we can fold it given a constant.
6325 // TODO: Generalize this to other nodes by adding computeConstantRange
6326 if (N0.getOpcode() == ISD::VSCALE) {
6327 if (auto *C1 = dyn_cast<ConstantSDNode>(Val&: N1)) {
6328 const Function &F = DAG.getMachineFunction().getFunction();
6329 ConstantRange Range =
6330 getVScaleRange(F: &F, BitWidth: VT.getScalarSizeInBits())
6331 .multiply(Other: ConstantRange(N0.getConstantOperandAPInt(i: 0)));
6332
6333 const APInt &C1V = C1->getAPIntValue();
6334 if ((Opcode == ISD::UMAX && Range.getUnsignedMax().ule(RHS: C1V)) ||
6335 (Opcode == ISD::UMIN && Range.getUnsignedMin().uge(RHS: C1V)) ||
6336 (Opcode == ISD::SMAX && Range.getSignedMax().sle(RHS: C1V)) ||
6337 (Opcode == ISD::SMIN && Range.getSignedMin().sge(RHS: C1V))) {
6338 return N1;
6339 }
6340 }
6341 }
6342
6343 // Simplify the operands using demanded-bits information.
6344 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
6345 return SDValue(N, 0);
6346
6347 return SDValue();
6348}
6349
6350/// If this is a bitwise logic instruction and both operands have the same
6351/// opcode, try to sink the other opcode after the logic instruction.
6352SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) {
6353 SDValue N0 = N->getOperand(Num: 0), N1 = N->getOperand(Num: 1);
6354 EVT VT = N0.getValueType();
6355 unsigned LogicOpcode = N->getOpcode();
6356 unsigned HandOpcode = N0.getOpcode();
6357 assert(ISD::isBitwiseLogicOp(LogicOpcode) && "Expected logic opcode");
6358 assert(HandOpcode == N1.getOpcode() && "Bad input!");
6359
6360 // Bail early if none of these transforms apply.
6361 if (N0.getNumOperands() == 0)
6362 return SDValue();
6363
6364 // FIXME: We should check number of uses of the operands to not increase
6365 // the instruction count for all transforms.
6366
6367 // Handle size-changing casts (or sign_extend_inreg).
6368 SDValue X = N0.getOperand(i: 0);
6369 SDValue Y = N1.getOperand(i: 0);
6370 EVT XVT = X.getValueType();
6371 SDLoc DL(N);
6372 if (ISD::isExtOpcode(Opcode: HandOpcode) || ISD::isExtVecInRegOpcode(Opcode: HandOpcode) ||
6373 (HandOpcode == ISD::SIGN_EXTEND_INREG &&
6374 N0.getOperand(i: 1) == N1.getOperand(i: 1))) {
6375 // If both operands have other uses, this transform would create extra
6376 // instructions without eliminating anything.
6377 if (!N0.hasOneUse() && !N1.hasOneUse())
6378 return SDValue();
6379 // We need matching integer source types.
6380 if (XVT != Y.getValueType())
6381 return SDValue();
6382 // Don't create an illegal op during or after legalization. Don't ever
6383 // create an unsupported vector op.
6384 if ((VT.isVector() || LegalOperations) &&
6385 !TLI.isOperationLegalOrCustom(Op: LogicOpcode, VT: XVT))
6386 return SDValue();
6387 // Avoid infinite looping with PromoteIntBinOp.
6388 // TODO: Should we apply desirable/legal constraints to all opcodes?
6389 if ((HandOpcode == ISD::ANY_EXTEND ||
6390 HandOpcode == ISD::ANY_EXTEND_VECTOR_INREG) &&
6391 LegalTypes && !TLI.isTypeDesirableForOp(LogicOpcode, VT: XVT))
6392 return SDValue();
6393 // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y)
6394 SDNodeFlags LogicFlags;
6395 LogicFlags.setDisjoint(N->getFlags().hasDisjoint() &&
6396 ISD::isExtOpcode(Opcode: HandOpcode));
6397 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y, Flags: LogicFlags);
6398 if (HandOpcode == ISD::SIGN_EXTEND_INREG)
6399 return DAG.getNode(Opcode: HandOpcode, DL, VT, N1: Logic, N2: N0.getOperand(i: 1));
6400 return DAG.getNode(Opcode: HandOpcode, DL, VT, Operand: Logic);
6401 }
6402
6403 // logic_op (truncate x), (truncate y) --> truncate (logic_op x, y)
6404 if (HandOpcode == ISD::TRUNCATE) {
6405 // If both operands have other uses, this transform would create extra
6406 // instructions without eliminating anything.
6407 if (!N0.hasOneUse() && !N1.hasOneUse())
6408 return SDValue();
6409 // We need matching source types.
6410 if (XVT != Y.getValueType())
6411 return SDValue();
6412 // Don't create an illegal op during or after legalization.
6413 if (LegalOperations && !TLI.isOperationLegal(Op: LogicOpcode, VT: XVT))
6414 return SDValue();
6415 // Be extra careful sinking truncate. If it's free, there's no benefit in
6416 // widening a binop. Also, don't create a logic op on an illegal type.
6417 if (TLI.isZExtFree(FromTy: VT, ToTy: XVT) && TLI.isTruncateFree(FromVT: XVT, ToVT: VT))
6418 return SDValue();
6419 if (!TLI.isTypeLegal(VT: XVT))
6420 return SDValue();
6421 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y);
6422 return DAG.getNode(Opcode: HandOpcode, DL, VT, Operand: Logic);
6423 }
6424
6425 // For binops SHL/SRL/SRA/AND:
6426 // logic_op (OP x, z), (OP y, z) --> OP (logic_op x, y), z
6427 if ((HandOpcode == ISD::SHL || HandOpcode == ISD::SRL ||
6428 HandOpcode == ISD::SRA || HandOpcode == ISD::AND) &&
6429 N0.getOperand(i: 1) == N1.getOperand(i: 1)) {
6430 // If either operand has other uses, this transform is not an improvement.
6431 if (!N0.hasOneUse() || !N1.hasOneUse())
6432 return SDValue();
6433 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y);
6434 return DAG.getNode(Opcode: HandOpcode, DL, VT, N1: Logic, N2: N0.getOperand(i: 1));
6435 }
6436
6437 // Unary ops: logic_op (bswap x), (bswap y) --> bswap (logic_op x, y)
6438 if (HandOpcode == ISD::BSWAP) {
6439 // If either operand has other uses, this transform is not an improvement.
6440 if (!N0.hasOneUse() || !N1.hasOneUse())
6441 return SDValue();
6442 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y);
6443 return DAG.getNode(Opcode: HandOpcode, DL, VT, Operand: Logic);
6444 }
6445
6446 // For funnel shifts FSHL/FSHR:
6447 // logic_op (OP x, x1, s), (OP y, y1, s) -->
6448 // --> OP (logic_op x, y), (logic_op, x1, y1), s
6449 if ((HandOpcode == ISD::FSHL || HandOpcode == ISD::FSHR) &&
6450 N0.getOperand(i: 2) == N1.getOperand(i: 2)) {
6451 if (!N0.hasOneUse() || !N1.hasOneUse())
6452 return SDValue();
6453 SDValue X1 = N0.getOperand(i: 1);
6454 SDValue Y1 = N1.getOperand(i: 1);
6455 SDValue S = N0.getOperand(i: 2);
6456 SDValue Logic0 = DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: X, N2: Y);
6457 SDValue Logic1 = DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: X1, N2: Y1);
6458 return DAG.getNode(Opcode: HandOpcode, DL, VT, N1: Logic0, N2: Logic1, N3: S);
6459 }
6460
6461 // Simplify xor/and/or (bitcast(A), bitcast(B)) -> bitcast(op (A,B))
6462 // Only perform this optimization up until type legalization, before
6463 // LegalizeVectorOps. LegalizeVectorOps promotes vector operations by
6464 // adding bitcasts. For example (xor v4i32) is promoted to (v2i64), and
6465 // we don't want to undo this promotion.
6466 // We also handle SCALAR_TO_VECTOR because xor/or/and operations are cheaper
6467 // on scalars.
6468 if ((HandOpcode == ISD::BITCAST || HandOpcode == ISD::SCALAR_TO_VECTOR) &&
6469 Level <= AfterLegalizeTypes) {
6470 // Input types must be integer and the same.
6471 if (XVT.isInteger() && XVT == Y.getValueType() &&
6472 !(VT.isVector() && TLI.isTypeLegal(VT) &&
6473 !XVT.isVector() && !TLI.isTypeLegal(VT: XVT))) {
6474 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y);
6475 return DAG.getNode(Opcode: HandOpcode, DL, VT, Operand: Logic);
6476 }
6477 }
6478
6479 // Xor/and/or are indifferent to the swizzle operation (shuffle of one value).
6480 // Simplify xor/and/or (shuff(A), shuff(B)) -> shuff(op (A,B))
6481 // If both shuffles use the same mask, and both shuffle within a single
6482 // vector, then it is worthwhile to move the swizzle after the operation.
6483 // The type-legalizer generates this pattern when loading illegal
6484 // vector types from memory. In many cases this allows additional shuffle
6485 // optimizations.
6486 // There are other cases where moving the shuffle after the xor/and/or
6487 // is profitable even if shuffles don't perform a swizzle.
6488 // If both shuffles use the same mask, and both shuffles have the same first
6489 // or second operand, then it might still be profitable to move the shuffle
6490 // after the xor/and/or operation.
6491 if (HandOpcode == ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG) {
6492 auto *SVN0 = cast<ShuffleVectorSDNode>(Val&: N0);
6493 auto *SVN1 = cast<ShuffleVectorSDNode>(Val&: N1);
6494 assert(X.getValueType() == Y.getValueType() &&
6495 "Inputs to shuffles are not the same type");
6496
6497 // Check that both shuffles use the same mask. The masks are known to be of
6498 // the same length because the result vector type is the same.
6499 // Check also that shuffles have only one use to avoid introducing extra
6500 // instructions.
6501 if (!SVN0->hasOneUse() || !SVN1->hasOneUse() ||
6502 !SVN0->getMask().equals(RHS: SVN1->getMask()))
6503 return SDValue();
6504
6505 // Don't try to fold this node if it requires introducing a
6506 // build vector of all zeros that might be illegal at this stage.
6507 SDValue ShOp = N0.getOperand(i: 1);
6508 if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
6509 ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
6510
6511 // (logic_op (shuf (A, C), shuf (B, C))) --> shuf (logic_op (A, B), C)
6512 if (N0.getOperand(i: 1) == N1.getOperand(i: 1) && ShOp.getNode()) {
6513 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT,
6514 N1: N0.getOperand(i: 0), N2: N1.getOperand(i: 0));
6515 return DAG.getVectorShuffle(VT, dl: DL, N1: Logic, N2: ShOp, Mask: SVN0->getMask());
6516 }
6517
6518 // Don't try to fold this node if it requires introducing a
6519 // build vector of all zeros that might be illegal at this stage.
6520 ShOp = N0.getOperand(i: 0);
6521 if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
6522 ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
6523
6524 // (logic_op (shuf (C, A), shuf (C, B))) --> shuf (C, logic_op (A, B))
6525 if (N0.getOperand(i: 0) == N1.getOperand(i: 0) && ShOp.getNode()) {
6526 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: N0.getOperand(i: 1),
6527 N2: N1.getOperand(i: 1));
6528 return DAG.getVectorShuffle(VT, dl: DL, N1: ShOp, N2: Logic, Mask: SVN0->getMask());
6529 }
6530 }
6531
6532 return SDValue();
6533}
6534
6535/// Try to make (and/or setcc (LL, LR), setcc (RL, RR)) more efficient.
6536SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
6537 const SDLoc &DL) {
6538 SDValue LL, LR, RL, RR, N0CC, N1CC;
6539 if (!isSetCCEquivalent(N: N0, LHS&: LL, RHS&: LR, CC&: N0CC) ||
6540 !isSetCCEquivalent(N: N1, LHS&: RL, RHS&: RR, CC&: N1CC))
6541 return SDValue();
6542
6543 assert(N0.getValueType() == N1.getValueType() &&
6544 "Unexpected operand types for bitwise logic op");
6545 assert(LL.getValueType() == LR.getValueType() &&
6546 RL.getValueType() == RR.getValueType() &&
6547 "Unexpected operand types for setcc");
6548
6549 // If we're here post-legalization or the logic op type is not i1, the logic
6550 // op type must match a setcc result type. Also, all folds require new
6551 // operations on the left and right operands, so those types must match.
6552 EVT VT = N0.getValueType();
6553 EVT OpVT = LL.getValueType();
6554 if (LegalOperations || VT.getScalarType() != MVT::i1)
6555 if (VT != getSetCCResultType(VT: OpVT))
6556 return SDValue();
6557 if (OpVT != RL.getValueType())
6558 return SDValue();
6559
6560 ISD::CondCode CC0 = cast<CondCodeSDNode>(Val&: N0CC)->get();
6561 ISD::CondCode CC1 = cast<CondCodeSDNode>(Val&: N1CC)->get();
6562 bool IsInteger = OpVT.isInteger();
6563 if (LR == RR && CC0 == CC1 && IsInteger) {
6564 bool IsZero = isNullOrNullSplat(V: LR);
6565 bool IsNeg1 = isAllOnesOrAllOnesSplat(V: LR);
6566
6567 // All bits clear?
6568 bool AndEqZero = IsAnd && CC1 == ISD::SETEQ && IsZero;
6569 // All sign bits clear?
6570 bool AndGtNeg1 = IsAnd && CC1 == ISD::SETGT && IsNeg1;
6571 // Any bits set?
6572 bool OrNeZero = !IsAnd && CC1 == ISD::SETNE && IsZero;
6573 // Any sign bits set?
6574 bool OrLtZero = !IsAnd && CC1 == ISD::SETLT && IsZero;
6575
6576 // (and (seteq X, 0), (seteq Y, 0)) --> (seteq (or X, Y), 0)
6577 // (and (setgt X, -1), (setgt Y, -1)) --> (setgt (or X, Y), -1)
6578 // (or (setne X, 0), (setne Y, 0)) --> (setne (or X, Y), 0)
6579 // (or (setlt X, 0), (setlt Y, 0)) --> (setlt (or X, Y), 0)
6580 if (AndEqZero || AndGtNeg1 || OrNeZero || OrLtZero) {
6581 SDValue Or = DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N0), VT: OpVT, N1: LL, N2: RL);
6582 AddToWorklist(N: Or.getNode());
6583 return DAG.getSetCC(DL, VT, LHS: Or, RHS: LR, Cond: CC1);
6584 }
6585
6586 // All bits set?
6587 bool AndEqNeg1 = IsAnd && CC1 == ISD::SETEQ && IsNeg1;
6588 // All sign bits set?
6589 bool AndLtZero = IsAnd && CC1 == ISD::SETLT && IsZero;
6590 // Any bits clear?
6591 bool OrNeNeg1 = !IsAnd && CC1 == ISD::SETNE && IsNeg1;
6592 // Any sign bits clear?
6593 bool OrGtNeg1 = !IsAnd && CC1 == ISD::SETGT && IsNeg1;
6594
6595 // (and (seteq X, -1), (seteq Y, -1)) --> (seteq (and X, Y), -1)
6596 // (and (setlt X, 0), (setlt Y, 0)) --> (setlt (and X, Y), 0)
6597 // (or (setne X, -1), (setne Y, -1)) --> (setne (and X, Y), -1)
6598 // (or (setgt X, -1), (setgt Y -1)) --> (setgt (and X, Y), -1)
6599 if (AndEqNeg1 || AndLtZero || OrNeNeg1 || OrGtNeg1) {
6600 SDValue And = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(N0), VT: OpVT, N1: LL, N2: RL);
6601 AddToWorklist(N: And.getNode());
6602 return DAG.getSetCC(DL, VT, LHS: And, RHS: LR, Cond: CC1);
6603 }
6604 }
6605
6606 // TODO: What is the 'or' equivalent of this fold?
6607 // (and (setne X, 0), (setne X, -1)) --> (setuge (add X, 1), 2)
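// Adding 1 maps 0 to 1 and -1 to 0, so X is neither 0 nor -1 exactly when
// X + 1 is unsigned-greater-than-or-equal to 2.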
6608 if (IsAnd && LL == RL && CC0 == CC1 && OpVT.getScalarSizeInBits() > 1 &&
6609 IsInteger && CC0 == ISD::SETNE &&
6610 ((isNullConstant(V: LR) && isAllOnesConstant(V: RR)) ||
6611 (isAllOnesConstant(V: LR) && isNullConstant(V: RR)))) {
6612 SDValue One = DAG.getConstant(Val: 1, DL, VT: OpVT);
6613 SDValue Two = DAG.getConstant(Val: 2, DL, VT: OpVT);
6614 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL: SDLoc(N0), VT: OpVT, N1: LL, N2: One);
6615 AddToWorklist(N: Add.getNode());
6616 return DAG.getSetCC(DL, VT, LHS: Add, RHS: Two, Cond: ISD::SETUGE);
6617 }
6618
6619 // Try more general transforms if the predicates match and the only user of
6620 // the compares is the 'and' or 'or'.
6621 if (IsInteger && TLI.convertSetCCLogicToBitwiseLogic(VT: OpVT) && CC0 == CC1 &&
6622 N0.hasOneUse() && N1.hasOneUse()) {
6623 // and (seteq A, B), (seteq C, D) --> seteq (or (xor A, B), (xor C, D)), 0
6624 // or (setne A, B), (setne C, D) --> setne (or (xor A, B), (xor C, D)), 0
6625 if ((IsAnd && CC1 == ISD::SETEQ) || (!IsAnd && CC1 == ISD::SETNE)) {
6626 SDValue XorL = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N0), VT: OpVT, N1: LL, N2: LR);
6627 SDValue XorR = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N1), VT: OpVT, N1: RL, N2: RR);
6628 SDValue Or = DAG.getNode(Opcode: ISD::OR, DL, VT: OpVT, N1: XorL, N2: XorR);
6629 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: OpVT);
6630 return DAG.getSetCC(DL, VT, LHS: Or, RHS: Zero, Cond: CC1);
6631 }
6632
6633 // Turn compare of constants whose difference is 1 bit into add+and+setcc.
6634 if ((IsAnd && CC1 == ISD::SETNE) || (!IsAnd && CC1 == ISD::SETEQ)) {
6635 // Match a shared variable operand and 2 non-opaque constant operands.
6636 auto MatchDiffPow2 = [&](ConstantSDNode *C0, ConstantSDNode *C1) {
6637 // The difference of the constants must be a single bit.
6638 const APInt &CMax =
6639 APIntOps::umax(A: C0->getAPIntValue(), B: C1->getAPIntValue());
6640 const APInt &CMin =
6641 APIntOps::umin(A: C0->getAPIntValue(), B: C1->getAPIntValue());
6642 return !C0->isOpaque() && !C1->isOpaque() && (CMax - CMin).isPowerOf2();
6643 };
6644 if (LL == RL && ISD::matchBinaryPredicate(LHS: LR, RHS: RR, Match: MatchDiffPow2)) {
6645 // and/or (setcc X, CMax, ne), (setcc X, CMin, ne/eq) -->
6646 // setcc (and (sub X, CMin), ~(CMax - CMin)), 0, ne/eq
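// E.g. (X == 8) | (X == 12) --> ((X - 8) & ~4) == 0, since only the values
// whose difference from 8 is 0 or 4 survive the mask.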
6647 SDValue Max = DAG.getNode(Opcode: ISD::UMAX, DL, VT: OpVT, N1: LR, N2: RR);
6648 SDValue Min = DAG.getNode(Opcode: ISD::UMIN, DL, VT: OpVT, N1: LR, N2: RR);
6649 SDValue Offset = DAG.getNode(Opcode: ISD::SUB, DL, VT: OpVT, N1: LL, N2: Min);
6650 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: OpVT, N1: Max, N2: Min);
6651 SDValue Mask = DAG.getNOT(DL, Val: Diff, VT: OpVT);
6652 SDValue And = DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: Offset, N2: Mask);
6653 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: OpVT);
6654 return DAG.getSetCC(DL, VT, LHS: And, RHS: Zero, Cond: CC0);
6655 }
6656 }
6657 }
6658
6659 // Canonicalize equivalent operands to LL == RL.
6660 if (LL == RR && LR == RL) {
6661 CC1 = ISD::getSetCCSwappedOperands(Operation: CC1);
6662 std::swap(a&: RL, b&: RR);
6663 }
6664
6665 // (and (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
6666 // (or (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
6667 if (LL == RL && LR == RR) {
6668 ISD::CondCode NewCC = IsAnd ? ISD::getSetCCAndOperation(Op1: CC0, Op2: CC1, Type: OpVT)
6669 : ISD::getSetCCOrOperation(Op1: CC0, Op2: CC1, Type: OpVT);
6670 if (NewCC != ISD::SETCC_INVALID &&
6671 (!LegalOperations ||
6672 (TLI.isCondCodeLegal(CC: NewCC, VT: LL.getSimpleValueType()) &&
6673 TLI.isOperationLegal(Op: ISD::SETCC, VT: OpVT))))
6674 return DAG.getSetCC(DL, VT, LHS: LL, RHS: LR, Cond: NewCC);
6675 }
6676
6677 return SDValue();
6678}
6679
6680static bool arebothOperandsNotSNan(SDValue Operand1, SDValue Operand2,
6681 SelectionDAG &DAG) {
6682 return DAG.isKnownNeverSNaN(Op: Operand2) && DAG.isKnownNeverSNaN(Op: Operand1);
6683}
6684
6685static bool arebothOperandsNotNan(SDValue Operand1, SDValue Operand2,
6686 SelectionDAG &DAG) {
6687 return DAG.isKnownNeverNaN(Op: Operand2) && DAG.isKnownNeverNaN(Op: Operand1);
6688}
6689
6690/// Returns an appropriate FP min/max opcode for clamping operations.
6691static unsigned getMinMaxOpcodeForClamp(bool IsMin, SDValue Operand1,
6692 SDValue Operand2, SelectionDAG &DAG,
6693 const TargetLowering &TLI) {
6694 EVT VT = Operand1.getValueType();
6695 unsigned IEEEOp = IsMin ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
6696 if (TLI.isOperationLegalOrCustom(Op: IEEEOp, VT) &&
6697 arebothOperandsNotNan(Operand1, Operand2, DAG))
6698 return IEEEOp;
6699 unsigned PreferredOp = IsMin ? ISD::FMINNUM : ISD::FMAXNUM;
6700 if (TLI.isOperationLegalOrCustom(Op: PreferredOp, VT))
6701 return PreferredOp;
6702 return ISD::DELETED_NODE;
6703}
6704
6705// FIXME: use FMINIMUMNUM if possible, such as for RISC-V.
6706static unsigned getMinMaxOpcodeForCompareFold(
6707 SDValue Operand1, SDValue Operand2, ISD::CondCode CC, unsigned OrAndOpcode,
6708 SelectionDAG &DAG, bool isFMAXNUMFMINNUM_IEEE, bool isFMAXNUMFMINNUM) {
6709 // The optimization cannot be applied for all the predicates because
6710 // of the way FMINNUM/FMAXNUM and FMINNUM_IEEE/FMAXNUM_IEEE handle
6711 // NaNs. For FMINNUM_IEEE/FMAXNUM_IEEE, the optimization cannot be
6712 // applied at all if one of the operands is a signaling NaN.
6713
6714 // It is safe to use FMINNUM_IEEE/FMAXNUM_IEEE if all the operands
6715 // are non-NaN values.
6716 if (((CC == ISD::SETLT || CC == ISD::SETLE) && (OrAndOpcode == ISD::OR)) ||
6717 ((CC == ISD::SETGT || CC == ISD::SETGE) && (OrAndOpcode == ISD::AND))) {
6718 return arebothOperandsNotNan(Operand1, Operand2, DAG) &&
6719 isFMAXNUMFMINNUM_IEEE
6720 ? ISD::FMINNUM_IEEE
6721 : ISD::DELETED_NODE;
6722 }
6723
6724 if (((CC == ISD::SETGT || CC == ISD::SETGE) && (OrAndOpcode == ISD::OR)) ||
6725 ((CC == ISD::SETLT || CC == ISD::SETLE) && (OrAndOpcode == ISD::AND))) {
6726 return arebothOperandsNotNan(Operand1, Operand2, DAG) &&
6727 isFMAXNUMFMINNUM_IEEE
6728 ? ISD::FMAXNUM_IEEE
6729 : ISD::DELETED_NODE;
6730 }
6731
6732 // Both FMINNUM/FMAXNUM and FMINNUM_IEEE/FMAXNUM_IEEE handle quiet
6733 // NaNs in the same way. But, FMINNUM/FMAXNUM and FMINNUM_IEEE/
6734 // FMAXNUM_IEEE handle signaling NaNs differently. If we cannot prove
6735 // that there are not any sNaNs, then the optimization is not valid
6736 // for FMINNUM_IEEE/FMAXNUM_IEEE. In the presence of sNaNs, we apply
6737 // the optimization using FMINNUM/FMAXNUM for the following cases. If
6738 // we can prove that we do not have any sNaNs, then we can do the
6739 // optimization using FMINNUM_IEEE/FMAXNUM_IEEE for the following
6740 // cases.
6741 if (((CC == ISD::SETOLT || CC == ISD::SETOLE) && (OrAndOpcode == ISD::OR)) ||
6742 ((CC == ISD::SETUGT || CC == ISD::SETUGE) && (OrAndOpcode == ISD::AND))) {
6743 return isFMAXNUMFMINNUM ? ISD::FMINNUM
6744 : arebothOperandsNotSNan(Operand1, Operand2, DAG) &&
6745 isFMAXNUMFMINNUM_IEEE
6746 ? ISD::FMINNUM_IEEE
6747 : ISD::DELETED_NODE;
6748 }
6749
6750 if (((CC == ISD::SETOGT || CC == ISD::SETOGE) && (OrAndOpcode == ISD::OR)) ||
6751 ((CC == ISD::SETULT || CC == ISD::SETULE) && (OrAndOpcode == ISD::AND))) {
6752 return isFMAXNUMFMINNUM ? ISD::FMAXNUM
6753 : arebothOperandsNotSNan(Operand1, Operand2, DAG) &&
6754 isFMAXNUMFMINNUM_IEEE
6755 ? ISD::FMAXNUM_IEEE
6756 : ISD::DELETED_NODE;
6757 }
6758
6759 return ISD::DELETED_NODE;
6760}
6761
6762static SDValue foldAndOrOfSETCC(SDNode *LogicOp, SelectionDAG &DAG) {
6763 using AndOrSETCCFoldKind = TargetLowering::AndOrSETCCFoldKind;
6764 assert(
6765 (LogicOp->getOpcode() == ISD::AND || LogicOp->getOpcode() == ISD::OR) &&
6766 "Invalid Op to combine SETCC with");
6767
6768 // TODO: Search past casts/truncates.
6769 SDValue LHS = LogicOp->getOperand(Num: 0);
6770 SDValue RHS = LogicOp->getOperand(Num: 1);
6771 if (LHS->getOpcode() != ISD::SETCC || RHS->getOpcode() != ISD::SETCC ||
6772 !LHS->hasOneUse() || !RHS->hasOneUse())
6773 return SDValue();
6774
6775 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6776 AndOrSETCCFoldKind TargetPreference = TLI.isDesirableToCombineLogicOpOfSETCC(
6777 LogicOp, SETCC0: LHS.getNode(), SETCC1: RHS.getNode());
6778
6779 SDValue LHS0 = LHS->getOperand(Num: 0);
6780 SDValue RHS0 = RHS->getOperand(Num: 0);
6781 SDValue LHS1 = LHS->getOperand(Num: 1);
6782 SDValue RHS1 = RHS->getOperand(Num: 1);
6783 // TODO: We don't actually need a splat here, for vectors we just need the
6784 // invariants to hold for each element.
6785 auto *LHS1C = isConstOrConstSplat(N: LHS1);
6786 auto *RHS1C = isConstOrConstSplat(N: RHS1);
6787 ISD::CondCode CCL = cast<CondCodeSDNode>(Val: LHS.getOperand(i: 2))->get();
6788 ISD::CondCode CCR = cast<CondCodeSDNode>(Val: RHS.getOperand(i: 2))->get();
6789 EVT VT = LogicOp->getValueType(ResNo: 0);
6790 EVT OpVT = LHS0.getValueType();
6791 SDLoc DL(LogicOp);
6792
6793 // Check if the operands of an and/or operation are comparisons and if they
6794 // compare against the same value. Replace the and/or-cmp-cmp sequence with
6795 // min/max cmp sequence. If LHS1 is equal to RHS1, then the or-cmp-cmp
6796 // sequence will be replaced with min-cmp sequence:
6797 // (LHS0 < LHS1) | (RHS0 < RHS1) -> min(LHS0, RHS0) < LHS1
6798 // and and-cmp-cmp will be replaced with max-cmp sequence:
6799 // (LHS0 < LHS1) & (RHS0 < RHS1) -> max(LHS0, RHS0) < LHS1
6800 // The optimization does not work for `==` or `!=`.
6801 // The two comparisons should have either the same predicate or the
6802 // predicate of one of the comparisons is the opposite of the other one.
6803 bool isFMAXNUMFMINNUM_IEEE = TLI.isOperationLegal(Op: ISD::FMAXNUM_IEEE, VT: OpVT) &&
6804 TLI.isOperationLegal(Op: ISD::FMINNUM_IEEE, VT: OpVT);
6805 bool isFMAXNUMFMINNUM = TLI.isOperationLegalOrCustom(Op: ISD::FMAXNUM, VT: OpVT) &&
6806 TLI.isOperationLegalOrCustom(Op: ISD::FMINNUM, VT: OpVT);
6807 if (((OpVT.isInteger() && TLI.isOperationLegal(Op: ISD::UMAX, VT: OpVT) &&
6808 TLI.isOperationLegal(Op: ISD::SMAX, VT: OpVT) &&
6809 TLI.isOperationLegal(Op: ISD::UMIN, VT: OpVT) &&
6810 TLI.isOperationLegal(Op: ISD::SMIN, VT: OpVT)) ||
6811 (OpVT.isFloatingPoint() &&
6812 (isFMAXNUMFMINNUM_IEEE || isFMAXNUMFMINNUM))) &&
6813 !ISD::isIntEqualitySetCC(Code: CCL) && !ISD::isFPEqualitySetCC(Code: CCL) &&
6814 CCL != ISD::SETFALSE && CCL != ISD::SETO && CCL != ISD::SETUO &&
6815 CCL != ISD::SETTRUE &&
6816 (CCL == CCR || CCL == ISD::getSetCCSwappedOperands(Operation: CCR))) {
6817
6818 SDValue CommonValue, Operand1, Operand2;
6819 ISD::CondCode CC = ISD::SETCC_INVALID;
6820 if (CCL == CCR) {
6821 if (LHS0 == RHS0) {
6822 CommonValue = LHS0;
6823 Operand1 = LHS1;
6824 Operand2 = RHS1;
6825 CC = ISD::getSetCCSwappedOperands(Operation: CCL);
6826 } else if (LHS1 == RHS1) {
6827 CommonValue = LHS1;
6828 Operand1 = LHS0;
6829 Operand2 = RHS0;
6830 CC = CCL;
6831 }
6832 } else {
6833 assert(CCL == ISD::getSetCCSwappedOperands(CCR) && "Unexpected CC");
6834 if (LHS0 == RHS1) {
6835 CommonValue = LHS0;
6836 Operand1 = LHS1;
6837 Operand2 = RHS0;
6838 CC = CCR;
6839 } else if (RHS0 == LHS1) {
6840 CommonValue = LHS1;
6841 Operand1 = LHS0;
6842 Operand2 = RHS1;
6843 CC = CCL;
6844 }
6845 }
6846
6847 // Don't do this transform for sign bit tests. Let foldLogicOfSetCCs
6848 // handle it using OR/AND.
6849 if (CC == ISD::SETLT && isNullOrNullSplat(V: CommonValue))
6850 CC = ISD::SETCC_INVALID;
6851 else if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(V: CommonValue))
6852 CC = ISD::SETCC_INVALID;
6853
6854 if (CC != ISD::SETCC_INVALID) {
6855 unsigned NewOpcode = ISD::DELETED_NODE;
6856 bool IsSigned = isSignedIntSetCC(Code: CC);
6857 if (OpVT.isInteger()) {
6858 bool IsLess = (CC == ISD::SETLE || CC == ISD::SETULE ||
6859 CC == ISD::SETLT || CC == ISD::SETULT);
6860 bool IsOr = (LogicOp->getOpcode() == ISD::OR);
6861 if (IsLess == IsOr)
6862 NewOpcode = IsSigned ? ISD::SMIN : ISD::UMIN;
6863 else
6864 NewOpcode = IsSigned ? ISD::SMAX : ISD::UMAX;
6865 } else if (OpVT.isFloatingPoint())
6866 NewOpcode = getMinMaxOpcodeForCompareFold(
6867 Operand1, Operand2, CC, OrAndOpcode: LogicOp->getOpcode(), DAG,
6868 isFMAXNUMFMINNUM_IEEE, isFMAXNUMFMINNUM);
6869
6870 if (NewOpcode != ISD::DELETED_NODE) {
6871 SDValue MinMaxValue =
6872 DAG.getNode(Opcode: NewOpcode, DL, VT: OpVT, N1: Operand1, N2: Operand2);
6873 return DAG.getSetCC(DL, VT, LHS: MinMaxValue, RHS: CommonValue, Cond: CC);
6874 }
6875 }
6876 }
6877
6878 if (LHS0 == LHS1 && RHS0 == RHS1 && CCL == CCR &&
6879 LHS0.getValueType() == RHS0.getValueType() &&
6880 ((LogicOp->getOpcode() == ISD::AND && CCL == ISD::SETO) ||
6881 (LogicOp->getOpcode() == ISD::OR && CCL == ISD::SETUO)))
6882 return DAG.getSetCC(DL, VT, LHS: LHS0, RHS: RHS0, Cond: CCL);
6883
6884 if (TargetPreference == AndOrSETCCFoldKind::None)
6885 return SDValue();
6886
6887 if (CCL == CCR &&
6888 CCL == (LogicOp->getOpcode() == ISD::AND ? ISD::SETNE : ISD::SETEQ) &&
6889 LHS0 == RHS0 && LHS1C && RHS1C && OpVT.isInteger()) {
6890 const APInt &APLhs = LHS1C->getAPIntValue();
6891 const APInt &APRhs = RHS1C->getAPIntValue();
6892
6893 // Preference is to use ISD::ABS or we already have an ISD::ABS (in which
6894 // case this is just a compare).
6895 if (APLhs == (-APRhs) &&
6896 ((TargetPreference & AndOrSETCCFoldKind::ABS) ||
6897 DAG.doesNodeExist(Opcode: ISD::ABS, VTList: DAG.getVTList(VT: OpVT), Ops: {LHS0}))) {
6898 const APInt &C = APLhs.isNegative() ? APRhs : APLhs;
6899 // (icmp eq A, C) | (icmp eq A, -C)
6900 // -> (icmp eq Abs(A), C)
6901 // (icmp ne A, C) & (icmp ne A, -C)
6902 // -> (icmp ne Abs(A), C)
6903 SDValue AbsOp = DAG.getNode(Opcode: ISD::ABS, DL, VT: OpVT, Operand: LHS0);
6904 return DAG.getNode(Opcode: ISD::SETCC, DL, VT, N1: AbsOp,
6905 N2: DAG.getConstant(Val: C, DL, VT: OpVT), N3: LHS.getOperand(i: 2));
6906 } else if (TargetPreference &
6907 (AndOrSETCCFoldKind::AddAnd | AndOrSETCCFoldKind::NotAnd)) {
6908
6909 // AndOrSETCCFoldKind::AddAnd:
6910 // A == C0 | A == C1
6911 // IF IsPow2(smax(C0, C1)-smin(C0, C1))
6912 // -> ((A - smin(C0, C1)) & ~(smax(C0, C1)-smin(C0, C1))) == 0
6913 // A != C0 & A != C1
6914 // IF IsPow2(smax(C0, C1)-smin(C0, C1))
6915 // -> ((A - smin(C0, C1)) & ~(smax(C0, C1)-smin(C0, C1))) != 0
6916
6917 // AndOrSETCCFoldKind::NotAnd:
6918 // A == C0 | A == C1
6919 // IF smax(C0, C1) == -1 AND IsPow2(smax(C0, C1) - smin(C0, C1))
6920 // -> ~A & smin(C0, C1) == 0
6921 // A != C0 & A != C1
6922 // IF smax(C0, C1) == -1 AND IsPow2(smax(C0, C1) - smin(C0, C1))
6923 // -> ~A & smin(C0, C1) != 0
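      // E.g. AddAnd with C0 = 12, C1 = 14: Dif = 2 (a power of two), so the
      // fold gives ((A - 12) & ~2) == 0.
      // E.g. NotAnd with C0 = -1, C1 = -5: smax = -1 and Dif = 4, so the fold
      // gives (~A & -5) == 0.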
6924
6925 const APInt &MaxC = APIntOps::smax(A: APRhs, B: APLhs);
6926 const APInt &MinC = APIntOps::smin(A: APRhs, B: APLhs);
6927 APInt Dif = MaxC - MinC;
6928 if (!Dif.isZero() && Dif.isPowerOf2()) {
6929 if (MaxC.isAllOnes() &&
6930 (TargetPreference & AndOrSETCCFoldKind::NotAnd)) {
6931 SDValue NotOp = DAG.getNOT(DL, Val: LHS0, VT: OpVT);
6932 SDValue AndOp = DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: NotOp,
6933 N2: DAG.getConstant(Val: MinC, DL, VT: OpVT));
6934 return DAG.getNode(Opcode: ISD::SETCC, DL, VT, N1: AndOp,
6935 N2: DAG.getConstant(Val: 0, DL, VT: OpVT), N3: LHS.getOperand(i: 2));
6936 } else if (TargetPreference & AndOrSETCCFoldKind::AddAnd) {
6937
6938 SDValue AddOp = DAG.getNode(Opcode: ISD::ADD, DL, VT: OpVT, N1: LHS0,
6939 N2: DAG.getConstant(Val: -MinC, DL, VT: OpVT));
6940 SDValue AndOp = DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: AddOp,
6941 N2: DAG.getConstant(Val: ~Dif, DL, VT: OpVT));
6942 return DAG.getNode(Opcode: ISD::SETCC, DL, VT, N1: AndOp,
6943 N2: DAG.getConstant(Val: 0, DL, VT: OpVT), N3: LHS.getOperand(i: 2));
6944 }
6945 }
6946 }
6947 }
6948
6949 return SDValue();
6950}
6951
6952// Combine `(select c, (X & 1), 0)` -> `(and (zext c), X)`.
6953// We canonicalize to the `select` form in the middle end, but the `and` form
6954 // gets better codegen on all tested targets (arm, x86, riscv).
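// E.g. with a zero-or-one boolean c, (select c, (and x, 1), 0) computes
// (x & 1) when c is 1 and 0 when c is 0, which is exactly (and (zext c), x).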
6955static SDValue combineSelectAsExtAnd(SDValue Cond, SDValue T, SDValue F,
6956 const SDLoc &DL, SelectionDAG &DAG) {
6957 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6958 if (!isNullConstant(V: F))
6959 return SDValue();
6960
6961 EVT CondVT = Cond.getValueType();
6962 if (TLI.getBooleanContents(Type: CondVT) !=
6963 TargetLoweringBase::ZeroOrOneBooleanContent)
6964 return SDValue();
6965
6966 if (T.getOpcode() != ISD::AND)
6967 return SDValue();
6968
6969 if (!isOneConstant(V: T.getOperand(i: 1)))
6970 return SDValue();
6971
6972 EVT OpVT = T.getValueType();
6973
6974 SDValue CondMask =
6975 OpVT == CondVT ? Cond : DAG.getBoolExtOrTrunc(Op: Cond, SL: DL, VT: OpVT, OpVT: CondVT);
6976 return DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: CondMask, N2: T.getOperand(i: 0));
6977}
6978
6979/// This contains all DAGCombine rules which reduce two values combined by
6980/// an And operation to a single value. This makes them reusable in the context
6981/// of visitSELECT(). Rules involving constants are not included as
6982/// visitSELECT() already handles those cases.
6983SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) {
6984 EVT VT = N1.getValueType();
6985 SDLoc DL(N);
6986
6987 // fold (and x, undef) -> 0
6988 if (N0.isUndef() || N1.isUndef())
6989 return DAG.getConstant(Val: 0, DL, VT);
6990
6991 if (SDValue V = foldLogicOfSetCCs(IsAnd: true, N0, N1, DL))
6992 return V;
6993
6994 // Canonicalize:
6995 // and(x, add) -> and(add, x)
6996 if (N1.getOpcode() == ISD::ADD)
6997 std::swap(a&: N0, b&: N1);
6998
6999 // TODO: Rewrite this to return a new 'AND' instead of using CombineTo.
7000 if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::SRL &&
7001 VT.isScalarInteger() && VT.getSizeInBits() <= 64 && N0->hasOneUse()) {
7002 if (ConstantSDNode *ADDI = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1))) {
7003 if (ConstantSDNode *SRLI = dyn_cast<ConstantSDNode>(Val: N1.getOperand(i: 1))) {
7004         // Look for (and (add x, c1), (lshr y, c2)). If c1 isn't a legal
7005         // immediate for an add, but would be if its top c2 bits were set,
7006         // transform the add so the immediate doesn't need to be materialized
7007         // in a register.
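        // E.g. for i64 with c2 == 48, the lshr result has its top 48 bits
        // zero, so the AND masks them off anyway; setting the top 48 bits of
        // c1 cannot change the result but may yield a legal add immediate.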
7008 APInt ADDC = ADDI->getAPIntValue();
7009 APInt SRLC = SRLI->getAPIntValue();
7010 if (ADDC.getSignificantBits() <= 64 && SRLC.ult(RHS: VT.getSizeInBits()) &&
7011 !TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
7012 APInt Mask = APInt::getHighBitsSet(numBits: VT.getSizeInBits(),
7013 hiBitsSet: SRLC.getZExtValue());
7014 if (DAG.MaskedValueIsZero(Op: N0.getOperand(i: 1), Mask)) {
7015 ADDC |= Mask;
7016 if (TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
7017 SDLoc DL0(N0);
7018 SDValue NewAdd =
7019 DAG.getNode(Opcode: ISD::ADD, DL: DL0, VT,
7020 N1: N0.getOperand(i: 0), N2: DAG.getConstant(Val: ADDC, DL, VT));
7021 CombineTo(N: N0.getNode(), Res: NewAdd);
7022 // Return N so it doesn't get rechecked!
7023 return SDValue(N, 0);
7024 }
7025 }
7026 }
7027 }
7028 }
7029 }
7030
7031 return SDValue();
7032}
7033
7034bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
7035 EVT LoadResultTy, EVT &ExtVT) {
7036 if (!AndC->getAPIntValue().isMask())
7037 return false;
7038
7039 unsigned ActiveBits = AndC->getAPIntValue().countr_one();
7040
7041 ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ActiveBits);
7042 EVT LoadedVT = LoadN->getMemoryVT();
7043
7044 if (ExtVT == LoadedVT &&
7045 (!LegalOperations ||
7046 TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: LoadResultTy, MemVT: ExtVT))) {
7047 // ZEXTLOAD will match without needing to change the size of the value being
7048 // loaded.
7049 return true;
7050 }
7051
7052   // Do not change the width of a volatile or atomic load.
7053 if (!LoadN->isSimple())
7054 return false;
7055
7056 // Do not generate loads of non-round integer types since these can
7057 // be expensive (and would be wrong if the type is not byte sized).
7058 if (!LoadedVT.bitsGT(VT: ExtVT) || !ExtVT.isRound())
7059 return false;
7060
7061 if (LegalOperations &&
7062 !TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: LoadResultTy, MemVT: ExtVT))
7063 return false;
7064
7065 if (!TLI.shouldReduceLoadWidth(Load: LoadN, ExtTy: ISD::ZEXTLOAD, NewVT: ExtVT, /*ByteOffset=*/0))
7066 return false;
7067
7068 return true;
7069}
7070
7071bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST,
7072 ISD::LoadExtType ExtType, EVT &MemVT,
7073 unsigned ShAmt) {
7074 if (!LDST)
7075 return false;
7076
7077 // Only allow byte offsets.
7078 if (ShAmt % 8)
7079 return false;
7080 const unsigned ByteShAmt = ShAmt / 8;
7081
7082 // Do not generate loads of non-round integer types since these can
7083 // be expensive (and would be wrong if the type is not byte sized).
7084 if (!MemVT.isRound())
7085 return false;
7086
7087   // Don't change the width of a volatile or atomic load.
7088 if (!LDST->isSimple())
7089 return false;
7090
7091 EVT LdStMemVT = LDST->getMemoryVT();
7092
7093 // Bail out when changing the scalable property, since we can't be sure that
7094 // we're actually narrowing here.
7095 if (LdStMemVT.isScalableVector() != MemVT.isScalableVector())
7096 return false;
7097
7098 // Verify that we are actually reducing a load width here.
7099 if (LdStMemVT.bitsLT(VT: MemVT))
7100 return false;
7101
7102 // Ensure that this isn't going to produce an unsupported memory access.
7103 if (ShAmt) {
7104 const Align LDSTAlign = LDST->getAlign();
7105 const Align NarrowAlign = commonAlignment(A: LDSTAlign, Offset: ByteShAmt);
7106 if (!TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: MemVT,
7107 AddrSpace: LDST->getAddressSpace(), Alignment: NarrowAlign,
7108 Flags: LDST->getMemOperand()->getFlags()))
7109 return false;
7110 }
7111
7112 // It's not possible to generate a constant of extended or untyped type.
7113 EVT PtrType = LDST->getBasePtr().getValueType();
7114 if (PtrType == MVT::Untyped || PtrType.isExtended())
7115 return false;
7116
7117 if (isa<LoadSDNode>(Val: LDST)) {
7118 LoadSDNode *Load = cast<LoadSDNode>(Val: LDST);
7119     // Don't transform one with multiple uses; this would require adding a new
7120     // load.
7121 if (!SDValue(Load, 0).hasOneUse())
7122 return false;
7123
7124 if (LegalOperations &&
7125 !TLI.isLoadExtLegal(ExtType, ValVT: Load->getValueType(ResNo: 0), MemVT))
7126 return false;
7127
7128 // For the transform to be legal, the load must produce only two values
7129 // (the value loaded and the chain). Don't transform a pre-increment
7130 // load, for example, which produces an extra value. Otherwise the
7131 // transformation is not equivalent, and the downstream logic to replace
7132 // uses gets things wrong.
7133 if (Load->getNumValues() > 2)
7134 return false;
7135
7136 // If the load that we're shrinking is an extload and we're not just
7137 // discarding the extension we can't simply shrink the load. Bail.
7138 // TODO: It would be possible to merge the extensions in some cases.
7139 if (Load->getExtensionType() != ISD::NON_EXTLOAD &&
7140 Load->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
7141 return false;
7142
7143 if (!TLI.shouldReduceLoadWidth(Load, ExtTy: ExtType, NewVT: MemVT, ByteOffset: ByteShAmt))
7144 return false;
7145 } else {
7146 assert(isa<StoreSDNode>(LDST) && "It is not a Load nor a Store SDNode");
7147 StoreSDNode *Store = cast<StoreSDNode>(Val: LDST);
7148 // Can't write outside the original store
7149 if (Store->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
7150 return false;
7151
7152 if (LegalOperations &&
7153 !TLI.isTruncStoreLegal(ValVT: Store->getValue().getValueType(), MemVT))
7154 return false;
7155 }
7156 return true;
7157}
7158
7159bool DAGCombiner::SearchForAndLoads(SDNode *N,
7160 SmallVectorImpl<LoadSDNode*> &Loads,
7161 SmallPtrSetImpl<SDNode*> &NodesWithConsts,
7162 ConstantSDNode *Mask,
7163 SDNode *&NodeToMask) {
7164 // Recursively search for the operands, looking for loads which can be
7165 // narrowed.
7166 for (SDValue Op : N->op_values()) {
7167 if (Op.getValueType().isVector())
7168 return false;
7169
7170 // Some constants may need fixing up later if they are too large.
7171 if (auto *C = dyn_cast<ConstantSDNode>(Val&: Op)) {
7172 assert(ISD::isBitwiseLogicOp(N->getOpcode()) &&
7173 "Expected bitwise logic operation");
7174 if (!C->getAPIntValue().isSubsetOf(RHS: Mask->getAPIntValue()))
7175 NodesWithConsts.insert(Ptr: N);
7176 continue;
7177 }
7178
7179 if (!Op.hasOneUse())
7180 return false;
7181
7182 switch(Op.getOpcode()) {
7183 case ISD::LOAD: {
7184 auto *Load = cast<LoadSDNode>(Val&: Op);
7185 EVT ExtVT;
7186 if (isAndLoadExtLoad(AndC: Mask, LoadN: Load, LoadResultTy: Load->getValueType(ResNo: 0), ExtVT) &&
7187 isLegalNarrowLdSt(LDST: Load, ExtType: ISD::ZEXTLOAD, MemVT&: ExtVT)) {
7188
7189 // ZEXTLOAD is already small enough.
7190 if (Load->getExtensionType() == ISD::ZEXTLOAD &&
7191 ExtVT.bitsGE(VT: Load->getMemoryVT()))
7192 continue;
7193
7194 // Use LE to convert equal sized loads to zext.
7195 if (ExtVT.bitsLE(VT: Load->getMemoryVT()))
7196 Loads.push_back(Elt: Load);
7197
7198 continue;
7199 }
7200 return false;
7201 }
7202 case ISD::ZERO_EXTEND:
7203 case ISD::AssertZext: {
7204 unsigned ActiveBits = Mask->getAPIntValue().countr_one();
7205 EVT ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ActiveBits);
7206 EVT VT = Op.getOpcode() == ISD::AssertZext ?
7207 cast<VTSDNode>(Val: Op.getOperand(i: 1))->getVT() :
7208 Op.getOperand(i: 0).getValueType();
7209
7210       // We can accept extending nodes if the mask is wider than, or equal in
7211       // width to, the original type.
7212 if (ExtVT.bitsGE(VT))
7213 continue;
7214 break;
7215 }
7216 case ISD::OR:
7217 case ISD::XOR:
7218 case ISD::AND:
7219 if (!SearchForAndLoads(N: Op.getNode(), Loads, NodesWithConsts, Mask,
7220 NodeToMask))
7221 return false;
7222 continue;
7223 }
7224
7225     // Allow one node which will be masked along with any loads found.
7226 if (NodeToMask)
7227 return false;
7228
7229 // Also ensure that the node to be masked only produces one data result.
7230 NodeToMask = Op.getNode();
7231 if (NodeToMask->getNumValues() > 1) {
7232 bool HasValue = false;
7233 for (unsigned i = 0, e = NodeToMask->getNumValues(); i < e; ++i) {
7234 MVT VT = SDValue(NodeToMask, i).getSimpleValueType();
7235 if (VT != MVT::Glue && VT != MVT::Other) {
7236 if (HasValue) {
7237 NodeToMask = nullptr;
7238 return false;
7239 }
7240 HasValue = true;
7241 }
7242 }
7243 assert(HasValue && "Node to be masked has no data result?");
7244 }
7245 }
7246 return true;
7247}
7248
7249bool DAGCombiner::BackwardsPropagateMask(SDNode *N) {
7250 auto *Mask = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
7251 if (!Mask)
7252 return false;
7253
7254 if (!Mask->getAPIntValue().isMask())
7255 return false;
7256
7257 // No need to do anything if the and directly uses a load.
7258 if (isa<LoadSDNode>(Val: N->getOperand(Num: 0)))
7259 return false;
7260
7261 SmallVector<LoadSDNode*, 8> Loads;
7262 SmallPtrSet<SDNode*, 2> NodesWithConsts;
7263 SDNode *FixupNode = nullptr;
7264 if (SearchForAndLoads(N, Loads, NodesWithConsts, Mask, NodeToMask&: FixupNode)) {
7265 if (Loads.empty())
7266 return false;
7267
7268 LLVM_DEBUG(dbgs() << "Backwards propagate AND: "; N->dump());
7269 SDValue MaskOp = N->getOperand(Num: 1);
7270
7271 // If it exists, fixup the single node we allow in the tree that needs
7272 // masking.
7273 if (FixupNode) {
7274 LLVM_DEBUG(dbgs() << "First, need to fix up: "; FixupNode->dump());
7275 SDValue And = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(FixupNode),
7276 VT: FixupNode->getValueType(ResNo: 0),
7277 N1: SDValue(FixupNode, 0), N2: MaskOp);
7278 DAG.ReplaceAllUsesOfValueWith(From: SDValue(FixupNode, 0), To: And);
7279       if (And.getOpcode() == ISD::AND)
7280 DAG.UpdateNodeOperands(N: And.getNode(), Op1: SDValue(FixupNode, 0), Op2: MaskOp);
7281 }
7282
7283 // Narrow any constants that need it.
7284 for (auto *LogicN : NodesWithConsts) {
7285 SDValue Op0 = LogicN->getOperand(Num: 0);
7286 SDValue Op1 = LogicN->getOperand(Num: 1);
7287
7288 // We only need to fix AND if both inputs are constants. And we only need
7289 // to fix one of the constants.
7290 if (LogicN->getOpcode() == ISD::AND &&
7291 (!isa<ConstantSDNode>(Val: Op0) || !isa<ConstantSDNode>(Val: Op1)))
7292 continue;
7293
7294 if (isa<ConstantSDNode>(Val: Op0) && LogicN->getOpcode() != ISD::AND)
7295 Op0 =
7296 DAG.getNode(Opcode: ISD::AND, DL: SDLoc(Op0), VT: Op0.getValueType(), N1: Op0, N2: MaskOp);
7297
7298 if (isa<ConstantSDNode>(Val: Op1))
7299 Op1 =
7300 DAG.getNode(Opcode: ISD::AND, DL: SDLoc(Op1), VT: Op1.getValueType(), N1: Op1, N2: MaskOp);
7301
7302 if (isa<ConstantSDNode>(Val: Op0) && !isa<ConstantSDNode>(Val: Op1))
7303 std::swap(a&: Op0, b&: Op1);
7304
7305 DAG.UpdateNodeOperands(N: LogicN, Op1: Op0, Op2: Op1);
7306 }
7307
7308 // Create narrow loads.
7309 for (auto *Load : Loads) {
7310 LLVM_DEBUG(dbgs() << "Propagate AND back to: "; Load->dump());
7311 SDValue And = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(Load), VT: Load->getValueType(ResNo: 0),
7312 N1: SDValue(Load, 0), N2: MaskOp);
7313 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Load, 0), To: And);
7314       if (And.getOpcode() == ISD::AND)
7315 And = SDValue(
7316 DAG.UpdateNodeOperands(N: And.getNode(), Op1: SDValue(Load, 0), Op2: MaskOp), 0);
7317 SDValue NewLoad = reduceLoadWidth(N: And.getNode());
7318 assert(NewLoad &&
7319 "Shouldn't be masking the load if it can't be narrowed");
7320 CombineTo(N: Load, Res0: NewLoad, Res1: NewLoad.getValue(R: 1));
7321 }
7322 DAG.ReplaceAllUsesWith(From: N, To: N->getOperand(Num: 0).getNode());
7323 return true;
7324 }
7325 return false;
7326}
7327
7328// Unfold
7329// x & (-1 'logical shift' y)
7330// To
7331// (x 'opposite logical shift' y) 'logical shift' y
7332// if it is better for performance.
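// E.g. (x & (-1 << y)) clears the low y bits of x, which can also be computed
// as ((x >> y) << y) using two dependent shifts instead of materializing the
// mask; the inverse pairing handles (x & (-1 >> y)).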
7333SDValue DAGCombiner::unfoldExtremeBitClearingToShifts(SDNode *N) {
7334 assert(N->getOpcode() == ISD::AND);
7335
7336 SDValue N0 = N->getOperand(Num: 0);
7337 SDValue N1 = N->getOperand(Num: 1);
7338
7339 // Do we actually prefer shifts over mask?
7340 if (!TLI.shouldFoldMaskToVariableShiftPair(X: N0))
7341 return SDValue();
7342
7343 // Try to match (-1 '[outer] logical shift' y)
7344 unsigned OuterShift;
7345 unsigned InnerShift; // The opposite direction to the OuterShift.
7346 SDValue Y; // Shift amount.
7347 auto matchMask = [&OuterShift, &InnerShift, &Y](SDValue M) -> bool {
7348 if (!M.hasOneUse())
7349 return false;
7350 OuterShift = M->getOpcode();
7351 if (OuterShift == ISD::SHL)
7352 InnerShift = ISD::SRL;
7353 else if (OuterShift == ISD::SRL)
7354 InnerShift = ISD::SHL;
7355 else
7356 return false;
7357 if (!isAllOnesConstant(V: M->getOperand(Num: 0)))
7358 return false;
7359 Y = M->getOperand(Num: 1);
7360 return true;
7361 };
7362
7363 SDValue X;
7364 if (matchMask(N1))
7365 X = N0;
7366 else if (matchMask(N0))
7367 X = N1;
7368 else
7369 return SDValue();
7370
7371 SDLoc DL(N);
7372 EVT VT = N->getValueType(ResNo: 0);
7373
7374 // tmp = x 'opposite logical shift' y
7375 SDValue T0 = DAG.getNode(Opcode: InnerShift, DL, VT, N1: X, N2: Y);
7376 // ret = tmp 'logical shift' y
7377 SDValue T1 = DAG.getNode(Opcode: OuterShift, DL, VT, N1: T0, N2: Y);
7378
7379 return T1;
7380}
7381
7382/// Try to replace shift/logic that tests if a bit is clear with mask + setcc.
7383/// For a target with a bit test, this is expected to become test + set and save
7384/// at least 1 instruction.
7385static SDValue combineShiftAnd1ToBitTest(SDNode *And, SelectionDAG &DAG) {
7386 assert(And->getOpcode() == ISD::AND && "Expected an 'and' op");
7387
7388 // Look through an optional extension.
7389 SDValue And0 = And->getOperand(Num: 0), And1 = And->getOperand(Num: 1);
7390 if (And0.getOpcode() == ISD::ANY_EXTEND && And0.hasOneUse())
7391 And0 = And0.getOperand(i: 0);
7392 if (!isOneConstant(V: And1) || !And0.hasOneUse())
7393 return SDValue();
7394
7395 SDValue Src = And0;
7396
7397 // Attempt to find a 'not' op.
7398 // TODO: Should we favor test+set even without the 'not' op?
7399 bool FoundNot = false;
7400 if (isBitwiseNot(V: Src)) {
7401 FoundNot = true;
7402 Src = Src.getOperand(i: 0);
7403
7404     // Look through an optional truncation. The source operand may not be the
7405 // same type as the original 'and', but that is ok because we are masking
7406 // off everything but the low bit.
7407 if (Src.getOpcode() == ISD::TRUNCATE && Src.hasOneUse())
7408 Src = Src.getOperand(i: 0);
7409 }
7410
7411 // Match a shift-right by constant.
7412 if (Src.getOpcode() != ISD::SRL || !Src.hasOneUse())
7413 return SDValue();
7414
7415 // This is probably not worthwhile without a supported type.
7416 EVT SrcVT = Src.getValueType();
7417 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
7418 if (!TLI.isTypeLegal(VT: SrcVT))
7419 return SDValue();
7420
7421 // We might have looked through casts that make this transform invalid.
7422 unsigned BitWidth = SrcVT.getScalarSizeInBits();
7423 SDValue ShiftAmt = Src.getOperand(i: 1);
7424 auto *ShiftAmtC = dyn_cast<ConstantSDNode>(Val&: ShiftAmt);
7425 if (!ShiftAmtC || !ShiftAmtC->getAPIntValue().ult(RHS: BitWidth))
7426 return SDValue();
7427
7428 // Set source to shift source.
7429 Src = Src.getOperand(i: 0);
7430
7431 // Try again to find a 'not' op.
7432 // TODO: Should we favor test+set even with two 'not' ops?
7433 if (!FoundNot) {
7434 if (!isBitwiseNot(V: Src))
7435 return SDValue();
7436 Src = Src.getOperand(i: 0);
7437 }
7438
7439 if (!TLI.hasBitTest(X: Src, Y: ShiftAmt))
7440 return SDValue();
7441
7442 // Turn this into a bit-test pattern using mask op + setcc:
7443 // and (not (srl X, C)), 1 --> (and X, 1<<C) == 0
7444   // and (srl (not X), C), 1 --> (and X, 1<<C) == 0
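  // E.g. with C == 3: and (not (srl X, 3)), 1 --> (and X, 8) == 0.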
7445 SDLoc DL(And);
7446 SDValue X = DAG.getZExtOrTrunc(Op: Src, DL, VT: SrcVT);
7447 EVT CCVT =
7448 TLI.getSetCCResultType(DL: DAG.getDataLayout(), Context&: *DAG.getContext(), VT: SrcVT);
7449 SDValue Mask = DAG.getConstant(
7450 Val: APInt::getOneBitSet(numBits: BitWidth, BitNo: ShiftAmtC->getZExtValue()), DL, VT: SrcVT);
7451 SDValue NewAnd = DAG.getNode(Opcode: ISD::AND, DL, VT: SrcVT, N1: X, N2: Mask);
7452 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: SrcVT);
7453 SDValue Setcc = DAG.getSetCC(DL, VT: CCVT, LHS: NewAnd, RHS: Zero, Cond: ISD::SETEQ);
7454 return DAG.getZExtOrTrunc(Op: Setcc, DL, VT: And->getValueType(ResNo: 0));
7455}
7456
7457/// For targets that support usubsat, match a bit-hack form of that operation
7458/// that ends in 'and' and convert it.
7459static SDValue foldAndToUsubsat(SDNode *N, SelectionDAG &DAG, const SDLoc &DL) {
7460 EVT VT = N->getValueType(ResNo: 0);
7461 unsigned BitWidth = VT.getScalarSizeInBits();
7462 APInt SignMask = APInt::getSignMask(BitWidth);
7463
7464 // (i8 X ^ 128) & (i8 X s>> 7) --> usubsat X, 128
7465 // (i8 X + 128) & (i8 X s>> 7) --> usubsat X, 128
7466 // xor/add with SMIN (signmask) are logically equivalent.
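  // If X has its sign bit set, X ^ 128 equals X - 128 and the sra produces all
  // ones, so the AND yields X - 128; otherwise the sra is zero and the AND
  // yields 0. Both cases match usubsat(X, 128).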
7467 SDValue X;
7468 if (!sd_match(N, P: m_And(L: m_OneUse(P: m_Xor(L: m_Value(N&: X), R: m_SpecificInt(V: SignMask))),
7469 R: m_OneUse(P: m_Sra(L: m_Deferred(V&: X),
7470 R: m_SpecificInt(V: BitWidth - 1))))) &&
7471 !sd_match(N, P: m_And(L: m_OneUse(P: m_Add(L: m_Value(N&: X), R: m_SpecificInt(V: SignMask))),
7472 R: m_OneUse(P: m_Sra(L: m_Deferred(V&: X),
7473 R: m_SpecificInt(V: BitWidth - 1))))))
7474 return SDValue();
7475
7476 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: X,
7477 N2: DAG.getConstant(Val: SignMask, DL, VT));
7478}
7479
7480/// Given a bitwise logic operation N with a matching bitwise logic operand,
7481/// fold a pattern where 2 of the source operands are identically shifted
7482/// values. For example:
7483/// ((X0 << Y) | Z) | (X1 << Y) --> ((X0 | X1) << Y) | Z
7484static SDValue foldLogicOfShifts(SDNode *N, SDValue LogicOp, SDValue ShiftOp,
7485 SelectionDAG &DAG) {
7486 unsigned LogicOpcode = N->getOpcode();
7487 assert(ISD::isBitwiseLogicOp(LogicOpcode) &&
7488 "Expected bitwise logic operation");
7489
7490 if (!LogicOp.hasOneUse() || !ShiftOp.hasOneUse())
7491 return SDValue();
7492
7493 // Match another bitwise logic op and a shift.
7494 unsigned ShiftOpcode = ShiftOp.getOpcode();
7495 if (LogicOp.getOpcode() != LogicOpcode ||
7496 !(ShiftOpcode == ISD::SHL || ShiftOpcode == ISD::SRL ||
7497 ShiftOpcode == ISD::SRA))
7498 return SDValue();
7499
7500 // Match another shift op inside the first logic operand. Handle both commuted
7501 // possibilities.
7502 // LOGIC (LOGIC (SH X0, Y), Z), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
7503 // LOGIC (LOGIC Z, (SH X0, Y)), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
7504 SDValue X1 = ShiftOp.getOperand(i: 0);
7505 SDValue Y = ShiftOp.getOperand(i: 1);
7506 SDValue X0, Z;
7507 if (LogicOp.getOperand(i: 0).getOpcode() == ShiftOpcode &&
7508 LogicOp.getOperand(i: 0).getOperand(i: 1) == Y) {
7509 X0 = LogicOp.getOperand(i: 0).getOperand(i: 0);
7510 Z = LogicOp.getOperand(i: 1);
7511 } else if (LogicOp.getOperand(i: 1).getOpcode() == ShiftOpcode &&
7512 LogicOp.getOperand(i: 1).getOperand(i: 1) == Y) {
7513 X0 = LogicOp.getOperand(i: 1).getOperand(i: 0);
7514 Z = LogicOp.getOperand(i: 0);
7515 } else {
7516 return SDValue();
7517 }
7518
7519 EVT VT = N->getValueType(ResNo: 0);
7520 SDLoc DL(N);
7521 SDValue LogicX = DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: X0, N2: X1);
7522 SDValue NewShift = DAG.getNode(Opcode: ShiftOpcode, DL, VT, N1: LogicX, N2: Y);
7523 return DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: NewShift, N2: Z);
7524}
7525
7526/// Given a tree of logic operations with shape like
7527/// (LOGIC (LOGIC (X, Y), LOGIC (Z, Y)))
7528/// try to match and fold shift operations with the same shift amount.
7529/// For example:
7530/// LOGIC (LOGIC (SH X0, Y), Z), (LOGIC (SH X1, Y), W) -->
7531/// --> LOGIC (SH (LOGIC X0, X1), Y), (LOGIC Z, W)
7532static SDValue foldLogicTreeOfShifts(SDNode *N, SDValue LeftHand,
7533 SDValue RightHand, SelectionDAG &DAG) {
7534 unsigned LogicOpcode = N->getOpcode();
7535 assert(ISD::isBitwiseLogicOp(LogicOpcode) &&
7536 "Expected bitwise logic operation");
7537 if (LeftHand.getOpcode() != LogicOpcode ||
7538 RightHand.getOpcode() != LogicOpcode)
7539 return SDValue();
7540 if (!LeftHand.hasOneUse() || !RightHand.hasOneUse())
7541 return SDValue();
7542
7543 // Try to match one of following patterns:
7544 // LOGIC (LOGIC (SH X0, Y), Z), (LOGIC (SH X1, Y), W)
7545 // LOGIC (LOGIC (SH X0, Y), Z), (LOGIC W, (SH X1, Y))
7546 // Note that foldLogicOfShifts will handle commuted versions of the left hand
7547 // itself.
7548 SDValue CombinedShifts, W;
7549 SDValue R0 = RightHand.getOperand(i: 0);
7550 SDValue R1 = RightHand.getOperand(i: 1);
7551 if ((CombinedShifts = foldLogicOfShifts(N, LogicOp: LeftHand, ShiftOp: R0, DAG)))
7552 W = R1;
7553 else if ((CombinedShifts = foldLogicOfShifts(N, LogicOp: LeftHand, ShiftOp: R1, DAG)))
7554 W = R0;
7555 else
7556 return SDValue();
7557
7558 EVT VT = N->getValueType(ResNo: 0);
7559 SDLoc DL(N);
7560 return DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: CombinedShifts, N2: W);
7561}
7562
7563/// Fold "masked merge" expressions like `(m & x) | (~m & y)` and its DeMorgan
7564 /// variant `(~m | x) & (m | y)` into the equivalent `(((x ^ y) & m) ^ y)`
7565/// pattern. This is typically a better representation for targets without a
7566/// fused "and-not" operation.
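/// For example, with m == 0xF0, (m & x) | (~m & y) selects the high nibble
/// from x and the low nibble from y; (((x ^ y) & m) ^ y) computes the same
/// result without needing an and-not.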
7567static SDValue foldMaskedMerge(SDNode *Node, SelectionDAG &DAG,
7568 const TargetLowering &TLI, const SDLoc &DL) {
7569 // Note that masked-merge variants using XOR or ADD expressions are
7570 // normalized to OR by InstCombine so we only check for OR or AND.
7571 assert((Node->getOpcode() == ISD::OR || Node->getOpcode() == ISD::AND) &&
7572 "Must be called with ISD::OR or ISD::AND node");
7573
7574 // If the target supports and-not, don't fold this.
7575 if (TLI.hasAndNot(X: SDValue(Node, 0)))
7576 return SDValue();
7577
7578 SDValue M, X, Y;
7579
7580 if (sd_match(N: Node,
7581 P: m_Or(L: m_OneUse(P: m_And(L: m_OneUse(P: m_Not(V: m_Value(N&: M))), R: m_Value(N&: Y))),
7582 R: m_OneUse(P: m_And(L: m_Deferred(V&: M), R: m_Value(N&: X))))) ||
7583 sd_match(N: Node,
7584 P: m_And(L: m_OneUse(P: m_Or(L: m_OneUse(P: m_Not(V: m_Value(N&: M))), R: m_Value(N&: X))),
7585 R: m_OneUse(P: m_Or(L: m_Deferred(V&: M), R: m_Value(N&: Y)))))) {
7586 EVT VT = M.getValueType();
7587 SDValue Xor = DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: X, N2: Y);
7588 SDValue And = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Xor, N2: M);
7589 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: And, N2: Y);
7590 }
7591 return SDValue();
7592}
7593
7594SDValue DAGCombiner::visitAND(SDNode *N) {
7595 SDValue N0 = N->getOperand(Num: 0);
7596 SDValue N1 = N->getOperand(Num: 1);
7597 EVT VT = N1.getValueType();
7598 SDLoc DL(N);
7599
7600 // x & x --> x
7601 if (N0 == N1)
7602 return N0;
7603
7604 // fold (and c1, c2) -> c1&c2
7605 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::AND, DL, VT, Ops: {N0, N1}))
7606 return C;
7607
7608 // canonicalize constant to RHS
7609 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
7610 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
7611 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1, N2: N0);
7612
7613 if (areBitwiseNotOfEachother(Op0: N0, Op1: N1))
7614 return DAG.getConstant(Val: APInt::getZero(numBits: VT.getScalarSizeInBits()), DL, VT);
7615
7616 // fold vector ops
7617 if (VT.isVector()) {
7618 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
7619 return FoldedVOp;
7620
7621 // fold (and x, 0) -> 0, vector edition
7622 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
7623       // do not return N1, because an undef node may exist in N1
7624 return DAG.getConstant(Val: APInt::getZero(numBits: N1.getScalarValueSizeInBits()), DL,
7625 VT: N1.getValueType());
7626
7627 // fold (and x, -1) -> x, vector edition
7628 if (ISD::isConstantSplatVectorAllOnes(N: N1.getNode()))
7629 return N0;
7630
7631 // fold (and (masked_load) (splat_vec (x, ...))) to zext_masked_load
7632 bool Frozen = N0.getOpcode() == ISD::FREEZE;
7633 auto *MLoad = dyn_cast<MaskedLoadSDNode>(Val: Frozen ? N0.getOperand(i: 0) : N0);
7634 ConstantSDNode *Splat = isConstOrConstSplat(N: N1, AllowUndefs: true, AllowTruncation: true);
7635 if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat) {
7636 EVT MemVT = MLoad->getMemoryVT();
7637 if (TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: VT, MemVT)) {
7638 // For this AND to be a zero extension of the masked load the elements
7639 // of the BuildVec must mask the bottom bits of the extended element
7640 // type
7641 if (Splat->getAPIntValue().isMask(numBits: MemVT.getScalarSizeInBits())) {
7642 SDValue NewLoad = DAG.getMaskedLoad(
7643 VT, dl: DL, Chain: MLoad->getChain(), Base: MLoad->getBasePtr(),
7644 Offset: MLoad->getOffset(), Mask: MLoad->getMask(), Src0: MLoad->getPassThru(), MemVT,
7645 MMO: MLoad->getMemOperand(), AM: MLoad->getAddressingMode(), ISD::ZEXTLOAD,
7646 IsExpanding: MLoad->isExpandingLoad());
7647 CombineTo(N, Res: Frozen ? N0 : NewLoad);
7648 CombineTo(N: MLoad, Res0: NewLoad, Res1: NewLoad.getValue(R: 1));
7649 return SDValue(N, 0);
7650 }
7651 }
7652 }
7653 }
7654
7655 // fold (and x, -1) -> x
7656 if (isAllOnesConstant(V: N1))
7657 return N0;
7658
7659 // if (and x, c) is known to be zero, return 0
7660 unsigned BitWidth = VT.getScalarSizeInBits();
7661 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
7662 if (N1C && DAG.MaskedValueIsZero(Op: SDValue(N, 0), Mask: APInt::getAllOnes(numBits: BitWidth)))
7663 return DAG.getConstant(Val: 0, DL, VT);
7664
7665 if (SDValue R = foldAndOrOfSETCC(LogicOp: N, DAG))
7666 return R;
7667
7668 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
7669 return NewSel;
7670
7671 // reassociate and
7672 if (SDValue RAND = reassociateOps(Opc: ISD::AND, DL, N0, N1, Flags: N->getFlags()))
7673 return RAND;
7674
7675 // Fold and(vecreduce(x), vecreduce(y)) -> vecreduce(and(x, y))
7676 if (SDValue SD =
7677 reassociateReduction(RedOpc: ISD::VECREDUCE_AND, Opc: ISD::AND, DL, VT, N0, N1))
7678 return SD;
7679
7680 // fold (and (or x, C), D) -> D if (C & D) == D
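  // E.g. (and (or x, 0xFF), 0x0F) --> 0x0F, since the OR guarantees all bits
  // of 0x0F are already set.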
7681 auto MatchSubset = [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
7682 return RHS->getAPIntValue().isSubsetOf(RHS: LHS->getAPIntValue());
7683 };
7684 if (N0.getOpcode() == ISD::OR &&
7685 ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchSubset))
7686 return N1;
7687
7688 if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
7689 SDValue N0Op0 = N0.getOperand(i: 0);
7690 EVT SrcVT = N0Op0.getValueType();
7691 unsigned SrcBitWidth = SrcVT.getScalarSizeInBits();
7692 APInt Mask = ~N1C->getAPIntValue();
7693 Mask = Mask.trunc(width: SrcBitWidth);
7694
7695 // fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits.
7696 if (DAG.MaskedValueIsZero(Op: N0Op0, Mask))
7697 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0Op0);
7698
7699 // fold (and (any_ext V), c) -> (zero_ext (and (trunc V), c)) if profitable.
7700 if (N1C->getAPIntValue().countLeadingZeros() >= (BitWidth - SrcBitWidth) &&
7701 TLI.isTruncateFree(FromVT: VT, ToVT: SrcVT) && TLI.isZExtFree(FromTy: SrcVT, ToTy: VT) &&
7702 TLI.isTypeDesirableForOp(ISD::AND, VT: SrcVT) &&
7703 TLI.isNarrowingProfitable(N, SrcVT: VT, DestVT: SrcVT))
7704 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT,
7705 Operand: DAG.getNode(Opcode: ISD::AND, DL, VT: SrcVT, N1: N0Op0,
7706 N2: DAG.getZExtOrTrunc(Op: N1, DL, VT: SrcVT)));
7707 }
7708
7709 // fold (and (ext (and V, c1)), c2) -> (and (ext V), (and c1, (ext c2)))
7710 if (ISD::isExtOpcode(Opcode: N0.getOpcode())) {
7711 unsigned ExtOpc = N0.getOpcode();
7712 SDValue N0Op0 = N0.getOperand(i: 0);
7713 if (N0Op0.getOpcode() == ISD::AND &&
7714 (ExtOpc != ISD::ZERO_EXTEND || !TLI.isZExtFree(Val: N0Op0, VT2: VT)) &&
7715 N0->hasOneUse() && N0Op0->hasOneUse()) {
7716 if (SDValue NewExt = DAG.FoldConstantArithmetic(Opcode: ExtOpc, DL, VT,
7717 Ops: {N0Op0.getOperand(i: 1)})) {
7718 if (SDValue NewMask =
7719 DAG.FoldConstantArithmetic(Opcode: ISD::AND, DL, VT, Ops: {N1, NewExt})) {
7720 return DAG.getNode(Opcode: ISD::AND, DL, VT,
7721 N1: DAG.getNode(Opcode: ExtOpc, DL, VT, Operand: N0Op0.getOperand(i: 0)),
7722 N2: NewMask);
7723 }
7724 }
7725 }
7726 }
7727
7728 // similarly fold (and (X (load ([non_ext|any_ext|zero_ext] V))), c) ->
7729 // (X (load ([non_ext|zero_ext] V))) if 'and' only clears top bits which must
7730 // already be zero by virtue of the width of the base type of the load.
7731 //
7732 // the 'X' node here can either be nothing or an extract_vector_elt to catch
7733 // more cases.
7734 if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
7735 N0.getValueSizeInBits() == N0.getOperand(i: 0).getScalarValueSizeInBits() &&
7736 N0.getOperand(i: 0).getOpcode() == ISD::LOAD &&
7737 N0.getOperand(i: 0).getResNo() == 0) ||
7738 (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) {
7739 auto *Load =
7740 cast<LoadSDNode>(Val: (N0.getOpcode() == ISD::LOAD) ? N0 : N0.getOperand(i: 0));
7741
7742 // Get the constant (if applicable) the zero'th operand is being ANDed with.
7743 // This can be a pure constant or a vector splat, in which case we treat the
7744 // vector as a scalar and use the splat value.
7745 APInt Constant = APInt::getZero(numBits: 1);
7746 if (const ConstantSDNode *C = isConstOrConstSplat(
7747 N: N1, /*AllowUndefs=*/false, /*AllowTruncation=*/true)) {
7748 Constant = C->getAPIntValue();
7749 } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(Val&: N1)) {
7750 unsigned EltBitWidth = Vector->getValueType(ResNo: 0).getScalarSizeInBits();
7751 APInt SplatValue, SplatUndef;
7752 unsigned SplatBitSize;
7753 bool HasAnyUndefs;
7754 // Endianness should not matter here. Code below makes sure that we only
7755 // use the result if the SplatBitSize is a multiple of the vector element
7756 // size. And after that we AND all element sized parts of the splat
7757 // together. So the end result should be the same regardless of in which
7758 // order we do those operations.
7759 const bool IsBigEndian = false;
7760 bool IsSplat =
7761 Vector->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
7762 HasAnyUndefs, MinSplatBits: EltBitWidth, isBigEndian: IsBigEndian);
7763
7764       // Make sure that variable 'Constant' is only set if 'SplatBitSize' is a
7765       // multiple of 'EltBitWidth'. Otherwise, we could propagate a wrong value.
7766 if (IsSplat && (SplatBitSize % EltBitWidth) == 0) {
7767 // Undef bits can contribute to a possible optimisation if set, so
7768 // set them.
7769 SplatValue |= SplatUndef;
7770
7771 // The splat value may be something like "0x00FFFFFF", which means 0 for
7772 // the first vector value and FF for the rest, repeating. We need a mask
7773 // that will apply equally to all members of the vector, so AND all the
7774 // lanes of the constant together.
7775 Constant = APInt::getAllOnes(numBits: EltBitWidth);
7776 for (unsigned i = 0, n = (SplatBitSize / EltBitWidth); i < n; ++i)
7777 Constant &= SplatValue.extractBits(numBits: EltBitWidth, bitPosition: i * EltBitWidth);
7778 }
7779 }
7780
7781 // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
7782 // actually legal and isn't going to get expanded, else this is a false
7783 // optimisation.
7784 bool CanZextLoadProfitably = TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD,
7785 ValVT: Load->getValueType(ResNo: 0),
7786 MemVT: Load->getMemoryVT());
7787
7788 // Resize the constant to the same size as the original memory access before
7789 // extension. If it is still the AllOnesValue then this AND is completely
7790 // unneeded.
7791 Constant = Constant.zextOrTrunc(width: Load->getMemoryVT().getScalarSizeInBits());
7792
7793 bool B;
7794 switch (Load->getExtensionType()) {
7795 default: B = false; break;
7796 case ISD::EXTLOAD: B = CanZextLoadProfitably; break;
7797 case ISD::ZEXTLOAD:
7798 case ISD::NON_EXTLOAD: B = true; break;
7799 }
7800
7801 if (B && Constant.isAllOnes()) {
7802 // If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
7803 // preserve semantics once we get rid of the AND.
7804 SDValue NewLoad(Load, 0);
7805
7806 // Fold the AND away. NewLoad may get replaced immediately.
7807 CombineTo(N, Res: (N0.getNode() == Load) ? NewLoad : N0);
7808
7809 if (Load->getExtensionType() == ISD::EXTLOAD) {
7810 NewLoad = DAG.getLoad(AM: Load->getAddressingMode(), ExtType: ISD::ZEXTLOAD,
7811 VT: Load->getValueType(ResNo: 0), dl: SDLoc(Load),
7812 Chain: Load->getChain(), Ptr: Load->getBasePtr(),
7813 Offset: Load->getOffset(), MemVT: Load->getMemoryVT(),
7814 MMO: Load->getMemOperand());
7815 // Replace uses of the EXTLOAD with the new ZEXTLOAD.
7816 if (Load->getNumValues() == 3) {
7817 // PRE/POST_INC loads have 3 values.
7818 SDValue To[] = { NewLoad.getValue(R: 0), NewLoad.getValue(R: 1),
7819 NewLoad.getValue(R: 2) };
7820 CombineTo(N: Load, To, NumTo: 3, AddTo: true);
7821 } else {
7822 CombineTo(N: Load, Res0: NewLoad.getValue(R: 0), Res1: NewLoad.getValue(R: 1));
7823 }
7824 }
7825
7826 return SDValue(N, 0); // Return N so it doesn't get rechecked!
7827 }
7828 }
7829
7830 // Try to convert a constant mask AND into a shuffle clear mask.
7831 if (VT.isVector())
7832 if (SDValue Shuffle = XformToShuffleWithZero(N))
7833 return Shuffle;
7834
7835 if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
7836 return Combined;
7837
7838 if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() && N1C &&
7839 ISD::isExtOpcode(Opcode: N0.getOperand(i: 0).getOpcode())) {
7840 SDValue Ext = N0.getOperand(i: 0);
7841 EVT ExtVT = Ext->getValueType(ResNo: 0);
7842 SDValue Extendee = Ext->getOperand(Num: 0);
7843
7844 unsigned ScalarWidth = Extendee.getValueType().getScalarSizeInBits();
7845 if (N1C->getAPIntValue().isMask(numBits: ScalarWidth) &&
7846 (!LegalOperations || TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT: ExtVT))) {
7847 // (and (extract_subvector (zext|anyext|sext v) _) iN_mask)
7848 // => (extract_subvector (iN_zeroext v))
7849 SDValue ZeroExtExtendee =
7850 DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: ExtVT, Operand: Extendee);
7851
7852 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT, N1: ZeroExtExtendee,
7853 N2: N0.getOperand(i: 1));
7854 }
7855 }
7856
7857 // fold (and (masked_gather x)) -> (zext_masked_gather x)
7858 if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(Val&: N0)) {
7859 EVT MemVT = GN0->getMemoryVT();
7860 EVT ScalarVT = MemVT.getScalarType();
7861
7862 if (SDValue(GN0, 0).hasOneUse() &&
7863 isConstantSplatVectorMaskForType(N: N1.getNode(), ScalarTy: ScalarVT) &&
7864 TLI.isVectorLoadExtDesirable(ExtVal: SDValue(N, 0))) {
7865 SDValue Ops[] = {GN0->getChain(), GN0->getPassThru(), GN0->getMask(),
7866 GN0->getBasePtr(), GN0->getIndex(), GN0->getScale()};
7867
7868 SDValue ZExtLoad = DAG.getMaskedGather(
7869 VTs: DAG.getVTList(VT1: VT, VT2: MVT::Other), MemVT, dl: DL, Ops, MMO: GN0->getMemOperand(),
7870 IndexType: GN0->getIndexType(), ExtTy: ISD::ZEXTLOAD);
7871
7872 CombineTo(N, Res: ZExtLoad);
7873 AddToWorklist(N: ZExtLoad.getNode());
7874 // Avoid recheck of N.
7875 return SDValue(N, 0);
7876 }
7877 }
7878
7879 // fold (and (load x), 255) -> (zextload x, i8)
7880 // fold (and (extload x, i16), 255) -> (zextload x, i8)
7881 if (N1C && N0.getOpcode() == ISD::LOAD && !VT.isVector())
7882 if (SDValue Res = reduceLoadWidth(N))
7883 return Res;
7884
7885 if (LegalTypes) {
7886 // Attempt to propagate the AND back up to the leaves which, if they're
7887 // loads, can be combined to narrow loads and the AND node can be removed.
7888 // Perform after legalization so that extend nodes will already be
7889 // combined into the loads.
7890 if (BackwardsPropagateMask(N))
7891 return SDValue(N, 0);
7892 }
7893
7894 if (SDValue Combined = visitANDLike(N0, N1, N))
7895 return Combined;
7896
7897 // Simplify: (and (op x...), (op y...)) -> (op (and x, y))
7898 if (N0.getOpcode() == N1.getOpcode())
7899 if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
7900 return V;
7901
7902 if (SDValue R = foldLogicOfShifts(N, LogicOp: N0, ShiftOp: N1, DAG))
7903 return R;
7904 if (SDValue R = foldLogicOfShifts(N, LogicOp: N1, ShiftOp: N0, DAG))
7905 return R;
7906
7907 // Fold (and X, (bswap (not Y))) -> (and X, (not (bswap Y)))
7908 // Fold (and X, (bitreverse (not Y))) -> (and X, (not (bitreverse Y)))
7909 SDValue X, Y, Z, NotY;
7910 for (unsigned Opc : {ISD::BSWAP, ISD::BITREVERSE})
7911 if (sd_match(N,
7912 P: m_And(L: m_Value(N&: X), R: m_OneUse(P: m_UnaryOp(Opc, Op: m_Value(N&: NotY))))) &&
7913 sd_match(N: NotY, P: m_Not(V: m_Value(N&: Y))) &&
7914 (TLI.hasAndNot(X: SDValue(N, 0)) || NotY->hasOneUse()))
7915 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: X,
7916 N2: DAG.getNOT(DL, Val: DAG.getNode(Opcode: Opc, DL, VT, Operand: Y), VT));
7917
7918 // Fold (and X, (rot (not Y), Z)) -> (and X, (not (rot Y, Z)))
7919 for (unsigned Opc : {ISD::ROTL, ISD::ROTR})
7920 if (sd_match(N, P: m_And(L: m_Value(N&: X),
7921 R: m_OneUse(P: m_BinOp(Opc, L: m_Value(N&: NotY), R: m_Value(N&: Z))))) &&
7922 sd_match(N: NotY, P: m_Not(V: m_Value(N&: Y))) &&
7923 (TLI.hasAndNot(X: SDValue(N, 0)) || NotY->hasOneUse()))
7924 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: X,
7925 N2: DAG.getNOT(DL, Val: DAG.getNode(Opcode: Opc, DL, VT, N1: Y, N2: Z), VT));
7926
7927 // Fold (and X, (add (not Y), Z)) -> (and X, (not (sub Y, Z)))
7928 // Fold (and X, (sub (not Y), Z)) -> (and X, (not (add Y, Z)))
7929 if (TLI.hasAndNot(X: SDValue(N, 0)))
7930 if (SDValue Folded = foldBitwiseOpWithNeg(N, DL, VT))
7931 return Folded;
7932
7933 // Fold (and (srl X, C), 1) -> (srl X, BW-1) for signbit extraction
7934 // If we are shifting down an extended sign bit, see if we can simplify
7935 // this to shifting the MSB directly to expose further simplifications.
7936 // This pattern often appears after sext_inreg legalization.
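  // E.g. if X is sign-extended from i8 to i32 (so it has 25 sign bits), then
  // (and (srl X, 30), 1) extracts a copy of the sign bit and is equivalent to
  // (srl X, 31).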
7937 APInt Amt;
7938 if (sd_match(N, P: m_And(L: m_Srl(L: m_Value(N&: X), R: m_ConstInt(V&: Amt)), R: m_One())) &&
7939 Amt.ult(RHS: BitWidth - 1) && Amt.uge(RHS: BitWidth - DAG.ComputeNumSignBits(Op: X)))
7940 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: X,
7941 N2: DAG.getShiftAmountConstant(Val: BitWidth - 1, VT, DL));
7942
7943 // Masking the negated extension of a boolean is just the zero-extended
7944 // boolean:
7945 // and (sub 0, zext(bool X)), 1 --> zext(bool X)
7946 // and (sub 0, sext(bool X)), 1 --> zext(bool X)
7947 //
7948 // Note: the SimplifyDemandedBits fold below can make an information-losing
7949 // transform, and then we have no way to find this better fold.
7950 if (sd_match(N, P: m_And(L: m_Sub(L: m_Zero(), R: m_Value(N&: X)), R: m_One()))) {
7951 if (X.getOpcode() == ISD::ZERO_EXTEND &&
7952 X.getOperand(i: 0).getScalarValueSizeInBits() == 1)
7953 return X;
7954 if (X.getOpcode() == ISD::SIGN_EXTEND &&
7955 X.getOperand(i: 0).getScalarValueSizeInBits() == 1)
7956 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: X.getOperand(i: 0));
7957 }
7958
7959 // fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1)
7960 // fold (and (sra)) -> (and (srl)) when possible.
7961 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
7962 return SDValue(N, 0);
7963
7964 // fold (zext_inreg (extload x)) -> (zextload x)
7965 // fold (zext_inreg (sextload x)) -> (zextload x) iff load has one use
7966 if (ISD::isUNINDEXEDLoad(N: N0.getNode()) &&
7967 (ISD::isEXTLoad(N: N0.getNode()) ||
7968 (ISD::isSEXTLoad(N: N0.getNode()) && N0.hasOneUse()))) {
7969 auto *LN0 = cast<LoadSDNode>(Val&: N0);
7970 EVT MemVT = LN0->getMemoryVT();
7971 // If we zero all the possible extended bits, then we can turn this into
7972 // a zextload if we are running before legalize or the operation is legal.
7973 unsigned ExtBitSize = N1.getScalarValueSizeInBits();
7974 unsigned MemBitSize = MemVT.getScalarSizeInBits();
7975 APInt ExtBits = APInt::getHighBitsSet(numBits: ExtBitSize, hiBitsSet: ExtBitSize - MemBitSize);
7976 if (DAG.MaskedValueIsZero(Op: N1, Mask: ExtBits) &&
7977 ((!LegalOperations && LN0->isSimple()) ||
7978 TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: VT, MemVT))) {
7979 SDValue ExtLoad =
7980 DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl: SDLoc(N0), VT, Chain: LN0->getChain(),
7981 Ptr: LN0->getBasePtr(), MemVT, MMO: LN0->getMemOperand());
7982 AddToWorklist(N);
7983 CombineTo(N: N0.getNode(), Res0: ExtLoad, Res1: ExtLoad.getValue(R: 1));
7984 return SDValue(N, 0); // Return N so it doesn't get rechecked!
7985 }
7986 }
7987
7988 // fold (and (or (srl N, 8), (shl N, 8)), 0xffff) -> (srl (bswap N), const)
7989 if (N1C && N1C->getAPIntValue() == 0xffff && N0.getOpcode() == ISD::OR) {
7990 if (SDValue BSwap = MatchBSwapHWordLow(N: N0.getNode(), N0: N0.getOperand(i: 0),
7991 N1: N0.getOperand(i: 1), DemandHighBits: false))
7992 return BSwap;
7993 }
7994
7995 if (SDValue Shifts = unfoldExtremeBitClearingToShifts(N))
7996 return Shifts;
7997
7998 if (SDValue V = combineShiftAnd1ToBitTest(And: N, DAG))
7999 return V;
8000
8001 // Recognize the following pattern:
8002 //
8003 // AndVT = (and (sign_extend NarrowVT to AndVT) #bitmask)
8004 //
8005 // where bitmask is a mask that clears the upper bits of AndVT. The
8006 // number of bits in bitmask must be a power of two.
8007 auto IsAndZeroExtMask = [](SDValue LHS, SDValue RHS) {
8008 if (LHS->getOpcode() != ISD::SIGN_EXTEND)
8009 return false;
8010
8011 auto *C = dyn_cast<ConstantSDNode>(Val&: RHS);
8012 if (!C)
8013 return false;
8014
8015 if (!C->getAPIntValue().isMask(
8016 numBits: LHS.getOperand(i: 0).getValueType().getFixedSizeInBits()))
8017 return false;
8018
8019 return true;
8020 };
8021
8022 // Replace (and (sign_extend ...) #bitmask) with (zero_extend ...).
8023 if (IsAndZeroExtMask(N0, N1))
8024 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0.getOperand(i: 0));
8025
8026 if (hasOperation(Opcode: ISD::USUBSAT, VT))
8027 if (SDValue V = foldAndToUsubsat(N, DAG, DL))
8028 return V;
8029
8030   // Postpone until legalization has completed to avoid interference with
8031   // bswap folding.
8032 if (LegalOperations || VT.isVector())
8033 if (SDValue R = foldLogicTreeOfShifts(N, LeftHand: N0, RightHand: N1, DAG))
8034 return R;
8035
8036 if (VT.isScalarInteger() && VT != MVT::i1)
8037 if (SDValue R = foldMaskedMerge(Node: N, DAG, TLI, DL))
8038 return R;
8039
8040 return SDValue();
8041}
8042
8043/// Match (a >> 8) | (a << 8) as (bswap a) >> 16.
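/// E.g. for i32 with only the low 16 bits of a set (a == 0x0000CCDD):
/// ((a & 0xff) << 8) | ((a >> 8) & 0xff) == 0x0000DDCC == (bswap a) >> 16.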
8044SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
8045 bool DemandHighBits) {
8046 if (!LegalOperations)
8047 return SDValue();
8048
8049 EVT VT = N->getValueType(ResNo: 0);
8050 if (VT != MVT::i64 && VT != MVT::i32 && VT != MVT::i16)
8051 return SDValue();
8052 if (!TLI.isOperationLegalOrCustom(Op: ISD::BSWAP, VT))
8053 return SDValue();
8054
8055 // Recognize (and (shl a, 8), 0xff00), (and (srl a, 8), 0xff)
8056 bool LookPassAnd0 = false;
8057 bool LookPassAnd1 = false;
8058 if (N0.getOpcode() == ISD::AND && N0.getOperand(i: 0).getOpcode() == ISD::SRL)
8059 std::swap(a&: N0, b&: N1);
8060 if (N1.getOpcode() == ISD::AND && N1.getOperand(i: 0).getOpcode() == ISD::SHL)
8061 std::swap(a&: N0, b&: N1);
8062 if (N0.getOpcode() == ISD::AND) {
8063 if (!N0->hasOneUse())
8064 return SDValue();
8065 ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
8066 // Also handle 0xffff since the LHS is guaranteed to have zeros there.
8067 // This is needed for X86.
8068 if (!N01C || (N01C->getZExtValue() != 0xFF00 &&
8069 N01C->getZExtValue() != 0xFFFF))
8070 return SDValue();
8071 N0 = N0.getOperand(i: 0);
8072 LookPassAnd0 = true;
8073 }
8074
8075 if (N1.getOpcode() == ISD::AND) {
8076 if (!N1->hasOneUse())
8077 return SDValue();
8078 ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(Val: N1.getOperand(i: 1));
8079 if (!N11C || N11C->getZExtValue() != 0xFF)
8080 return SDValue();
8081 N1 = N1.getOperand(i: 0);
8082 LookPassAnd1 = true;
8083 }
8084
8085 if (N0.getOpcode() == ISD::SRL && N1.getOpcode() == ISD::SHL)
8086 std::swap(a&: N0, b&: N1);
8087 if (N0.getOpcode() != ISD::SHL || N1.getOpcode() != ISD::SRL)
8088 return SDValue();
8089 if (!N0->hasOneUse() || !N1->hasOneUse())
8090 return SDValue();
8091
8092 ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
8093 ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(Val: N1.getOperand(i: 1));
8094 if (!N01C || !N11C)
8095 return SDValue();
8096 if (N01C->getZExtValue() != 8 || N11C->getZExtValue() != 8)
8097 return SDValue();
8098
8099 // Look for (shl (and a, 0xff), 8), (srl (and a, 0xff00), 8)
8100 SDValue N00 = N0->getOperand(Num: 0);
8101 if (!LookPassAnd0 && N00.getOpcode() == ISD::AND) {
8102 if (!N00->hasOneUse())
8103 return SDValue();
8104 ConstantSDNode *N001C = dyn_cast<ConstantSDNode>(Val: N00.getOperand(i: 1));
8105 if (!N001C || N001C->getZExtValue() != 0xFF)
8106 return SDValue();
8107 N00 = N00.getOperand(i: 0);
8108 LookPassAnd0 = true;
8109 }
8110
8111 SDValue N10 = N1->getOperand(Num: 0);
8112 if (!LookPassAnd1 && N10.getOpcode() == ISD::AND) {
8113 if (!N10->hasOneUse())
8114 return SDValue();
8115 ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(Val: N10.getOperand(i: 1));
8116 // Also allow 0xFFFF since the bits will be shifted out. This is needed
8117 // for X86.
8118 if (!N101C || (N101C->getZExtValue() != 0xFF00 &&
8119 N101C->getZExtValue() != 0xFFFF))
8120 return SDValue();
8121 N10 = N10.getOperand(i: 0);
8122 LookPassAnd1 = true;
8123 }
8124
8125 if (N00 != N10)
8126 return SDValue();
8127
8128 // Make sure everything beyond the low halfword gets set to zero since the SRL
8129 // 16 will clear the top bits.
8130 unsigned OpSizeInBits = VT.getSizeInBits();
8131 if (OpSizeInBits > 16) {
8132 // If the left-shift isn't masked out then the only way this is a bswap is
8133 // if all bits beyond the low 8 are 0. In that case the entire pattern
8134 // reduces to a left shift anyway: leave it for other parts of the combiner.
8135 if (DemandHighBits && !LookPassAnd0)
8136 return SDValue();
8137
8138 // However, if the right shift isn't masked out then it might be because
8139 // it's not needed. See if we can spot that too. If the high bits aren't
8140 // demanded, we only need bits 23:16 to be zero. Otherwise, we need all
8141 // upper bits to be zero.
8142 if (!LookPassAnd1) {
8143 unsigned HighBit = DemandHighBits ? OpSizeInBits : 24;
8144 if (!DAG.MaskedValueIsZero(Op: N10,
8145 Mask: APInt::getBitsSet(numBits: OpSizeInBits, loBit: 16, hiBit: HighBit)))
8146 return SDValue();
8147 }
8148 }
8149
8150 SDValue Res = DAG.getNode(Opcode: ISD::BSWAP, DL: SDLoc(N), VT, Operand: N00);
8151 if (OpSizeInBits > 16) {
8152 SDLoc DL(N);
8153 Res = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Res,
8154 N2: DAG.getShiftAmountConstant(Val: OpSizeInBits - 16, VT, DL));
8155 }
8156 return Res;
8157}
8158
8159/// Return true if the specified node is an element that makes up a 32-bit
8160/// packed halfword byteswap.
8161/// ((x & 0x000000ff) << 8) |
8162/// ((x & 0x0000ff00) >> 8) |
8163/// ((x & 0x00ff0000) << 8) |
8164/// ((x & 0xff000000) >> 8)
8165static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
8166 if (!N->hasOneUse())
8167 return false;
8168
8169 unsigned Opc = N.getOpcode();
8170 if (Opc != ISD::AND && Opc != ISD::SHL && Opc != ISD::SRL)
8171 return false;
8172
8173 SDValue N0 = N.getOperand(i: 0);
8174 unsigned Opc0 = N0.getOpcode();
8175 if (Opc0 != ISD::AND && Opc0 != ISD::SHL && Opc0 != ISD::SRL)
8176 return false;
8177
8178 ConstantSDNode *N1C = nullptr;
8179 // SHL or SRL: look upstream for AND mask operand
8180 if (Opc == ISD::AND)
8181 N1C = dyn_cast<ConstantSDNode>(Val: N.getOperand(i: 1));
8182 else if (Opc0 == ISD::AND)
8183 N1C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
8184 if (!N1C)
8185 return false;
8186
8187 unsigned MaskByteOffset;
8188 switch (N1C->getZExtValue()) {
8189 default:
8190 return false;
8191 case 0xFF: MaskByteOffset = 0; break;
8192 case 0xFF00: MaskByteOffset = 1; break;
8193 case 0xFFFF:
8194 // In case demanded bits didn't clear the bits that will be shifted out.
8195 // This is needed for X86.
8196 if (Opc == ISD::SRL || (Opc == ISD::AND && Opc0 == ISD::SHL)) {
8197 MaskByteOffset = 1;
8198 break;
8199 }
8200 return false;
8201 case 0xFF0000: MaskByteOffset = 2; break;
8202 case 0xFF000000: MaskByteOffset = 3; break;
8203 }
8204
8205 // Look for (x & 0xff) << 8 as well as ((x << 8) & 0xff00).
8206 if (Opc == ISD::AND) {
8207 if (MaskByteOffset == 0 || MaskByteOffset == 2) {
8208 // (x >> 8) & 0xff
8209 // (x >> 8) & 0xff0000
8210 if (Opc0 != ISD::SRL)
8211 return false;
8212 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
8213 if (!C || C->getZExtValue() != 8)
8214 return false;
8215 } else {
8216 // (x << 8) & 0xff00
8217 // (x << 8) & 0xff000000
8218 if (Opc0 != ISD::SHL)
8219 return false;
8220 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
8221 if (!C || C->getZExtValue() != 8)
8222 return false;
8223 }
8224 } else if (Opc == ISD::SHL) {
8225 // (x & 0xff) << 8
8226 // (x & 0xff0000) << 8
8227 if (MaskByteOffset != 0 && MaskByteOffset != 2)
8228 return false;
8229 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val: N.getOperand(i: 1));
8230 if (!C || C->getZExtValue() != 8)
8231 return false;
8232 } else { // Opc == ISD::SRL
8233 // (x & 0xff00) >> 8
8234 // (x & 0xff000000) >> 8
8235 if (MaskByteOffset != 1 && MaskByteOffset != 3)
8236 return false;
8237 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val: N.getOperand(i: 1));
8238 if (!C || C->getZExtValue() != 8)
8239 return false;
8240 }
8241
8242 if (Parts[MaskByteOffset])
8243 return false;
8244
8245 Parts[MaskByteOffset] = N0.getOperand(i: 0).getNode();
8246 return true;
8247}
8248
8249// Match 2 elements of a packed halfword bswap.
8250static bool isBSwapHWordPair(SDValue N, MutableArrayRef<SDNode *> Parts) {
8251 if (N.getOpcode() == ISD::OR)
8252 return isBSwapHWordElement(N: N.getOperand(i: 0), Parts) &&
8253 isBSwapHWordElement(N: N.getOperand(i: 1), Parts);
8254
8255 if (N.getOpcode() == ISD::SRL && N.getOperand(i: 0).getOpcode() == ISD::BSWAP) {
8256 ConstantSDNode *C = isConstOrConstSplat(N: N.getOperand(i: 1));
8257 if (!C || C->getAPIntValue() != 16)
8258 return false;
8259 Parts[0] = Parts[1] = N.getOperand(i: 0).getOperand(i: 0).getNode();
8260 return true;
8261 }
8262
8263 return false;
8264}
8265
8266// Match this pattern:
8267// (or (and (shl (A, 8)), 0xff00ff00), (and (srl (A, 8)), 0x00ff00ff))
8268// And rewrite this to:
8269// (rotr (bswap A), 16)
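// E.g. for A == 0xAABBCCDD this gives 0xBBAADDCC == rotr(0xDDCCBBAA, 16).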
8270static SDValue matchBSwapHWordOrAndAnd(const TargetLowering &TLI,
8271 SelectionDAG &DAG, SDNode *N, SDValue N0,
8272 SDValue N1, EVT VT) {
8273 assert(N->getOpcode() == ISD::OR && VT == MVT::i32 &&
8274 "MatchBSwapHWordOrAndAnd: expecting i32");
8275 if (!TLI.isOperationLegalOrCustom(Op: ISD::ROTR, VT))
8276 return SDValue();
8277 if (N0.getOpcode() != ISD::AND || N1.getOpcode() != ISD::AND)
8278 return SDValue();
8279 // TODO: this is too restrictive; lifting this restriction requires more tests
8280 if (!N0->hasOneUse() || !N1->hasOneUse())
8281 return SDValue();
8282 ConstantSDNode *Mask0 = isConstOrConstSplat(N: N0.getOperand(i: 1));
8283 ConstantSDNode *Mask1 = isConstOrConstSplat(N: N1.getOperand(i: 1));
8284 if (!Mask0 || !Mask1)
8285 return SDValue();
8286 if (Mask0->getAPIntValue() != 0xff00ff00 ||
8287 Mask1->getAPIntValue() != 0x00ff00ff)
8288 return SDValue();
8289 SDValue Shift0 = N0.getOperand(i: 0);
8290 SDValue Shift1 = N1.getOperand(i: 0);
8291 if (Shift0.getOpcode() != ISD::SHL || Shift1.getOpcode() != ISD::SRL)
8292 return SDValue();
8293 ConstantSDNode *ShiftAmt0 = isConstOrConstSplat(N: Shift0.getOperand(i: 1));
8294 ConstantSDNode *ShiftAmt1 = isConstOrConstSplat(N: Shift1.getOperand(i: 1));
8295 if (!ShiftAmt0 || !ShiftAmt1)
8296 return SDValue();
8297 if (ShiftAmt0->getAPIntValue() != 8 || ShiftAmt1->getAPIntValue() != 8)
8298 return SDValue();
8299 if (Shift0.getOperand(i: 0) != Shift1.getOperand(i: 0))
8300 return SDValue();
8301
8302 SDLoc DL(N);
8303 SDValue BSwap = DAG.getNode(Opcode: ISD::BSWAP, DL, VT, Operand: Shift0.getOperand(i: 0));
8304 SDValue ShAmt = DAG.getShiftAmountConstant(Val: 16, VT, DL);
8305 return DAG.getNode(Opcode: ISD::ROTR, DL, VT, N1: BSwap, N2: ShAmt);
8306}
8307
8308/// Match a 32-bit packed halfword bswap. That is
8309/// ((x & 0x000000ff) << 8) |
8310/// ((x & 0x0000ff00) >> 8) |
8311/// ((x & 0x00ff0000) << 8) |
8312/// ((x & 0xff000000) >> 8)
8313/// => (rotl (bswap x), 16)
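/// E.g. for x == 0xAABBCCDD the pattern yields 0xBBAADDCC, which equals
/// rotl(0xDDCCBBAA, 16) where 0xDDCCBBAA == bswap(x).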
8314SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
8315 if (!LegalOperations)
8316 return SDValue();
8317
8318 EVT VT = N->getValueType(ResNo: 0);
8319 if (VT != MVT::i32)
8320 return SDValue();
8321 if (!TLI.isOperationLegalOrCustom(Op: ISD::BSWAP, VT))
8322 return SDValue();
8323
8324 if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT))
8325 return BSwap;
8326
8327 // Try again with commuted operands.
8328 if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0: N1, N1: N0, VT))
8329 return BSwap;
8330
8332 // Look for either
8333 // (or (bswaphpair), (bswaphpair))
8334 // (or (or (bswaphpair), (and)), (and))
8335 // (or (or (and), (bswaphpair)), (and))
8336 SDNode *Parts[4] = {};
8337
8338 if (isBSwapHWordPair(N: N0, Parts)) {
8339 // (or (or (and), (and)), (or (and), (and)))
8340 if (!isBSwapHWordPair(N: N1, Parts))
8341 return SDValue();
8342 } else if (N0.getOpcode() == ISD::OR) {
8343 // (or (or (or (and), (and)), (and)), (and))
8344 if (!isBSwapHWordElement(N: N1, Parts))
8345 return SDValue();
8346 SDValue N00 = N0.getOperand(i: 0);
8347 SDValue N01 = N0.getOperand(i: 1);
8348 if (!(isBSwapHWordElement(N: N01, Parts) && isBSwapHWordPair(N: N00, Parts)) &&
8349 !(isBSwapHWordElement(N: N00, Parts) && isBSwapHWordPair(N: N01, Parts)))
8350 return SDValue();
8351 } else {
8352 return SDValue();
8353 }
8354
8355 // Make sure the parts are all coming from the same node.
8356 if (Parts[0] != Parts[1] || Parts[0] != Parts[2] || Parts[0] != Parts[3])
8357 return SDValue();
8358
8359 SDLoc DL(N);
8360 SDValue BSwap = DAG.getNode(Opcode: ISD::BSWAP, DL, VT,
8361 Operand: SDValue(Parts[0], 0));
8362
8363 // Result of the bswap should be rotated by 16. If it's not legal, then
8364 // do (x << 16) | (x >> 16).
8365 SDValue ShAmt = DAG.getShiftAmountConstant(Val: 16, VT, DL);
8366 if (TLI.isOperationLegalOrCustom(Op: ISD::ROTL, VT))
8367 return DAG.getNode(Opcode: ISD::ROTL, DL, VT, N1: BSwap, N2: ShAmt);
8368 if (TLI.isOperationLegalOrCustom(Op: ISD::ROTR, VT))
8369 return DAG.getNode(Opcode: ISD::ROTR, DL, VT, N1: BSwap, N2: ShAmt);
8370 return DAG.getNode(Opcode: ISD::OR, DL, VT,
8371 N1: DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: BSwap, N2: ShAmt),
8372 N2: DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: BSwap, N2: ShAmt));
8373}
8374
8375/// This contains all DAGCombine rules which reduce two values combined by
8376/// an Or operation to a single value; \see visitANDLike().
8377SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, const SDLoc &DL) {
8378 EVT VT = N1.getValueType();
8379
8380 // fold (or x, undef) -> -1
8381 if (!LegalOperations && (N0.isUndef() || N1.isUndef()))
8382 return DAG.getAllOnesConstant(DL, VT);
8383
8384 if (SDValue V = foldLogicOfSetCCs(IsAnd: false, N0, N1, DL))
8385 return V;
8386
8387 // (or (and X, C1), (and Y, C2)) -> (and (or X, Y), C3) if possible.
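  // e.g. (or (and X, 0xFF00), (and Y, 0x00FF)) -> (and (or X, Y), 0xFFFF)
  // when bits 0-7 of X and bits 8-15 of Y are known to be zero.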
8388 if (N0.getOpcode() == ISD::AND && N1.getOpcode() == ISD::AND &&
8389 // Don't increase # computations.
8390 (N0->hasOneUse() || N1->hasOneUse())) {
8391 // We can only do this xform if we know that bits from X that are set in C2
8392 // but not in C1 are already zero. Likewise for Y.
8393 if (const ConstantSDNode *N0O1C =
8394 getAsNonOpaqueConstant(N: N0.getOperand(i: 1))) {
8395 if (const ConstantSDNode *N1O1C =
8396 getAsNonOpaqueConstant(N: N1.getOperand(i: 1))) {
8399 const APInt &LHSMask = N0O1C->getAPIntValue();
8400 const APInt &RHSMask = N1O1C->getAPIntValue();
8401
8402 if (DAG.MaskedValueIsZero(Op: N0.getOperand(i: 0), Mask: RHSMask&~LHSMask) &&
8403 DAG.MaskedValueIsZero(Op: N1.getOperand(i: 0), Mask: LHSMask&~RHSMask)) {
8404 SDValue X = DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N0), VT,
8405 N1: N0.getOperand(i: 0), N2: N1.getOperand(i: 0));
8406 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: X,
8407 N2: DAG.getConstant(Val: LHSMask | RHSMask, DL, VT));
8408 }
8409 }
8410 }
8411 }
8412
8413 // (or (and X, M), (and X, N)) -> (and X, (or M, N))
8414 if (N0.getOpcode() == ISD::AND &&
8415 N1.getOpcode() == ISD::AND &&
8416 N0.getOperand(i: 0) == N1.getOperand(i: 0) &&
8417 // Don't increase # computations.
8418 (N0->hasOneUse() || N1->hasOneUse())) {
8419 SDValue X = DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N0), VT,
8420 N1: N0.getOperand(i: 1), N2: N1.getOperand(i: 1));
8421 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0.getOperand(i: 0), N2: X);
8422 }
8423
8424 return SDValue();
8425}
8426
8427/// OR combines for which the commuted variant will be tried as well.
8428static SDValue visitORCommutative(SelectionDAG &DAG, SDValue N0, SDValue N1,
8429 SDNode *N) {
8430 EVT VT = N0.getValueType();
8431 unsigned BW = VT.getScalarSizeInBits();
8432 SDLoc DL(N);
8433
8434 auto peekThroughResize = [](SDValue V) {
8435 if (V->getOpcode() == ISD::ZERO_EXTEND || V->getOpcode() == ISD::TRUNCATE)
8436 return V->getOperand(Num: 0);
8437 return V;
8438 };
8439
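  // Peek through a zext/trunc of the AND so the folds below still apply when
  // the operands were resized; operands are re-extended with getZExtOrTrunc
  // where needed.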
8440 SDValue N0Resized = peekThroughResize(N0);
8441 if (N0Resized.getOpcode() == ISD::AND) {
8442 SDValue N1Resized = peekThroughResize(N1);
8443 SDValue N00 = N0Resized.getOperand(i: 0);
8444 SDValue N01 = N0Resized.getOperand(i: 1);
8445
8446 // fold or (and x, y), x --> x
8447 if (N00 == N1Resized || N01 == N1Resized)
8448 return N1;
8449
8450 // fold (or (and X, (xor Y, -1)), Y) -> (or X, Y)
8451 // TODO: Set AllowUndefs = true.
8452 if (SDValue NotOperand = getBitwiseNotOperand(V: N01, Mask: N00,
8453 /* AllowUndefs */ false)) {
8454 if (peekThroughResize(NotOperand) == N1Resized)
8455 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: DAG.getZExtOrTrunc(Op: N00, DL, VT),
8456 N2: N1);
8457 }
8458
8459 // fold (or (and (xor Y, -1), X), Y) -> (or X, Y)
8460 if (SDValue NotOperand = getBitwiseNotOperand(V: N00, Mask: N01,
8461 /* AllowUndefs */ false)) {
8462 if (peekThroughResize(NotOperand) == N1Resized)
8463 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: DAG.getZExtOrTrunc(Op: N01, DL, VT),
8464 N2: N1);
8465 }
8466 }
8467
8468 SDValue X, Y;
8469
8470 // fold or (xor X, N1), N1 --> or X, N1
8471 if (sd_match(N: N0, P: m_Xor(L: m_Value(N&: X), R: m_Specific(N: N1))))
8472 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: X, N2: N1);
8473
8474 // fold or (xor x, y), (x and/or y) --> or x, y
8475 if (sd_match(N: N0, P: m_Xor(L: m_Value(N&: X), R: m_Value(N&: Y))) &&
8476 (sd_match(N: N1, P: m_And(L: m_Specific(N: X), R: m_Specific(N: Y))) ||
8477 sd_match(N: N1, P: m_Or(L: m_Specific(N: X), R: m_Specific(N: Y)))))
8478 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: X, N2: Y);
8479
8480 if (SDValue R = foldLogicOfShifts(N, LogicOp: N0, ShiftOp: N1, DAG))
8481 return R;
8482
8483 auto peekThroughZext = [](SDValue V) {
8484 if (V->getOpcode() == ISD::ZERO_EXTEND)
8485 return V->getOperand(Num: 0);
8486 return V;
8487 };
8488
8489 // (fshl X, ?, Y) | (shl X, Y) --> fshl X, ?, Y
8490 if (N0.getOpcode() == ISD::FSHL && N1.getOpcode() == ISD::SHL &&
8491 N0.getOperand(i: 0) == N1.getOperand(i: 0) &&
8492 peekThroughZext(N0.getOperand(i: 2)) == peekThroughZext(N1.getOperand(i: 1)))
8493 return N0;
8494
8495 // (fshr ?, X, Y) | (srl X, Y) --> fshr ?, X, Y
8496 if (N0.getOpcode() == ISD::FSHR && N1.getOpcode() == ISD::SRL &&
8497 N0.getOperand(i: 1) == N1.getOperand(i: 0) &&
8498 peekThroughZext(N0.getOperand(i: 2)) == peekThroughZext(N1.getOperand(i: 1)))
8499 return N0;
8500
8501 // Attempt to match a legalized build_pair-esque pattern:
8502 // or(shl(aext(Hi),BW/2),zext(Lo))
8503 SDValue Lo, Hi;
8504 if (sd_match(N: N0,
8505 P: m_OneUse(P: m_Shl(L: m_AnyExt(Op: m_Value(N&: Hi)), R: m_SpecificInt(V: BW / 2)))) &&
8506 sd_match(N: N1, P: m_ZExt(Op: m_Value(N&: Lo))) &&
8507 Lo.getScalarValueSizeInBits() == (BW / 2) &&
8508 Lo.getValueType() == Hi.getValueType()) {
8509 // Fold build_pair(not(Lo),not(Hi)) -> not(build_pair(Lo,Hi)).
8510 SDValue NotLo, NotHi;
8511 if (sd_match(N: Lo, P: m_OneUse(P: m_Not(V: m_Value(N&: NotLo)))) &&
8512 sd_match(N: Hi, P: m_OneUse(P: m_Not(V: m_Value(N&: NotHi))))) {
8513 Lo = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: NotLo);
8514 Hi = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT, Operand: NotHi);
8515 Hi = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Hi,
8516 N2: DAG.getShiftAmountConstant(Val: BW / 2, VT, DL));
8517 return DAG.getNOT(DL, Val: DAG.getNode(Opcode: ISD::OR, DL, VT, N1: Lo, N2: Hi), VT);
8518 }
8519 }
8520
8521 return SDValue();
8522}
8523
8524SDValue DAGCombiner::visitOR(SDNode *N) {
8525 SDValue N0 = N->getOperand(Num: 0);
8526 SDValue N1 = N->getOperand(Num: 1);
8527 EVT VT = N1.getValueType();
8528 SDLoc DL(N);
8529
8530 // x | x --> x
8531 if (N0 == N1)
8532 return N0;
8533
8534 // fold (or c1, c2) -> c1|c2
8535 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::OR, DL, VT, Ops: {N0, N1}))
8536 return C;
8537
8538 // canonicalize constant to RHS
8539 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
8540 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
8541 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1, N2: N0);
8542
8543 // fold vector ops
8544 if (VT.isVector()) {
8545 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
8546 return FoldedVOp;
8547
8548 // fold (or x, 0) -> x, vector edition
8549 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
8550 return N0;
8551
8552 // fold (or x, -1) -> -1, vector edition
8553 if (ISD::isConstantSplatVectorAllOnes(N: N1.getNode()))
8554      // do not return N1, because an undef element may exist in N1
8555 return DAG.getAllOnesConstant(DL, VT: N1.getValueType());
8556
8557 // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask)
8558 // Do this only if the resulting type / shuffle is legal.
8559 auto *SV0 = dyn_cast<ShuffleVectorSDNode>(Val&: N0);
8560 auto *SV1 = dyn_cast<ShuffleVectorSDNode>(Val&: N1);
8561 if (SV0 && SV1 && TLI.isTypeLegal(VT)) {
8562 bool ZeroN00 = ISD::isBuildVectorAllZeros(N: N0.getOperand(i: 0).getNode());
8563 bool ZeroN01 = ISD::isBuildVectorAllZeros(N: N0.getOperand(i: 1).getNode());
8564 bool ZeroN10 = ISD::isBuildVectorAllZeros(N: N1.getOperand(i: 0).getNode());
8565 bool ZeroN11 = ISD::isBuildVectorAllZeros(N: N1.getOperand(i: 1).getNode());
8566 // Ensure both shuffles have a zero input.
8567 if ((ZeroN00 != ZeroN01) && (ZeroN10 != ZeroN11)) {
8568 assert((!ZeroN00 || !ZeroN01) && "Both inputs zero!");
8569 assert((!ZeroN10 || !ZeroN11) && "Both inputs zero!");
8570 bool CanFold = true;
8571 int NumElts = VT.getVectorNumElements();
8572 SmallVector<int, 4> Mask(NumElts, -1);
8573
8574 for (int i = 0; i != NumElts; ++i) {
8575 int M0 = SV0->getMaskElt(Idx: i);
8576 int M1 = SV1->getMaskElt(Idx: i);
8577
8578 // Determine if either index is pointing to a zero vector.
8579 bool M0Zero = M0 < 0 || (ZeroN00 == (M0 < NumElts));
8580 bool M1Zero = M1 < 0 || (ZeroN10 == (M1 < NumElts));
8581
8582          // If one element is zero and the other side is undef, keep undef.
8583 // This also handles the case that both are undef.
8584 if ((M0Zero && M1 < 0) || (M1Zero && M0 < 0))
8585 continue;
8586
8587 // Make sure only one of the elements is zero.
8588 if (M0Zero == M1Zero) {
8589 CanFold = false;
8590 break;
8591 }
8592
8593 assert((M0 >= 0 || M1 >= 0) && "Undef index!");
8594
8595          // We have a zero and a non-zero element. If the non-zero came from
8596          // SV0, make the index an LHS index; if it came from SV1, make it
8597          // an RHS index. We need to mod by NumElts because we don't care
8598 // which operand it came from in the original shuffles.
8599 Mask[i] = M1Zero ? M0 % NumElts : (M1 % NumElts) + NumElts;
8600 }
8601
8602 if (CanFold) {
8603 SDValue NewLHS = ZeroN00 ? N0.getOperand(i: 1) : N0.getOperand(i: 0);
8604 SDValue NewRHS = ZeroN10 ? N1.getOperand(i: 1) : N1.getOperand(i: 0);
8605 SDValue LegalShuffle =
8606 TLI.buildLegalVectorShuffle(VT, DL, N0: NewLHS, N1: NewRHS, Mask, DAG);
8607 if (LegalShuffle)
8608 return LegalShuffle;
8609 }
8610 }
8611 }
8612 }
8613
8614 // fold (or x, 0) -> x
8615 if (isNullConstant(V: N1))
8616 return N0;
8617
8618 // fold (or x, -1) -> -1
8619 if (isAllOnesConstant(V: N1))
8620 return N1;
8621
8622 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
8623 return NewSel;
8624
8625 // fold (or x, c) -> c iff (x & ~c) == 0
8626 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
8627 if (N1C && DAG.MaskedValueIsZero(Op: N0, Mask: ~N1C->getAPIntValue()))
8628 return N1;
8629
8630 if (SDValue R = foldAndOrOfSETCC(LogicOp: N, DAG))
8631 return R;
8632
8633 if (SDValue Combined = visitORLike(N0, N1, DL))
8634 return Combined;
8635
8636 if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
8637 return Combined;
8638
8639 // Recognize halfword bswaps as (bswap + rotl 16) or (bswap + shl 16)
8640 if (SDValue BSwap = MatchBSwapHWord(N, N0, N1))
8641 return BSwap;
8642 if (SDValue BSwap = MatchBSwapHWordLow(N, N0, N1))
8643 return BSwap;
8644
8645 // reassociate or
8646 if (SDValue ROR = reassociateOps(Opc: ISD::OR, DL, N0, N1, Flags: N->getFlags()))
8647 return ROR;
8648
8649 // Fold or(vecreduce(x), vecreduce(y)) -> vecreduce(or(x, y))
8650 if (SDValue SD =
8651 reassociateReduction(RedOpc: ISD::VECREDUCE_OR, Opc: ISD::OR, DL, VT, N0, N1))
8652 return SD;
8653
8654 // Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2)
8655 // iff (c1 & c2) != 0 or c1/c2 are undef.
8656 auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) {
8657 return !C1 || !C2 || C1->getAPIntValue().intersects(RHS: C2->getAPIntValue());
8658 };
8659 if (N0.getOpcode() == ISD::AND && N0->hasOneUse() &&
8660 ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchIntersect, AllowUndefs: true)) {
8661 if (SDValue COR = DAG.FoldConstantArithmetic(Opcode: ISD::OR, DL: SDLoc(N1), VT,
8662 Ops: {N1, N0.getOperand(i: 1)})) {
8663 SDValue IOR = DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N0), VT, N1: N0.getOperand(i: 0), N2: N1);
8664 AddToWorklist(N: IOR.getNode());
8665 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: COR, N2: IOR);
8666 }
8667 }
8668
8669 if (SDValue Combined = visitORCommutative(DAG, N0, N1, N))
8670 return Combined;
8671 if (SDValue Combined = visitORCommutative(DAG, N0: N1, N1: N0, N))
8672 return Combined;
8673
8674 // Simplify: (or (op x...), (op y...)) -> (op (or x, y))
8675 if (N0.getOpcode() == N1.getOpcode())
8676 if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
8677 return V;
8678
8679 // See if this is some rotate idiom.
8680 if (SDValue Rot = MatchRotate(LHS: N0, RHS: N1, DL, /*FromAdd=*/false))
8681 return Rot;
8682
8683 if (SDValue Load = MatchLoadCombine(N))
8684 return Load;
8685
8686 // Simplify the operands using demanded-bits information.
8687 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
8688 return SDValue(N, 0);
8689
8690 // If OR can be rewritten into ADD, try combines based on ADD.
8691 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::ADD, VT)) &&
8692 DAG.isADDLike(Op: SDValue(N, 0)))
8693 if (SDValue Combined = visitADDLike(N))
8694 return Combined;
8695
8696 // Postpone until legalization completed to avoid interference with bswap
8697 // folding
8698 if (LegalOperations || VT.isVector())
8699 if (SDValue R = foldLogicTreeOfShifts(N, LeftHand: N0, RightHand: N1, DAG))
8700 return R;
8701
8702 if (VT.isScalarInteger() && VT != MVT::i1)
8703 if (SDValue R = foldMaskedMerge(Node: N, DAG, TLI, DL))
8704 return R;
8705
8706 return SDValue();
8707}
8708
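/// If \p Op is an AND with a constant (or constant splat) RHS, set \p Mask to
/// that constant and return the AND's other operand; otherwise return \p Op
/// unchanged.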
8709static SDValue stripConstantMask(const SelectionDAG &DAG, SDValue Op,
8710 SDValue &Mask) {
8711 if (Op.getOpcode() == ISD::AND &&
8712 DAG.isConstantIntBuildVectorOrConstantInt(N: Op.getOperand(i: 1))) {
8713 Mask = Op.getOperand(i: 1);
8714 return Op.getOperand(i: 0);
8715 }
8716 return Op;
8717}
8718
8719/// Match "(X shl/srl V1) & V2" where V2 may not be present.
8720static bool matchRotateHalf(const SelectionDAG &DAG, SDValue Op, SDValue &Shift,
8721 SDValue &Mask) {
8722 Op = stripConstantMask(DAG, Op, Mask);
8723 if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) {
8724 Shift = Op;
8725 return true;
8726 }
8727 return false;
8728}
8729
8730/// Helper function for visitOR to extract the needed side of a rotate idiom
8731/// from a shl/srl/mul/udiv. This is meant to handle cases where
8732/// InstCombine merged some outside op with one of the shifts from
8733/// the rotate pattern.
8734/// \returns An empty \c SDValue if the needed shift couldn't be extracted.
8735/// Otherwise, returns an expansion of \p ExtractFrom based on the following
8736/// patterns:
8737///
8738/// (or (add v v) (shrl v bitwidth-1)):
8739/// expands (add v v) -> (shl v 1)
8740///
8741/// (or (mul v c0) (shrl (mul v c1) c2)):
8742/// expands (mul v c0) -> (shl (mul v c1) c3)
8743///
8744/// (or (udiv v c0) (shl (udiv v c1) c2)):
8745/// expands (udiv v c0) -> (shrl (udiv v c1) c3)
8746///
8747/// (or (shl v c0) (shrl (shl v c1) c2)):
8748/// expands (shl v c0) -> (shl (shl v c1) c3)
8749///
8750/// (or (shrl v c0) (shl (shrl v c1) c2)):
8751/// expands (shrl v c0) -> (shrl (shrl v c1) c3)
8752///
8753/// Such that in all cases, c3+c2==bitwidth(op v c1).
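///
/// For example, given i32 (or (mul v 4), (srl (mul v 2), 31)), the mul by 4
/// can be rewritten as (shl (mul v 2), 1), i.e. c3 == 1 and c3 + c2 == 32,
/// exposing a rotate of (mul v 2) to the caller.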
8754static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift,
8755 SDValue ExtractFrom, SDValue &Mask,
8756 const SDLoc &DL) {
8757 assert(OppShift && ExtractFrom && "Empty SDValue");
8758 if (OppShift.getOpcode() != ISD::SHL && OppShift.getOpcode() != ISD::SRL)
8759 return SDValue();
8760
8761 ExtractFrom = stripConstantMask(DAG, Op: ExtractFrom, Mask);
8762
8763 // Value and Type of the shift.
8764 SDValue OppShiftLHS = OppShift.getOperand(i: 0);
8765 EVT ShiftedVT = OppShiftLHS.getValueType();
8766
8767 // Amount of the existing shift.
8768 ConstantSDNode *OppShiftCst = isConstOrConstSplat(N: OppShift.getOperand(i: 1));
8769
8770 // (add v v) -> (shl v 1)
8771 // TODO: Should this be a general DAG canonicalization?
8772 if (OppShift.getOpcode() == ISD::SRL && OppShiftCst &&
8773 ExtractFrom.getOpcode() == ISD::ADD &&
8774 ExtractFrom.getOperand(i: 0) == ExtractFrom.getOperand(i: 1) &&
8775 ExtractFrom.getOperand(i: 0) == OppShiftLHS &&
8776 OppShiftCst->getAPIntValue() == ShiftedVT.getScalarSizeInBits() - 1)
8777 return DAG.getNode(Opcode: ISD::SHL, DL, VT: ShiftedVT, N1: OppShiftLHS,
8778 N2: DAG.getShiftAmountConstant(Val: 1, VT: ShiftedVT, DL));
8779
8780 // Preconditions:
8781 // (or (op0 v c0) (shiftl/r (op0 v c1) c2))
8782 //
8783 // Find opcode of the needed shift to be extracted from (op0 v c0).
8784 unsigned Opcode = ISD::DELETED_NODE;
8785 bool IsMulOrDiv = false;
8786 // Set Opcode and IsMulOrDiv if the extract opcode matches the needed shift
8787 // opcode or its arithmetic (mul or udiv) variant.
8788 auto SelectOpcode = [&](unsigned NeededShift, unsigned MulOrDivVariant) {
8789 IsMulOrDiv = ExtractFrom.getOpcode() == MulOrDivVariant;
8790 if (!IsMulOrDiv && ExtractFrom.getOpcode() != NeededShift)
8791 return false;
8792 Opcode = NeededShift;
8793 return true;
8794 };
8795 // op0 must be either the needed shift opcode or the mul/udiv equivalent
8796 // that the needed shift can be extracted from.
8797 if ((OppShift.getOpcode() != ISD::SRL || !SelectOpcode(ISD::SHL, ISD::MUL)) &&
8798 (OppShift.getOpcode() != ISD::SHL || !SelectOpcode(ISD::SRL, ISD::UDIV)))
8799 return SDValue();
8800
8801 // op0 must be the same opcode on both sides, have the same LHS argument,
8802 // and produce the same value type.
8803 if (OppShiftLHS.getOpcode() != ExtractFrom.getOpcode() ||
8804 OppShiftLHS.getOperand(i: 0) != ExtractFrom.getOperand(i: 0) ||
8805 ShiftedVT != ExtractFrom.getValueType())
8806 return SDValue();
8807
8808 // Constant mul/udiv/shift amount from the RHS of the shift's LHS op.
8809 ConstantSDNode *OppLHSCst = isConstOrConstSplat(N: OppShiftLHS.getOperand(i: 1));
8810 // Constant mul/udiv/shift amount from the RHS of the ExtractFrom op.
8811 ConstantSDNode *ExtractFromCst =
8812 isConstOrConstSplat(N: ExtractFrom.getOperand(i: 1));
8813 // TODO: We should be able to handle non-uniform constant vectors for these values
8814 // Check that we have constant values.
8815 if (!OppShiftCst || !OppShiftCst->getAPIntValue() ||
8816 !OppLHSCst || !OppLHSCst->getAPIntValue() ||
8817 !ExtractFromCst || !ExtractFromCst->getAPIntValue())
8818 return SDValue();
8819
8820 // Compute the shift amount we need to extract to complete the rotate.
8821 const unsigned VTWidth = ShiftedVT.getScalarSizeInBits();
8822 if (OppShiftCst->getAPIntValue().ugt(RHS: VTWidth))
8823 return SDValue();
8824 APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue();
8825 // Normalize the bitwidth of the two mul/udiv/shift constant operands.
8826 APInt ExtractFromAmt = ExtractFromCst->getAPIntValue();
8827 APInt OppLHSAmt = OppLHSCst->getAPIntValue();
8828 zeroExtendToMatch(LHS&: ExtractFromAmt, RHS&: OppLHSAmt);
8829
8830 // Now try extract the needed shift from the ExtractFrom op and see if the
8831 // result matches up with the existing shift's LHS op.
8832 if (IsMulOrDiv) {
8833 // Op to extract from is a mul or udiv by a constant.
8834 // Check:
8835 // c2 / (1 << (bitwidth(op0 v c0) - c1)) == c0
8836 // c2 % (1 << (bitwidth(op0 v c0) - c1)) == 0
8837 const APInt ExtractDiv = APInt::getOneBitSet(numBits: ExtractFromAmt.getBitWidth(),
8838 BitNo: NeededShiftAmt.getZExtValue());
8839 APInt ResultAmt;
8840 APInt Rem;
8841 APInt::udivrem(LHS: ExtractFromAmt, RHS: ExtractDiv, Quotient&: ResultAmt, Remainder&: Rem);
8842 if (Rem != 0 || ResultAmt != OppLHSAmt)
8843 return SDValue();
8844 } else {
8845 // Op to extract from is a shift by a constant.
8846 // Check:
8847 // c2 - (bitwidth(op0 v c0) - c1) == c0
8848 if (OppLHSAmt != ExtractFromAmt - NeededShiftAmt.zextOrTrunc(
8849 width: ExtractFromAmt.getBitWidth()))
8850 return SDValue();
8851 }
8852
8853 // Return the expanded shift op that should allow a rotate to be formed.
8854 EVT ShiftVT = OppShift.getOperand(i: 1).getValueType();
8855 EVT ResVT = ExtractFrom.getValueType();
8856 SDValue NewShiftNode = DAG.getConstant(Val: NeededShiftAmt, DL, VT: ShiftVT);
8857 return DAG.getNode(Opcode, DL, VT: ResVT, N1: OppShiftLHS, N2: NewShiftNode);
8858}
8859
8860// Return true if we can prove that, whenever Neg and Pos are both in the
8861// range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos). This means that
8862// for two opposing shifts shift1 and shift2 and a value X with OpBits bits:
8863//
8864// (or (shift1 X, Neg), (shift2 X, Pos))
8865//
8866// reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate
8867// in direction shift1 by Neg. The range [0, EltSize) means that we only need
8868// to consider shift amounts with defined behavior.
8869//
8870// The IsRotate flag should be set when the LHS of both shifts is the same.
8871// Otherwise if matching a general funnel shift, it should be clear.
8872static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
8873 SelectionDAG &DAG, bool IsRotate, bool FromAdd) {
8874 const auto &TLI = DAG.getTargetLoweringInfo();
8875 // If EltSize is a power of 2 then:
8876 //
8877 // (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
8878 // (b) Neg == Neg & (EltSize - 1) whenever Neg is in [0, EltSize).
8879 //
8880 // So if EltSize is a power of 2 and Neg is (and Neg', EltSize-1), we check
8881 // for the stronger condition:
8882 //
8883 // Neg & (EltSize - 1) == (EltSize - Pos) & (EltSize - 1) [A]
8884 //
8885 // for all Neg and Pos. Since Neg & (EltSize - 1) == Neg' & (EltSize - 1)
8886 // we can just replace Neg with Neg' for the rest of the function.
8887 //
8888 // In other cases we check for the even stronger condition:
8889 //
8890 // Neg == EltSize - Pos [B]
8891 //
8892 // for all Neg and Pos. Note that the (or ...) then invokes undefined
8893 // behavior if Pos == 0 (and consequently Neg == EltSize).
8894 //
8895 // We could actually use [A] whenever EltSize is a power of 2, but the
8896 // only extra cases that it would match are those uninteresting ones
8897 // where Neg and Pos are never in range at the same time. E.g. for
8898 // EltSize == 32, using [A] would allow a Neg of the form (sub 64, Pos)
8899 // as well as (sub 32, Pos), but:
8900 //
8901 // (or (shift1 X, (sub 64, Pos)), (shift2 X, Pos))
8902 //
8903 // always invokes undefined behavior for 32-bit X.
8904 //
8905 // Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise.
8906 // This allows us to peek through any operations that only affect Mask's
8907 // un-demanded bits.
8908 //
8909 // NOTE: We can only do this when matching operations which won't modify the
8910 // least Log2(EltSize) significant bits and not a general funnel shift.
8911 unsigned MaskLoBits = 0;
8912 if (IsRotate && !FromAdd && isPowerOf2_64(Value: EltSize)) {
8913 unsigned Bits = Log2_64(Value: EltSize);
8914 unsigned NegBits = Neg.getScalarValueSizeInBits();
8915 if (NegBits >= Bits) {
8916 APInt DemandedBits = APInt::getLowBitsSet(numBits: NegBits, loBitsSet: Bits);
8917 if (SDValue Inner =
8918 TLI.SimplifyMultipleUseDemandedBits(Op: Neg, DemandedBits, DAG)) {
8919 Neg = Inner;
8920 MaskLoBits = Bits;
8921 }
8922 }
8923 }
8924
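  // For example, with EltSize == 32, Neg == (sub 32, Pos) satisfies [B]
  // directly (NegC == 32, NegOp1 == Pos, so Width below equals EltSize), and
  // (and (sub 32, Pos), 31) reduces to the same form once the mask has been
  // stripped above, satisfying [A].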
8925 // Check whether Neg has the form (sub NegC, NegOp1) for some NegC and NegOp1.
8926 if (Neg.getOpcode() != ISD::SUB)
8927 return false;
8928 ConstantSDNode *NegC = isConstOrConstSplat(N: Neg.getOperand(i: 0));
8929 if (!NegC)
8930 return false;
8931 SDValue NegOp1 = Neg.getOperand(i: 1);
8932
8933 // On the RHS of [A], if Pos is the result of operation on Pos' that won't
8934 // affect Mask's demanded bits, just replace Pos with Pos'. These operations
8935 // are redundant for the purpose of the equality.
8936 if (MaskLoBits) {
8937 unsigned PosBits = Pos.getScalarValueSizeInBits();
8938 if (PosBits >= MaskLoBits) {
8939 APInt DemandedBits = APInt::getLowBitsSet(numBits: PosBits, loBitsSet: MaskLoBits);
8940 if (SDValue Inner =
8941 TLI.SimplifyMultipleUseDemandedBits(Op: Pos, DemandedBits, DAG)) {
8942 Pos = Inner;
8943 }
8944 }
8945 }
8946
8947 // The condition we need is now:
8948 //
8949 // (NegC - NegOp1) & Mask == (EltSize - Pos) & Mask
8950 //
8951 // If NegOp1 == Pos then we need:
8952 //
8953 // EltSize & Mask == NegC & Mask
8954 //
8955 // (because "x & Mask" is a truncation and distributes through subtraction).
8956 //
8957 // We also need to account for a potential truncation of NegOp1 if the amount
8958 // has already been legalized to a shift amount type.
8959 APInt Width;
8960 if ((Pos == NegOp1) ||
8961 (NegOp1.getOpcode() == ISD::TRUNCATE && Pos == NegOp1.getOperand(i: 0)))
8962 Width = NegC->getAPIntValue();
8963
8964 // Check for cases where Pos has the form (add NegOp1, PosC) for some PosC.
8965 // Then the condition we want to prove becomes:
8966 //
8967 // (NegC - NegOp1) & Mask == (EltSize - (NegOp1 + PosC)) & Mask
8968 //
8969 // which, again because "x & Mask" is a truncation, becomes:
8970 //
8971 // NegC & Mask == (EltSize - PosC) & Mask
8972 // EltSize & Mask == (NegC + PosC) & Mask
8973 else if (Pos.getOpcode() == ISD::ADD && Pos.getOperand(i: 0) == NegOp1) {
8974 if (ConstantSDNode *PosC = isConstOrConstSplat(N: Pos.getOperand(i: 1)))
8975 Width = PosC->getAPIntValue() + NegC->getAPIntValue();
8976 else
8977 return false;
8978 } else
8979 return false;
8980
8981 // Now we just need to check that EltSize & Mask == Width & Mask.
8982 if (MaskLoBits)
8983 // EltSize & Mask is 0 since Mask is EltSize - 1.
8984 return Width.getLoBits(numBits: MaskLoBits) == 0;
8985 return Width == EltSize;
8986}
8987
8988// A subroutine of MatchRotate used once we have found an OR of two opposite
8989// shifts of Shifted. If Neg == <operand size> - Pos then the OR reduces
8990// to both (PosOpcode Shifted, Pos) and (NegOpcode Shifted, Neg), with the
8991// former being preferred if supported. InnerPos and InnerNeg are Pos and
8992// Neg with outer conversions stripped away.
8993SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
8994 SDValue Neg, SDValue InnerPos,
8995 SDValue InnerNeg, bool FromAdd,
8996 bool HasPos, unsigned PosOpcode,
8997 unsigned NegOpcode, const SDLoc &DL) {
8998 // fold (or/add (shl x, (*ext y)),
8999 // (srl x, (*ext (sub 32, y)))) ->
9000 // (rotl x, y) or (rotr x, (sub 32, y))
9001 //
9002 // fold (or/add (shl x, (*ext (sub 32, y))),
9003 // (srl x, (*ext y))) ->
9004 // (rotr x, y) or (rotl x, (sub 32, y))
9005 EVT VT = Shifted.getValueType();
9006 if (matchRotateSub(Pos: InnerPos, Neg: InnerNeg, EltSize: VT.getScalarSizeInBits(), DAG,
9007 /*IsRotate*/ true, FromAdd))
9008 return DAG.getNode(Opcode: HasPos ? PosOpcode : NegOpcode, DL, VT, N1: Shifted,
9009 N2: HasPos ? Pos : Neg);
9010
9011 return SDValue();
9012}
9013
9014// A subroutine of MatchRotate used once we have found an OR of two opposite
9015// shifts of N0 + N1. If Neg == <operand size> - Pos then the OR reduces
9016// to both (PosOpcode N0, N1, Pos) and (NegOpcode N0, N1, Neg), with the
9017// former being preferred if supported. InnerPos and InnerNeg are Pos and
9018// Neg with outer conversions stripped away.
9019// TODO: Merge with MatchRotatePosNeg.
9020SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
9021 SDValue Neg, SDValue InnerPos,
9022 SDValue InnerNeg, bool FromAdd,
9023 bool HasPos, unsigned PosOpcode,
9024 unsigned NegOpcode, const SDLoc &DL) {
9025 EVT VT = N0.getValueType();
9026 unsigned EltBits = VT.getScalarSizeInBits();
9027
9028 // fold (or/add (shl x0, (*ext y)),
9029 // (srl x1, (*ext (sub 32, y)))) ->
9030 // (fshl x0, x1, y) or (fshr x0, x1, (sub 32, y))
9031 //
9032 // fold (or/add (shl x0, (*ext (sub 32, y))),
9033 // (srl x1, (*ext y))) ->
9034 // (fshr x0, x1, y) or (fshl x0, x1, (sub 32, y))
9035 if (matchRotateSub(Pos: InnerPos, Neg: InnerNeg, EltSize: EltBits, DAG, /*IsRotate*/ N0 == N1,
9036 FromAdd))
9037 return DAG.getNode(Opcode: HasPos ? PosOpcode : NegOpcode, DL, VT, N1: N0, N2: N1,
9038 N3: HasPos ? Pos : Neg);
9039
9040 // Matching the shift+xor cases, we can't easily use the xor'd shift amount
9041  // so for now just use the PosOpcode case if it's legal.
9042 // TODO: When can we use the NegOpcode case?
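  // For a power-of-2 bitwidth BW, (xor y, BW-1) == (BW-1) - y for y in
  // [0, BW), so (srl (srl x1, 1), (xor y, BW-1)) shifts x1 right by a total
  // of BW - y, which is exactly the srl half of the fshl expansion (and
  // contributes no bits when y == 0).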
9043 if (PosOpcode == ISD::FSHL && isPowerOf2_32(Value: EltBits)) {
9044 SDValue X;
9045 // fold (or/add (shl x0, y), (srl (srl x1, 1), (xor y, 31)))
9046 // -> (fshl x0, x1, y)
9047 if (sd_match(N: N1, P: m_Srl(L: m_Value(N&: X), R: m_One())) &&
9048 sd_match(N: InnerNeg,
9049 P: m_Xor(L: m_Specific(N: InnerPos), R: m_SpecificInt(V: EltBits - 1))) &&
9050 TLI.isOperationLegalOrCustom(Op: ISD::FSHL, VT)) {
9051 return DAG.getNode(Opcode: ISD::FSHL, DL, VT, N1: N0, N2: X, N3: Pos);
9052 }
9053
9054 // fold (or/add (shl (shl x0, 1), (xor y, 31)), (srl x1, y))
9055 // -> (fshr x0, x1, y)
9056 if (sd_match(N: N0, P: m_Shl(L: m_Value(N&: X), R: m_One())) &&
9057 sd_match(N: InnerPos,
9058 P: m_Xor(L: m_Specific(N: InnerNeg), R: m_SpecificInt(V: EltBits - 1))) &&
9059 TLI.isOperationLegalOrCustom(Op: ISD::FSHR, VT)) {
9060 return DAG.getNode(Opcode: ISD::FSHR, DL, VT, N1: X, N2: N1, N3: Neg);
9061 }
9062
9063 // fold (or/add (shl (add x0, x0), (xor y, 31)), (srl x1, y))
9064 // -> (fshr x0, x1, y)
9065 // TODO: Should add(x,x) -> shl(x,1) be a general DAG canonicalization?
9066 if (sd_match(N: N0, P: m_Add(L: m_Value(N&: X), R: m_Deferred(V&: X))) &&
9067 sd_match(N: InnerPos,
9068 P: m_Xor(L: m_Specific(N: InnerNeg), R: m_SpecificInt(V: EltBits - 1))) &&
9069 TLI.isOperationLegalOrCustom(Op: ISD::FSHR, VT)) {
9070 return DAG.getNode(Opcode: ISD::FSHR, DL, VT, N1: X, N2: N1, N3: Neg);
9071 }
9072 }
9073
9074 return SDValue();
9075}
9076
9077// MatchRotate - Handle an 'or' or 'add' of two operands. If this is one of the
9078// many idioms for rotate, and if the target supports rotation instructions,
9079// generate a rot[lr]. This also matches funnel shift patterns, similar to
9080// rotation but with different shifted sources.
9081SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL,
9082 bool FromAdd) {
9083 EVT VT = LHS.getValueType();
9084
9085 // The target must have at least one rotate/funnel flavor.
9086 // We still try to match rotate by constant pre-legalization.
9087 // TODO: Support pre-legalization funnel-shift by constant.
9088 bool HasROTL = hasOperation(Opcode: ISD::ROTL, VT);
9089 bool HasROTR = hasOperation(Opcode: ISD::ROTR, VT);
9090 bool HasFSHL = hasOperation(Opcode: ISD::FSHL, VT);
9091 bool HasFSHR = hasOperation(Opcode: ISD::FSHR, VT);
9092
9093 // If the type is going to be promoted and the target has enabled custom
9094 // lowering for rotate, allow matching rotate by non-constants. Only allow
9095 // this for scalar types.
9096 if (VT.isScalarInteger() && TLI.getTypeAction(Context&: *DAG.getContext(), VT) ==
9097 TargetLowering::TypePromoteInteger) {
9098 HasROTL |= TLI.getOperationAction(Op: ISD::ROTL, VT) == TargetLowering::Custom;
9099 HasROTR |= TLI.getOperationAction(Op: ISD::ROTR, VT) == TargetLowering::Custom;
9100 }
9101
9102 if (LegalOperations && !HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
9103 return SDValue();
9104
9105 // Check for truncated rotate.
9106 if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE &&
9107 LHS.getOperand(i: 0).getValueType() == RHS.getOperand(i: 0).getValueType()) {
9108 assert(LHS.getValueType() == RHS.getValueType());
9109 if (SDValue Rot =
9110 MatchRotate(LHS: LHS.getOperand(i: 0), RHS: RHS.getOperand(i: 0), DL, FromAdd))
9111 return DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(LHS), VT: LHS.getValueType(), Operand: Rot);
9112 }
9113
9114 // Match "(X shl/srl V1) & V2" where V2 may not be present.
9115 SDValue LHSShift; // The shift.
9116 SDValue LHSMask; // AND value if any.
9117 matchRotateHalf(DAG, Op: LHS, Shift&: LHSShift, Mask&: LHSMask);
9118
9119 SDValue RHSShift; // The shift.
9120 SDValue RHSMask; // AND value if any.
9121 matchRotateHalf(DAG, Op: RHS, Shift&: RHSShift, Mask&: RHSMask);
9122
9123 // If neither side matched a rotate half, bail
9124 if (!LHSShift && !RHSShift)
9125 return SDValue();
9126
9127 // InstCombine may have combined a constant shl, srl, mul, or udiv with one
9128 // side of the rotate, so try to handle that here. In all cases we need to
9129 // pass the matched shift from the opposite side to compute the opcode and
9130 // needed shift amount to extract. We still want to do this if both sides
9131 // matched a rotate half because one half may be a potential overshift that
9132 // can be broken down (ie if InstCombine merged two shl or srl ops into a
9133 // single one).
9134
9135 // Have LHS side of the rotate, try to extract the needed shift from the RHS.
9136 if (LHSShift)
9137 if (SDValue NewRHSShift =
9138 extractShiftForRotate(DAG, OppShift: LHSShift, ExtractFrom: RHS, Mask&: RHSMask, DL))
9139 RHSShift = NewRHSShift;
9140 // Have RHS side of the rotate, try to extract the needed shift from the LHS.
9141 if (RHSShift)
9142 if (SDValue NewLHSShift =
9143 extractShiftForRotate(DAG, OppShift: RHSShift, ExtractFrom: LHS, Mask&: LHSMask, DL))
9144 LHSShift = NewLHSShift;
9145
9146 // If a side is still missing, nothing else we can do.
9147 if (!RHSShift || !LHSShift)
9148 return SDValue();
9149
9150 // At this point we've matched or extracted a shift op on each side.
9151
9152 if (LHSShift.getOpcode() == RHSShift.getOpcode())
9153 return SDValue(); // Shifts must disagree.
9154
9155 // Canonicalize shl to left side in a shl/srl pair.
9156 if (RHSShift.getOpcode() == ISD::SHL) {
9157 std::swap(a&: LHS, b&: RHS);
9158 std::swap(a&: LHSShift, b&: RHSShift);
9159 std::swap(a&: LHSMask, b&: RHSMask);
9160 }
9161
9162 // Something has gone wrong - we've lost the shl/srl pair - bail.
9163 if (LHSShift.getOpcode() != ISD::SHL || RHSShift.getOpcode() != ISD::SRL)
9164 return SDValue();
9165
9166 unsigned EltSizeInBits = VT.getScalarSizeInBits();
9167 SDValue LHSShiftArg = LHSShift.getOperand(i: 0);
9168 SDValue LHSShiftAmt = LHSShift.getOperand(i: 1);
9169 SDValue RHSShiftArg = RHSShift.getOperand(i: 0);
9170 SDValue RHSShiftAmt = RHSShift.getOperand(i: 1);
9171
9172 auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS,
9173 ConstantSDNode *RHS) {
9174 return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits;
9175 };
9176
9177 auto ApplyMasks = [&](SDValue Res) {
9178 // If there is an AND of either shifted operand, apply it to the result.
9179 if (LHSMask.getNode() || RHSMask.getNode()) {
9180 SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
9181 SDValue Mask = AllOnes;
9182
9183 if (LHSMask.getNode()) {
9184 SDValue RHSBits = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: AllOnes, N2: RHSShiftAmt);
9185 Mask = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Mask,
9186 N2: DAG.getNode(Opcode: ISD::OR, DL, VT, N1: LHSMask, N2: RHSBits));
9187 }
9188 if (RHSMask.getNode()) {
9189 SDValue LHSBits = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: AllOnes, N2: LHSShiftAmt);
9190 Mask = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Mask,
9191 N2: DAG.getNode(Opcode: ISD::OR, DL, VT, N1: RHSMask, N2: LHSBits));
9192 }
9193
9194 Res = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Res, N2: Mask);
9195 }
9196
9197 return Res;
9198 };
9199
9200 // TODO: Support pre-legalization funnel-shift by constant.
9201 bool IsRotate = LHSShiftArg == RHSShiftArg;
9202 if (!IsRotate && !(HasFSHL || HasFSHR)) {
9203 if (TLI.isTypeLegal(VT) && LHS.hasOneUse() && RHS.hasOneUse() &&
9204 ISD::matchBinaryPredicate(LHS: LHSShiftAmt, RHS: RHSShiftAmt, Match: MatchRotateSum)) {
9205 // Look for a disguised rotate by constant.
9206 // The common shifted operand X may be hidden inside another 'or'.
9207 SDValue X, Y;
9208 auto matchOr = [&X, &Y](SDValue Or, SDValue CommonOp) {
9209 if (!Or.hasOneUse() || Or.getOpcode() != ISD::OR)
9210 return false;
9211 if (CommonOp == Or.getOperand(i: 0)) {
9212 X = CommonOp;
9213 Y = Or.getOperand(i: 1);
9214 return true;
9215 }
9216 if (CommonOp == Or.getOperand(i: 1)) {
9217 X = CommonOp;
9218 Y = Or.getOperand(i: 0);
9219 return true;
9220 }
9221 return false;
9222 };
9223
9224 SDValue Res;
9225 if (matchOr(LHSShiftArg, RHSShiftArg)) {
9226 // (shl (X | Y), C1) | (srl X, C2) --> (rotl X, C1) | (shl Y, C1)
9227 SDValue RotX = DAG.getNode(Opcode: ISD::ROTL, DL, VT, N1: X, N2: LHSShiftAmt);
9228 SDValue ShlY = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Y, N2: LHSShiftAmt);
9229 Res = DAG.getNode(Opcode: ISD::OR, DL, VT, N1: RotX, N2: ShlY);
9230 } else if (matchOr(RHSShiftArg, LHSShiftArg)) {
9231 // (shl X, C1) | (srl (X | Y), C2) --> (rotl X, C1) | (srl Y, C2)
9232 SDValue RotX = DAG.getNode(Opcode: ISD::ROTL, DL, VT, N1: X, N2: LHSShiftAmt);
9233 SDValue SrlY = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Y, N2: RHSShiftAmt);
9234 Res = DAG.getNode(Opcode: ISD::OR, DL, VT, N1: RotX, N2: SrlY);
9235 } else {
9236 return SDValue();
9237 }
9238
9239 return ApplyMasks(Res);
9240 }
9241
9242 return SDValue(); // Requires funnel shift support.
9243 }
9244
9245 // fold (or/add (shl x, C1), (srl x, C2)) -> (rotl x, C1)
9246 // fold (or/add (shl x, C1), (srl x, C2)) -> (rotr x, C2)
9247 // fold (or/add (shl x, C1), (srl y, C2)) -> (fshl x, y, C1)
9248 // fold (or/add (shl x, C1), (srl y, C2)) -> (fshr x, y, C2)
9249 // iff C1+C2 == EltSizeInBits
9250 if (ISD::matchBinaryPredicate(LHS: LHSShiftAmt, RHS: RHSShiftAmt, Match: MatchRotateSum)) {
9251 SDValue Res;
9252 if (IsRotate && (HasROTL || HasROTR || !(HasFSHL || HasFSHR))) {
9253 bool UseROTL = !LegalOperations || HasROTL;
9254 Res = DAG.getNode(Opcode: UseROTL ? ISD::ROTL : ISD::ROTR, DL, VT, N1: LHSShiftArg,
9255 N2: UseROTL ? LHSShiftAmt : RHSShiftAmt);
9256 } else {
9257 bool UseFSHL = !LegalOperations || HasFSHL;
9258 Res = DAG.getNode(Opcode: UseFSHL ? ISD::FSHL : ISD::FSHR, DL, VT, N1: LHSShiftArg,
9259 N2: RHSShiftArg, N3: UseFSHL ? LHSShiftAmt : RHSShiftAmt);
9260 }
9261
9262 return ApplyMasks(Res);
9263 }
9264
9265 // Even pre-legalization, we can't easily rotate/funnel-shift by a variable
9266 // shift.
9267 if (!HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
9268 return SDValue();
9269
9270 // If there is a mask here, and we have a variable shift, we can't be sure
9271 // that we're masking out the right stuff.
9272 if (LHSMask.getNode() || RHSMask.getNode())
9273 return SDValue();
9274
9275 // If the shift amount is sign/zext/any-extended just peel it off.
9276 SDValue LExtOp0 = LHSShiftAmt;
9277 SDValue RExtOp0 = RHSShiftAmt;
9278 if ((LHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
9279 LHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
9280 LHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
9281 LHSShiftAmt.getOpcode() == ISD::TRUNCATE) &&
9282 (RHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
9283 RHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
9284 RHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
9285 RHSShiftAmt.getOpcode() == ISD::TRUNCATE)) {
9286 LExtOp0 = LHSShiftAmt.getOperand(i: 0);
9287 RExtOp0 = RHSShiftAmt.getOperand(i: 0);
9288 }
9289
9290 if (IsRotate && (HasROTL || HasROTR)) {
9291 if (SDValue TryL = MatchRotatePosNeg(Shifted: LHSShiftArg, Pos: LHSShiftAmt, Neg: RHSShiftAmt,
9292 InnerPos: LExtOp0, InnerNeg: RExtOp0, FromAdd, HasPos: HasROTL,
9293 PosOpcode: ISD::ROTL, NegOpcode: ISD::ROTR, DL))
9294 return TryL;
9295
9296 if (SDValue TryR = MatchRotatePosNeg(Shifted: RHSShiftArg, Pos: RHSShiftAmt, Neg: LHSShiftAmt,
9297 InnerPos: RExtOp0, InnerNeg: LExtOp0, FromAdd, HasPos: HasROTR,
9298 PosOpcode: ISD::ROTR, NegOpcode: ISD::ROTL, DL))
9299 return TryR;
9300 }
9301
9302 if (SDValue TryL = MatchFunnelPosNeg(N0: LHSShiftArg, N1: RHSShiftArg, Pos: LHSShiftAmt,
9303 Neg: RHSShiftAmt, InnerPos: LExtOp0, InnerNeg: RExtOp0, FromAdd,
9304 HasPos: HasFSHL, PosOpcode: ISD::FSHL, NegOpcode: ISD::FSHR, DL))
9305 return TryL;
9306
9307 if (SDValue TryR = MatchFunnelPosNeg(N0: LHSShiftArg, N1: RHSShiftArg, Pos: RHSShiftAmt,
9308 Neg: LHSShiftAmt, InnerPos: RExtOp0, InnerNeg: LExtOp0, FromAdd,
9309 HasPos: HasFSHR, PosOpcode: ISD::FSHR, NegOpcode: ISD::FSHL, DL))
9310 return TryR;
9311
9312 return SDValue();
9313}
9314
9315/// Recursively traverses the expression calculating the origin of the requested
9316/// byte of the given value. Returns std::nullopt if the provider can't be
9317/// calculated.
9318///
9319/// For all the values except the root of the expression, we verify that the
9320/// value has exactly one use and if not then return std::nullopt. This way if
9321/// the origin of the byte is returned it's guaranteed that the values which
9322/// contribute to the byte are not used outside of this expression.
9323///
9324/// However, there is a special case when dealing with vector loads -- we allow
9325/// more than one use if the load is a vector type. Since the values that
9326/// contribute to the byte ultimately come from the ExtractVectorElements of the
9327/// Load, we don't care if the Load has uses other than ExtractVectorElements,
9328/// because those operations are independent from the pattern to be combined.
9329/// For vector loads, we simply care that the ByteProviders are adjacent
9330/// positions of the same vector, and their index matches the byte that is being
9331/// provided. This is captured by the \p VectorIndex algorithm. \p VectorIndex
9332/// is the index used in an ExtractVectorElement, and \p StartingIndex is the
9333/// byte position we are trying to provide for the LoadCombine. If these do
9334/// not match, then we can not combine the vector loads. \p Index uses the
9335/// byte position we are trying to provide for and is matched against the
9336/// shl and load size. The \p Index algorithm ensures the requested byte is
9337/// provided for by the pattern, and the pattern does not over provide bytes.
9338///
9339///
9340/// The supported LoadCombine pattern for vector loads is as follows
9341/// or
9342/// / \
9343/// or shl
9344/// / \ |
9345/// or shl zext
9346/// / \ | |
9347/// shl zext zext EVE*
9348/// | | | |
9349/// zext EVE* EVE* LOAD
9350/// | | |
9351/// EVE* LOAD LOAD
9352/// |
9353/// LOAD
9354///
9355/// *ExtractVectorElement
9356using SDByteProvider = ByteProvider<SDNode *>;
9357
9358static std::optional<SDByteProvider>
9359calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
9360 std::optional<uint64_t> VectorIndex,
9361 unsigned StartingIndex = 0) {
9362
9363  // A typical i64-by-i8 pattern requires recursion up to 8 calls deep.
9364 if (Depth == 10)
9365 return std::nullopt;
9366
9367 // Only allow multiple uses if the instruction is a vector load (in which
9368 // case we will use the load for every ExtractVectorElement)
9369 if (Depth && !Op.hasOneUse() &&
9370 (Op.getOpcode() != ISD::LOAD || !Op.getValueType().isVector()))
9371 return std::nullopt;
9372
9373 // Fail to combine if we have encountered anything but a LOAD after handling
9374 // an ExtractVectorElement.
9375 if (Op.getOpcode() != ISD::LOAD && VectorIndex.has_value())
9376 return std::nullopt;
9377
9378 unsigned BitWidth = Op.getScalarValueSizeInBits();
9379 if (BitWidth % 8 != 0)
9380 return std::nullopt;
9381 unsigned ByteWidth = BitWidth / 8;
9382 assert(Index < ByteWidth && "invalid index requested");
9383 (void) ByteWidth;
9384
9385 switch (Op.getOpcode()) {
9386 case ISD::OR: {
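    // A byte of an OR is provided by whichever operand actually defines it;
    // the other operand must be known to contribute zero at that position.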
9387 auto LHS =
9388 calculateByteProvider(Op: Op->getOperand(Num: 0), Index, Depth: Depth + 1, VectorIndex);
9389 if (!LHS)
9390 return std::nullopt;
9391 auto RHS =
9392 calculateByteProvider(Op: Op->getOperand(Num: 1), Index, Depth: Depth + 1, VectorIndex);
9393 if (!RHS)
9394 return std::nullopt;
9395
9396 if (LHS->isConstantZero())
9397 return RHS;
9398 if (RHS->isConstantZero())
9399 return LHS;
9400 return std::nullopt;
9401 }
9402 case ISD::SHL: {
9403 auto ShiftOp = dyn_cast<ConstantSDNode>(Val: Op->getOperand(Num: 1));
9404 if (!ShiftOp)
9405 return std::nullopt;
9406
9407 uint64_t BitShift = ShiftOp->getZExtValue();
9408
9409 if (BitShift % 8 != 0)
9410 return std::nullopt;
9411 uint64_t ByteShift = BitShift / 8;
9412
9413 // If we are shifting by an amount greater than the index we are trying to
9414 // provide, then do not provide anything. Otherwise, subtract the index by
9415 // the amount we shifted by.
9416 return Index < ByteShift
9417 ? SDByteProvider::getConstantZero()
9418 : calculateByteProvider(Op: Op->getOperand(Num: 0), Index: Index - ByteShift,
9419 Depth: Depth + 1, VectorIndex, StartingIndex: Index);
9420 }
9421 case ISD::ANY_EXTEND:
9422 case ISD::SIGN_EXTEND:
9423 case ISD::ZERO_EXTEND: {
9424 SDValue NarrowOp = Op->getOperand(Num: 0);
9425 unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
9426 if (NarrowBitWidth % 8 != 0)
9427 return std::nullopt;
9428 uint64_t NarrowByteWidth = NarrowBitWidth / 8;
9429
9430 if (Index >= NarrowByteWidth)
9431 return Op.getOpcode() == ISD::ZERO_EXTEND
9432 ? std::optional<SDByteProvider>(
9433 SDByteProvider::getConstantZero())
9434 : std::nullopt;
9435 return calculateByteProvider(Op: NarrowOp, Index, Depth: Depth + 1, VectorIndex,
9436 StartingIndex);
9437 }
9438 case ISD::BSWAP:
9439 return calculateByteProvider(Op: Op->getOperand(Num: 0), Index: ByteWidth - Index - 1,
9440 Depth: Depth + 1, VectorIndex, StartingIndex);
9441 case ISD::EXTRACT_VECTOR_ELT: {
9442 auto OffsetOp = dyn_cast<ConstantSDNode>(Val: Op->getOperand(Num: 1));
9443 if (!OffsetOp)
9444 return std::nullopt;
9445
9446 VectorIndex = OffsetOp->getZExtValue();
9447
9448 SDValue NarrowOp = Op->getOperand(Num: 0);
9449 unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
9450 if (NarrowBitWidth % 8 != 0)
9451 return std::nullopt;
9452 uint64_t NarrowByteWidth = NarrowBitWidth / 8;
9453 // EXTRACT_VECTOR_ELT can extend the element type to the width of the return
9454 // type, leaving the high bits undefined.
9455 if (Index >= NarrowByteWidth)
9456 return std::nullopt;
9457
9458 // Check to see if the position of the element in the vector corresponds
9459 // with the byte we are trying to provide for. In the case of a vector of
9460 // i8, this simply means the VectorIndex == StartingIndex. For non i8 cases,
9461 // the element will provide a range of bytes. For example, if we have a
9462 // vector of i16s, each element provides two bytes (V[1] provides byte 2 and
9463 // 3).
9464 if (*VectorIndex * NarrowByteWidth > StartingIndex)
9465 return std::nullopt;
9466 if ((*VectorIndex + 1) * NarrowByteWidth <= StartingIndex)
9467 return std::nullopt;
9468
9469 return calculateByteProvider(Op: Op->getOperand(Num: 0), Index, Depth: Depth + 1,
9470 VectorIndex, StartingIndex);
9471 }
9472 case ISD::LOAD: {
9473 auto L = cast<LoadSDNode>(Val: Op.getNode());
9474 if (!L->isSimple() || L->isIndexed())
9475 return std::nullopt;
9476
9477 unsigned NarrowBitWidth = L->getMemoryVT().getScalarSizeInBits();
9478 if (NarrowBitWidth % 8 != 0)
9479 return std::nullopt;
9480 uint64_t NarrowByteWidth = NarrowBitWidth / 8;
9481
9482    // If the width of the load does not reach the byte we are trying to provide
9483    // for and it is not a ZEXTLOAD, then the load does not provide the byte in
9484    // question.
9485 if (Index >= NarrowByteWidth)
9486 return L->getExtensionType() == ISD::ZEXTLOAD
9487 ? std::optional<SDByteProvider>(
9488 SDByteProvider::getConstantZero())
9489 : std::nullopt;
9490
9491 unsigned BPVectorIndex = VectorIndex.value_or(u: 0U);
9492 return SDByteProvider::getSrc(Val: L, ByteOffset: Index, VectorOffset: BPVectorIndex);
9493 }
9494 }
9495
9496 return std::nullopt;
9497}
9498
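// Map byte index i within a BW-byte wide value to its offset in memory for a
// little-endian or big-endian layout, respectively.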
9499static unsigned littleEndianByteAt(unsigned BW, unsigned i) {
9500 return i;
9501}
9502
9503static unsigned bigEndianByteAt(unsigned BW, unsigned i) {
9504 return BW - i - 1;
9505}
9506
9507// Check if the byte offsets we are looking at match either a big or little
9508// endian value being loaded. Return true for big endian, false for little
9509// endian, and std::nullopt if the match failed.
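// For example, offsets {0, 1, 2, 3} relative to FirstOffset indicate a
// little-endian layout, while {3, 2, 1, 0} indicate a big-endian layout.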
9510static std::optional<bool> isBigEndian(ArrayRef<int64_t> ByteOffsets,
9511 int64_t FirstOffset) {
9512 // The endian can be decided only when it is 2 bytes at least.
9513 unsigned Width = ByteOffsets.size();
9514 if (Width < 2)
9515 return std::nullopt;
9516
9517 bool BigEndian = true, LittleEndian = true;
9518 for (unsigned i = 0; i < Width; i++) {
9519 int64_t CurrentByteOffset = ByteOffsets[i] - FirstOffset;
9520 LittleEndian &= CurrentByteOffset == littleEndianByteAt(BW: Width, i);
9521 BigEndian &= CurrentByteOffset == bigEndianByteAt(BW: Width, i);
9522 if (!BigEndian && !LittleEndian)
9523 return std::nullopt;
9524 }
9525
9526  assert((BigEndian != LittleEndian) &&
9527         "It should be either big endian or little endian");
9528 return BigEndian;
9529}
9530
9531// Look through one layer of truncate or extend.
9532static SDValue stripTruncAndExt(SDValue Value) {
9533 switch (Value.getOpcode()) {
9534 case ISD::TRUNCATE:
9535 case ISD::ZERO_EXTEND:
9536 case ISD::SIGN_EXTEND:
9537 case ISD::ANY_EXTEND:
9538 return Value.getOperand(i: 0);
9539 }
9540 return SDValue();
9541}
9542
9543/// Match a pattern where a wide type scalar value is stored by several narrow
9544/// stores. Fold it into a single store or a BSWAP and a store if the target
9545/// supports it.
9546///
9547/// Assuming little endian target:
9548/// i8 *p = ...
9549/// i32 val = ...
9550/// p[0] = (val >> 0) & 0xFF;
9551/// p[1] = (val >> 8) & 0xFF;
9552/// p[2] = (val >> 16) & 0xFF;
9553/// p[3] = (val >> 24) & 0xFF;
9554/// =>
9555/// *((i32)p) = val;
9556///
9557/// i8 *p = ...
9558/// i32 val = ...
9559/// p[0] = (val >> 24) & 0xFF;
9560/// p[1] = (val >> 16) & 0xFF;
9561/// p[2] = (val >> 8) & 0xFF;
9562/// p[3] = (val >> 0) & 0xFF;
9563/// =>
9564/// *((i32)p) = BSWAP(val);
9565SDValue DAGCombiner::mergeTruncStores(StoreSDNode *N) {
9566 // The matching looks for "store (trunc x)" patterns that appear early but are
9567 // likely to be replaced by truncating store nodes during combining.
9568 // TODO: If there is evidence that running this later would help, this
9569 // limitation could be removed. Legality checks may need to be added
9570 // for the created store and optional bswap/rotate.
9571 if (LegalOperations || OptLevel == CodeGenOptLevel::None)
9572 return SDValue();
9573
9574 // We only handle merging simple stores of 1-4 bytes.
9575 // TODO: Allow unordered atomics when wider type is legal (see D66309)
9576 EVT MemVT = N->getMemoryVT();
9577 if (!(MemVT == MVT::i8 || MemVT == MVT::i16 || MemVT == MVT::i32) ||
9578 !N->isSimple() || N->isIndexed())
9579 return SDValue();
9580
9581  // Collect all of the stores in the chain, up to the maximum store width (i64).
9582 SDValue Chain = N->getChain();
9583 SmallVector<StoreSDNode *, 8> Stores = {N};
9584 unsigned NarrowNumBits = MemVT.getScalarSizeInBits();
9585 unsigned MaxWideNumBits = 64;
9586 unsigned MaxStores = MaxWideNumBits / NarrowNumBits;
9587 while (auto *Store = dyn_cast<StoreSDNode>(Val&: Chain)) {
9588 // All stores must be the same size to ensure that we are writing all of the
9589 // bytes in the wide value.
9590 // This store should have exactly one use as a chain operand for another
9591 // store in the merging set. If there are other chain uses, then the
9592 // transform may not be safe because order of loads/stores outside of this
9593 // set may not be preserved.
9594 // TODO: We could allow multiple sizes by tracking each stored byte.
9595 if (Store->getMemoryVT() != MemVT || !Store->isSimple() ||
9596 Store->isIndexed() || !Store->hasOneUse())
9597 return SDValue();
9598 Stores.push_back(Elt: Store);
9599 Chain = Store->getChain();
9600 if (MaxStores < Stores.size())
9601 return SDValue();
9602 }
9603 // There is no reason to continue if we do not have at least a pair of stores.
9604 if (Stores.size() < 2)
9605 return SDValue();
9606
9607 // Handle simple types only.
9608 LLVMContext &Context = *DAG.getContext();
9609 unsigned NumStores = Stores.size();
9610 unsigned WideNumBits = NumStores * NarrowNumBits;
9611 if (WideNumBits != 16 && WideNumBits != 32 && WideNumBits != 64)
9612 return SDValue();
9613
9614 // Check if all bytes of the source value that we are looking at are stored
9615 // to the same base address. Collect offsets from Base address into OffsetMap.
9616 SDValue SourceValue;
9617 SmallVector<int64_t, 8> OffsetMap(NumStores, INT64_MAX);
9618 int64_t FirstOffset = INT64_MAX;
9619 StoreSDNode *FirstStore = nullptr;
9620 std::optional<BaseIndexOffset> Base;
9621 for (auto *Store : Stores) {
9622 // All the stores store different parts of the CombinedValue. A truncate is
9623 // required to get the partial value.
9624 SDValue Trunc = Store->getValue();
9625 if (Trunc.getOpcode() != ISD::TRUNCATE)
9626 return SDValue();
9627 // Other than the first/last part, a shift operation is required to get the
9628 // offset.
9629 int64_t Offset = 0;
9630 SDValue WideVal = Trunc.getOperand(i: 0);
9631 if ((WideVal.getOpcode() == ISD::SRL || WideVal.getOpcode() == ISD::SRA) &&
9632 isa<ConstantSDNode>(Val: WideVal.getOperand(i: 1))) {
9633 // The shift amount must be a constant multiple of the narrow type.
9634 // It is translated to the offset address in the wide source value "y".
9635 //
9636 // x = srl y, ShiftAmtC
9637 // i8 z = trunc x
9638 // store z, ...
9639 uint64_t ShiftAmtC = WideVal.getConstantOperandVal(i: 1);
9640 if (ShiftAmtC % NarrowNumBits != 0)
9641 return SDValue();
9642
9643 // Make sure we aren't reading bits that are shifted in.
9644 if (ShiftAmtC > WideVal.getScalarValueSizeInBits() - NarrowNumBits)
9645 return SDValue();
9646
9647 Offset = ShiftAmtC / NarrowNumBits;
9648 WideVal = WideVal.getOperand(i: 0);
9649 }
9650
9651 // Stores must share the same source value with different offsets.
9652 if (!SourceValue)
9653 SourceValue = WideVal;
9654 else if (SourceValue != WideVal) {
9655 // Truncate and extends can be stripped to see if the values are related.
9656 if (stripTruncAndExt(Value: SourceValue) != WideVal &&
9657 stripTruncAndExt(Value: WideVal) != SourceValue)
9658 return SDValue();
9659
9660 if (WideVal.getScalarValueSizeInBits() >
9661 SourceValue.getScalarValueSizeInBits())
9662 SourceValue = WideVal;
9663
9664 // Give up if the source value type is smaller than the store size.
9665 if (SourceValue.getScalarValueSizeInBits() < WideNumBits)
9666 return SDValue();
9667 }
9668
9669 // Stores must share the same base address.
9670 BaseIndexOffset Ptr = BaseIndexOffset::match(N: Store, DAG);
9671 int64_t ByteOffsetFromBase = 0;
9672 if (!Base)
9673 Base = Ptr;
9674 else if (!Base->equalBaseIndex(Other: Ptr, DAG, Off&: ByteOffsetFromBase))
9675 return SDValue();
9676
9677 // Remember the first store.
9678 if (ByteOffsetFromBase < FirstOffset) {
9679 FirstStore = Store;
9680 FirstOffset = ByteOffsetFromBase;
9681 }
9682 // Map the offset in the store and the offset in the combined value, and
9683 // early return if it has been set before.
9684 if (Offset < 0 || Offset >= NumStores || OffsetMap[Offset] != INT64_MAX)
9685 return SDValue();
9686 OffsetMap[Offset] = ByteOffsetFromBase;
9687 }
9688
9689 EVT WideVT = EVT::getIntegerVT(Context, BitWidth: WideNumBits);
9690
9691 assert(FirstOffset != INT64_MAX && "First byte offset must be set");
9692 assert(FirstStore && "First store must be set");
9693
9694 // Check that a store of the wide type is both allowed and fast on the target
9695 const DataLayout &Layout = DAG.getDataLayout();
9696 unsigned Fast = 0;
9697 bool Allowed = TLI.allowsMemoryAccess(Context, DL: Layout, VT: WideVT,
9698 MMO: *FirstStore->getMemOperand(), Fast: &Fast);
9699 if (!Allowed || !Fast)
9700 return SDValue();
9701
9702 // Check if the pieces of the value are going to the expected places in memory
9703 // to merge the stores.
9704 auto checkOffsets = [&](bool MatchLittleEndian) {
9705 if (MatchLittleEndian) {
9706 for (unsigned i = 0; i != NumStores; ++i)
9707 if (OffsetMap[i] != i * (NarrowNumBits / 8) + FirstOffset)
9708 return false;
9709 } else { // MatchBigEndian by reversing loop counter.
9710 for (unsigned i = 0, j = NumStores - 1; i != NumStores; ++i, --j)
9711 if (OffsetMap[j] != i * (NarrowNumBits / 8) + FirstOffset)
9712 return false;
9713 }
9714 return true;
9715 };
9716
9717 // Check if the offsets line up for the native data layout of this target.
9718 bool NeedBswap = false;
9719 bool NeedRotate = false;
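// For example, on a little-endian target two i16 halves of an i32 that were
// stored in big-endian order can still be merged: rotating the combined i32
// value by 16 bits swaps the halves into place before the single wide store.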
9720 if (!checkOffsets(Layout.isLittleEndian())) {
9721 // Special-case: check if byte offsets line up for the opposite endian.
9722 if (NarrowNumBits == 8 && checkOffsets(Layout.isBigEndian()))
9723 NeedBswap = true;
9724 else if (NumStores == 2 && checkOffsets(Layout.isBigEndian()))
9725 NeedRotate = true;
9726 else
9727 return SDValue();
9728 }
9729
9730 SDLoc DL(N);
9731 if (WideVT != SourceValue.getValueType()) {
9732 assert(SourceValue.getValueType().getScalarSizeInBits() > WideNumBits &&
9733 "Unexpected store value to merge");
9734 SourceValue = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: WideVT, Operand: SourceValue);
9735 }
9736
9737 // Before legalize we can introduce illegal bswaps/rotates which will be later
9738 // converted to an explicit bswap sequence. This way we end up with a single
9739 // store and byte shuffling instead of several stores and byte shuffling.
9740 if (NeedBswap) {
9741 SourceValue = DAG.getNode(Opcode: ISD::BSWAP, DL, VT: WideVT, Operand: SourceValue);
9742 } else if (NeedRotate) {
9743 assert(WideNumBits % 2 == 0 && "Unexpected type for rotate");
9744 SDValue RotAmt = DAG.getConstant(Val: WideNumBits / 2, DL, VT: WideVT);
9745 SourceValue = DAG.getNode(Opcode: ISD::ROTR, DL, VT: WideVT, N1: SourceValue, N2: RotAmt);
9746 }
9747
9748 SDValue NewStore =
9749 DAG.getStore(Chain, dl: DL, Val: SourceValue, Ptr: FirstStore->getBasePtr(),
9750 PtrInfo: FirstStore->getPointerInfo(), Alignment: FirstStore->getAlign());
9751
9752 // Rely on other DAG combine rules to remove the other individual stores.
9753 DAG.ReplaceAllUsesWith(From: N, To: NewStore.getNode());
9754 return NewStore;
9755}
9756
9757/// Match a pattern where a wide type scalar value is loaded by several narrow
9758/// loads and combined by shifts and ors. Fold it into a single load or a load
9759 /// and a BSWAP if the target supports it.
9760///
9761/// Assuming little endian target:
9762/// i8 *a = ...
9763/// i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
9764/// =>
9765/// i32 val = *((i32)a)
9766///
9767/// i8 *a = ...
9768/// i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
9769/// =>
9770/// i32 val = BSWAP(*((i32)a))
9771///
9772/// TODO: This rule matches complex patterns with OR node roots and doesn't
9773/// interact well with the worklist mechanism. When a part of the pattern is
9774 /// updated (e.g. one of the loads), its direct users are put into the worklist,
9775/// but the root node of the pattern which triggers the load combine is not
9776/// necessarily a direct user of the changed node. For example, once the address
9777 /// of the t28 load is reassociated, the load combine won't be triggered:
9778/// t25: i32 = add t4, Constant:i32<2>
9779/// t26: i64 = sign_extend t25
9780/// t27: i64 = add t2, t26
9781/// t28: i8,ch = load<LD1[%tmp9]> t0, t27, undef:i64
9782/// t29: i32 = zero_extend t28
9783/// t32: i32 = shl t29, Constant:i8<8>
9784/// t33: i32 = or t23, t32
9785/// As a possible fix visitLoad can check if the load can be a part of a load
9786/// combine pattern and add corresponding OR roots to the worklist.
9787SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
9788 assert(N->getOpcode() == ISD::OR &&
9789 "Can only match load combining against OR nodes");
9790
9791 // Handles simple types only
9792 EVT VT = N->getValueType(ResNo: 0);
9793 if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
9794 return SDValue();
9795 unsigned ByteWidth = VT.getSizeInBits() / 8;
9796
9797 bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
9798 auto MemoryByteOffset = [&](SDByteProvider P) {
9799 assert(P.hasSrc() && "Must be a memory byte provider");
9800 auto *Load = cast<LoadSDNode>(Val: P.Src.value());
9801
9802 unsigned LoadBitWidth = Load->getMemoryVT().getScalarSizeInBits();
9803
9804 assert(LoadBitWidth % 8 == 0 &&
9805 "can only analyze providers for individual bytes not bit");
9806 unsigned LoadByteWidth = LoadBitWidth / 8;
9807 return IsBigEndianTarget ? bigEndianByteAt(BW: LoadByteWidth, i: P.DestOffset)
9808 : littleEndianByteAt(BW: LoadByteWidth, i: P.DestOffset);
9809 };
9810
9811 std::optional<BaseIndexOffset> Base;
9812 SDValue Chain;
9813
9814 SmallPtrSet<LoadSDNode *, 8> Loads;
9815 std::optional<SDByteProvider> FirstByteProvider;
9816 int64_t FirstOffset = INT64_MAX;
9817
9818 // Check if all the bytes of the OR we are looking at are loaded from the same
9819 // base address. Collect byte offsets from the Base address in ByteOffsets.
9820 SmallVector<int64_t, 8> ByteOffsets(ByteWidth);
9821 unsigned ZeroExtendedBytes = 0;
9822 for (int i = ByteWidth - 1; i >= 0; --i) {
9823 auto P =
9824 calculateByteProvider(Op: SDValue(N, 0), Index: i, Depth: 0, /*VectorIndex*/ std::nullopt,
9825 /*StartingIndex*/ i);
9826 if (!P)
9827 return SDValue();
9828
9829 if (P->isConstantZero()) {
9830 // It's OK for the N most significant bytes to be 0; we can just
9831 // zero-extend the load.
9832 if (++ZeroExtendedBytes != (ByteWidth - static_cast<unsigned>(i)))
9833 return SDValue();
9834 continue;
9835 }
9836 assert(P->hasSrc() && "provenance should either be memory or zero");
9837 auto *L = cast<LoadSDNode>(Val: P->Src.value());
9838
9839 // All loads must share the same chain
9840 SDValue LChain = L->getChain();
9841 if (!Chain)
9842 Chain = LChain;
9843 else if (Chain != LChain)
9844 return SDValue();
9845
9846 // Loads must share the same base address
9847 BaseIndexOffset Ptr = BaseIndexOffset::match(N: L, DAG);
9848 int64_t ByteOffsetFromBase = 0;
9849
9850 // For vector loads, the expected load combine pattern will have an
9851 // ExtractElement for each index in the vector. While each of these
9852 // ExtractElements will be accessing the same base address as determined
9853 // by the load instruction, the actual bytes they interact with will differ
9854 // due to different ExtractElement indices. To accurately determine the
9855 // byte position of an ExtractElement, we offset the base load ptr with
9856 // the index multiplied by the byte size of each element in the vector.
9857 if (L->getMemoryVT().isVector()) {
9858 unsigned LoadWidthInBit = L->getMemoryVT().getScalarSizeInBits();
9859 if (LoadWidthInBit % 8 != 0)
9860 return SDValue();
9861 unsigned ByteOffsetFromVector = P->SrcOffset * LoadWidthInBit / 8;
9862 Ptr.addToOffset(VectorOff: ByteOffsetFromVector);
9863 }
9864
9865 if (!Base)
9866 Base = Ptr;
9867
9868 else if (!Base->equalBaseIndex(Other: Ptr, DAG, Off&: ByteOffsetFromBase))
9869 return SDValue();
9870
9871 // Calculate the offset of the current byte from the base address
9872 ByteOffsetFromBase += MemoryByteOffset(*P);
9873 ByteOffsets[i] = ByteOffsetFromBase;
9874
9875 // Remember the first byte load
9876 if (ByteOffsetFromBase < FirstOffset) {
9877 FirstByteProvider = P;
9878 FirstOffset = ByteOffsetFromBase;
9879 }
9880
9881 Loads.insert(Ptr: L);
9882 }
9883
9884 assert(!Loads.empty() && "All the bytes of the value must be loaded from "
9885 "memory, so there must be at least one load which produces the value");
9886 assert(Base && "Base address of the accessed memory location must be set");
9887 assert(FirstOffset != INT64_MAX && "First byte offset must be set");
9888
9889 bool NeedsZext = ZeroExtendedBytes > 0;
9890
9891 EVT MemVT =
9892 EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: (ByteWidth - ZeroExtendedBytes) * 8);
9893
9894 if (!MemVT.isSimple())
9895 return SDValue();
9896
9897 // Before legalize we can introduce too wide illegal loads which will be later
9898 // split into legal sized loads. This enables us to combine i64 load by i8
9899 // patterns to a couple of i32 loads on 32 bit targets.
9900 if (LegalOperations &&
9901 !TLI.isLoadExtLegal(ExtType: NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, ValVT: VT,
9902 MemVT))
9903 return SDValue();
9904
9905 // Check if the bytes of the OR we are looking at match with either big or
9906 // little endian value load
9907 std::optional<bool> IsBigEndian = isBigEndian(
9908 ByteOffsets: ArrayRef(ByteOffsets).drop_back(N: ZeroExtendedBytes), FirstOffset);
9909 if (!IsBigEndian)
9910 return SDValue();
9911
9912 assert(FirstByteProvider && "must be set");
9913
9914 // Ensure that the first byte is loaded from offset zero of the first load,
9915 // so the combined value can be loaded from the first load's address.
9916 if (MemoryByteOffset(*FirstByteProvider) != 0)
9917 return SDValue();
9918 auto *FirstLoad = cast<LoadSDNode>(Val: FirstByteProvider->Src.value());
9919
9920 // The node we are looking at matches with the pattern, check if we can
9921 // replace it with a single (possibly zero-extended) load and bswap + shift if
9922 // needed.
9923
9924 // If the load needs byte swap check if the target supports it
9925 bool NeedsBswap = IsBigEndianTarget != *IsBigEndian;
9926
9927 // Before legalize we can introduce illegal bswaps which will be later
9928 // converted to an explicit bswap sequence. This way we end up with a single
9929 // load and byte shuffling instead of several loads and byte shuffling.
9930 // We do not introduce illegal bswaps when zero-extending as this tends to
9931 // introduce too many arithmetic instructions.
9932 if (NeedsBswap && (LegalOperations || NeedsZext) &&
9933 !TLI.isOperationLegal(Op: ISD::BSWAP, VT))
9934 return SDValue();
9935
9936 // If we need to bswap and zero extend, we have to insert a shift. Check that
9937 // it is legal.
9938 if (NeedsBswap && NeedsZext && LegalOperations &&
9939 !TLI.isOperationLegal(Op: ISD::SHL, VT))
9940 return SDValue();
9941
9942 // Check that a load of the wide type is both allowed and fast on the target
9943 unsigned Fast = 0;
9944 bool Allowed =
9945 TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: MemVT,
9946 MMO: *FirstLoad->getMemOperand(), Fast: &Fast);
9947 if (!Allowed || !Fast)
9948 return SDValue();
9949
9950 SDValue NewLoad =
9951 DAG.getExtLoad(ExtType: NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, dl: SDLoc(N), VT,
9952 Chain, Ptr: FirstLoad->getBasePtr(),
9953 PtrInfo: FirstLoad->getPointerInfo(), MemVT, Alignment: FirstLoad->getAlign());
9954
9955 // Transfer chain users from old loads to the new load.
9956 for (LoadSDNode *L : Loads)
9957 DAG.makeEquivalentMemoryOrdering(OldLoad: L, NewMemOp: NewLoad);
9958
9959 if (!NeedsBswap)
9960 return NewLoad;
9961
9962 SDValue ShiftedLoad =
9963 NeedsZext ? DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N), VT, N1: NewLoad,
9964 N2: DAG.getShiftAmountConstant(Val: ZeroExtendedBytes * 8,
9965 VT, DL: SDLoc(N)))
9966 : NewLoad;
9967 return DAG.getNode(Opcode: ISD::BSWAP, DL: SDLoc(N), VT, Operand: ShiftedLoad);
9968}
9969
9970// If the target has andn, bsl, or a similar bit-select instruction,
9971// we want to unfold masked merge, with canonical pattern of:
9972 //   |     A      | |B|
9973 //   ((x ^ y) & m) ^ y
9974 //    |  D  |
9975// Into:
9976// (x & m) | (y & ~m)
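// For example, with m = 0x0F: ((x ^ y) & 0x0F) ^ y keeps the low nibble of x
// and the high nibble of y, i.e. (x & 0x0F) | (y & 0xF0).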
9977// If y is a constant, m is not a 'not', and the 'andn' does not work with
9978// immediates, we unfold into a different pattern:
9979// ~(~x & m) & (m | y)
9980// If x is a constant, m is a 'not', and the 'andn' does not work with
9981// immediates, we unfold into a different pattern:
9982// (x | ~m) & ~(~m & ~y)
9983// NOTE: we don't unfold the pattern if 'xor' is actually a 'not', because at
9984// the very least that breaks andnpd / andnps patterns, and because those
9985// patterns are simplified in IR and shouldn't be created in the DAG
9986SDValue DAGCombiner::unfoldMaskedMerge(SDNode *N) {
9987 assert(N->getOpcode() == ISD::XOR);
9988
9989 // Don't touch 'not' (i.e. where y = -1).
9990 if (isAllOnesOrAllOnesSplat(V: N->getOperand(Num: 1)))
9991 return SDValue();
9992
9993 EVT VT = N->getValueType(ResNo: 0);
9994
9995 // There are 3 commutable operators in the pattern,
9996 // so we have to deal with 8 possible variants of the basic pattern.
9997 SDValue X, Y, M;
9998 auto matchAndXor = [&X, &Y, &M](SDValue And, unsigned XorIdx, SDValue Other) {
9999 if (And.getOpcode() != ISD::AND || !And.hasOneUse())
10000 return false;
10001 SDValue Xor = And.getOperand(i: XorIdx);
10002 if (Xor.getOpcode() != ISD::XOR || !Xor.hasOneUse())
10003 return false;
10004 SDValue Xor0 = Xor.getOperand(i: 0);
10005 SDValue Xor1 = Xor.getOperand(i: 1);
10006 // Don't touch 'not' (i.e. where y = -1).
10007 if (isAllOnesOrAllOnesSplat(V: Xor1))
10008 return false;
10009 if (Other == Xor0)
10010 std::swap(a&: Xor0, b&: Xor1);
10011 if (Other != Xor1)
10012 return false;
10013 X = Xor0;
10014 Y = Xor1;
10015 M = And.getOperand(i: XorIdx ? 0 : 1);
10016 return true;
10017 };
10018
10019 SDValue N0 = N->getOperand(Num: 0);
10020 SDValue N1 = N->getOperand(Num: 1);
10021 if (!matchAndXor(N0, 0, N1) && !matchAndXor(N0, 1, N1) &&
10022 !matchAndXor(N1, 0, N0) && !matchAndXor(N1, 1, N0))
10023 return SDValue();
10024
10025 // Don't do anything if the mask is constant. This should not be reachable.
10026 // InstCombine should have already unfolded this pattern, and DAGCombiner
10027 // probably shouldn't produce it either.
10028 if (isa<ConstantSDNode>(Val: M.getNode()))
10029 return SDValue();
10030
10031 // We can transform if the target has AndNot
10032 if (!TLI.hasAndNot(X: M))
10033 return SDValue();
10034
10035 SDLoc DL(N);
10036
10037 // If Y is a constant, check that 'andn' works with immediates, unless M is
10038 // a bitwise not that would already allow ANDN to be used.
10039 if (!TLI.hasAndNot(X: Y) && !isBitwiseNot(V: M)) {
10040 assert(TLI.hasAndNot(X) && "Only mask is a variable? Unreachable.");
10041 // If not, we need to do a bit more work to make sure andn is still used.
10042 SDValue NotX = DAG.getNOT(DL, Val: X, VT);
10043 SDValue LHS = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: NotX, N2: M);
10044 SDValue NotLHS = DAG.getNOT(DL, Val: LHS, VT);
10045 SDValue RHS = DAG.getNode(Opcode: ISD::OR, DL, VT, N1: M, N2: Y);
10046 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: NotLHS, N2: RHS);
10047 }
10048
10049 // If X is a constant and M is a bitwise not, check that 'andn' works with
10050 // immediates.
10051 if (!TLI.hasAndNot(X) && isBitwiseNot(V: M)) {
10052 assert(TLI.hasAndNot(Y) && "Only mask is a variable? Unreachable.");
10053 // If not, we need to do a bit more work to make sure andn is still used.
10054 SDValue NotM = M.getOperand(i: 0);
10055 SDValue LHS = DAG.getNode(Opcode: ISD::OR, DL, VT, N1: X, N2: NotM);
10056 SDValue NotY = DAG.getNOT(DL, Val: Y, VT);
10057 SDValue RHS = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: NotM, N2: NotY);
10058 SDValue NotRHS = DAG.getNOT(DL, Val: RHS, VT);
10059 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: LHS, N2: NotRHS);
10060 }
10061
10062 SDValue LHS = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: X, N2: M);
10063 SDValue NotM = DAG.getNOT(DL, Val: M, VT);
10064 SDValue RHS = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Y, N2: NotM);
10065
10066 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: LHS, N2: RHS);
10067}
10068
10069SDValue DAGCombiner::visitXOR(SDNode *N) {
10070 SDValue N0 = N->getOperand(Num: 0);
10071 SDValue N1 = N->getOperand(Num: 1);
10072 EVT VT = N0.getValueType();
10073 SDLoc DL(N);
10074
10075 // fold (xor undef, undef) -> 0. This is a common idiom (misuse).
10076 if (N0.isUndef() && N1.isUndef())
10077 return DAG.getConstant(Val: 0, DL, VT);
10078
10079 // fold (xor x, undef) -> undef
10080 if (N0.isUndef())
10081 return N0;
10082 if (N1.isUndef())
10083 return N1;
10084
10085 // fold (xor c1, c2) -> c1^c2
10086 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::XOR, DL, VT, Ops: {N0, N1}))
10087 return C;
10088
10089 // canonicalize constant to RHS
10090 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
10091 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
10092 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0);
10093
10094 // fold vector ops
10095 if (VT.isVector()) {
10096 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
10097 return FoldedVOp;
10098
10099 // fold (xor x, 0) -> x, vector edition
10100 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
10101 return N0;
10102 }
10103
10104 // fold (xor x, 0) -> x
10105 if (isNullConstant(V: N1))
10106 return N0;
10107
10108 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
10109 return NewSel;
10110
10111 // reassociate xor
10112 if (SDValue RXOR = reassociateOps(Opc: ISD::XOR, DL, N0, N1, Flags: N->getFlags()))
10113 return RXOR;
10114
10115 // Fold xor(vecreduce(x), vecreduce(y)) -> vecreduce(xor(x, y))
10116 if (SDValue SD =
10117 reassociateReduction(RedOpc: ISD::VECREDUCE_XOR, Opc: ISD::XOR, DL, VT, N0, N1))
10118 return SD;
10119
10120 // fold (a^b) -> (a|b) iff a and b share no bits.
10121 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::OR, VT)) &&
10122 DAG.haveNoCommonBitsSet(A: N0, B: N1))
10123 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: N0, N2: N1, Flags: SDNodeFlags::Disjoint);
10124
10125 // look for 'add-like' folds:
10126 // XOR(N0,MIN_SIGNED_VALUE) == ADD(N0,MIN_SIGNED_VALUE)
10127 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::ADD, VT)) &&
10128 isMinSignedConstant(V: N1))
10129 if (SDValue Combined = visitADDLike(N))
10130 return Combined;
10131
10132 // fold (not (setcc x, y, cc)) -> (setcc x, y, !cc)
10133 // Avoid breaking: and (not(setcc x, y, cc), z) -> andn for vec
10134 unsigned N0Opcode = N0.getOpcode();
10135 SDValue LHS, RHS, CC;
10136 if (TLI.isConstTrueVal(N: N1) &&
10137 isSetCCEquivalent(N: N0, LHS, RHS, CC, /*MatchStrict*/ true) &&
10138 !(VT.isVector() && TLI.hasAndNot(X: SDValue(N, 0)) && N->hasOneUse() &&
10139 N->use_begin()->getUser()->getOpcode() == ISD::AND)) {
10140 ISD::CondCode NotCC = ISD::getSetCCInverse(Operation: cast<CondCodeSDNode>(Val&: CC)->get(),
10141 Type: LHS.getValueType());
10142 if (!LegalOperations ||
10143 TLI.isCondCodeLegal(CC: NotCC, VT: LHS.getSimpleValueType())) {
10144 switch (N0Opcode) {
10145 default:
10146 llvm_unreachable("Unhandled SetCC Equivalent!");
10147 case ISD::SETCC:
10148 return DAG.getSetCC(DL: SDLoc(N0), VT, LHS, RHS, Cond: NotCC);
10149 case ISD::SELECT_CC:
10150 return DAG.getSelectCC(DL: SDLoc(N0), LHS, RHS, True: N0.getOperand(i: 2),
10151 False: N0.getOperand(i: 3), Cond: NotCC);
10152 case ISD::STRICT_FSETCC:
10153 case ISD::STRICT_FSETCCS: {
10154 if (N0.hasOneUse()) {
10155 // FIXME Can we handle multiple uses? Could we token factor the chain
10156 // results from the new/old setcc?
10157 SDValue SetCC =
10158 DAG.getSetCC(DL: SDLoc(N0), VT, LHS, RHS, Cond: NotCC,
10159 Chain: N0.getOperand(i: 0), IsSignaling: N0Opcode == ISD::STRICT_FSETCCS);
10160 CombineTo(N, Res: SetCC);
10161 DAG.ReplaceAllUsesOfValueWith(From: N0.getValue(R: 1), To: SetCC.getValue(R: 1));
10162 recursivelyDeleteUnusedNodes(N: N0.getNode());
10163 return SDValue(N, 0); // Return N so it doesn't get rechecked!
10164 }
10165 break;
10166 }
10167 }
10168 }
10169 }
10170
10171 // fold (not (zext (setcc x, y))) -> (zext (not (setcc x, y)))
10172 if (isOneConstant(V: N1) && N0Opcode == ISD::ZERO_EXTEND && N0.hasOneUse() &&
10173 isSetCCEquivalent(N: N0.getOperand(i: 0), LHS, RHS, CC)){
10174 SDValue V = N0.getOperand(i: 0);
10175 SDLoc DL0(N0);
10176 V = DAG.getNode(Opcode: ISD::XOR, DL: DL0, VT: V.getValueType(), N1: V,
10177 N2: DAG.getConstant(Val: 1, DL: DL0, VT: V.getValueType()));
10178 AddToWorklist(N: V.getNode());
10179 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: V);
10180 }
10181
10182 // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are setcc
10183 // fold (not (and x, y)) -> (or (not x), (not y)) iff x or y are setcc
10184 if (isOneConstant(V: N1) && VT == MVT::i1 && N0.hasOneUse() &&
10185 (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
10186 SDValue N00 = N0.getOperand(i: 0), N01 = N0.getOperand(i: 1);
10187 if (isOneUseSetCC(N: N01) || isOneUseSetCC(N: N00)) {
10188 unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
10189 N00 = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N00), VT, N1: N00, N2: N1); // N00 = ~N00
10190 N01 = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N01), VT, N1: N01, N2: N1); // N01 = ~N01
10191 AddToWorklist(N: N00.getNode()); AddToWorklist(N: N01.getNode());
10192 return DAG.getNode(Opcode: NewOpcode, DL, VT, N1: N00, N2: N01);
10193 }
10194 }
10195 // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are constants
10196 // fold (not (and x, y)) -> (or (not x), (not y)) iff x or y are constants
10197 if (isAllOnesConstant(V: N1) && N0.hasOneUse() &&
10198 (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
10199 SDValue N00 = N0.getOperand(i: 0), N01 = N0.getOperand(i: 1);
10200 if (isa<ConstantSDNode>(Val: N01) || isa<ConstantSDNode>(Val: N00)) {
10201 unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
10202 N00 = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N00), VT, N1: N00, N2: N1); // N00 = ~N00
10203 N01 = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N01), VT, N1: N01, N2: N1); // N01 = ~N01
10204 AddToWorklist(N: N00.getNode()); AddToWorklist(N: N01.getNode());
10205 return DAG.getNode(Opcode: NewOpcode, DL, VT, N1: N00, N2: N01);
10206 }
10207 }
10208
10209 // fold (not (sub Y, X)) -> (add X, ~Y) if Y is a constant
10210 if (N0.getOpcode() == ISD::SUB && isAllOnesConstant(V: N1)) {
10211 SDValue Y = N0.getOperand(i: 0);
10212 SDValue X = N0.getOperand(i: 1);
10213
10214 if (auto *YConst = dyn_cast<ConstantSDNode>(Val&: Y)) {
10215 APInt NotYValue = ~YConst->getAPIntValue();
10216 SDValue NotY = DAG.getConstant(Val: NotYValue, DL, VT);
10217 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: X, N2: NotY, Flags: N->getFlags());
10218 }
10219 }
10220
10221 // fold (not (add X, -1)) -> (neg X)
10222 if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() && isAllOnesConstant(V: N1) &&
10223 isAllOnesOrAllOnesSplat(V: N0.getOperand(i: 1))) {
10224 return DAG.getNegative(Val: N0.getOperand(i: 0), DL, VT);
10225 }
10226
10227 // fold (xor (and x, y), y) -> (and (not x), y)
10228 if (N0Opcode == ISD::AND && N0.hasOneUse() && N0->getOperand(Num: 1) == N1) {
10229 SDValue X = N0.getOperand(i: 0);
10230 SDValue NotX = DAG.getNOT(DL: SDLoc(X), Val: X, VT);
10231 AddToWorklist(N: NotX.getNode());
10232 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: NotX, N2: N1);
10233 }
10234
10235 // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X)
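// (Y is the sign mask of X, so (X + Y) ^ Y yields X for X >= 0 and -X for
// X < 0, i.e. the absolute value.)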
10236 if (!LegalOperations || hasOperation(Opcode: ISD::ABS, VT)) {
10237 SDValue A = N0Opcode == ISD::ADD ? N0 : N1;
10238 SDValue S = N0Opcode == ISD::SRA ? N0 : N1;
10239 if (A.getOpcode() == ISD::ADD && S.getOpcode() == ISD::SRA) {
10240 SDValue A0 = A.getOperand(i: 0), A1 = A.getOperand(i: 1);
10241 SDValue S0 = S.getOperand(i: 0);
10242 if ((A0 == S && A1 == S0) || (A1 == S && A0 == S0))
10243 if (ConstantSDNode *C = isConstOrConstSplat(N: S.getOperand(i: 1)))
10244 if (C->getAPIntValue() == (VT.getScalarSizeInBits() - 1))
10245 return DAG.getNode(Opcode: ISD::ABS, DL, VT, Operand: S0);
10246 }
10247 }
10248
10249 // fold (xor x, x) -> 0
10250 if (N0 == N1)
10251 return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
10252
10253 // fold (xor (shl 1, x), -1) -> (rotl ~1, x)
10254 // Here is a concrete example of this equivalence:
10255 // i16 x == 14
10256 // i16 shl == 1 << 14 == 16384 == 0b0100000000000000
10257 // i16 xor == ~(1 << 14) == 49151 == 0b1011111111111111
10258 //
10259 // =>
10260 //
10261 // i16 ~1 == 0b1111111111111110
10262 // i16 rol(~1, 14) == 0b1011111111111111
10263 //
10264 // Some additional tips to help conceptualize this transform:
10265 // - Try to see the operation as placing a single zero in a value of all ones.
10266 // - There exists no value for x which would allow the result to contain zero.
10267 // - Values of x larger than the bitwidth are undefined and do not require a
10268 // consistent result.
10269 // - Pushing the zero left requires shifting one bits in from the right.
10270 // A rotate left of ~1 is a nice way of achieving the desired result.
10271 if (TLI.isOperationLegalOrCustom(Op: ISD::ROTL, VT) && N0Opcode == ISD::SHL &&
10272 isAllOnesConstant(V: N1) && isOneConstant(V: N0.getOperand(i: 0))) {
10273 return DAG.getNode(Opcode: ISD::ROTL, DL, VT, N1: DAG.getSignedConstant(Val: ~1, DL, VT),
10274 N2: N0.getOperand(i: 1));
10275 }
10276
10277 // Simplify: xor (op x...), (op y...) -> (op (xor x, y))
10278 if (N0Opcode == N1.getOpcode())
10279 if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
10280 return V;
10281
10282 if (SDValue R = foldLogicOfShifts(N, LogicOp: N0, ShiftOp: N1, DAG))
10283 return R;
10284 if (SDValue R = foldLogicOfShifts(N, LogicOp: N1, ShiftOp: N0, DAG))
10285 return R;
10286 if (SDValue R = foldLogicTreeOfShifts(N, LeftHand: N0, RightHand: N1, DAG))
10287 return R;
10288
10289 // Unfold ((x ^ y) & m) ^ y into (x & m) | (y & ~m) if profitable
10290 if (SDValue MM = unfoldMaskedMerge(N))
10291 return MM;
10292
10293 // Simplify the expression using non-local knowledge.
10294 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
10295 return SDValue(N, 0);
10296
10297 if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
10298 return Combined;
10299
10300 // fold (xor (smin(x, C), C)) -> select (x < C), xor(x, C), 0
10301 // fold (xor (smax(x, C), C)) -> select (x > C), xor(x, C), 0
10302 // fold (xor (umin(x, C), C)) -> select (x < C), xor(x, C), 0
10303 // fold (xor (umax(x, C), C)) -> select (x > C), xor(x, C), 0
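// (When x is clamped, the min/max result is C and xor(C, C) is 0; otherwise the
// result is x itself, giving xor(x, C).)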
10304 SDValue Op0;
10305 if (sd_match(N: N0, P: m_OneUse(P: m_AnyOf(preds: m_SMin(L: m_Value(N&: Op0), R: m_Specific(N: N1)),
10306 preds: m_SMax(L: m_Value(N&: Op0), R: m_Specific(N: N1)),
10307 preds: m_UMin(L: m_Value(N&: Op0), R: m_Specific(N: N1)),
10308 preds: m_UMax(L: m_Value(N&: Op0), R: m_Specific(N: N1)))))) {
10309
10310 if (isa<ConstantSDNode>(Val: N1) ||
10311 ISD::isBuildVectorOfConstantSDNodes(N: N1.getNode())) {
10312 // For vectors, only optimize when the constant is zero or all-ones to
10313 // avoid generating more instructions
10314 if (VT.isVector()) {
10315 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
10316 if (!N1C || (!N1C->isZero() && !N1C->isAllOnes()))
10317 return SDValue();
10318 }
10319
10320 // Avoid the fold if the minmax operation is legal and select is expensive
10321 if (TLI.isOperationLegal(Op: N0.getOpcode(), VT) &&
10322 TLI.isPredictableSelectExpensive())
10323 return SDValue();
10324
10325 EVT CCVT = getSetCCResultType(VT);
10326 ISD::CondCode CC;
10327 switch (N0.getOpcode()) {
10328 case ISD::SMIN:
10329 CC = ISD::SETLT;
10330 break;
10331 case ISD::SMAX:
10332 CC = ISD::SETGT;
10333 break;
10334 case ISD::UMIN:
10335 CC = ISD::SETULT;
10336 break;
10337 case ISD::UMAX:
10338 CC = ISD::SETUGT;
10339 break;
10340 }
10341 SDValue FN1 = DAG.getFreeze(V: N1);
10342 SDValue Cmp = DAG.getSetCC(DL, VT: CCVT, LHS: Op0, RHS: FN1, Cond: CC);
10343 SDValue XorXC = DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: Op0, N2: FN1);
10344 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
10345 return DAG.getSelect(DL, VT, Cond: Cmp, LHS: XorXC, RHS: Zero);
10346 }
10347 }
10348
10349 return SDValue();
10350}
10351
10352/// If we have a shift-by-constant of a bitwise logic op that itself has a
10353/// shift-by-constant operand with identical opcode, we may be able to convert
10354/// that into 2 independent shifts followed by the logic op. This is a
10355/// throughput improvement.
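/// For example: shl (and (shl X, 2), Y), 3 --> and (shl X, 5), (shl Y, 3)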
10356static SDValue combineShiftOfShiftedLogic(SDNode *Shift, SelectionDAG &DAG) {
10357 // Match a one-use bitwise logic op.
10358 SDValue LogicOp = Shift->getOperand(Num: 0);
10359 if (!LogicOp.hasOneUse())
10360 return SDValue();
10361
10362 unsigned LogicOpcode = LogicOp.getOpcode();
10363 if (LogicOpcode != ISD::AND && LogicOpcode != ISD::OR &&
10364 LogicOpcode != ISD::XOR)
10365 return SDValue();
10366
10367 // Find a matching one-use shift by constant.
10368 unsigned ShiftOpcode = Shift->getOpcode();
10369 SDValue C1 = Shift->getOperand(Num: 1);
10370 ConstantSDNode *C1Node = isConstOrConstSplat(N: C1);
10371 assert(C1Node && "Expected a shift with constant operand");
10372 const APInt &C1Val = C1Node->getAPIntValue();
10373 auto matchFirstShift = [&](SDValue V, SDValue &ShiftOp,
10374 const APInt *&ShiftAmtVal) {
10375 if (V.getOpcode() != ShiftOpcode || !V.hasOneUse())
10376 return false;
10377
10378 ConstantSDNode *ShiftCNode = isConstOrConstSplat(N: V.getOperand(i: 1));
10379 if (!ShiftCNode)
10380 return false;
10381
10382 // Capture the shifted operand and shift amount value.
10383 ShiftOp = V.getOperand(i: 0);
10384 ShiftAmtVal = &ShiftCNode->getAPIntValue();
10385
10386 // Shift amount types do not have to match their operand type, so check that
10387 // the constants are the same width.
10388 if (ShiftAmtVal->getBitWidth() != C1Val.getBitWidth())
10389 return false;
10390
10391 // The fold is not valid if the sum of the shift values doesn't fit in the
10392 // given shift amount type.
10393 bool Overflow = false;
10394 APInt NewShiftAmt = C1Val.uadd_ov(RHS: *ShiftAmtVal, Overflow);
10395 if (Overflow)
10396 return false;
10397
10398 // The fold is not valid if the sum of the shift values exceeds bitwidth.
10399 if (NewShiftAmt.uge(RHS: V.getScalarValueSizeInBits()))
10400 return false;
10401
10402 return true;
10403 };
10404
10405 // Logic ops are commutative, so check each operand for a match.
10406 SDValue X, Y;
10407 const APInt *C0Val;
10408 if (matchFirstShift(LogicOp.getOperand(i: 0), X, C0Val))
10409 Y = LogicOp.getOperand(i: 1);
10410 else if (matchFirstShift(LogicOp.getOperand(i: 1), X, C0Val))
10411 Y = LogicOp.getOperand(i: 0);
10412 else
10413 return SDValue();
10414
10415 // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
10416 SDLoc DL(Shift);
10417 EVT VT = Shift->getValueType(ResNo: 0);
10418 EVT ShiftAmtVT = Shift->getOperand(Num: 1).getValueType();
10419 SDValue ShiftSumC = DAG.getConstant(Val: *C0Val + C1Val, DL, VT: ShiftAmtVT);
10420 SDValue NewShift1 = DAG.getNode(Opcode: ShiftOpcode, DL, VT, N1: X, N2: ShiftSumC);
10421 SDValue NewShift2 = DAG.getNode(Opcode: ShiftOpcode, DL, VT, N1: Y, N2: C1);
10422 return DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: NewShift1, N2: NewShift2,
10423 Flags: LogicOp->getFlags());
10424}
10425
10426/// Handle transforms common to the three shifts, when the shift amount is a
10427/// constant.
10428/// We are looking for: (shift being one of shl/sra/srl)
10429/// shift (binop X, C0), C1
10430/// And want to transform into:
10431/// binop (shift X, C1), (shift C0, C1)
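/// For example: shl (or X, 0xF0), 8 --> or (shl X, 8), 0xF000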
10432SDValue DAGCombiner::visitShiftByConstant(SDNode *N) {
10433 assert(isConstOrConstSplat(N->getOperand(1)) && "Expected constant operand");
10434
10435 // Do not turn a 'not' into a regular xor.
10436 if (isBitwiseNot(V: N->getOperand(Num: 0)))
10437 return SDValue();
10438
10439 // The inner binop must be one-use, since we want to replace it.
10440 SDValue LHS = N->getOperand(Num: 0);
10441 if (!LHS.hasOneUse() || !TLI.isDesirableToCommuteWithShift(N, Level))
10442 return SDValue();
10443
10444 // Fold shift(bitop(shift(x,c1),y), c2) -> bitop(shift(x,c1+c2),shift(y,c2)).
10445 if (SDValue R = combineShiftOfShiftedLogic(Shift: N, DAG))
10446 return R;
10447
10448 // We want to pull some binops through shifts, so that we have (and (shift))
10449 // instead of (shift (and)), likewise for add, or, xor, etc. This sort of
10450 // thing happens with address calculations, so it's important to canonicalize
10451 // it.
10452 switch (LHS.getOpcode()) {
10453 default:
10454 return SDValue();
10455 case ISD::OR:
10456 case ISD::XOR:
10457 case ISD::AND:
10458 break;
10459 case ISD::ADD:
10460 if (N->getOpcode() != ISD::SHL)
10461 return SDValue(); // only shl(add) not sr[al](add).
10462 break;
10463 }
10464
10465 // FIXME: disable this unless the input to the binop is a shift by a constant
10466 // or is copy/select. Enable this in other cases when we figure out it's exactly
10467 // profitable.
10468 SDValue BinOpLHSVal = LHS.getOperand(i: 0);
10469 bool IsShiftByConstant = (BinOpLHSVal.getOpcode() == ISD::SHL ||
10470 BinOpLHSVal.getOpcode() == ISD::SRA ||
10471 BinOpLHSVal.getOpcode() == ISD::SRL) &&
10472 isa<ConstantSDNode>(Val: BinOpLHSVal.getOperand(i: 1));
10473 bool IsCopyOrSelect = BinOpLHSVal.getOpcode() == ISD::CopyFromReg ||
10474 BinOpLHSVal.getOpcode() == ISD::SELECT;
10475
10476 if (!IsShiftByConstant && !IsCopyOrSelect)
10477 return SDValue();
10478
10479 if (IsCopyOrSelect && N->hasOneUse())
10480 return SDValue();
10481
10482 // Attempt to fold the constants, shifting the binop RHS by the shift amount.
10483 SDLoc DL(N);
10484 EVT VT = N->getValueType(ResNo: 0);
10485 if (SDValue NewRHS = DAG.FoldConstantArithmetic(
10486 Opcode: N->getOpcode(), DL, VT, Ops: {LHS.getOperand(i: 1), N->getOperand(Num: 1)})) {
10487 SDValue NewShift = DAG.getNode(Opcode: N->getOpcode(), DL, VT, N1: LHS.getOperand(i: 0),
10488 N2: N->getOperand(Num: 1));
10489 return DAG.getNode(Opcode: LHS.getOpcode(), DL, VT, N1: NewShift, N2: NewRHS);
10490 }
10491
10492 return SDValue();
10493}
10494
10495SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) {
10496 assert(N->getOpcode() == ISD::TRUNCATE);
10497 assert(N->getOperand(0).getOpcode() == ISD::AND);
10498
10499 // (truncate:TruncVT (and N00, N01C)) -> (and (truncate:TruncVT N00), TruncC)
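// e.g. (truncate:i16 (and X:i32, 511)) -> (and (truncate:i16 X), 511)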
10500 EVT TruncVT = N->getValueType(ResNo: 0);
10501 if (N->hasOneUse() && N->getOperand(Num: 0).hasOneUse() &&
10502 TLI.isTypeDesirableForOp(ISD::AND, VT: TruncVT)) {
10503 SDValue N01 = N->getOperand(Num: 0).getOperand(i: 1);
10504 if (isConstantOrConstantVector(N: N01, /* NoOpaques */ true)) {
10505 SDLoc DL(N);
10506 SDValue N00 = N->getOperand(Num: 0).getOperand(i: 0);
10507 SDValue Trunc00 = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: TruncVT, Operand: N00);
10508 SDValue Trunc01 = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: TruncVT, Operand: N01);
10509 AddToWorklist(N: Trunc00.getNode());
10510 AddToWorklist(N: Trunc01.getNode());
10511 return DAG.getNode(Opcode: ISD::AND, DL, VT: TruncVT, N1: Trunc00, N2: Trunc01);
10512 }
10513 }
10514
10515 return SDValue();
10516}
10517
10518SDValue DAGCombiner::visitRotate(SDNode *N) {
10519 SDLoc dl(N);
10520 SDValue N0 = N->getOperand(Num: 0);
10521 SDValue N1 = N->getOperand(Num: 1);
10522 EVT VT = N->getValueType(ResNo: 0);
10523 unsigned Bitsize = VT.getScalarSizeInBits();
10524
10525 // fold (rot x, 0) -> x
10526 if (isNullOrNullSplat(V: N1))
10527 return N0;
10528
10529 // fold (rot x, c) -> x iff (c % BitSize) == 0
10530 if (isPowerOf2_32(Value: Bitsize) && Bitsize > 1) {
10531 APInt ModuloMask(N1.getScalarValueSizeInBits(), Bitsize - 1);
10532 if (DAG.MaskedValueIsZero(Op: N1, Mask: ModuloMask))
10533 return N0;
10534 }
10535
10536 // fold (rot x, c) -> (rot x, c % BitSize)
10537 bool OutOfRange = false;
10538 auto MatchOutOfRange = [Bitsize, &OutOfRange](ConstantSDNode *C) {
10539 OutOfRange |= C->getAPIntValue().uge(RHS: Bitsize);
10540 return true;
10541 };
10542 if (ISD::matchUnaryPredicate(Op: N1, Match: MatchOutOfRange) && OutOfRange) {
10543 EVT AmtVT = N1.getValueType();
10544 SDValue Bits = DAG.getConstant(Val: Bitsize, DL: dl, VT: AmtVT);
10545 if (SDValue Amt =
10546 DAG.FoldConstantArithmetic(Opcode: ISD::UREM, DL: dl, VT: AmtVT, Ops: {N1, Bits}))
10547 return DAG.getNode(Opcode: N->getOpcode(), DL: dl, VT, N1: N0, N2: Amt);
10548 }
10549
10550 // rot i16 X, 8 --> bswap X
10551 auto *RotAmtC = isConstOrConstSplat(N: N1);
10552 if (RotAmtC && RotAmtC->getAPIntValue() == 8 &&
10553 VT.getScalarSizeInBits() == 16 && hasOperation(Opcode: ISD::BSWAP, VT))
10554 return DAG.getNode(Opcode: ISD::BSWAP, DL: dl, VT, Operand: N0);
10555
10556 // Simplify the operands using demanded-bits information.
10557 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
10558 return SDValue(N, 0);
10559
10560 // fold (rot* x, (trunc (and y, c))) -> (rot* x, (and (trunc y), (trunc c))).
10561 if (N1.getOpcode() == ISD::TRUNCATE &&
10562 N1.getOperand(i: 0).getOpcode() == ISD::AND) {
10563 if (SDValue NewOp1 = distributeTruncateThroughAnd(N: N1.getNode()))
10564 return DAG.getNode(Opcode: N->getOpcode(), DL: dl, VT, N1: N0, N2: NewOp1);
10565 }
10566
10567 unsigned NextOp = N0.getOpcode();
10568
10569 // fold (rot* (rot* x, c2), c1)
10570 // -> (rot* x, ((c1 % bitsize) +- (c2 % bitsize) + bitsize) % bitsize)
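// e.g. for i8: (rotl (rotr x, 3), 5) -> (rotl x, ((5 - 3) + 8) % 8) = (rotl x, 2)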
10571 if (NextOp == ISD::ROTL || NextOp == ISD::ROTR) {
10572 bool C1 = DAG.isConstantIntBuildVectorOrConstantInt(N: N1);
10573 bool C2 = DAG.isConstantIntBuildVectorOrConstantInt(N: N0.getOperand(i: 1));
10574 if (C1 && C2 && N1.getValueType() == N0.getOperand(i: 1).getValueType()) {
10575 EVT ShiftVT = N1.getValueType();
10576 bool SameSide = (N->getOpcode() == NextOp);
10577 unsigned CombineOp = SameSide ? ISD::ADD : ISD::SUB;
10578 SDValue BitsizeC = DAG.getConstant(Val: Bitsize, DL: dl, VT: ShiftVT);
10579 SDValue Norm1 = DAG.FoldConstantArithmetic(Opcode: ISD::UREM, DL: dl, VT: ShiftVT,
10580 Ops: {N1, BitsizeC});
10581 SDValue Norm2 = DAG.FoldConstantArithmetic(Opcode: ISD::UREM, DL: dl, VT: ShiftVT,
10582 Ops: {N0.getOperand(i: 1), BitsizeC});
10583 if (Norm1 && Norm2)
10584 if (SDValue CombinedShift = DAG.FoldConstantArithmetic(
10585 Opcode: CombineOp, DL: dl, VT: ShiftVT, Ops: {Norm1, Norm2})) {
10586 CombinedShift = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL: dl, VT: ShiftVT,
10587 Ops: {CombinedShift, BitsizeC});
10588 SDValue CombinedShiftNorm = DAG.FoldConstantArithmetic(
10589 Opcode: ISD::UREM, DL: dl, VT: ShiftVT, Ops: {CombinedShift, BitsizeC});
10590 return DAG.getNode(Opcode: N->getOpcode(), DL: dl, VT, N1: N0->getOperand(Num: 0),
10591 N2: CombinedShiftNorm);
10592 }
10593 }
10594 }
10595 return SDValue();
10596}
10597
10598SDValue DAGCombiner::visitSHL(SDNode *N) {
10599 SDValue N0 = N->getOperand(Num: 0);
10600 SDValue N1 = N->getOperand(Num: 1);
10601 if (SDValue V = DAG.simplifyShift(X: N0, Y: N1))
10602 return V;
10603
10604 SDLoc DL(N);
10605 EVT VT = N0.getValueType();
10606 EVT ShiftVT = N1.getValueType();
10607 unsigned OpSizeInBits = VT.getScalarSizeInBits();
10608
10609 // fold (shl c1, c2) -> c1<<c2
10610 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL, VT, Ops: {N0, N1}))
10611 return C;
10612
10613 // fold vector ops
10614 if (VT.isVector()) {
10615 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
10616 return FoldedVOp;
10617
10618 BuildVectorSDNode *N1CV = dyn_cast<BuildVectorSDNode>(Val&: N1);
10619 // If setcc produces all-one true value then:
10620 // (shl (and (setcc) N01CV) N1CV) -> (and (setcc) N01CV<<N1CV)
10621 if (N1CV && N1CV->isConstant()) {
10622 if (N0.getOpcode() == ISD::AND) {
10623 SDValue N00 = N0->getOperand(Num: 0);
10624 SDValue N01 = N0->getOperand(Num: 1);
10625 BuildVectorSDNode *N01CV = dyn_cast<BuildVectorSDNode>(Val&: N01);
10626
10627 if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC &&
10628 TLI.getBooleanContents(Type: N00.getOperand(i: 0).getValueType()) ==
10629 TargetLowering::ZeroOrNegativeOneBooleanContent) {
10630 if (SDValue C =
10631 DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL, VT, Ops: {N01, N1}))
10632 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N00, N2: C);
10633 }
10634 }
10635 }
10636 }
10637
10638 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
10639 return NewSel;
10640
10641 // if (shl x, c) is known to be zero, return 0
10642 if (DAG.MaskedValueIsZero(Op: SDValue(N, 0), Mask: APInt::getAllOnes(numBits: OpSizeInBits)))
10643 return DAG.getConstant(Val: 0, DL, VT);
10644
10645 // fold (shl x, (trunc (and y, c))) -> (shl x, (and (trunc y), (trunc c))).
10646 if (N1.getOpcode() == ISD::TRUNCATE &&
10647 N1.getOperand(i: 0).getOpcode() == ISD::AND) {
10648 if (SDValue NewOp1 = distributeTruncateThroughAnd(N: N1.getNode()))
10649 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0, N2: NewOp1);
10650 }
10651
10652 // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
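// e.g. for i32: (shl (shl x, 5), 7) -> (shl x, 12); (shl (shl x, 20), 20) -> 0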
10653 if (N0.getOpcode() == ISD::SHL) {
10654 auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
10655 ConstantSDNode *RHS) {
10656 APInt c1 = LHS->getAPIntValue();
10657 APInt c2 = RHS->getAPIntValue();
10658 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
10659 return (c1 + c2).uge(RHS: OpSizeInBits);
10660 };
10661 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchOutOfRange))
10662 return DAG.getConstant(Val: 0, DL, VT);
10663
10664 auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
10665 ConstantSDNode *RHS) {
10666 APInt c1 = LHS->getAPIntValue();
10667 APInt c2 = RHS->getAPIntValue();
10668 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
10669 return (c1 + c2).ult(RHS: OpSizeInBits);
10670 };
10671 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchInRange)) {
10672 SDValue Sum = DAG.getNode(Opcode: ISD::ADD, DL, VT: ShiftVT, N1, N2: N0.getOperand(i: 1));
10673 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0.getOperand(i: 0), N2: Sum);
10674 }
10675 }
10676
10677 // fold (shl (ext (shl x, c1)), c2) -> (shl (ext x), (add c1, c2))
10678 // For this to be valid, the second form must not preserve any of the bits
10679 // that are shifted out by the inner shift in the first form. This means
10680 // the outer shift size must be >= the number of bits added by the ext.
10681 // As a corollary, we don't care what kind of ext it is.
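// e.g. (shl (zext (shl x:i16, 4) to i32), 20) -> (shl (zext x to i32), 24),
// since the outer shift amount (20) covers the 16 bits added by the extend.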
10682 if ((N0.getOpcode() == ISD::ZERO_EXTEND ||
10683 N0.getOpcode() == ISD::ANY_EXTEND ||
10684 N0.getOpcode() == ISD::SIGN_EXTEND) &&
10685 N0.getOperand(i: 0).getOpcode() == ISD::SHL) {
10686 SDValue N0Op0 = N0.getOperand(i: 0);
10687 SDValue InnerShiftAmt = N0Op0.getOperand(i: 1);
10688 EVT InnerVT = N0Op0.getValueType();
10689 uint64_t InnerBitwidth = InnerVT.getScalarSizeInBits();
10690
10691 auto MatchOutOfRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
10692 ConstantSDNode *RHS) {
10693 APInt c1 = LHS->getAPIntValue();
10694 APInt c2 = RHS->getAPIntValue();
10695 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
10696 return c2.uge(RHS: OpSizeInBits - InnerBitwidth) &&
10697 (c1 + c2).uge(RHS: OpSizeInBits);
10698 };
10699 if (ISD::matchBinaryPredicate(LHS: InnerShiftAmt, RHS: N1, Match: MatchOutOfRange,
10700 /*AllowUndefs*/ false,
10701 /*AllowTypeMismatch*/ true))
10702 return DAG.getConstant(Val: 0, DL, VT);
10703
10704 auto MatchInRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
10705 ConstantSDNode *RHS) {
10706 APInt c1 = LHS->getAPIntValue();
10707 APInt c2 = RHS->getAPIntValue();
10708 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
10709 return c2.uge(RHS: OpSizeInBits - InnerBitwidth) &&
10710 (c1 + c2).ult(RHS: OpSizeInBits);
10711 };
10712 if (ISD::matchBinaryPredicate(LHS: InnerShiftAmt, RHS: N1, Match: MatchInRange,
10713 /*AllowUndefs*/ false,
10714 /*AllowTypeMismatch*/ true)) {
10715 SDValue Ext = DAG.getNode(Opcode: N0.getOpcode(), DL, VT, Operand: N0Op0.getOperand(i: 0));
10716 SDValue Sum = DAG.getZExtOrTrunc(Op: InnerShiftAmt, DL, VT: ShiftVT);
10717 Sum = DAG.getNode(Opcode: ISD::ADD, DL, VT: ShiftVT, N1: Sum, N2: N1);
10718 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Ext, N2: Sum);
10719 }
10720 }
10721
10722 // fold (shl (zext (srl x, C)), C) -> (zext (shl (srl x, C), C))
10723 // Only fold this if the inner zext has no other uses to avoid increasing
10724 // the total number of instructions.
10725 if (N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() &&
10726 N0.getOperand(i: 0).getOpcode() == ISD::SRL) {
10727 SDValue N0Op0 = N0.getOperand(i: 0);
10728 SDValue InnerShiftAmt = N0Op0.getOperand(i: 1);
10729
10730 auto MatchEqual = [VT](ConstantSDNode *LHS, ConstantSDNode *RHS) {
10731 APInt c1 = LHS->getAPIntValue();
10732 APInt c2 = RHS->getAPIntValue();
10733 zeroExtendToMatch(LHS&: c1, RHS&: c2);
10734 return c1.ult(RHS: VT.getScalarSizeInBits()) && (c1 == c2);
10735 };
10736 if (ISD::matchBinaryPredicate(LHS: InnerShiftAmt, RHS: N1, Match: MatchEqual,
10737 /*AllowUndefs*/ false,
10738 /*AllowTypeMismatch*/ true)) {
10739 EVT InnerShiftAmtVT = N0Op0.getOperand(i: 1).getValueType();
10740 SDValue NewSHL = DAG.getZExtOrTrunc(Op: N1, DL, VT: InnerShiftAmtVT);
10741 NewSHL = DAG.getNode(Opcode: ISD::SHL, DL, VT: N0Op0.getValueType(), N1: N0Op0, N2: NewSHL);
10742 AddToWorklist(N: NewSHL.getNode());
10743 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: SDLoc(N0), VT, Operand: NewSHL);
10744 }
10745 }
10746
10747 if (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) {
10748 auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
10749 ConstantSDNode *RHS) {
10750 const APInt &LHSC = LHS->getAPIntValue();
10751 const APInt &RHSC = RHS->getAPIntValue();
10752 return LHSC.ult(RHS: OpSizeInBits) && RHSC.ult(RHS: OpSizeInBits) &&
10753 LHSC.getZExtValue() <= RHSC.getZExtValue();
10754 };
10755
10756 // fold (shl (sr[la] exact X, C1), C2) -> (shl X, (C2-C1)) if C1 <= C2
10757 // fold (shl (sr[la] exact X, C1), C2) -> (sr[la] X, (C1-C2)) if C1 >= C2
10758 if (N0->getFlags().hasExact()) {
10759 if (ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchShiftAmount,
10760 /*AllowUndefs*/ false,
10761 /*AllowTypeMismatch*/ true)) {
10762 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
10763 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1, N2: N01);
10764 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
10765 }
10766 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchShiftAmount,
10767 /*AllowUndefs*/ false,
10768 /*AllowTypeMismatch*/ true)) {
10769 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
10770 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1: N01, N2: N1);
10771 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
10772 }
10773 }
10774
10775 // fold (shl (srl x, c1), c2) -> (and (shl x, (sub c2, c1)), MASK) or
10776 // (and (srl x, (sub c1, c2)), MASK)
10777 // Only fold this if the inner shift has no other uses -- if it does,
10778 // folding this will increase the total number of instructions.
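// e.g. for i8: (shl (srl x, 3), 5) -> (and (shl x, 2), 0xE0)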
10779 if (N0.getOpcode() == ISD::SRL &&
10780 (N0.getOperand(i: 1) == N1 || N0.hasOneUse()) &&
10781 TLI.shouldFoldConstantShiftPairToMask(N)) {
10782 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchShiftAmount,
10783 /*AllowUndefs*/ false,
10784 /*AllowTypeMismatch*/ true)) {
10785 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
10786 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1: N01, N2: N1);
10787 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10788 Mask = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Mask, N2: N01);
10789 Mask = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Mask, N2: Diff);
10790 SDValue Shift = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
10791 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shift, N2: Mask);
10792 }
10793 if (ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchShiftAmount,
10794 /*AllowUndefs*/ false,
10795 /*AllowTypeMismatch*/ true)) {
10796 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
10797 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1, N2: N01);
10798 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10799 Mask = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Mask, N2: N1);
10800 SDValue Shift = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
10801 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shift, N2: Mask);
10802 }
10803 }
10804 }
10805
10806 // fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1))
10807 if (N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(i: 1) &&
10808 isConstantOrConstantVector(N: N1, /* No Opaques */ NoOpaques: true)) {
10809 SDValue AllBits = DAG.getAllOnesConstant(DL, VT);
10810 SDValue HiBitsMask = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: AllBits, N2: N1);
10811 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0.getOperand(i: 0), N2: HiBitsMask);
10812 }
10813
10814 // fold (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
10815 // fold (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
10816 // This is a variant of the fold done on multiply, except that a mul by a power
10817 // of 2 is turned into a shift.
10818 if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) &&
10819 TLI.isDesirableToCommuteWithShift(N, Level)) {
10820 SDValue N01 = N0.getOperand(i: 1);
10821 if (SDValue Shl1 =
10822 DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL: SDLoc(N1), VT, Ops: {N01, N1})) {
10823 SDValue Shl0 = DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N0), VT, N1: N0.getOperand(i: 0), N2: N1);
10824 AddToWorklist(N: Shl0.getNode());
10825 SDNodeFlags Flags;
10826 // Preserve the disjoint flag for Or.
10827 if (N0.getOpcode() == ISD::OR && N0->getFlags().hasDisjoint())
10828 Flags |= SDNodeFlags::Disjoint;
10829 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: Shl0, N2: Shl1, Flags);
10830 }
10831 }
10832
10833 // fold (shl (sext (add_nsw x, c1)), c2) -> (add (shl (sext x), c2), c1 << c2)
10834 // TODO: Add zext/add_nuw variant with suitable test coverage
10835 // TODO: Should we limit this with isLegalAddImmediate?
10836 if (N0.getOpcode() == ISD::SIGN_EXTEND &&
10837 N0.getOperand(i: 0).getOpcode() == ISD::ADD &&
10838 N0.getOperand(i: 0)->getFlags().hasNoSignedWrap() &&
10839 TLI.isDesirableToCommuteWithShift(N, Level)) {
10840 SDValue Add = N0.getOperand(i: 0);
10841 SDLoc DL(N0);
10842 if (SDValue ExtC = DAG.FoldConstantArithmetic(Opcode: N0.getOpcode(), DL, VT,
10843 Ops: {Add.getOperand(i: 1)})) {
10844 if (SDValue ShlC =
10845 DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL, VT, Ops: {ExtC, N1})) {
10846 SDValue ExtX = DAG.getNode(Opcode: N0.getOpcode(), DL, VT, Operand: Add.getOperand(i: 0));
10847 SDValue ShlX = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: ExtX, N2: N1);
10848 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: ShlX, N2: ShlC);
10849 }
10850 }
10851 }
10852
10853 // fold (shl (mul x, c1), c2) -> (mul x, c1 << c2)
10854 if (N0.getOpcode() == ISD::MUL && N0->hasOneUse()) {
10855 SDValue N01 = N0.getOperand(i: 1);
10856 if (SDValue Shl =
10857 DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL: SDLoc(N1), VT, Ops: {N01, N1}))
10858 return DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: N0.getOperand(i: 0), N2: Shl);
10859 }
10860
10861 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
10862 if (N1C && !N1C->isOpaque())
10863 if (SDValue NewSHL = visitShiftByConstant(N))
10864 return NewSHL;
10865
10866 // fold (shl X, cttz(Y)) -> (mul (Y & -Y), X) if cttz is unsupported on the
10867 // target.
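// ((Y & -Y) isolates the lowest set bit of Y, which equals (1 << cttz(Y)) for
// non-zero Y; for Y == 0 the original shift amount is out of range anyway.)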
10868 if (((N1.getOpcode() == ISD::CTTZ &&
10869 VT.getScalarSizeInBits() <= ShiftVT.getScalarSizeInBits()) ||
10870 N1.getOpcode() == ISD::CTTZ_ZERO_UNDEF) &&
10871 N1.hasOneUse() && !TLI.isOperationLegalOrCustom(Op: ISD::CTTZ, VT: ShiftVT) &&
10872 TLI.isOperationLegalOrCustom(Op: ISD::MUL, VT)) {
10873 SDValue Y = N1.getOperand(i: 0);
10874 SDLoc DL(N);
10875 SDValue NegY = DAG.getNegative(Val: Y, DL, VT: ShiftVT);
10876 SDValue And =
10877 DAG.getZExtOrTrunc(Op: DAG.getNode(Opcode: ISD::AND, DL, VT: ShiftVT, N1: Y, N2: NegY), DL, VT);
10878 return DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: And, N2: N0);
10879 }
10880
10881 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
10882 return SDValue(N, 0);
10883
10884 // Fold (shl (vscale * C0), C1) to (vscale * (C0 << C1)).
10885 if (N0.getOpcode() == ISD::VSCALE && N1C) {
10886 const APInt &C0 = N0.getConstantOperandAPInt(i: 0);
10887 const APInt &C1 = N1C->getAPIntValue();
10888 return DAG.getVScale(DL, VT, MulImm: C0 << C1);
10889 }
10890
10891 SDValue X;
10892 APInt VS0;
10893
10894 // fold (shl (X * vscale(VS0)), C1) -> (X * vscale(VS0 << C1))
10895 if (N1C && sd_match(N: N0, P: m_Mul(L: m_Value(N&: X), R: m_VScale(Op: m_ConstInt(V&: VS0))))) {
10896 SDNodeFlags Flags;
10897 Flags.setNoUnsignedWrap(N->getFlags().hasNoUnsignedWrap() &&
10898 N0->getFlags().hasNoUnsignedWrap());
10899
10900 SDValue VScale = DAG.getVScale(DL, VT, MulImm: VS0 << N1C->getAPIntValue());
10901 return DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: X, N2: VScale, Flags);
10902 }
10903
10904 // Fold (shl step_vector(C0), C1) to (step_vector(C0 << C1)).
10905 APInt ShlVal;
10906 if (N0.getOpcode() == ISD::STEP_VECTOR &&
10907 ISD::isConstantSplatVector(N: N1.getNode(), SplatValue&: ShlVal)) {
10908 const APInt &C0 = N0.getConstantOperandAPInt(i: 0);
10909 if (ShlVal.ult(RHS: C0.getBitWidth())) {
10910 APInt NewStep = C0 << ShlVal;
10911 return DAG.getStepVector(DL, ResVT: VT, StepVal: NewStep);
10912 }
10913 }
10914
10915 return SDValue();
10916}
10917
10918// Transform a right shift of a multiply into a multiply-high.
10919// Examples:
10920 // (srl (mul (zext i32:$a to i64), (zext i32:$b to i64)), 32) -> (mulhu $a, $b)
10921 // (sra (mul (sext i32:$a to i64), (sext i32:$b to i64)), 32) -> (mulhs $a, $b)
10922static SDValue combineShiftToMULH(SDNode *N, const SDLoc &DL, SelectionDAG &DAG,
10923 const TargetLowering &TLI) {
10924 assert((N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) &&
10925 "SRL or SRA node is required here!");
10926
10927 // Check the shift amount. Proceed with the transformation if the shift
10928 // amount is constant.
10929 ConstantSDNode *ShiftAmtSrc = isConstOrConstSplat(N: N->getOperand(Num: 1));
10930 if (!ShiftAmtSrc)
10931 return SDValue();
10932
10933 // The operation feeding into the shift must be a multiply.
10934 SDValue ShiftOperand = N->getOperand(Num: 0);
10935 if (ShiftOperand.getOpcode() != ISD::MUL)
10936 return SDValue();
10937
10938 // Both operands must be equivalent extend nodes.
10939 SDValue LeftOp = ShiftOperand.getOperand(i: 0);
10940 SDValue RightOp = ShiftOperand.getOperand(i: 1);
10941
10942 bool IsSignExt = LeftOp.getOpcode() == ISD::SIGN_EXTEND;
10943 bool IsZeroExt = LeftOp.getOpcode() == ISD::ZERO_EXTEND;
10944
10945 if (!IsSignExt && !IsZeroExt)
10946 return SDValue();
10947
10948 EVT NarrowVT = LeftOp.getOperand(i: 0).getValueType();
10949 unsigned NarrowVTSize = NarrowVT.getScalarSizeInBits();
10950
10951 // return true if U may use the lower bits of its operands
10952 auto UserOfLowerBits = [NarrowVTSize](SDNode *U) {
10953 if (U->getOpcode() != ISD::SRL && U->getOpcode() != ISD::SRA) {
10954 return true;
10955 }
10956 ConstantSDNode *UShiftAmtSrc = isConstOrConstSplat(N: U->getOperand(Num: 1));
10957 if (!UShiftAmtSrc) {
10958 return true;
10959 }
10960 unsigned UShiftAmt = UShiftAmtSrc->getZExtValue();
10961 return UShiftAmt < NarrowVTSize;
10962 };
10963
10964 // If the lower half of the MUL is also used and MUL_LOHI is supported, do not
10965 // introduce the MULH; keep the MUL so it can become a MUL_LOHI instead.
10966 unsigned MulLoHiOp = IsSignExt ? ISD::SMUL_LOHI : ISD::UMUL_LOHI;
10967 if (!ShiftOperand.hasOneUse() &&
10968 TLI.isOperationLegalOrCustom(Op: MulLoHiOp, VT: NarrowVT) &&
10969 llvm::any_of(Range: ShiftOperand->users(), P: UserOfLowerBits)) {
10970 return SDValue();
10971 }
10972
10973 SDValue MulhRightOp;
10974 if (ConstantSDNode *Constant = isConstOrConstSplat(N: RightOp)) {
10975 unsigned ActiveBits = IsSignExt
10976 ? Constant->getAPIntValue().getSignificantBits()
10977 : Constant->getAPIntValue().getActiveBits();
10978 if (ActiveBits > NarrowVTSize)
10979 return SDValue();
10980 MulhRightOp = DAG.getConstant(
10981 Val: Constant->getAPIntValue().trunc(width: NarrowVT.getScalarSizeInBits()), DL,
10982 VT: NarrowVT);
10983 } else {
10984 if (LeftOp.getOpcode() != RightOp.getOpcode())
10985 return SDValue();
10986 // Check that the two extend nodes are the same type.
10987 if (NarrowVT != RightOp.getOperand(i: 0).getValueType())
10988 return SDValue();
10989 MulhRightOp = RightOp.getOperand(i: 0);
10990 }
10991
10992 EVT WideVT = LeftOp.getValueType();
10993 // Proceed with the transformation if the wide types match.
10994 assert((WideVT == RightOp.getValueType()) &&
10995 "Cannot have a multiply node with two different operand types.");
10996
10997 // Proceed with the transformation if the wide type is twice as large
10998 // as the narrow type.
10999 if (WideVT.getScalarSizeInBits() != 2 * NarrowVTSize)
11000 return SDValue();
11001
11002 // Check the shift amount with the narrow type size.
11003 // Proceed with the transformation if the shift amount is the width
11004 // of the narrow type.
11005 unsigned ShiftAmt = ShiftAmtSrc->getZExtValue();
11006 if (ShiftAmt != NarrowVTSize)
11007 return SDValue();
11008
11009 // If the operation feeding into the MUL is a sign extend (sext),
11010 // we use mulhs. Otherwise, zero extends (zext) use mulhu.
11011 unsigned MulhOpcode = IsSignExt ? ISD::MULHS : ISD::MULHU;
11012
11013 // Combine to mulh if mulh is legal/custom for the narrow type on the target.
11014 // For vector types we instead check the type legalization would transform to,
11015 // and rely on legalization to split/combine the result.
11016 EVT TransformVT = NarrowVT;
11017 if (NarrowVT.isVector()) {
11018 TransformVT = TLI.getLegalTypeToTransformTo(Context&: *DAG.getContext(), VT: NarrowVT);
11019 if (TransformVT.getScalarType() != NarrowVT.getScalarType())
11020 return SDValue();
11021 }
11022 if (!TLI.isOperationLegalOrCustom(Op: MulhOpcode, VT: TransformVT))
11023 return SDValue();
11024
11025 SDValue Result =
11026 DAG.getNode(Opcode: MulhOpcode, DL, VT: NarrowVT, N1: LeftOp.getOperand(i: 0), N2: MulhRightOp);
11027 bool IsSigned = N->getOpcode() == ISD::SRA;
11028 return DAG.getExtOrTrunc(IsSigned, Op: Result, DL, VT: WideVT);
11029}
11030
11031// fold (bswap (logic_op(bswap(x),y))) -> logic_op(x,bswap(y))
// This helper accepts SDNodes with opcode ISD::BSWAP or ISD::BITREVERSE.
11033static SDValue foldBitOrderCrossLogicOp(SDNode *N, SelectionDAG &DAG) {
11034 unsigned Opcode = N->getOpcode();
11035 if (Opcode != ISD::BSWAP && Opcode != ISD::BITREVERSE)
11036 return SDValue();
11037
11038 SDValue N0 = N->getOperand(Num: 0);
11039 EVT VT = N->getValueType(ResNo: 0);
11040 SDLoc DL(N);
11041 SDValue X, Y;
11042
  // If both operands are bswap/bitreverse, allow them to have multiple uses.
11044 if (sd_match(N: N0, P: m_OneUse(P: m_BitwiseLogic(L: m_UnaryOp(Opc: Opcode, Op: m_Value(N&: X)),
11045 R: m_UnaryOp(Opc: Opcode, Op: m_Value(N&: Y))))))
11046 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: X, N2: Y);
11047
  // Otherwise we need to ensure the logic_op and bswap/bitreverse(x) each have
  // one use.
11049 if (sd_match(N: N0, P: m_OneUse(P: m_BitwiseLogic(
11050 L: m_OneUse(P: m_UnaryOp(Opc: Opcode, Op: m_Value(N&: X))), R: m_Value(N&: Y))))) {
11051 SDValue NewBitReorder = DAG.getNode(Opcode, DL, VT, Operand: Y);
11052 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: X, N2: NewBitReorder);
11053 }
11054
11055 return SDValue();
11056}
11057
11058SDValue DAGCombiner::visitSRA(SDNode *N) {
11059 SDValue N0 = N->getOperand(Num: 0);
11060 SDValue N1 = N->getOperand(Num: 1);
11061 if (SDValue V = DAG.simplifyShift(X: N0, Y: N1))
11062 return V;
11063
11064 SDLoc DL(N);
11065 EVT VT = N0.getValueType();
11066 unsigned OpSizeInBits = VT.getScalarSizeInBits();
11067
  // fold (sra c1, c2) -> c1 >>s c2
11069 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SRA, DL, VT, Ops: {N0, N1}))
11070 return C;
11071
11072 // Arithmetic shifting an all-sign-bit value is a no-op.
11073 // fold (sra 0, x) -> 0
11074 // fold (sra -1, x) -> -1
11075 if (DAG.ComputeNumSignBits(Op: N0) == OpSizeInBits)
11076 return N0;
11077
11078 // fold vector ops
11079 if (VT.isVector())
11080 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
11081 return FoldedVOp;
11082
11083 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
11084 return NewSel;
11085
11086 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
11087
11088 // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
11089 // clamp (add c1, c2) to max shift.
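  // e.g. (sra (sra x, 3), 5) --> (sra x, 8). If the sum exceeds the bit width
  // it is clamped to bitwidth - 1, since arithmetic shifts beyond that point
  // only replicate the sign bit.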
11090 if (N0.getOpcode() == ISD::SRA) {
11091 EVT ShiftVT = N1.getValueType();
11092 EVT ShiftSVT = ShiftVT.getScalarType();
11093 SmallVector<SDValue, 16> ShiftValues;
11094
11095 auto SumOfShifts = [&](ConstantSDNode *LHS, ConstantSDNode *RHS) {
11096 APInt c1 = LHS->getAPIntValue();
11097 APInt c2 = RHS->getAPIntValue();
11098 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
11099 APInt Sum = c1 + c2;
11100 unsigned ShiftSum =
11101 Sum.uge(RHS: OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
11102 ShiftValues.push_back(Elt: DAG.getConstant(Val: ShiftSum, DL, VT: ShiftSVT));
11103 return true;
11104 };
11105 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: SumOfShifts)) {
11106 SDValue ShiftValue;
11107 if (N1.getOpcode() == ISD::BUILD_VECTOR)
11108 ShiftValue = DAG.getBuildVector(VT: ShiftVT, DL, Ops: ShiftValues);
11109 else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
11110 assert(ShiftValues.size() == 1 &&
11111 "Expected matchBinaryPredicate to return one element for "
11112 "SPLAT_VECTORs");
11113 ShiftValue = DAG.getSplatVector(VT: ShiftVT, DL, Op: ShiftValues[0]);
11114 } else
11115 ShiftValue = ShiftValues[0];
11116 return DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: N0.getOperand(i: 0), N2: ShiftValue);
11117 }
11118 }
11119
11120 // fold (sra (xor (sra x, c1), -1), c2) -> (xor (sra x, c3), -1)
11121 // This allows merging two arithmetic shifts even when there's a NOT in
11122 // between.
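  // This is valid because (sra (xor v, -1), c) == (xor (sra v, c), -1), so the
  // outer shift can be moved inside the NOT and merged with the inner one.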
11123 SDValue X;
11124 APInt C1;
11125 if (N1C && sd_match(N: N0, P: m_OneUse(P: m_Not(
11126 V: m_OneUse(P: m_Sra(L: m_Value(N&: X), R: m_ConstInt(V&: C1))))))) {
11127 APInt C2 = N1C->getAPIntValue();
11128 zeroExtendToMatch(LHS&: C1, RHS&: C2, Offset: 1 /* Overflow Bit */);
11129 APInt Sum = C1 + C2;
11130 unsigned ShiftSum = Sum.getLimitedValue(Limit: OpSizeInBits - 1);
11131 SDValue NewShift = DAG.getNode(
11132 Opcode: ISD::SRA, DL, VT, N1: X, N2: DAG.getShiftAmountConstant(Val: ShiftSum, VT, DL));
11133 return DAG.getNOT(DL, Val: NewShift, VT);
11134 }
11135
  // fold (sra (shl X, m), (sub result_size, n))
  //   -> (sign_extend (trunc (srl X, (sub (sub result_size, n), m)))) for
  //   result_size - n != m.
  // If truncate is free for the target, this sext(trunc(srl)) sequence is
  // likely to result in better code.
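  // e.g. for i32 X: (sra (shl X, 8), 16) yields bits 8..23 of X sign-extended
  // to i32, which is (sext (trunc (srl X, 8) to i16)).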
11141 if (N0.getOpcode() == ISD::SHL && N1C) {
11142 // Get the two constants of the shifts, CN0 = m, CN = n.
11143 const ConstantSDNode *N01C = isConstOrConstSplat(N: N0.getOperand(i: 1));
11144 if (N01C) {
11145 LLVMContext &Ctx = *DAG.getContext();
11146 // Determine what the truncate's result bitsize and type would be.
11147 EVT TruncVT = VT.changeElementType(
11148 Context&: Ctx, EltVT: EVT::getIntegerVT(Context&: Ctx, BitWidth: OpSizeInBits - N1C->getZExtValue()));
11149
11150 // Determine the residual right-shift amount.
11151 int ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue();
11152
      // If the shift is not a no-op (in which case this should be just a sign
      // extend already), the truncated-to type is legal, sign_extend is legal
      // on that type, and the truncate to that type is both legal and free,
      // perform the transform.
11157 if ((ShiftAmt > 0) &&
11158 TLI.isOperationLegalOrCustom(Op: ISD::SIGN_EXTEND, VT: TruncVT) &&
11159 TLI.isOperationLegalOrCustom(Op: ISD::TRUNCATE, VT) &&
11160 TLI.isTruncateFree(FromVT: VT, ToVT: TruncVT)) {
11161 SDValue Amt = DAG.getShiftAmountConstant(Val: ShiftAmt, VT, DL);
11162 SDValue Shift = DAG.getNode(Opcode: ISD::SRL, DL, VT,
11163 N1: N0.getOperand(i: 0), N2: Amt);
11164 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: TruncVT,
11165 Operand: Shift);
11166 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL,
11167 VT: N->getValueType(ResNo: 0), Operand: Trunc);
11168 }
11169 }
11170 }
11171
11172 // We convert trunc/ext to opposing shifts in IR, but casts may be cheaper.
11173 // sra (add (shl X, N1C), AddC), N1C -->
11174 // sext (add (trunc X to (width - N1C)), AddC')
11175 // sra (sub AddC, (shl X, N1C)), N1C -->
  // sext (sub AddC', (trunc X to (width - N1C)))
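  // e.g. i32: (sra (add (shl X, 16), 0xA0000), 16)
  //   --> (sext (add (trunc X to i16), 0xA))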
11177 if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB) && N1C &&
11178 N0.hasOneUse()) {
11179 bool IsAdd = N0.getOpcode() == ISD::ADD;
11180 SDValue Shl = N0.getOperand(i: IsAdd ? 0 : 1);
11181 if (Shl.getOpcode() == ISD::SHL && Shl.getOperand(i: 1) == N1 &&
11182 Shl.hasOneUse()) {
11183 // TODO: AddC does not need to be a splat.
11184 if (ConstantSDNode *AddC =
11185 isConstOrConstSplat(N: N0.getOperand(i: IsAdd ? 1 : 0))) {
11186 // Determine what the truncate's type would be and ask the target if
11187 // that is a free operation.
11188 LLVMContext &Ctx = *DAG.getContext();
11189 unsigned ShiftAmt = N1C->getZExtValue();
11190 EVT TruncVT = VT.changeElementType(
11191 Context&: Ctx, EltVT: EVT::getIntegerVT(Context&: Ctx, BitWidth: OpSizeInBits - ShiftAmt));
11192
11193 // TODO: The simple type check probably belongs in the default hook
11194 // implementation and/or target-specific overrides (because
11195 // non-simple types likely require masking when legalized), but
11196 // that restriction may conflict with other transforms.
11197 if (TruncVT.isSimple() && isTypeLegal(VT: TruncVT) &&
11198 TLI.isTruncateFree(FromVT: VT, ToVT: TruncVT)) {
11199 SDValue Trunc = DAG.getZExtOrTrunc(Op: Shl.getOperand(i: 0), DL, VT: TruncVT);
11200 SDValue ShiftC =
11201 DAG.getConstant(Val: AddC->getAPIntValue().lshr(shiftAmt: ShiftAmt).trunc(
11202 width: TruncVT.getScalarSizeInBits()),
11203 DL, VT: TruncVT);
11204 SDValue Add;
11205 if (IsAdd)
11206 Add = DAG.getNode(Opcode: ISD::ADD, DL, VT: TruncVT, N1: Trunc, N2: ShiftC);
11207 else
11208 Add = DAG.getNode(Opcode: ISD::SUB, DL, VT: TruncVT, N1: ShiftC, N2: Trunc);
11209 return DAG.getSExtOrTrunc(Op: Add, DL, VT);
11210 }
11211 }
11212 }
11213 }
11214
11215 // fold (sra x, (trunc (and y, c))) -> (sra x, (and (trunc y), (trunc c))).
11216 if (N1.getOpcode() == ISD::TRUNCATE &&
11217 N1.getOperand(i: 0).getOpcode() == ISD::AND) {
11218 if (SDValue NewOp1 = distributeTruncateThroughAnd(N: N1.getNode()))
11219 return DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: N0, N2: NewOp1);
11220 }
11221
11222 // fold (sra (trunc (sra x, c1)), c2) -> (trunc (sra x, c1 + c2))
11223 // fold (sra (trunc (srl x, c1)), c2) -> (trunc (sra x, c1 + c2))
11224 // if c1 is equal to the number of bits the trunc removes
11225 // TODO - support non-uniform vector shift amounts.
11226 if (N0.getOpcode() == ISD::TRUNCATE &&
11227 (N0.getOperand(i: 0).getOpcode() == ISD::SRL ||
11228 N0.getOperand(i: 0).getOpcode() == ISD::SRA) &&
11229 N0.getOperand(i: 0).hasOneUse() &&
11230 N0.getOperand(i: 0).getOperand(i: 1).hasOneUse() && N1C) {
11231 SDValue N0Op0 = N0.getOperand(i: 0);
11232 if (ConstantSDNode *LargeShift = isConstOrConstSplat(N: N0Op0.getOperand(i: 1))) {
11233 EVT LargeVT = N0Op0.getValueType();
11234 unsigned TruncBits = LargeVT.getScalarSizeInBits() - OpSizeInBits;
11235 if (LargeShift->getAPIntValue() == TruncBits) {
11236 EVT LargeShiftVT = getShiftAmountTy(LHSTy: LargeVT);
11237 SDValue Amt = DAG.getZExtOrTrunc(Op: N1, DL, VT: LargeShiftVT);
11238 Amt = DAG.getNode(Opcode: ISD::ADD, DL, VT: LargeShiftVT, N1: Amt,
11239 N2: DAG.getConstant(Val: TruncBits, DL, VT: LargeShiftVT));
11240 SDValue SRA =
11241 DAG.getNode(Opcode: ISD::SRA, DL, VT: LargeVT, N1: N0Op0.getOperand(i: 0), N2: Amt);
11242 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: SRA);
11243 }
11244 }
11245 }
11246
11247 // Simplify, based on bits shifted out of the LHS.
11248 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
11249 return SDValue(N, 0);
11250
11251 // If the sign bit is known to be zero, switch this to a SRL.
11252 if (DAG.SignBitIsZero(Op: N0))
11253 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0, N2: N1);
11254
11255 if (N1C && !N1C->isOpaque())
11256 if (SDValue NewSRA = visitShiftByConstant(N))
11257 return NewSRA;
11258
11259 // Try to transform this shift into a multiply-high if
11260 // it matches the appropriate pattern detected in combineShiftToMULH.
11261 if (SDValue MULH = combineShiftToMULH(N, DL, DAG, TLI))
11262 return MULH;
11263
11264 // Attempt to convert a sra of a load into a narrower sign-extending load.
11265 if (SDValue NarrowLoad = reduceLoadWidth(N))
11266 return NarrowLoad;
11267
11268 if (SDValue AVG = foldShiftToAvg(N, DL))
11269 return AVG;
11270
11271 return SDValue();
11272}
11273
11274SDValue DAGCombiner::visitSRL(SDNode *N) {
11275 SDValue N0 = N->getOperand(Num: 0);
11276 SDValue N1 = N->getOperand(Num: 1);
11277 if (SDValue V = DAG.simplifyShift(X: N0, Y: N1))
11278 return V;
11279
11280 SDLoc DL(N);
11281 EVT VT = N0.getValueType();
11282 EVT ShiftVT = N1.getValueType();
11283 unsigned OpSizeInBits = VT.getScalarSizeInBits();
11284
11285 // fold (srl c1, c2) -> c1 >>u c2
11286 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SRL, DL, VT, Ops: {N0, N1}))
11287 return C;
11288
11289 // fold vector ops
11290 if (VT.isVector())
11291 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
11292 return FoldedVOp;
11293
11294 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
11295 return NewSel;
11296
11297 // if (srl x, c) is known to be zero, return 0
11298 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
11299 if (N1C &&
11300 DAG.MaskedValueIsZero(Op: SDValue(N, 0), Mask: APInt::getAllOnes(numBits: OpSizeInBits)))
11301 return DAG.getConstant(Val: 0, DL, VT);
11302
11303 // fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2))
11304 if (N0.getOpcode() == ISD::SRL) {
11305 auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
11306 ConstantSDNode *RHS) {
11307 APInt c1 = LHS->getAPIntValue();
11308 APInt c2 = RHS->getAPIntValue();
11309 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
11310 return (c1 + c2).uge(RHS: OpSizeInBits);
11311 };
11312 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchOutOfRange))
11313 return DAG.getConstant(Val: 0, DL, VT);
11314
11315 auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
11316 ConstantSDNode *RHS) {
11317 APInt c1 = LHS->getAPIntValue();
11318 APInt c2 = RHS->getAPIntValue();
11319 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
11320 return (c1 + c2).ult(RHS: OpSizeInBits);
11321 };
11322 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchInRange)) {
11323 SDValue Sum = DAG.getNode(Opcode: ISD::ADD, DL, VT: ShiftVT, N1, N2: N0.getOperand(i: 1));
11324 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0.getOperand(i: 0), N2: Sum);
11325 }
11326 }
11327
11328 if (N1C && N0.getOpcode() == ISD::TRUNCATE &&
11329 N0.getOperand(i: 0).getOpcode() == ISD::SRL) {
11330 SDValue InnerShift = N0.getOperand(i: 0);
11331 // TODO - support non-uniform vector shift amounts.
11332 if (auto *N001C = isConstOrConstSplat(N: InnerShift.getOperand(i: 1))) {
11333 uint64_t c1 = N001C->getZExtValue();
11334 uint64_t c2 = N1C->getZExtValue();
11335 EVT InnerShiftVT = InnerShift.getValueType();
11336 EVT ShiftAmtVT = InnerShift.getOperand(i: 1).getValueType();
11337 uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits();
11338 // srl (trunc (srl x, c1)), c2 --> 0 or (trunc (srl x, (add c1, c2)))
      // This is only valid if OpSizeInBits + c1 equals the size of the inner
      // shift.
11340 if (c1 + OpSizeInBits == InnerShiftSize) {
11341 if (c1 + c2 >= InnerShiftSize)
11342 return DAG.getConstant(Val: 0, DL, VT);
11343 SDValue NewShiftAmt = DAG.getConstant(Val: c1 + c2, DL, VT: ShiftAmtVT);
11344 SDValue NewShift = DAG.getNode(Opcode: ISD::SRL, DL, VT: InnerShiftVT,
11345 N1: InnerShift.getOperand(i: 0), N2: NewShiftAmt);
11346 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: NewShift);
11347 }
11348 // In the more general case, we can clear the high bits after the shift:
11349 // srl (trunc (srl x, c1)), c2 --> trunc (and (srl x, (c1+c2)), Mask)
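      // e.g. x:i64, c1=8, c2=4, VT=i32:
      //   srl (trunc (srl x, 8)), 4 --> trunc (and (srl x, 12), 0x0FFFFFFF)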
11350 if (N0.hasOneUse() && InnerShift.hasOneUse() &&
11351 c1 + c2 < InnerShiftSize) {
11352 SDValue NewShiftAmt = DAG.getConstant(Val: c1 + c2, DL, VT: ShiftAmtVT);
11353 SDValue NewShift = DAG.getNode(Opcode: ISD::SRL, DL, VT: InnerShiftVT,
11354 N1: InnerShift.getOperand(i: 0), N2: NewShiftAmt);
11355 SDValue Mask = DAG.getConstant(Val: APInt::getLowBitsSet(numBits: InnerShiftSize,
11356 loBitsSet: OpSizeInBits - c2),
11357 DL, VT: InnerShiftVT);
11358 SDValue And = DAG.getNode(Opcode: ISD::AND, DL, VT: InnerShiftVT, N1: NewShift, N2: Mask);
11359 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: And);
11360 }
11361 }
11362 }
11363
11364 if (N0.getOpcode() == ISD::SHL) {
11365 // fold (srl (shl nuw x, c), c) -> x
11366 if (N0.getOperand(i: 1) == N1 && N0->getFlags().hasNoUnsignedWrap())
11367 return N0.getOperand(i: 0);
11368
    // fold (srl (shl x, c1), c2) -> (and (shl x, (sub c1, c2)), MASK) or
    //                               (and (srl x, (sub c2, c1)), MASK)
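    // e.g. i32: (srl (shl x, 6), 2) --> (and (shl x, 4), 0x3FFFFFF0)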
11371 if ((N0.getOperand(i: 1) == N1 || N0->hasOneUse()) &&
11372 TLI.shouldFoldConstantShiftPairToMask(N)) {
11373 auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
11374 ConstantSDNode *RHS) {
11375 const APInt &LHSC = LHS->getAPIntValue();
11376 const APInt &RHSC = RHS->getAPIntValue();
11377 return LHSC.ult(RHS: OpSizeInBits) && RHSC.ult(RHS: OpSizeInBits) &&
11378 LHSC.getZExtValue() <= RHSC.getZExtValue();
11379 };
11380 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchShiftAmount,
11381 /*AllowUndefs*/ false,
11382 /*AllowTypeMismatch*/ true)) {
11383 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
11384 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1: N01, N2: N1);
11385 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
11386 Mask = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Mask, N2: N01);
11387 Mask = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Mask, N2: Diff);
11388 SDValue Shift = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
11389 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shift, N2: Mask);
11390 }
11391 if (ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchShiftAmount,
11392 /*AllowUndefs*/ false,
11393 /*AllowTypeMismatch*/ true)) {
11394 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
11395 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1, N2: N01);
11396 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
11397 Mask = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Mask, N2: N1);
11398 SDValue Shift = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
11399 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shift, N2: Mask);
11400 }
11401 }
11402 }
11403
11404 // fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask)
11405 // TODO - support non-uniform vector shift amounts.
11406 if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
11407 // Shifting in all undef bits?
11408 EVT SmallVT = N0.getOperand(i: 0).getValueType();
11409 unsigned BitSize = SmallVT.getScalarSizeInBits();
11410 if (N1C->getAPIntValue().uge(RHS: BitSize))
11411 return DAG.getUNDEF(VT);
11412
11413 if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, VT: SmallVT)) {
11414 uint64_t ShiftAmt = N1C->getZExtValue();
11415 SDLoc DL0(N0);
11416 SDValue SmallShift =
11417 DAG.getNode(Opcode: ISD::SRL, DL: DL0, VT: SmallVT, N1: N0.getOperand(i: 0),
11418 N2: DAG.getShiftAmountConstant(Val: ShiftAmt, VT: SmallVT, DL: DL0));
11419 AddToWorklist(N: SmallShift.getNode());
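      // Mask off the top ShiftAmt bits, which the original wide srl guarantees
      // to be zero.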
11420 APInt Mask = APInt::getLowBitsSet(numBits: OpSizeInBits, loBitsSet: OpSizeInBits - ShiftAmt);
11421 return DAG.getNode(Opcode: ISD::AND, DL, VT,
11422 N1: DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT, Operand: SmallShift),
11423 N2: DAG.getConstant(Val: Mask, DL, VT));
11424 }
11425 }
11426
11427 // fold (srl (sra X, Y), 31) -> (srl X, 31). This srl only looks at the sign
11428 // bit, which is unmodified by sra.
11429 if (N1C && N1C->getAPIntValue() == (OpSizeInBits - 1)) {
11430 if (N0.getOpcode() == ISD::SRA)
11431 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0.getOperand(i: 0), N2: N1);
11432 }
11433
  // fold (srl (ctlz x), "5") -> x iff x has one bit set (the low bit), and x
  // has a power-of-two bitwidth. The "5" represents (log2 (bitwidth x)).
11436 if (N1C && N0.getOpcode() == ISD::CTLZ &&
11437 isPowerOf2_32(Value: OpSizeInBits) &&
11438 N1C->getAPIntValue() == Log2_32(Value: OpSizeInBits)) {
11439 KnownBits Known = DAG.computeKnownBits(Op: N0.getOperand(i: 0));
11440
11441 // If any of the input bits are KnownOne, then the input couldn't be all
11442 // zeros, thus the result of the srl will always be zero.
11443 if (Known.One.getBoolValue()) return DAG.getConstant(Val: 0, DL: SDLoc(N0), VT);
11444
    // If all of the bits input to the ctlz node are known to be zero, then
    // the result of the ctlz is the bitwidth and the result of the shift is
    // one.
11447 APInt UnknownBits = ~Known.Zero;
11448 if (UnknownBits == 0) return DAG.getConstant(Val: 1, DL: SDLoc(N0), VT);
11449
11450 // Otherwise, check to see if there is exactly one bit input to the ctlz.
11451 if (UnknownBits.isPowerOf2()) {
      // Okay, we know that only the single bit specified by UnknownBits
11453 // could be set on input to the CTLZ node. If this bit is set, the SRL
11454 // will return 0, if it is clear, it returns 1. Change the CTLZ/SRL pair
11455 // to an SRL/XOR pair, which is likely to simplify more.
11456 unsigned ShAmt = UnknownBits.countr_zero();
11457 SDValue Op = N0.getOperand(i: 0);
11458
11459 if (ShAmt) {
11460 SDLoc DL(N0);
11461 Op = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Op,
11462 N2: DAG.getShiftAmountConstant(Val: ShAmt, VT, DL));
11463 AddToWorklist(N: Op.getNode());
11464 }
11465 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: Op, N2: DAG.getConstant(Val: 1, DL, VT));
11466 }
11467 }
11468
11469 // fold (srl x, (trunc (and y, c))) -> (srl x, (and (trunc y), (trunc c))).
11470 if (N1.getOpcode() == ISD::TRUNCATE &&
11471 N1.getOperand(i: 0).getOpcode() == ISD::AND) {
11472 if (SDValue NewOp1 = distributeTruncateThroughAnd(N: N1.getNode()))
11473 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0, N2: NewOp1);
11474 }
11475
11476 // fold (srl (logic_op x, (shl (zext y), c1)), c1)
11477 // -> (logic_op (srl x, c1), (zext y))
11478 // c1 <= leadingzeros(zext(y))
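  // e.g. (srl (or x, (shl (zext i8 y to i32), 24)), 24)
  //   --> (or (srl x, 24), (zext i8 y to i32))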
11479 SDValue X, ZExtY;
11480 if (N1C && sd_match(N: N0, P: m_OneUse(P: m_BitwiseLogic(
11481 L: m_Value(N&: X),
11482 R: m_OneUse(P: m_Shl(L: m_AllOf(preds: m_Value(N&: ZExtY),
11483 preds: m_Opc(Opcode: ISD::ZERO_EXTEND)),
11484 R: m_Specific(N: N1))))))) {
11485 unsigned NumLeadingZeros = ZExtY.getScalarValueSizeInBits() -
11486 ZExtY.getOperand(i: 0).getScalarValueSizeInBits();
11487 if (N1C->getZExtValue() <= NumLeadingZeros)
11488 return DAG.getNode(Opcode: N0.getOpcode(), DL: SDLoc(N0), VT,
11489 N1: DAG.getNode(Opcode: ISD::SRL, DL: SDLoc(N0), VT, N1: X, N2: N1), N2: ZExtY);
11490 }
11491
11492 // fold operands of srl based on knowledge that the low bits are not
11493 // demanded.
11494 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
11495 return SDValue(N, 0);
11496
11497 if (N1C && !N1C->isOpaque())
11498 if (SDValue NewSRL = visitShiftByConstant(N))
11499 return NewSRL;
11500
11501 // Attempt to convert a srl of a load into a narrower zero-extending load.
11502 if (SDValue NarrowLoad = reduceLoadWidth(N))
11503 return NarrowLoad;
11504
11505 // Here is a common situation. We want to optimize:
11506 //
11507 // %a = ...
11508 // %b = and i32 %a, 2
11509 // %c = srl i32 %b, 1
11510 // brcond i32 %c ...
11511 //
11512 // into
11513 //
11514 // %a = ...
11515 // %b = and %a, 2
11516 // %c = setcc eq %b, 0
11517 // brcond %c ...
11518 //
  // However, after the source operand of the SRL is optimized into an AND, the
  // SRL itself may not be optimized further. Look for it and add the BRCOND to
  // the worklist.
  //
  // This also tends to happen for binary operations when SimplifyDemandedBits
11524 // is involved.
11525 //
  // FIXME: This is unnecessary if we process the DAG in topological order,
11527 // which we plan to do. This workaround can be removed once the DAG is
11528 // processed in topological order.
11529 if (N->hasOneUse()) {
11530 SDNode *User = *N->user_begin();
11531
    // Look past the truncate.
11533 if (User->getOpcode() == ISD::TRUNCATE && User->hasOneUse())
11534 User = *User->user_begin();
11535
11536 if (User->getOpcode() == ISD::BRCOND || User->getOpcode() == ISD::AND ||
11537 User->getOpcode() == ISD::OR || User->getOpcode() == ISD::XOR)
11538 AddToWorklist(N: User);
11539 }
11540
11541 // Try to transform this shift into a multiply-high if
11542 // it matches the appropriate pattern detected in combineShiftToMULH.
11543 if (SDValue MULH = combineShiftToMULH(N, DL, DAG, TLI))
11544 return MULH;
11545
11546 if (SDValue AVG = foldShiftToAvg(N, DL))
11547 return AVG;
11548
11549 SDValue Y;
11550 if (VT.getScalarSizeInBits() % 2 == 0 && N1C) {
    // Fold clmul(zext(x), zext(y)) >> (HalfBW - 1) -> clmulr(x, y) and
    // clmul(zext(x), zext(y)) >> HalfBW -> clmulh(x, y), where x and y are
    // HalfBW bits wide.
11552 unsigned HalfBW = VT.getScalarSizeInBits() / 2;
11553 if (sd_match(N: N0, P: m_Clmul(L: m_ZExt(Op: m_Value(N&: X)), R: m_ZExt(Op: m_Value(N&: Y)))) &&
11554 X.getScalarValueSizeInBits() == HalfBW &&
11555 Y.getScalarValueSizeInBits() == HalfBW) {
11556 if (N1C->getZExtValue() == HalfBW - 1 &&
11557 (!LegalOperations ||
11558 TLI.isOperationLegalOrCustom(Op: ISD::CLMULR, VT: X.getValueType())))
11559 return DAG.getNode(
11560 Opcode: ISD::ZERO_EXTEND, DL, VT,
11561 Operand: DAG.getNode(Opcode: ISD::CLMULR, DL, VT: X.getValueType(), N1: X, N2: Y));
11562 if (N1C->getZExtValue() == HalfBW &&
11563 (!LegalOperations ||
11564 TLI.isOperationLegalOrCustom(Op: ISD::CLMULH, VT: X.getValueType())))
11565 return DAG.getNode(
11566 Opcode: ISD::ZERO_EXTEND, DL, VT,
11567 Operand: DAG.getNode(Opcode: ISD::CLMULH, DL, VT: X.getValueType(), N1: X, N2: Y));
11568 }
11569 }
11570
11571 // Fold bitreverse(clmul(bitreverse(x), bitreverse(y))) >> 1 ->
11572 // clmulh(x, y).
11573 if (N1C && N1C->getZExtValue() == 1 &&
11574 sd_match(N: N0, P: m_BitReverse(Op: m_Clmul(L: m_BitReverse(Op: m_Value(N&: X)),
11575 R: m_BitReverse(Op: m_Value(N&: Y))))))
11576 return DAG.getNode(Opcode: ISD::CLMULH, DL, VT, N1: X, N2: Y);
11577
11578 return SDValue();
11579}
11580
11581SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
11582 EVT VT = N->getValueType(ResNo: 0);
11583 SDValue N0 = N->getOperand(Num: 0);
11584 SDValue N1 = N->getOperand(Num: 1);
11585 SDValue N2 = N->getOperand(Num: 2);
11586 bool IsFSHL = N->getOpcode() == ISD::FSHL;
11587 unsigned BitWidth = VT.getScalarSizeInBits();
11588 SDLoc DL(N);
11589
11590 // fold (fshl/fshr C0, C1, C2) -> C3
11591 if (SDValue C =
11592 DAG.FoldConstantArithmetic(Opcode: N->getOpcode(), DL, VT, Ops: {N0, N1, N2}))
11593 return C;
11594
11595 // fold (fshl N0, N1, 0) -> N0
11596 // fold (fshr N0, N1, 0) -> N1
11597 if (isPowerOf2_32(Value: BitWidth))
11598 if (DAG.MaskedValueIsZero(
11599 Op: N2, Mask: APInt(N2.getScalarValueSizeInBits(), BitWidth - 1)))
11600 return IsFSHL ? N0 : N1;
11601
11602 auto IsUndefOrZero = [](SDValue V) {
11603 return V.isUndef() || isNullOrNullSplat(V, /*AllowUndefs*/ true);
11604 };
11605
11606 // TODO - support non-uniform vector shift amounts.
11607 if (ConstantSDNode *Cst = isConstOrConstSplat(N: N2)) {
11608 EVT ShAmtTy = N2.getValueType();
11609
11610 // fold (fsh* N0, N1, c) -> (fsh* N0, N1, c % BitWidth)
11611 if (Cst->getAPIntValue().uge(RHS: BitWidth)) {
11612 uint64_t RotAmt = Cst->getAPIntValue().urem(RHS: BitWidth);
11613 return DAG.getNode(Opcode: N->getOpcode(), DL, VT, N1: N0, N2: N1,
11614 N3: DAG.getConstant(Val: RotAmt, DL, VT: ShAmtTy));
11615 }
11616
11617 unsigned ShAmt = Cst->getZExtValue();
11618 if (ShAmt == 0)
11619 return IsFSHL ? N0 : N1;
11620
11621 // fold fshl(undef_or_zero, N1, C) -> lshr(N1, BW-C)
11622 // fold fshr(undef_or_zero, N1, C) -> lshr(N1, C)
11623 // fold fshl(N0, undef_or_zero, C) -> shl(N0, C)
11624 // fold fshr(N0, undef_or_zero, C) -> shl(N0, BW-C)
11625 if (IsUndefOrZero(N0))
11626 return DAG.getNode(
11627 Opcode: ISD::SRL, DL, VT, N1,
11628 N2: DAG.getConstant(Val: IsFSHL ? BitWidth - ShAmt : ShAmt, DL, VT: ShAmtTy));
11629 if (IsUndefOrZero(N1))
11630 return DAG.getNode(
11631 Opcode: ISD::SHL, DL, VT, N1: N0,
11632 N2: DAG.getConstant(Val: IsFSHL ? ShAmt : BitWidth - ShAmt, DL, VT: ShAmtTy));
11633
11634 // fold (fshl ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
11635 // fold (fshr ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
11636 // TODO - bigendian support once we have test coverage.
    // TODO - can we merge this with CombineConsecutiveLoads/MatchLoadCombine?
11638 // TODO - permit LHS EXTLOAD if extensions are shifted out.
11639 if ((BitWidth % 8) == 0 && (ShAmt % 8) == 0 && !VT.isVector() &&
11640 !DAG.getDataLayout().isBigEndian()) {
11641 auto *LHS = dyn_cast<LoadSDNode>(Val&: N0);
11642 auto *RHS = dyn_cast<LoadSDNode>(Val&: N1);
11643 if (LHS && RHS && LHS->isSimple() && RHS->isSimple() &&
11644 LHS->getAddressSpace() == RHS->getAddressSpace() &&
11645 (LHS->hasNUsesOfValue(NUses: 1, Value: 0) || RHS->hasNUsesOfValue(NUses: 1, Value: 0)) &&
11646 ISD::isNON_EXTLoad(N: RHS) && ISD::isNON_EXTLoad(N: LHS)) {
11647 if (DAG.areNonVolatileConsecutiveLoads(LD: LHS, Base: RHS, Bytes: BitWidth / 8, Dist: 1)) {
11648 SDLoc DL(RHS);
11649 uint64_t PtrOff =
11650 IsFSHL ? (((BitWidth - ShAmt) % BitWidth) / 8) : (ShAmt / 8);
11651 Align NewAlign = commonAlignment(A: RHS->getAlign(), Offset: PtrOff);
11652 unsigned Fast = 0;
11653 if (TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT,
11654 AddrSpace: RHS->getAddressSpace(), Alignment: NewAlign,
11655 Flags: RHS->getMemOperand()->getFlags(), Fast: &Fast) &&
11656 Fast) {
11657 SDValue NewPtr = DAG.getMemBasePlusOffset(
11658 Base: RHS->getBasePtr(), Offset: TypeSize::getFixed(ExactSize: PtrOff), DL);
11659 AddToWorklist(N: NewPtr.getNode());
11660 SDValue Load = DAG.getLoad(
11661 VT, dl: DL, Chain: RHS->getChain(), Ptr: NewPtr,
11662 PtrInfo: RHS->getPointerInfo().getWithOffset(O: PtrOff), Alignment: NewAlign,
11663 MMOFlags: RHS->getMemOperand()->getFlags(), AAInfo: RHS->getAAInfo());
11664 DAG.makeEquivalentMemoryOrdering(OldLoad: LHS, NewMemOp: Load.getValue(R: 1));
11665 DAG.makeEquivalentMemoryOrdering(OldLoad: RHS, NewMemOp: Load.getValue(R: 1));
11666 return Load;
11667 }
11668 }
11669 }
11670 }
11671 }
11672
11673 // fold fshr(undef_or_zero, N1, N2) -> lshr(N1, N2)
11674 // fold fshl(N0, undef_or_zero, N2) -> shl(N0, N2)
  // iff we know the shift amount is in range.
11676 // TODO: when is it worth doing SUB(BW, N2) as well?
11677 if (isPowerOf2_32(Value: BitWidth)) {
11678 APInt ModuloBits(N2.getScalarValueSizeInBits(), BitWidth - 1);
11679 if (IsUndefOrZero(N0) && !IsFSHL && DAG.MaskedValueIsZero(Op: N2, Mask: ~ModuloBits))
11680 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1, N2);
11681 if (IsUndefOrZero(N1) && IsFSHL && DAG.MaskedValueIsZero(Op: N2, Mask: ~ModuloBits))
11682 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0, N2);
11683 }
11684
11685 // fold (fshl N0, N0, N2) -> (rotl N0, N2)
11686 // fold (fshr N0, N0, N2) -> (rotr N0, N2)
11687 // TODO: Investigate flipping this rotate if only one is legal.
  // If funnel shift is legal as well, we might be better off avoiding a
  // non-constant (BW - N2).
11690 unsigned RotOpc = IsFSHL ? ISD::ROTL : ISD::ROTR;
11691 if (N0 == N1 && hasOperation(Opcode: RotOpc, VT))
11692 return DAG.getNode(Opcode: RotOpc, DL, VT, N1: N0, N2);
11693
11694 // Simplify, based on bits shifted out of N0/N1.
11695 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
11696 return SDValue(N, 0);
11697
11698 return SDValue();
11699}
11700
11701SDValue DAGCombiner::visitSHLSAT(SDNode *N) {
11702 SDValue N0 = N->getOperand(Num: 0);
11703 SDValue N1 = N->getOperand(Num: 1);
11704 if (SDValue V = DAG.simplifyShift(X: N0, Y: N1))
11705 return V;
11706
11707 SDLoc DL(N);
11708 EVT VT = N0.getValueType();
11709
11710 // fold (*shlsat c1, c2) -> c1<<c2
11711 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: N->getOpcode(), DL, VT, Ops: {N0, N1}))
11712 return C;
11713
11714 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
11715
11716 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::SHL, VT)) {
11717 // fold (sshlsat x, c) -> (shl x, c)
11718 if (N->getOpcode() == ISD::SSHLSAT && N1C &&
11719 N1C->getAPIntValue().ult(RHS: DAG.ComputeNumSignBits(Op: N0)))
11720 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0, N2: N1);
11721
11722 // fold (ushlsat x, c) -> (shl x, c)
11723 if (N->getOpcode() == ISD::USHLSAT && N1C &&
11724 N1C->getAPIntValue().ule(
11725 RHS: DAG.computeKnownBits(Op: N0).countMinLeadingZeros()))
11726 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0, N2: N1);
11727 }
11728
11729 return SDValue();
11730}
11731
// Given an ABS node, detect the following patterns:
// (ABS (SUB (EXTEND a), (EXTEND b))).
// (TRUNC (ABS (SUB (EXTEND a), (EXTEND b)))).
// Generates a UABD/SABD instruction.
11736SDValue DAGCombiner::foldABSToABD(SDNode *N, const SDLoc &DL) {
11737 EVT SrcVT = N->getValueType(ResNo: 0);
11738
11739 if (N->getOpcode() == ISD::TRUNCATE)
11740 N = N->getOperand(Num: 0).getNode();
11741
11742 EVT VT = N->getValueType(ResNo: 0);
11743 SDValue Op0, Op1;
11744
11745 if (!sd_match(N, P: m_Abs(Op: m_AnyOf(preds: m_Sub(L: m_Value(N&: Op0), R: m_Value(N&: Op1)),
11746 preds: m_Add(L: m_Value(N&: Op0), R: m_Value(N&: Op1))))))
11747 return SDValue();
11748
11749 SDValue AbsOp0 = N->getOperand(Num: 0);
11750 bool IsAdd = AbsOp0.getOpcode() == ISD::ADD;
  // For the add form, make sure negating Op1 cannot overflow (no element is
  // the minimum signed value).
11752 if (IsAdd) {
11753 // Elements of Op1 must be constant and != VT.minSignedValue() (or undef)
11754 std::function<bool(ConstantSDNode *)> IsNotMinSignedInt =
11755 [VT](ConstantSDNode *C) {
11756 if (C == nullptr)
11757 return true;
11758 return !C->getAPIntValue()
11759 .trunc(width: VT.getScalarSizeInBits())
11760 .isMinSignedValue();
11761 };
11762
11763 if (!ISD::matchUnaryPredicate(Op: Op1, Match: IsNotMinSignedInt, /*AllowUndefs=*/true,
11764 /*AllowTruncation=*/true))
11765 return SDValue();
11766 }
11767
11768 unsigned Opc0 = Op0.getOpcode();
11769
  // Check if the operands of the sub are (zero|sign)-extended; otherwise, fall
  // back to ValueTracking.
11772 if (Opc0 != Op1.getOpcode() ||
11773 (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND &&
11774 Opc0 != ISD::SIGN_EXTEND_INREG)) {
11775 // fold (abs (sub nsw x, y)) -> abds(x, y)
11776 // fold (abs (add nsw x, -y)) -> abds(x, y)
11777 bool AbsOpWillNSW =
11778 AbsOp0->getFlags().hasNoSignedWrap() ||
11779 (IsAdd ? DAG.willNotOverflowAdd(/*IsSigned=*/true, N0: Op0, N1: Op1)
11780 : DAG.willNotOverflowSub(/*IsSigned=*/true, N0: Op0, N1: Op1));
11781
11782 // Don't fold this for unsupported types as we lose the NSW handling.
11783 if (hasOperation(Opcode: ISD::ABDS, VT) && TLI.preferABDSToABSWithNSW(VT) &&
11784 AbsOpWillNSW) {
11785 if (IsAdd)
11786 Op1 = DAG.getNegative(Val: Op1, DL: SDLoc(Op1), VT);
11787 SDValue ABD = DAG.getNode(Opcode: ISD::ABDS, DL, VT, N1: Op0, N2: Op1);
11788 return DAG.getZExtOrTrunc(Op: ABD, DL, VT: SrcVT);
11789 }
11790 // fold (abs (sub x, y)) -> abdu(x, y)
11791 if (hasOperation(Opcode: ISD::ABDU, VT) && DAG.SignBitIsZero(Op: Op0) &&
11792 DAG.SignBitIsZero(Op: Op1)) {
11793 if (IsAdd)
11794 Op1 = DAG.getNegative(Val: Op1, DL: SDLoc(Op1), VT);
11795 SDValue ABD = DAG.getNode(Opcode: ISD::ABDU, DL, VT, N1: Op0, N2: Op1);
11796 return DAG.getZExtOrTrunc(Op: ABD, DL, VT: SrcVT);
11797 }
11798 return SDValue();
11799 }
11800
  // The IsAdd case explicitly checks for a constant/bv-of-constant RHS, which
  // implies either Opc0 != Op1.getOpcode() or Opc0 is not one of
  // {zext/sext/sign_ext_inreg}, so it was already handled by the above if
  // statement.
11804 assert(!IsAdd && "Unexpected abs(add(x,y)) pattern");
11805
11806 EVT VT0, VT1;
11807 if (Opc0 == ISD::SIGN_EXTEND_INREG) {
11808 VT0 = cast<VTSDNode>(Val: Op0.getOperand(i: 1))->getVT();
11809 VT1 = cast<VTSDNode>(Val: Op1.getOperand(i: 1))->getVT();
11810 } else {
11811 VT0 = Op0.getOperand(i: 0).getValueType();
11812 VT1 = Op1.getOperand(i: 0).getValueType();
11813 }
11814 unsigned ABDOpcode = (Opc0 == ISD::ZERO_EXTEND) ? ISD::ABDU : ISD::ABDS;
11815
11816 // fold abs(sext(x) - sext(y)) -> zext(abds(x, y))
11817 // fold abs(zext(x) - zext(y)) -> zext(abdu(x, y))
11818 EVT MaxVT = VT0.bitsGT(VT: VT1) ? VT0 : VT1;
11819 if ((VT0 == MaxVT || Op0->hasOneUse()) &&
11820 (VT1 == MaxVT || Op1->hasOneUse()) &&
11821 (!LegalTypes || hasOperation(Opcode: ABDOpcode, VT: MaxVT))) {
11822 SDValue ABD = DAG.getNode(Opcode: ABDOpcode, DL, VT: MaxVT,
11823 N1: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MaxVT, Operand: Op0),
11824 N2: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MaxVT, Operand: Op1));
11825 ABD = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: ABD);
11826 return DAG.getZExtOrTrunc(Op: ABD, DL, VT: SrcVT);
11827 }
11828
11829 // fold abs(sext(x) - sext(y)) -> abds(sext(x), sext(y))
11830 // fold abs(zext(x) - zext(y)) -> abdu(zext(x), zext(y))
11831 if (!LegalOperations || hasOperation(Opcode: ABDOpcode, VT)) {
11832 SDValue ABD = DAG.getNode(Opcode: ABDOpcode, DL, VT, N1: Op0, N2: Op1);
11833 return DAG.getZExtOrTrunc(Op: ABD, DL, VT: SrcVT);
11834 }
11835
11836 return SDValue();
11837}
11838
11839SDValue DAGCombiner::visitABS(SDNode *N) {
11840 SDValue N0 = N->getOperand(Num: 0);
11841 EVT VT = N->getValueType(ResNo: 0);
11842 SDLoc DL(N);
11843
11844 // fold (abs c1) -> c2
11845 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::ABS, DL, VT, Ops: {N0}))
11846 return C;
11847 // fold (abs (abs x)) -> (abs x)
11848 if (N0.getOpcode() == ISD::ABS)
11849 return N0;
11850 // fold (abs x) -> x iff not-negative
11851 if (DAG.SignBitIsZero(Op: N0))
11852 return N0;
11853
11854 if (SDValue ABD = foldABSToABD(N, DL))
11855 return ABD;
11856
11857 // fold (abs (sign_extend_inreg x)) -> (zero_extend (abs (truncate x)))
11858 // iff zero_extend/truncate are free.
11859 if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
11860 EVT ExtVT = cast<VTSDNode>(Val: N0.getOperand(i: 1))->getVT();
11861 if (TLI.isTruncateFree(FromVT: VT, ToVT: ExtVT) && TLI.isZExtFree(FromTy: ExtVT, ToTy: VT) &&
11862 TLI.isTypeDesirableForOp(ISD::ABS, VT: ExtVT) &&
11863 hasOperation(Opcode: ISD::ABS, VT: ExtVT)) {
11864 return DAG.getNode(
11865 Opcode: ISD::ZERO_EXTEND, DL, VT,
11866 Operand: DAG.getNode(Opcode: ISD::ABS, DL, VT: ExtVT,
11867 Operand: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: ExtVT, Operand: N0.getOperand(i: 0))));
11868 }
11869 }
11870
11871 return SDValue();
11872}
11873
11874SDValue DAGCombiner::visitCLMUL(SDNode *N) {
11875 unsigned Opcode = N->getOpcode();
11876 SDValue N0 = N->getOperand(Num: 0);
11877 SDValue N1 = N->getOperand(Num: 1);
11878 EVT VT = N->getValueType(ResNo: 0);
11879 SDLoc DL(N);
11880
11881 // fold (clmul c1, c2)
11882 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
11883 return C;
11884
11885 // canonicalize constant to RHS
11886 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
11887 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
11888 return DAG.getNode(Opcode, DL, VT, N1, N2: N0);
11889
11890 // fold (clmul x, 0) -> 0
11891 if (isNullConstant(V: N1) || ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
11892 return DAG.getConstant(Val: 0, DL, VT);
11893
11894 return SDValue();
11895}
11896
11897SDValue DAGCombiner::visitBSWAP(SDNode *N) {
11898 SDValue N0 = N->getOperand(Num: 0);
11899 EVT VT = N->getValueType(ResNo: 0);
11900 SDLoc DL(N);
11901
11902 // fold (bswap c1) -> c2
11903 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::BSWAP, DL, VT, Ops: {N0}))
11904 return C;
11905 // fold (bswap (bswap x)) -> x
11906 if (N0.getOpcode() == ISD::BSWAP)
11907 return N0.getOperand(i: 0);
11908
11909 // Canonicalize bswap(bitreverse(x)) -> bitreverse(bswap(x)). If bitreverse
11910 // isn't supported, it will be expanded to bswap followed by a manual reversal
11911 // of bits in each byte. By placing bswaps before bitreverse, we can remove
11912 // the two bswaps if the bitreverse gets expanded.
11913 if (N0.getOpcode() == ISD::BITREVERSE && N0.hasOneUse()) {
11914 SDValue BSwap = DAG.getNode(Opcode: ISD::BSWAP, DL, VT, Operand: N0.getOperand(i: 0));
11915 return DAG.getNode(Opcode: ISD::BITREVERSE, DL, VT, Operand: BSwap);
11916 }
11917
  // fold (bswap shl(x,c)) -> (zext(bswap(trunc(shl(x,sub(c,bw/2))))))
  // iff c >= bw/2 (i.e. the lower half is known zero)
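  // e.g. i64: (bswap (shl x, 48)) --> (zext (bswap (trunc i32 (shl x, 16))))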
11920 unsigned BW = VT.getScalarSizeInBits();
11921 if (BW >= 32 && N0.getOpcode() == ISD::SHL && N0.hasOneUse()) {
11922 auto *ShAmt = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
11923 EVT HalfVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: BW / 2);
11924 if (ShAmt && ShAmt->getAPIntValue().ult(RHS: BW) &&
11925 ShAmt->getZExtValue() >= (BW / 2) &&
11926 (ShAmt->getZExtValue() % 16) == 0 && TLI.isTypeLegal(VT: HalfVT) &&
11927 TLI.isTruncateFree(FromVT: VT, ToVT: HalfVT) &&
11928 (!LegalOperations || hasOperation(Opcode: ISD::BSWAP, VT: HalfVT))) {
11929 SDValue Res = N0.getOperand(i: 0);
11930 if (uint64_t NewShAmt = (ShAmt->getZExtValue() - (BW / 2)))
11931 Res = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Res,
11932 N2: DAG.getShiftAmountConstant(Val: NewShAmt, VT, DL));
11933 Res = DAG.getZExtOrTrunc(Op: Res, DL, VT: HalfVT);
11934 Res = DAG.getNode(Opcode: ISD::BSWAP, DL, VT: HalfVT, Operand: Res);
11935 return DAG.getZExtOrTrunc(Op: Res, DL, VT);
11936 }
11937 }
11938
11939 // Try to canonicalize bswap-of-logical-shift-by-8-bit-multiple as
11940 // inverse-shift-of-bswap:
11941 // bswap (X u<< C) --> (bswap X) u>> C
11942 // bswap (X u>> C) --> (bswap X) u<< C
11943 if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
11944 N0.hasOneUse()) {
11945 auto *ShAmt = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
11946 if (ShAmt && ShAmt->getAPIntValue().ult(RHS: BW) &&
11947 ShAmt->getZExtValue() % 8 == 0) {
11948 SDValue NewSwap = DAG.getNode(Opcode: ISD::BSWAP, DL, VT, Operand: N0.getOperand(i: 0));
11949 unsigned InverseShift = N0.getOpcode() == ISD::SHL ? ISD::SRL : ISD::SHL;
11950 return DAG.getNode(Opcode: InverseShift, DL, VT, N1: NewSwap, N2: N0.getOperand(i: 1));
11951 }
11952 }
11953
11954 if (SDValue V = foldBitOrderCrossLogicOp(N, DAG))
11955 return V;
11956
11957 return SDValue();
11958}
11959
11960SDValue DAGCombiner::visitBITREVERSE(SDNode *N) {
11961 SDValue N0 = N->getOperand(Num: 0);
11962 EVT VT = N->getValueType(ResNo: 0);
11963 SDLoc DL(N);
11964
11965 // fold (bitreverse c1) -> c2
11966 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::BITREVERSE, DL, VT, Ops: {N0}))
11967 return C;
11968
11969 // fold (bitreverse (bitreverse x)) -> x
11970 if (N0.getOpcode() == ISD::BITREVERSE)
11971 return N0.getOperand(i: 0);
11972
11973 SDValue X, Y;
11974
11975 // fold (bitreverse (lshr (bitreverse x), y)) -> (shl x, y)
11976 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::SHL, VT)) &&
11977 sd_match(N: N0, P: m_Srl(L: m_BitReverse(Op: m_Value(N&: X)), R: m_Value(N&: Y))))
11978 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: X, N2: Y);
11979
11980 // fold (bitreverse (shl (bitreverse x), y)) -> (lshr x, y)
11981 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::SRL, VT)) &&
11982 sd_match(N: N0, P: m_Shl(L: m_BitReverse(Op: m_Value(N&: X)), R: m_Value(N&: Y))))
11983 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: X, N2: Y);
11984
11985 // fold bitreverse(clmul(bitreverse(x), bitreverse(y))) -> clmulr(x, y)
11986 if (sd_match(N: N0, P: m_Clmul(L: m_BitReverse(Op: m_Value(N&: X)), R: m_BitReverse(Op: m_Value(N&: Y)))))
11987 return DAG.getNode(Opcode: ISD::CLMULR, DL, VT, N1: X, N2: Y);
11988
11989 return SDValue();
11990}
11991
11992// Fold (ctlz (xor x, (sra x, bitwidth-1))) -> (add (ctls x), 1).
// Fold (ctlz (or (shl (xor x, (sra x, bitwidth-1)), 1), 1)) -> (ctls x).
11994SDValue DAGCombiner::foldCTLZToCTLS(SDValue Src, const SDLoc &DL) {
11995 EVT VT = Src.getValueType();
11996
11997 auto LK = TLI.getTypeConversion(Context&: *DAG.getContext(), VT);
11998 if ((LK.first != TargetLoweringBase::TypeLegal &&
11999 LK.first != TargetLoweringBase::TypePromoteInteger) ||
12000 !TLI.isOperationLegalOrCustom(Op: ISD::CTLS, VT: LK.second))
12001 return SDValue();
12002
12003 unsigned BitWidth = VT.getScalarSizeInBits();
12004
12005 bool NeedAdd = true;
12006
12007 SDValue X;
12008 if (sd_match(N: Src, P: m_OneUse(P: m_Or(L: m_OneUse(P: m_Shl(L: m_Value(N&: X), R: m_SpecificInt(V: 1))),
12009 R: m_SpecificInt(V: 1))))) {
12010 NeedAdd = false;
12011 Src = X;
12012 }
12013
12014 if (!sd_match(N: Src,
12015 P: m_OneUse(P: m_Xor(L: m_Value(N&: X),
12016 R: m_OneUse(P: m_Sra(L: m_Deferred(V&: X),
12017 R: m_SpecificInt(V: BitWidth - 1)))))))
12018 return SDValue();
12019
12020 SDValue Res = DAG.getNode(Opcode: ISD::CTLS, DL, VT, Operand: X);
12021 if (!NeedAdd)
12022 return Res;
12023
12024 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Res, N2: DAG.getConstant(Val: 1, DL, VT));
12025}
12026
12027SDValue DAGCombiner::visitCTLZ(SDNode *N) {
12028 SDValue N0 = N->getOperand(Num: 0);
12029 EVT VT = N->getValueType(ResNo: 0);
12030 SDLoc DL(N);
12031
12032 // fold (ctlz c1) -> c2
12033 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::CTLZ, DL, VT, Ops: {N0}))
12034 return C;
12035
12036 // If the value is known never to be zero, switch to the undef version.
12037 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::CTLZ_ZERO_UNDEF, VT))
12038 if (DAG.isKnownNeverZero(Op: N0))
12039 return DAG.getNode(Opcode: ISD::CTLZ_ZERO_UNDEF, DL, VT, Operand: N0);
12040
12041 if (SDValue V = foldCTLZToCTLS(Src: N0, DL))
12042 return V;
12043
12044 return SDValue();
12045}
12046
12047SDValue DAGCombiner::visitCTLZ_ZERO_UNDEF(SDNode *N) {
12048 SDValue N0 = N->getOperand(Num: 0);
12049 EVT VT = N->getValueType(ResNo: 0);
12050 SDLoc DL(N);
12051
12052 // fold (ctlz_zero_undef c1) -> c2
12053 if (SDValue C =
12054 DAG.FoldConstantArithmetic(Opcode: ISD::CTLZ_ZERO_UNDEF, DL, VT, Ops: {N0}))
12055 return C;
12056
12057 if (SDValue V = foldCTLZToCTLS(Src: N0, DL))
12058 return V;
12059
12060 return SDValue();
12061}
12062
12063SDValue DAGCombiner::visitCTTZ(SDNode *N) {
12064 SDValue N0 = N->getOperand(Num: 0);
12065 EVT VT = N->getValueType(ResNo: 0);
12066 SDLoc DL(N);
12067
12068 // fold (cttz c1) -> c2
12069 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::CTTZ, DL, VT, Ops: {N0}))
12070 return C;
12071
12072 // If the value is known never to be zero, switch to the undef version.
12073 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::CTTZ_ZERO_UNDEF, VT))
12074 if (DAG.isKnownNeverZero(Op: N0))
12075 return DAG.getNode(Opcode: ISD::CTTZ_ZERO_UNDEF, DL, VT, Operand: N0);
12076
12077 return SDValue();
12078}
12079
12080SDValue DAGCombiner::visitCTTZ_ZERO_UNDEF(SDNode *N) {
12081 SDValue N0 = N->getOperand(Num: 0);
12082 EVT VT = N->getValueType(ResNo: 0);
12083 SDLoc DL(N);
12084
12085 // fold (cttz_zero_undef c1) -> c2
12086 if (SDValue C =
12087 DAG.FoldConstantArithmetic(Opcode: ISD::CTTZ_ZERO_UNDEF, DL, VT, Ops: {N0}))
12088 return C;
12089 return SDValue();
12090}
12091
12092SDValue DAGCombiner::visitCTPOP(SDNode *N) {
12093 SDValue N0 = N->getOperand(Num: 0);
12094 EVT VT = N->getValueType(ResNo: 0);
12095 unsigned NumBits = VT.getScalarSizeInBits();
12096 SDLoc DL(N);
12097
12098 // fold (ctpop c1) -> c2
12099 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::CTPOP, DL, VT, Ops: {N0}))
12100 return C;
12101
  // If the source is being shifted, but the shift doesn't affect any active
  // bits, then we can call CTPOP on the shift source directly.
12104 if (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SHL) {
12105 if (ConstantSDNode *AmtC = isConstOrConstSplat(N: N0.getOperand(i: 1))) {
12106 const APInt &Amt = AmtC->getAPIntValue();
12107 if (Amt.ult(RHS: NumBits)) {
12108 KnownBits KnownSrc = DAG.computeKnownBits(Op: N0.getOperand(i: 0));
12109 if ((N0.getOpcode() == ISD::SRL &&
12110 Amt.ule(RHS: KnownSrc.countMinTrailingZeros())) ||
12111 (N0.getOpcode() == ISD::SHL &&
12112 Amt.ule(RHS: KnownSrc.countMinLeadingZeros()))) {
12113 return DAG.getNode(Opcode: ISD::CTPOP, DL, VT, Operand: N0.getOperand(i: 0));
12114 }
12115 }
12116 }
12117 }
12118
  // If the upper bits are known to be zero, then see if it's profitable to
  // only count the lower bits.
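  // e.g. for an i64 ctpop whose top 32 bits are known zero:
  //   (ctpop x) --> (zext (ctpop (trunc i32 x)))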
12121 if (VT.isScalarInteger() && NumBits > 8 && (NumBits & 1) == 0) {
12122 EVT HalfVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: NumBits / 2);
12123 if (hasOperation(Opcode: ISD::CTPOP, VT: HalfVT) &&
12124 TLI.isTypeDesirableForOp(ISD::CTPOP, VT: HalfVT) &&
12125 TLI.isTruncateFree(Val: N0, VT2: HalfVT) && TLI.isZExtFree(FromTy: HalfVT, ToTy: VT)) {
12126 APInt UpperBits = APInt::getHighBitsSet(numBits: NumBits, hiBitsSet: NumBits / 2);
12127 if (DAG.MaskedValueIsZero(Op: N0, Mask: UpperBits)) {
12128 SDValue PopCnt = DAG.getNode(Opcode: ISD::CTPOP, DL, VT: HalfVT,
12129 Operand: DAG.getZExtOrTrunc(Op: N0, DL, VT: HalfVT));
12130 return DAG.getZExtOrTrunc(Op: PopCnt, DL, VT);
12131 }
12132 }
12133 }
12134
12135 return SDValue();
12136}
12137
12138static bool isLegalToCombineMinNumMaxNum(SelectionDAG &DAG, SDValue LHS,
12139 SDValue RHS, const SDNodeFlags Flags,
12140 const TargetLowering &TLI) {
12141 EVT VT = LHS.getValueType();
12142 if (!VT.isFloatingPoint())
12143 return false;
12144
12145 return Flags.hasNoSignedZeros() &&
12146 TLI.isProfitableToCombineMinNumMaxNum(VT) &&
12147 (Flags.hasNoNaNs() ||
12148 (DAG.isKnownNeverNaN(Op: RHS) && DAG.isKnownNeverNaN(Op: LHS)));
12149}
12150
12151static SDValue combineMinNumMaxNumImpl(const SDLoc &DL, EVT VT, SDValue LHS,
12152 SDValue RHS, SDValue True, SDValue False,
12153 ISD::CondCode CC,
12154 const TargetLowering &TLI,
12155 SelectionDAG &DAG) {
12156 EVT TransformVT = TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT);
12157 switch (CC) {
12158 case ISD::SETOLT:
12159 case ISD::SETOLE:
12160 case ISD::SETLT:
12161 case ISD::SETLE:
12162 case ISD::SETULT:
12163 case ISD::SETULE: {
    // Since the operands are already known never to be NaN here, either fminnum
    // or fminnum_ieee is OK. Try the IEEE version first, since fminnum is
    // expanded in terms of it.
12167 unsigned IEEEOpcode = (LHS == True) ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
12168 if (TLI.isOperationLegalOrCustom(Op: IEEEOpcode, VT))
12169 return DAG.getNode(Opcode: IEEEOpcode, DL, VT, N1: LHS, N2: RHS);
12170
12171 unsigned Opcode = (LHS == True) ? ISD::FMINNUM : ISD::FMAXNUM;
12172 if (TLI.isOperationLegalOrCustom(Op: Opcode, VT: TransformVT))
12173 return DAG.getNode(Opcode, DL, VT, N1: LHS, N2: RHS);
12174 return SDValue();
12175 }
12176 case ISD::SETOGT:
12177 case ISD::SETOGE:
12178 case ISD::SETGT:
12179 case ISD::SETGE:
12180 case ISD::SETUGT:
12181 case ISD::SETUGE: {
12182 unsigned IEEEOpcode = (LHS == True) ? ISD::FMAXNUM_IEEE : ISD::FMINNUM_IEEE;
12183 if (TLI.isOperationLegalOrCustom(Op: IEEEOpcode, VT))
12184 return DAG.getNode(Opcode: IEEEOpcode, DL, VT, N1: LHS, N2: RHS);
12185
12186 unsigned Opcode = (LHS == True) ? ISD::FMAXNUM : ISD::FMINNUM;
12187 if (TLI.isOperationLegalOrCustom(Op: Opcode, VT: TransformVT))
12188 return DAG.getNode(Opcode, DL, VT, N1: LHS, N2: RHS);
12189 return SDValue();
12190 }
12191 default:
12192 return SDValue();
12193 }
12194}
12195
// Convert (sr[al] (add n[su]w x, y), 1) -> (avgfloor[su] x, y)
12197SDValue DAGCombiner::foldShiftToAvg(SDNode *N, const SDLoc &DL) {
12198 const unsigned Opcode = N->getOpcode();
12199 if (Opcode != ISD::SRA && Opcode != ISD::SRL)
12200 return SDValue();
12201
12202 EVT VT = N->getValueType(ResNo: 0);
12203 bool IsUnsigned = Opcode == ISD::SRL;
12204
12205 // Captured values.
12206 SDValue A, B;
12207
  // Match the floor average, as it is common to both floor/ceil avgs, and make
  // sure the add doesn't wrap.
12210 SDNodeFlags Flags =
12211 IsUnsigned ? SDNodeFlags::NoUnsignedWrap : SDNodeFlags::NoSignedWrap;
12212 if (sd_match(N, P: m_BinOp(Opc: Opcode,
12213 L: m_c_BinOp(Opc: ISD::ADD, L: m_Value(N&: A), R: m_Value(N&: B), Flgs: Flags),
12214 R: m_One()))) {
12215 // Decide whether signed or unsigned.
12216 unsigned FloorISD = IsUnsigned ? ISD::AVGFLOORU : ISD::AVGFLOORS;
12217 if (hasOperation(Opcode: FloorISD, VT))
12218 return DAG.getNode(Opcode: FloorISD, DL, VT, Ops: {A, B});
12219 }
12220
12221 return SDValue();
12222}
12223
12224SDValue DAGCombiner::foldBitwiseOpWithNeg(SDNode *N, const SDLoc &DL, EVT VT) {
12225 unsigned Opc = N->getOpcode();
12226 SDValue X, Y, Z;
12227 if (sd_match(
12228 N, P: m_BitwiseLogic(L: m_Value(N&: X), R: m_Add(L: m_Not(V: m_Value(N&: Y)), R: m_Value(N&: Z)))))
12229 return DAG.getNode(Opcode: Opc, DL, VT, N1: X,
12230 N2: DAG.getNOT(DL, Val: DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Y, N2: Z), VT));
12231
12232 if (sd_match(N, P: m_BitwiseLogic(L: m_Value(N&: X), R: m_Sub(L: m_OneUse(P: m_Not(V: m_Value(N&: Y))),
12233 R: m_Value(N&: Z)))))
12234 return DAG.getNode(Opcode: Opc, DL, VT, N1: X,
12235 N2: DAG.getNOT(DL, Val: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Y, N2: Z), VT));
12236
12237 return SDValue();
12238}
12239
12240/// Generate Min/Max node
12241SDValue DAGCombiner::combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
12242 SDValue RHS, SDValue True,
12243 SDValue False, ISD::CondCode CC) {
12244 if ((LHS == True && RHS == False) || (LHS == False && RHS == True))
12245 return combineMinNumMaxNumImpl(DL, VT, LHS, RHS, True, False, CC, TLI, DAG);
12246
12247 // If we can't directly match this, try to see if we can pull an fneg out of
12248 // the select.
12249 SDValue NegTrue = TLI.getCheaperOrNeutralNegatedExpression(
12250 Op: True, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize);
12251 if (!NegTrue)
12252 return SDValue();
12253
12254 HandleSDNode NegTrueHandle(NegTrue);
12255
12256 // Try to unfold an fneg from the select if we are comparing the negated
12257 // constant.
12258 //
12259 // select (setcc x, K) (fneg x), -K -> fneg(minnum(x, K))
12260 //
12261 // TODO: Handle fabs
12262 if (LHS == NegTrue) {
    // See if we can also pull an fneg out of the RHS comparison constant.
12265 SDValue NegRHS = TLI.getCheaperOrNeutralNegatedExpression(
12266 Op: RHS, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize);
12267 if (NegRHS) {
12268 HandleSDNode NegRHSHandle(NegRHS);
12269 if (NegRHS == False) {
12270 SDValue Combined = combineMinNumMaxNumImpl(DL, VT, LHS, RHS, True: NegTrue,
12271 False, CC, TLI, DAG);
12272 if (Combined)
12273 return DAG.getNode(Opcode: ISD::FNEG, DL, VT, Operand: Combined);
12274 }
12275 }
12276 }
12277
12278 return SDValue();
12279}
12280
12281/// If a (v)select has a condition value that is a sign-bit test, try to smear
12282/// the condition operand sign-bit across the value width and use it as a mask.
12283static SDValue foldSelectOfConstantsUsingSra(SDNode *N, const SDLoc &DL,
12284 SelectionDAG &DAG) {
12285 SDValue Cond = N->getOperand(Num: 0);
12286 SDValue C1 = N->getOperand(Num: 1);
12287 SDValue C2 = N->getOperand(Num: 2);
12288 if (!isConstantOrConstantVector(N: C1) || !isConstantOrConstantVector(N: C2))
12289 return SDValue();
12290
12291 EVT VT = N->getValueType(ResNo: 0);
12292 if (Cond.getOpcode() != ISD::SETCC || !Cond.hasOneUse() ||
12293 VT != Cond.getOperand(i: 0).getValueType())
12294 return SDValue();
12295
12296 // The inverted-condition + commuted-select variants of these patterns are
12297 // canonicalized to these forms in IR.
12298 SDValue X = Cond.getOperand(i: 0);
12299 SDValue CondC = Cond.getOperand(i: 1);
12300 ISD::CondCode CC = cast<CondCodeSDNode>(Val: Cond.getOperand(i: 2))->get();
12301 if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(V: CondC) &&
12302 isAllOnesOrAllOnesSplat(V: C2)) {
12303 // i32 X > -1 ? C1 : -1 --> (X >>s 31) | C1
12304 SDValue ShAmtC = DAG.getConstant(Val: X.getScalarValueSizeInBits() - 1, DL, VT);
12305 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: X, N2: ShAmtC);
12306 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: Sra, N2: C1);
12307 }
12308 if (CC == ISD::SETLT && isNullOrNullSplat(V: CondC) && isNullOrNullSplat(V: C2)) {
12309 // i8 X < 0 ? C1 : 0 --> (X >>s 7) & C1
12310 SDValue ShAmtC = DAG.getConstant(Val: X.getScalarValueSizeInBits() - 1, DL, VT);
12311 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: X, N2: ShAmtC);
12312 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Sra, N2: C1);
12313 }
12314 return SDValue();
12315}
12316
12317static bool shouldConvertSelectOfConstantsToMath(const SDValue &Cond, EVT VT,
12318 const TargetLowering &TLI) {
12319 if (!TLI.convertSelectOfConstantsToMath(VT))
12320 return false;
12321
12322 if (Cond.getOpcode() != ISD::SETCC || !Cond->hasOneUse())
12323 return true;
12324 if (!TLI.isOperationLegalOrCustom(Op: ISD::SELECT_CC, VT))
12325 return true;
12326
12327 ISD::CondCode CC = cast<CondCodeSDNode>(Val: Cond.getOperand(i: 2))->get();
12328 if (CC == ISD::SETLT && isNullOrNullSplat(V: Cond.getOperand(i: 1)))
12329 return true;
12330 if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(V: Cond.getOperand(i: 1)))
12331 return true;
12332
12333 return false;
12334}
12335
12336SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
12337 SDValue Cond = N->getOperand(Num: 0);
12338 SDValue N1 = N->getOperand(Num: 1);
12339 SDValue N2 = N->getOperand(Num: 2);
12340 EVT VT = N->getValueType(ResNo: 0);
12341 EVT CondVT = Cond.getValueType();
12342 SDLoc DL(N);
12343
12344 if (!VT.isInteger())
12345 return SDValue();
12346
12347 auto *C1 = dyn_cast<ConstantSDNode>(Val&: N1);
12348 auto *C2 = dyn_cast<ConstantSDNode>(Val&: N2);
12349 if (!C1 || !C2)
12350 return SDValue();
12351
12352 if (CondVT != MVT::i1 || LegalOperations) {
12353 // fold (select Cond, 0, 1) -> (xor Cond, 1)
    // We can't do this reliably if integer-based booleans have different
    // contents from floating-point-based booleans. This is because we can't
    // tell whether we
12356 // have an integer-based boolean or a floating-point-based boolean unless we
12357 // can find the SETCC that produced it and inspect its operands. This is
12358 // fairly easy if C is the SETCC node, but it can potentially be
12359 // undiscoverable (or not reasonably discoverable). For example, it could be
12360 // in another basic block or it could require searching a complicated
12361 // expression.
12362 if (CondVT.isInteger() &&
12363 TLI.getBooleanContents(/*isVec*/false, /*isFloat*/true) ==
12364 TargetLowering::ZeroOrOneBooleanContent &&
12365 TLI.getBooleanContents(/*isVec*/false, /*isFloat*/false) ==
12366 TargetLowering::ZeroOrOneBooleanContent &&
12367 C1->isZero() && C2->isOne()) {
12368 SDValue NotCond =
12369 DAG.getNode(Opcode: ISD::XOR, DL, VT: CondVT, N1: Cond, N2: DAG.getConstant(Val: 1, DL, VT: CondVT));
12370 if (VT.bitsEq(VT: CondVT))
12371 return NotCond;
12372 return DAG.getZExtOrTrunc(Op: NotCond, DL, VT);
12373 }
12374
12375 return SDValue();
12376 }
12377
12378 // Only do this before legalization to avoid conflicting with target-specific
12379 // transforms in the other direction (create a select from a zext/sext). There
12380 // is also a target-independent combine here in DAGCombiner in the other
12381 // direction for (select Cond, -1, 0) when the condition is not i1.
12382 assert(CondVT == MVT::i1 && !LegalOperations);
12383
12384 // select Cond, 1, 0 --> zext (Cond)
12385 if (C1->isOne() && C2->isZero())
12386 return DAG.getZExtOrTrunc(Op: Cond, DL, VT);
12387
12388 // select Cond, -1, 0 --> sext (Cond)
12389 if (C1->isAllOnes() && C2->isZero())
12390 return DAG.getSExtOrTrunc(Op: Cond, DL, VT);
12391
12392 // select Cond, 0, 1 --> zext (!Cond)
12393 if (C1->isZero() && C2->isOne()) {
12394 SDValue NotCond = DAG.getNOT(DL, Val: Cond, VT: MVT::i1);
12395 NotCond = DAG.getZExtOrTrunc(Op: NotCond, DL, VT);
12396 return NotCond;
12397 }
12398
12399 // select Cond, 0, -1 --> sext (!Cond)
12400 if (C1->isZero() && C2->isAllOnes()) {
12401 SDValue NotCond = DAG.getNOT(DL, Val: Cond, VT: MVT::i1);
12402 NotCond = DAG.getSExtOrTrunc(Op: NotCond, DL, VT);
12403 return NotCond;
12404 }
12405
12406 // Use a target hook because some targets may prefer to transform in the
12407 // other direction.
12408 if (!shouldConvertSelectOfConstantsToMath(Cond, VT, TLI))
12409 return SDValue();
12410
12411 // For any constants that differ by 1, we can transform the select into
12412 // an extend and add.
12413 const APInt &C1Val = C1->getAPIntValue();
12414 const APInt &C2Val = C2->getAPIntValue();
12415
12416 // select Cond, C1, C1-1 --> add (zext Cond), C1-1
12417 if (C1Val - 1 == C2Val) {
12418 Cond = DAG.getZExtOrTrunc(Op: Cond, DL, VT);
12419 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Cond, N2);
12420 }
12421
12422 // select Cond, C1, C1+1 --> add (sext Cond), C1+1
12423 if (C1Val + 1 == C2Val) {
12424 Cond = DAG.getSExtOrTrunc(Op: Cond, DL, VT);
12425 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Cond, N2);
12426 }
12427
12428 // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
12429 if (C1Val.isPowerOf2() && C2Val.isZero()) {
12430 Cond = DAG.getZExtOrTrunc(Op: Cond, DL, VT);
12431 SDValue ShAmtC =
12432 DAG.getShiftAmountConstant(Val: C1Val.exactLogBase2(), VT, DL);
12433 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Cond, N2: ShAmtC);
12434 }
12435
12436 // select Cond, -1, C --> or (sext Cond), C
12437 if (C1->isAllOnes()) {
12438 Cond = DAG.getSExtOrTrunc(Op: Cond, DL, VT);
12439 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: Cond, N2);
12440 }
12441
12442 // select Cond, C, -1 --> or (sext (not Cond)), C
12443 if (C2->isAllOnes()) {
12444 SDValue NotCond = DAG.getNOT(DL, Val: Cond, VT: MVT::i1);
12445 NotCond = DAG.getSExtOrTrunc(Op: NotCond, DL, VT);
12446 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: NotCond, N2: N1);
12447 }
12448
12449 if (SDValue V = foldSelectOfConstantsUsingSra(N, DL, DAG))
12450 return V;
12451
12452 return SDValue();
12453}
12454
12455template <class MatchContextClass>
12456static SDValue foldBoolSelectToLogic(SDNode *N, const SDLoc &DL,
12457 SelectionDAG &DAG) {
12458 assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT ||
12459 N->getOpcode() == ISD::VP_SELECT) &&
12460 "Expected a (v)(vp.)select");
12461 SDValue Cond = N->getOperand(Num: 0);
12462 SDValue T = N->getOperand(Num: 1), F = N->getOperand(Num: 2);
12463 EVT VT = N->getValueType(ResNo: 0);
12464 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
12465 MatchContextClass matcher(DAG, TLI, N);
12466
12467 if (VT != Cond.getValueType() || VT.getScalarSizeInBits() != 1)
12468 return SDValue();
12469
12470 // select Cond, Cond, F --> or Cond, freeze(F)
12471 // select Cond, 1, F --> or Cond, freeze(F)
12472 if (Cond == T || isOneOrOneSplat(V: T, /* AllowUndefs */ true))
12473 return matcher.getNode(ISD::OR, DL, VT, Cond, DAG.getFreeze(V: F));
12474
12475 // select Cond, T, Cond --> and Cond, freeze(T)
12476 // select Cond, T, 0 --> and Cond, freeze(T)
12477 if (Cond == F || isNullOrNullSplat(V: F, /* AllowUndefs */ true))
12478 return matcher.getNode(ISD::AND, DL, VT, Cond, DAG.getFreeze(V: T));
12479
12480 // select Cond, T, 1 --> or (not Cond), freeze(T)
12481 if (isOneOrOneSplat(V: F, /* AllowUndefs */ true)) {
12482 SDValue NotCond =
12483 matcher.getNode(ISD::XOR, DL, VT, Cond, DAG.getAllOnesConstant(DL, VT));
12484 return matcher.getNode(ISD::OR, DL, VT, NotCond, DAG.getFreeze(V: T));
12485 }
12486
12487 // select Cond, 0, F --> and (not Cond), freeze(F)
12488 if (isNullOrNullSplat(V: T, /* AllowUndefs */ true)) {
12489 SDValue NotCond =
12490 matcher.getNode(ISD::XOR, DL, VT, Cond, DAG.getAllOnesConstant(DL, VT));
12491 return matcher.getNode(ISD::AND, DL, VT, NotCond, DAG.getFreeze(V: F));
12492 }
12493
12494 return SDValue();
12495}
12496
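/// If a vselect's condition is a sign-bit test of a value with the same type
/// as the select, lower the select to a sign-splat mask: smear the sign bit
/// across the element (sra by BW-1) and combine it with AND/OR instead of
/// emitting a blend.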
12497static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) {
12498 SDValue N0 = N->getOperand(Num: 0);
12499 SDValue N1 = N->getOperand(Num: 1);
12500 SDValue N2 = N->getOperand(Num: 2);
12501 EVT VT = N->getValueType(ResNo: 0);
12502 unsigned EltSizeInBits = VT.getScalarSizeInBits();
12503
12504 SDValue Cond0, Cond1;
12505 ISD::CondCode CC;
12506 if (!sd_match(N: N0, P: m_OneUse(P: m_SetCC(LHS: m_Value(N&: Cond0), RHS: m_Value(N&: Cond1),
12507 CC: m_CondCode(CC)))) ||
12508 VT != Cond0.getValueType())
12509 return SDValue();
12510
12511 // Match a signbit check of Cond0 as "Cond0 s<0". Swap select operands if the
12512 // compare is inverted from that pattern ("Cond0 s> -1").
12513 if (CC == ISD::SETLT && isNullOrNullSplat(V: Cond1))
12514 ; // This is the pattern we are looking for.
12515 else if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(V: Cond1))
12516 std::swap(a&: N1, b&: N2);
12517 else
12518 return SDValue();
12519
12520 // (Cond0 s< 0) ? N1 : 0 --> (Cond0 s>> BW-1) & freeze(N1)
12521 if (isNullOrNullSplat(V: N2)) {
12522 SDLoc DL(N);
12523 SDValue ShiftAmt = DAG.getShiftAmountConstant(Val: EltSizeInBits - 1, VT, DL);
12524 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: Cond0, N2: ShiftAmt);
12525 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Sra, N2: DAG.getFreeze(V: N1));
12526 }
12527
12528 // (Cond0 s< 0) ? -1 : N2 --> (Cond0 s>> BW-1) | freeze(N2)
12529 if (isAllOnesOrAllOnesSplat(V: N1)) {
12530 SDLoc DL(N);
12531 SDValue ShiftAmt = DAG.getShiftAmountConstant(Val: EltSizeInBits - 1, VT, DL);
12532 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: Cond0, N2: ShiftAmt);
12533 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: Sra, N2: DAG.getFreeze(V: N2));
12534 }
12535
12536 // If we have to invert the sign bit mask, only do that transform if the
12537 // target has a bitwise 'and not' instruction (the invert is free).
12538  // (Cond0 s< 0) ? 0 : N2 --> ~(Cond0 s>> BW-1) & freeze(N2)
12539 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
12540 if (isNullOrNullSplat(V: N1) && TLI.hasAndNot(X: N1)) {
12541 SDLoc DL(N);
12542 SDValue ShiftAmt = DAG.getShiftAmountConstant(Val: EltSizeInBits - 1, VT, DL);
12543 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: Cond0, N2: ShiftAmt);
12544 SDValue Not = DAG.getNOT(DL, Val: Sra, VT);
12545 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Not, N2: DAG.getFreeze(V: N2));
12546 }
12547
12548 // TODO: There's another pattern in this family, but it may require
12549 // implementing hasOrNot() to check for profitability:
12550 // (Cond0 s> -1) ? -1 : N2 --> ~(Cond0 s>> BW-1) | freeze(N2)
12551
12552 return SDValue();
12553}
12554
12555// Match SELECTs with absolute difference patterns.
12556// (select (setcc a, b, set?gt), (sub a, b), (sub b, a)) --> (abd? a, b)
12557// (select (setcc a, b, set?ge), (sub a, b), (sub b, a)) --> (abd? a, b)
12558// (select (setcc a, b, set?lt), (sub b, a), (sub a, b)) --> (abd? a, b)
12559// (select (setcc a, b, set?le), (sub b, a), (sub a, b)) --> (abd? a, b)
12560SDValue DAGCombiner::foldSelectToABD(SDValue LHS, SDValue RHS, SDValue True,
12561 SDValue False, ISD::CondCode CC,
12562 const SDLoc &DL) {
12563 bool IsSigned = isSignedIntSetCC(Code: CC);
12564 unsigned ABDOpc = IsSigned ? ISD::ABDS : ISD::ABDU;
12565 EVT VT = LHS.getValueType();
12566
12567 if (LegalOperations && !hasOperation(Opcode: ABDOpc, VT))
12568 return SDValue();
12569
12570 switch (CC) {
12571 case ISD::SETGT:
12572 case ISD::SETGE:
12573 case ISD::SETUGT:
12574 case ISD::SETUGE:
12575 if (sd_match(N: True, P: m_AnyOf(preds: m_Sub(L: m_Specific(N: LHS), R: m_Specific(N: RHS)),
12576 preds: m_Add(L: m_Specific(N: LHS), R: m_SpecificNeg(V: RHS)))) &&
12577 sd_match(N: False, P: m_AnyOf(preds: m_Sub(L: m_Specific(N: RHS), R: m_Specific(N: LHS)),
12578 preds: m_Add(L: m_Specific(N: RHS), R: m_SpecificNeg(V: LHS)))))
12579 return DAG.getNode(Opcode: ABDOpc, DL, VT, N1: LHS, N2: RHS);
12580 if (sd_match(N: True, P: m_AnyOf(preds: m_Sub(L: m_Specific(N: RHS), R: m_Specific(N: LHS)),
12581 preds: m_Add(L: m_Specific(N: RHS), R: m_SpecificNeg(V: LHS)))) &&
12582 sd_match(N: False, P: m_AnyOf(preds: m_Sub(L: m_Specific(N: LHS), R: m_Specific(N: RHS)),
12583 preds: m_Add(L: m_Specific(N: LHS), R: m_SpecificNeg(V: RHS)))) &&
12584 hasOperation(Opcode: ABDOpc, VT))
12585 return DAG.getNegative(Val: DAG.getNode(Opcode: ABDOpc, DL, VT, N1: LHS, N2: RHS), DL, VT);
12586 break;
12587 case ISD::SETLT:
12588 case ISD::SETLE:
12589 case ISD::SETULT:
12590 case ISD::SETULE:
12591 if (sd_match(N: True, P: m_AnyOf(preds: m_Sub(L: m_Specific(N: RHS), R: m_Specific(N: LHS)),
12592 preds: m_Add(L: m_Specific(N: RHS), R: m_SpecificNeg(V: LHS)))) &&
12593 sd_match(N: False, P: m_AnyOf(preds: m_Sub(L: m_Specific(N: LHS), R: m_Specific(N: RHS)),
12594 preds: m_Add(L: m_Specific(N: LHS), R: m_SpecificNeg(V: RHS)))))
12595 return DAG.getNode(Opcode: ABDOpc, DL, VT, N1: LHS, N2: RHS);
12596 if (sd_match(N: True, P: m_AnyOf(preds: m_Sub(L: m_Specific(N: LHS), R: m_Specific(N: RHS)),
12597 preds: m_Add(L: m_Specific(N: LHS), R: m_SpecificNeg(V: RHS)))) &&
12598 sd_match(N: False, P: m_AnyOf(preds: m_Sub(L: m_Specific(N: RHS), R: m_Specific(N: LHS)),
12599 preds: m_Add(L: m_Specific(N: RHS), R: m_SpecificNeg(V: LHS)))) &&
12600 hasOperation(Opcode: ABDOpc, VT))
12601 return DAG.getNegative(Val: DAG.getNode(Opcode: ABDOpc, DL, VT, N1: LHS, N2: RHS), DL, VT);
12602 break;
12603 default:
12604 break;
12605 }
12606
12607 return SDValue();
12608}
12609
12610// ([v]select (ugt x, C), (add x, ~C), x) -> (umin (add x, ~C), x)
12611// ([v]select (ult x, C), x, (add x, -C)) -> (umin x, (add x, -C))
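// For example, with i8 elements and C == 230 (so ~C == 25):
//   (select (ugt x, 230), (add x, 25), x) -> (umin (add x, 25), x)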
12612SDValue DAGCombiner::foldSelectToUMin(SDValue LHS, SDValue RHS, SDValue True,
12613 SDValue False, ISD::CondCode CC,
12614 const SDLoc &DL) {
12615 APInt C;
12616 EVT VT = True.getValueType();
12617 if (sd_match(N: RHS, P: m_ConstInt(V&: C)) && hasUMin(VT)) {
12618 if (CC == ISD::SETUGT && LHS == False &&
12619 sd_match(N: True, P: m_Add(L: m_Specific(N: False), R: m_SpecificInt(V: ~C)))) {
12620 SDValue AddC = DAG.getConstant(Val: ~C, DL, VT);
12621 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: False, N2: AddC);
12622 return DAG.getNode(Opcode: ISD::UMIN, DL, VT, N1: Add, N2: False);
12623 }
12624 if (CC == ISD::SETULT && LHS == True &&
12625 sd_match(N: False, P: m_Add(L: m_Specific(N: True), R: m_SpecificInt(V: -C)))) {
12626 SDValue AddC = DAG.getConstant(Val: -C, DL, VT);
12627 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: True, N2: AddC);
12628 return DAG.getNode(Opcode: ISD::UMIN, DL, VT, N1: True, N2: Add);
12629 }
12630 }
12631 return SDValue();
12632}
12633
12634SDValue DAGCombiner::visitSELECT(SDNode *N) {
12635 SDValue N0 = N->getOperand(Num: 0);
12636 SDValue N1 = N->getOperand(Num: 1);
12637 SDValue N2 = N->getOperand(Num: 2);
12638 EVT VT = N->getValueType(ResNo: 0);
12639 EVT VT0 = N0.getValueType();
12640 SDLoc DL(N);
12641 SDNodeFlags Flags = N->getFlags();
12642
12643 if (SDValue V = DAG.simplifySelect(Cond: N0, TVal: N1, FVal: N2))
12644 return V;
12645
12646 if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DL, DAG))
12647 return V;
12648
12649 // select (not Cond), N1, N2 -> select Cond, N2, N1
12650 if (SDValue F = extractBooleanFlip(V: N0, DAG, TLI, Force: false))
12651 return DAG.getSelect(DL, VT, Cond: F, LHS: N2, RHS: N1, Flags);
12652
12653 if (SDValue V = foldSelectOfConstants(N))
12654 return V;
12655
12656 // If we can fold this based on the true/false value, do so.
12657 if (SimplifySelectOps(SELECT: N, LHS: N1, RHS: N2))
12658 return SDValue(N, 0); // Don't revisit N.
12659
12660 if (VT0 == MVT::i1) {
12661 // The code in this block deals with the following 2 equivalences:
12662 // select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y))
12663 // select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y)
12664 // The target can specify its preferred form with the
12665    // shouldNormalizeToSelectSequence() callback. However, we always transform
12666    // to the right side if the inner select already exists in the DAG, and we
12667    // always transform to the left side if we know that we can further
12668 // optimize the combination of the conditions.
12669 bool normalizeToSequence =
12670 TLI.shouldNormalizeToSelectSequence(Context&: *DAG.getContext(), VT);
12671 // select (and Cond0, Cond1), X, Y
12672 // -> select Cond0, (select Cond1, X, Y), Y
12673 if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
12674 SDValue Cond0 = N0->getOperand(Num: 0);
12675 SDValue Cond1 = N0->getOperand(Num: 1);
12676 SDValue InnerSelect =
12677 DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Cond1, N2: N1, N3: N2, Flags);
12678 if (normalizeToSequence || !InnerSelect.use_empty())
12679 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Cond0,
12680 N2: InnerSelect, N3: N2, Flags);
12681 // Cleanup on failure.
12682 if (InnerSelect.use_empty())
12683 recursivelyDeleteUnusedNodes(N: InnerSelect.getNode());
12684 }
12685 // select (or Cond0, Cond1), X, Y -> select Cond0, X, (select Cond1, X, Y)
12686 if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
12687 SDValue Cond0 = N0->getOperand(Num: 0);
12688 SDValue Cond1 = N0->getOperand(Num: 1);
12689 SDValue InnerSelect = DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(),
12690 N1: Cond1, N2: N1, N3: N2, Flags);
12691 if (normalizeToSequence || !InnerSelect.use_empty())
12692 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Cond0, N2: N1,
12693 N3: InnerSelect, Flags);
12694 // Cleanup on failure.
12695 if (InnerSelect.use_empty())
12696 recursivelyDeleteUnusedNodes(N: InnerSelect.getNode());
12697 }
12698
12699 // select Cond0, (select Cond1, X, Y), Y -> select (and Cond0, Cond1), X, Y
12700 if (N1->getOpcode() == ISD::SELECT && N1->hasOneUse()) {
12701 SDValue N1_0 = N1->getOperand(Num: 0);
12702 SDValue N1_1 = N1->getOperand(Num: 1);
12703 SDValue N1_2 = N1->getOperand(Num: 2);
12704 if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
12705 // Create the actual and node if we can generate good code for it.
12706 if (!normalizeToSequence) {
12707 SDValue And = DAG.getNode(Opcode: ISD::AND, DL, VT: N0.getValueType(), N1: N0, N2: N1_0);
12708 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: And, N2: N1_1,
12709 N3: N2, Flags);
12710 }
12711 // Otherwise see if we can optimize the "and" to a better pattern.
12712 if (SDValue Combined = visitANDLike(N0, N1: N1_0, N)) {
12713 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Combined, N2: N1_1,
12714 N3: N2, Flags);
12715 }
12716 }
12717 }
12718 // select Cond0, X, (select Cond1, X, Y) -> select (or Cond0, Cond1), X, Y
12719 if (N2->getOpcode() == ISD::SELECT && N2->hasOneUse()) {
12720 SDValue N2_0 = N2->getOperand(Num: 0);
12721 SDValue N2_1 = N2->getOperand(Num: 1);
12722 SDValue N2_2 = N2->getOperand(Num: 2);
12723 if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
12724 // Create the actual or node if we can generate good code for it.
12725 if (!normalizeToSequence) {
12726 SDValue Or = DAG.getNode(Opcode: ISD::OR, DL, VT: N0.getValueType(), N1: N0, N2: N2_0);
12727 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Or, N2: N1,
12728 N3: N2_2, Flags);
12729 }
12730 // Otherwise see if we can optimize to a better pattern.
12731 if (SDValue Combined = visitORLike(N0, N1: N2_0, DL))
12732 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Combined, N2: N1,
12733 N3: N2_2, Flags);
12734 }
12735 }
12736
12737 // select usubo(x, y).overflow, (sub y, x), (usubo x, y) -> abdu(x, y)
12738 if (N0.getOpcode() == ISD::USUBO && N0.getResNo() == 1 &&
12739 N2.getNode() == N0.getNode() && N2.getResNo() == 0 &&
12740 N1.getOpcode() == ISD::SUB && N2.getOperand(i: 0) == N1.getOperand(i: 1) &&
12741 N2.getOperand(i: 1) == N1.getOperand(i: 0) &&
12742 (!LegalOperations || TLI.isOperationLegal(Op: ISD::ABDU, VT)))
12743 return DAG.getNode(Opcode: ISD::ABDU, DL, VT, N1: N0.getOperand(i: 0), N2: N0.getOperand(i: 1));
12744
12745 // select usubo(x, y).overflow, (usubo x, y), (sub y, x) -> neg (abdu x, y)
12746 if (N0.getOpcode() == ISD::USUBO && N0.getResNo() == 1 &&
12747 N1.getNode() == N0.getNode() && N1.getResNo() == 0 &&
12748 N2.getOpcode() == ISD::SUB && N2.getOperand(i: 0) == N1.getOperand(i: 1) &&
12749 N2.getOperand(i: 1) == N1.getOperand(i: 0) &&
12750 (!LegalOperations || TLI.isOperationLegal(Op: ISD::ABDU, VT)))
12751 return DAG.getNegative(
12752 Val: DAG.getNode(Opcode: ISD::ABDU, DL, VT, N1: N0.getOperand(i: 0), N2: N0.getOperand(i: 1)),
12753 DL, VT);
12754 }
12755
12756 // Fold selects based on a setcc into other things, such as min/max/abs.
12757 if (N0.getOpcode() == ISD::SETCC) {
12758 SDValue Cond0 = N0.getOperand(i: 0), Cond1 = N0.getOperand(i: 1);
12759 ISD::CondCode CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get();
12760
12761 // select (fcmp lt x, y), x, y -> fminnum x, y
12762 // select (fcmp gt x, y), x, y -> fmaxnum x, y
12763 //
12764 // This is OK if we don't care what happens if either operand is a NaN.
12765 if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, LHS: N1, RHS: N2, Flags, TLI))
12766 if (SDValue FMinMax =
12767 combineMinNumMaxNum(DL, VT, LHS: Cond0, RHS: Cond1, True: N1, False: N2, CC))
12768 return FMinMax;
12769
12770 // Use 'unsigned add with overflow' to optimize an unsigned saturating add.
12771 // This is conservatively limited to pre-legal-operations to give targets
12772 // a chance to reverse the transform if they want to do that. Also, it is
12773 // unlikely that the pattern would be formed late, so it's probably not
12774 // worth going through the other checks.
12775 if (!LegalOperations && TLI.isOperationLegalOrCustom(Op: ISD::UADDO, VT) &&
12776 CC == ISD::SETUGT && N0.hasOneUse() && isAllOnesConstant(V: N1) &&
12777 N2.getOpcode() == ISD::ADD && Cond0 == N2.getOperand(i: 0)) {
12778 auto *C = dyn_cast<ConstantSDNode>(Val: N2.getOperand(i: 1));
12779 auto *NotC = dyn_cast<ConstantSDNode>(Val&: Cond1);
12780 if (C && NotC && C->getAPIntValue() == ~NotC->getAPIntValue()) {
12781 // select (setcc Cond0, ~C, ugt), -1, (add Cond0, C) -->
12782 // uaddo Cond0, C; select uaddo.1, -1, uaddo.0
12783 //
12784 // The IR equivalent of this transform would have this form:
12785 // %a = add %x, C
12786 // %c = icmp ugt %x, ~C
12787 // %r = select %c, -1, %a
12788 // =>
12789 // %u = call {iN,i1} llvm.uadd.with.overflow(%x, C)
12790 // %u0 = extractvalue %u, 0
12791 // %u1 = extractvalue %u, 1
12792 // %r = select %u1, -1, %u0
12793 SDVTList VTs = DAG.getVTList(VT1: VT, VT2: VT0);
12794 SDValue UAO = DAG.getNode(Opcode: ISD::UADDO, DL, VTList: VTs, N1: Cond0, N2: N2.getOperand(i: 1));
12795 return DAG.getSelect(DL, VT, Cond: UAO.getValue(R: 1), LHS: N1, RHS: UAO.getValue(R: 0));
12796 }
12797 }
12798
12799 if (TLI.isOperationLegal(Op: ISD::SELECT_CC, VT) ||
12800 (!LegalOperations &&
12801 TLI.isOperationLegalOrCustom(Op: ISD::SELECT_CC, VT))) {
12802 // Any flags available in a select/setcc fold will be on the setcc as they
12803 // migrated from fcmp
12804 return DAG.getNode(Opcode: ISD::SELECT_CC, DL, VT, N1: Cond0, N2: Cond1, N3: N1, N4: N2,
12805 N5: N0.getOperand(i: 2), Flags: N0->getFlags());
12806 }
12807
12808 if (SDValue ABD = foldSelectToABD(LHS: Cond0, RHS: Cond1, True: N1, False: N2, CC, DL))
12809 return ABD;
12810
12811 if (SDValue NewSel = SimplifySelect(DL, N0, N1, N2))
12812 return NewSel;
12813
12814 // (select (ugt x, C), (add x, ~C), x) -> (umin (add x, ~C), x)
12815 // (select (ult x, C), x, (add x, -C)) -> (umin x, (add x, -C))
12816 if (SDValue UMin = foldSelectToUMin(LHS: Cond0, RHS: Cond1, True: N1, False: N2, CC, DL))
12817 return UMin;
12818 }
12819
12820 if (!VT.isVector())
12821 if (SDValue BinOp = foldSelectOfBinops(N))
12822 return BinOp;
12823
12824 if (SDValue R = combineSelectAsExtAnd(Cond: N0, T: N1, F: N2, DL, DAG))
12825 return R;
12826
12827 return SDValue();
12828}
12829
12830// This function assumes both value operands of the vselect are CONCAT_VECTORS
12831// nodes and that the condition is a BUILD_VECTOR of ConstantSDNodes (or undefs).
12832static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
12833 SDLoc DL(N);
12834 SDValue Cond = N->getOperand(Num: 0);
12835 SDValue LHS = N->getOperand(Num: 1);
12836 SDValue RHS = N->getOperand(Num: 2);
12837 EVT VT = N->getValueType(ResNo: 0);
12838 int NumElems = VT.getVectorNumElements();
12839 assert(LHS.getOpcode() == ISD::CONCAT_VECTORS &&
12840 RHS.getOpcode() == ISD::CONCAT_VECTORS &&
12841 Cond.getOpcode() == ISD::BUILD_VECTOR);
12842
12843  // CONCAT_VECTORS can take an arbitrary number of arguments. We only care
12844  // about the two-operand form here.
12845 if (LHS->getNumOperands() != 2 || RHS->getNumOperands() != 2)
12846 return SDValue();
12847
12848 // We're sure we have an even number of elements due to the
12849 // concat_vectors we have as arguments to vselect.
12850  // Skip BUILD_VECTOR elements until we find one that's not an UNDEF.
12851  // After finding the first non-UNDEF element, keep looping up to half the
12852  // length of the BUILD_VECTOR and check that all non-undef elements match it.
12853 ConstantSDNode *BottomHalf = nullptr;
12854 for (int i = 0; i < NumElems / 2; ++i) {
12855 if (Cond->getOperand(Num: i)->isUndef())
12856 continue;
12857
12858 if (BottomHalf == nullptr)
12859 BottomHalf = cast<ConstantSDNode>(Val: Cond.getOperand(i));
12860 else if (Cond->getOperand(Num: i).getNode() != BottomHalf)
12861 return SDValue();
12862 }
12863
12864 // Do the same for the second half of the BuildVector
12865 ConstantSDNode *TopHalf = nullptr;
12866 for (int i = NumElems / 2; i < NumElems; ++i) {
12867 if (Cond->getOperand(Num: i)->isUndef())
12868 continue;
12869
12870 if (TopHalf == nullptr)
12871 TopHalf = cast<ConstantSDNode>(Val: Cond.getOperand(i));
12872 else if (Cond->getOperand(Num: i).getNode() != TopHalf)
12873 return SDValue();
12874 }
12875
12876 assert(TopHalf && BottomHalf &&
12877 "One half of the selector was all UNDEFs and the other was all the "
12878 "same value. This should have been addressed before this function.");
12879 return DAG.getNode(
12880 Opcode: ISD::CONCAT_VECTORS, DL, VT,
12881 N1: BottomHalf->isZero() ? RHS->getOperand(Num: 0) : LHS->getOperand(Num: 0),
12882 N2: TopHalf->isZero() ? RHS->getOperand(Num: 1) : LHS->getOperand(Num: 1));
12883}
12884
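// Try to fold a uniform (splat) component of a gather/scatter index into the
// scalar base pointer so the remaining index is cheaper:
//   Index = splat(x)         --> BasePtr += x, Index = splat(0)
//   Index = add(splat(x), y) --> BasePtr += x, Index = y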
12885bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
12886 SelectionDAG &DAG, const SDLoc &DL) {
12887
12888 // Only perform the transformation when existing operands can be reused.
12889 if (IndexIsScaled)
12890 return false;
12891
12892 if (!isNullConstant(V: BasePtr) && !Index.hasOneUse())
12893 return false;
12894
12895 EVT VT = BasePtr.getValueType();
12896
12897 if (SDValue SplatVal = DAG.getSplatValue(V: Index);
12898 SplatVal && !isNullConstant(V: SplatVal) &&
12899 SplatVal.getValueType() == VT) {
12900 BasePtr = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: BasePtr, N2: SplatVal);
12901 Index = DAG.getSplat(VT: Index.getValueType(), DL, Op: DAG.getConstant(Val: 0, DL, VT));
12902 return true;
12903 }
12904
12905 if (Index.getOpcode() != ISD::ADD)
12906 return false;
12907
12908 if (SDValue SplatVal = DAG.getSplatValue(V: Index.getOperand(i: 0));
12909 SplatVal && SplatVal.getValueType() == VT) {
12910 BasePtr = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: BasePtr, N2: SplatVal);
12911 Index = Index.getOperand(i: 1);
12912 return true;
12913 }
12914 if (SDValue SplatVal = DAG.getSplatValue(V: Index.getOperand(i: 1));
12915 SplatVal && SplatVal.getValueType() == VT) {
12916 BasePtr = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: BasePtr, N2: SplatVal);
12917 Index = Index.getOperand(i: 0);
12918 return true;
12919 }
12920 return false;
12921}
12922
12923// Fold sext/zext of index into index type.
12924bool refineIndexType(SDValue &Index, ISD::MemIndexType &IndexType, EVT DataVT,
12925 SelectionDAG &DAG) {
12926 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
12927
12928 // It's always safe to look through zero extends.
12929 if (Index.getOpcode() == ISD::ZERO_EXTEND) {
12930 if (TLI.shouldRemoveExtendFromGSIndex(Extend: Index, DataVT)) {
12931 IndexType = ISD::UNSIGNED_SCALED;
12932 Index = Index.getOperand(i: 0);
12933 return true;
12934 }
12935 if (ISD::isIndexTypeSigned(IndexType)) {
12936 IndexType = ISD::UNSIGNED_SCALED;
12937 return true;
12938 }
12939 }
12940
12941 // It's only safe to look through sign extends when Index is signed.
12942 if (Index.getOpcode() == ISD::SIGN_EXTEND &&
12943 ISD::isIndexTypeSigned(IndexType) &&
12944 TLI.shouldRemoveExtendFromGSIndex(Extend: Index, DataVT)) {
12945 Index = Index.getOperand(i: 0);
12946 return true;
12947 }
12948
12949 return false;
12950}
12951
12952SDValue DAGCombiner::visitVPSCATTER(SDNode *N) {
12953 VPScatterSDNode *MSC = cast<VPScatterSDNode>(Val: N);
12954 SDValue Mask = MSC->getMask();
12955 SDValue Chain = MSC->getChain();
12956 SDValue Index = MSC->getIndex();
12957 SDValue Scale = MSC->getScale();
12958 SDValue StoreVal = MSC->getValue();
12959 SDValue BasePtr = MSC->getBasePtr();
12960 SDValue VL = MSC->getVectorLength();
12961 ISD::MemIndexType IndexType = MSC->getIndexType();
12962 SDLoc DL(N);
12963
12964 // Zap scatters with a zero mask.
12965 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
12966 return Chain;
12967
12968 if (refineUniformBase(BasePtr, Index, IndexIsScaled: MSC->isIndexScaled(), DAG, DL)) {
12969 SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
12970 return DAG.getScatterVP(VTs: DAG.getVTList(VT: MVT::Other), VT: MSC->getMemoryVT(),
12971 dl: DL, Ops, MMO: MSC->getMemOperand(), IndexType);
12972 }
12973
12974 if (refineIndexType(Index, IndexType, DataVT: StoreVal.getValueType(), DAG)) {
12975 SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
12976 return DAG.getScatterVP(VTs: DAG.getVTList(VT: MVT::Other), VT: MSC->getMemoryVT(),
12977 dl: DL, Ops, MMO: MSC->getMemOperand(), IndexType);
12978 }
12979
12980 return SDValue();
12981}
12982
12983SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
12984 MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(Val: N);
12985 SDValue Mask = MSC->getMask();
12986 SDValue Chain = MSC->getChain();
12987 SDValue Index = MSC->getIndex();
12988 SDValue Scale = MSC->getScale();
12989 SDValue StoreVal = MSC->getValue();
12990 SDValue BasePtr = MSC->getBasePtr();
12991 ISD::MemIndexType IndexType = MSC->getIndexType();
12992 SDLoc DL(N);
12993
12994 // Zap scatters with a zero mask.
12995 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
12996 return Chain;
12997
12998 if (refineUniformBase(BasePtr, Index, IndexIsScaled: MSC->isIndexScaled(), DAG, DL)) {
12999 SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
13000 return DAG.getMaskedScatter(VTs: DAG.getVTList(VT: MVT::Other), MemVT: MSC->getMemoryVT(),
13001 dl: DL, Ops, MMO: MSC->getMemOperand(), IndexType,
13002 IsTruncating: MSC->isTruncatingStore());
13003 }
13004
13005 if (refineIndexType(Index, IndexType, DataVT: StoreVal.getValueType(), DAG)) {
13006 SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
13007 return DAG.getMaskedScatter(VTs: DAG.getVTList(VT: MVT::Other), MemVT: MSC->getMemoryVT(),
13008 dl: DL, Ops, MMO: MSC->getMemOperand(), IndexType,
13009 IsTruncating: MSC->isTruncatingStore());
13010 }
13011
13012 return SDValue();
13013}
13014
13015SDValue DAGCombiner::visitMSTORE(SDNode *N) {
13016 MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(Val: N);
13017 SDValue Mask = MST->getMask();
13018 SDValue Chain = MST->getChain();
13019 SDValue Value = MST->getValue();
13020 SDValue Ptr = MST->getBasePtr();
13021
13022 // Zap masked stores with a zero mask.
13023 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
13024 return Chain;
13025
13026  // Remove a dead masked store if this store overwrites it: the base pointers
  // match and either the masks and store sizes are equal, or this store has an
  // all-ones mask and is at least as wide.
13027 if (MaskedStoreSDNode *MST1 = dyn_cast<MaskedStoreSDNode>(Val&: Chain)) {
13028 if (MST->isUnindexed() && MST->isSimple() && MST1->isUnindexed() &&
13029 MST1->isSimple() && MST1->getBasePtr() == Ptr &&
13030 !MST->getBasePtr().isUndef() &&
13031 ((Mask == MST1->getMask() && MST->getMemoryVT().getStoreSize() ==
13032 MST1->getMemoryVT().getStoreSize()) ||
13033 ISD::isConstantSplatVectorAllOnes(N: Mask.getNode())) &&
13034 TypeSize::isKnownLE(LHS: MST1->getMemoryVT().getStoreSize(),
13035 RHS: MST->getMemoryVT().getStoreSize())) {
13036 CombineTo(N: MST1, Res: MST1->getChain());
13037 if (N->getOpcode() != ISD::DELETED_NODE)
13038 AddToWorklist(N);
13039 return SDValue(N, 0);
13040 }
13041 }
13042
13043  // If this is a masked store with an all-ones mask, we can use an unmasked store.
13044 // FIXME: Can we do this for indexed, compressing, or truncating stores?
13045 if (ISD::isConstantSplatVectorAllOnes(N: Mask.getNode()) && MST->isUnindexed() &&
13046 !MST->isCompressingStore() && !MST->isTruncatingStore())
13047 return DAG.getStore(Chain: MST->getChain(), dl: SDLoc(N), Val: MST->getValue(),
13048 Ptr: MST->getBasePtr(), PtrInfo: MST->getPointerInfo(),
13049 Alignment: MST->getBaseAlign(), MMOFlags: MST->getMemOperand()->getFlags(),
13050 AAInfo: MST->getAAInfo());
13051
13052 // Try transforming N to an indexed store.
13053 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
13054 return SDValue(N, 0);
13055
13056 if (MST->isTruncatingStore() && MST->isUnindexed() &&
13057 Value.getValueType().isInteger() &&
13058 (!isa<ConstantSDNode>(Val: Value) ||
13059 !cast<ConstantSDNode>(Val&: Value)->isOpaque())) {
13060 APInt TruncDemandedBits =
13061 APInt::getLowBitsSet(numBits: Value.getScalarValueSizeInBits(),
13062 loBitsSet: MST->getMemoryVT().getScalarSizeInBits());
13063
13064 // See if we can simplify the operation with
13065 // SimplifyDemandedBits, which only works if the value has a single use.
13066 if (SimplifyDemandedBits(Op: Value, DemandedBits: TruncDemandedBits)) {
13067      // Re-visit the store if anything changed and the store hasn't been merged
13068      // with another node (N is deleted). SimplifyDemandedBits will add Value's
13069 // node back to the worklist if necessary, but we also need to re-visit
13070 // the Store node itself.
13071 if (N->getOpcode() != ISD::DELETED_NODE)
13072 AddToWorklist(N);
13073 return SDValue(N, 0);
13074 }
13075 }
13076
13077 // If this is a TRUNC followed by a masked store, fold this into a masked
13078 // truncating store. We can do this even if this is already a masked
13079 // truncstore.
13080  // TODO: Try to combine to a masked compress store if possible.
13081 if ((Value.getOpcode() == ISD::TRUNCATE) && Value->hasOneUse() &&
13082 MST->isUnindexed() && !MST->isCompressingStore() &&
13083 TLI.canCombineTruncStore(ValVT: Value.getOperand(i: 0).getValueType(),
13084 MemVT: MST->getMemoryVT(), LegalOnly: LegalOperations)) {
13085 auto Mask = TLI.promoteTargetBoolean(DAG, Bool: MST->getMask(),
13086 ValVT: Value.getOperand(i: 0).getValueType());
13087 return DAG.getMaskedStore(Chain, dl: SDLoc(N), Val: Value.getOperand(i: 0), Base: Ptr,
13088 Offset: MST->getOffset(), Mask, MemVT: MST->getMemoryVT(),
13089 MMO: MST->getMemOperand(), AM: MST->getAddressingMode(),
13090 /*IsTruncating=*/true);
13091 }
13092
13093 return SDValue();
13094}
13095
13096SDValue DAGCombiner::visitVP_STRIDED_STORE(SDNode *N) {
13097 auto *SST = cast<VPStridedStoreSDNode>(Val: N);
13098 EVT EltVT = SST->getValue().getValueType().getVectorElementType();
13099 // Combine strided stores with unit-stride to a regular VP store.
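  // (A unit stride is one whose byte value equals the element's store size, so
  // the elements are laid out contiguously, exactly like a regular VP store.)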
13100 if (auto *CStride = dyn_cast<ConstantSDNode>(Val: SST->getStride());
13101 CStride && CStride->getZExtValue() == EltVT.getStoreSize()) {
13102 return DAG.getStoreVP(Chain: SST->getChain(), dl: SDLoc(N), Val: SST->getValue(),
13103 Ptr: SST->getBasePtr(), Offset: SST->getOffset(), Mask: SST->getMask(),
13104 EVL: SST->getVectorLength(), MemVT: SST->getMemoryVT(),
13105 MMO: SST->getMemOperand(), AM: SST->getAddressingMode(),
13106 IsTruncating: SST->isTruncatingStore(), IsCompressing: SST->isCompressingStore());
13107 }
13108 return SDValue();
13109}
13110
13111SDValue DAGCombiner::visitVECTOR_COMPRESS(SDNode *N) {
13112 SDLoc DL(N);
13113 SDValue Vec = N->getOperand(Num: 0);
13114 SDValue Mask = N->getOperand(Num: 1);
13115 SDValue Passthru = N->getOperand(Num: 2);
13116 EVT VecVT = Vec.getValueType();
13117
13118 bool HasPassthru = !Passthru.isUndef();
13119
13120 APInt SplatVal;
13121 if (ISD::isConstantSplatVector(N: Mask.getNode(), SplatValue&: SplatVal))
13122 return TLI.isConstTrueVal(N: Mask) ? Vec : Passthru;
13123
13124 if (Vec.isUndef() || Mask.isUndef())
13125 return Passthru;
13126
13127 // No need for potentially expensive compress if the mask is constant.
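  // For example, compress(<a,b,c,d>, <1,0,1,0>, <p0,p1,p2,p3>) becomes the
  // build_vector <a, c, p2, p3>, with undef instead of p2/p3 when there is no
  // passthru.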
13128 if (ISD::isBuildVectorOfConstantSDNodes(N: Mask.getNode())) {
13129 SmallVector<SDValue, 16> Ops;
13130 EVT ScalarVT = VecVT.getVectorElementType();
13131 unsigned NumSelected = 0;
13132 unsigned NumElmts = VecVT.getVectorNumElements();
13133 for (unsigned I = 0; I < NumElmts; ++I) {
13134 SDValue MaskI = Mask.getOperand(i: I);
13135 // We treat undef mask entries as "false".
13136 if (MaskI.isUndef())
13137 continue;
13138
13139 if (TLI.isConstTrueVal(N: MaskI)) {
13140 SDValue VecI = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ScalarVT, N1: Vec,
13141 N2: DAG.getVectorIdxConstant(Val: I, DL));
13142 Ops.push_back(Elt: VecI);
13143 NumSelected++;
13144 }
13145 }
13146 for (unsigned Rest = NumSelected; Rest < NumElmts; ++Rest) {
13147 SDValue Val =
13148 HasPassthru
13149 ? DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ScalarVT, N1: Passthru,
13150 N2: DAG.getVectorIdxConstant(Val: Rest, DL))
13151 : DAG.getUNDEF(VT: ScalarVT);
13152 Ops.push_back(Elt: Val);
13153 }
13154 return DAG.getBuildVector(VT: VecVT, DL, Ops);
13155 }
13156
13157 return SDValue();
13158}
13159
13160SDValue DAGCombiner::visitVPGATHER(SDNode *N) {
13161 VPGatherSDNode *MGT = cast<VPGatherSDNode>(Val: N);
13162 SDValue Mask = MGT->getMask();
13163 SDValue Chain = MGT->getChain();
13164 SDValue Index = MGT->getIndex();
13165 SDValue Scale = MGT->getScale();
13166 SDValue BasePtr = MGT->getBasePtr();
13167 SDValue VL = MGT->getVectorLength();
13168 ISD::MemIndexType IndexType = MGT->getIndexType();
13169 SDLoc DL(N);
13170
13171 if (refineUniformBase(BasePtr, Index, IndexIsScaled: MGT->isIndexScaled(), DAG, DL)) {
13172 SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
13173 return DAG.getGatherVP(
13174 VTs: DAG.getVTList(VT1: N->getValueType(ResNo: 0), VT2: MVT::Other), VT: MGT->getMemoryVT(), dl: DL,
13175 Ops, MMO: MGT->getMemOperand(), IndexType);
13176 }
13177
13178 if (refineIndexType(Index, IndexType, DataVT: N->getValueType(ResNo: 0), DAG)) {
13179 SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
13180 return DAG.getGatherVP(
13181 VTs: DAG.getVTList(VT1: N->getValueType(ResNo: 0), VT2: MVT::Other), VT: MGT->getMemoryVT(), dl: DL,
13182 Ops, MMO: MGT->getMemOperand(), IndexType);
13183 }
13184
13185 return SDValue();
13186}
13187
13188SDValue DAGCombiner::visitMGATHER(SDNode *N) {
13189 MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(Val: N);
13190 SDValue Mask = MGT->getMask();
13191 SDValue Chain = MGT->getChain();
13192 SDValue Index = MGT->getIndex();
13193 SDValue Scale = MGT->getScale();
13194 SDValue PassThru = MGT->getPassThru();
13195 SDValue BasePtr = MGT->getBasePtr();
13196 ISD::MemIndexType IndexType = MGT->getIndexType();
13197 SDLoc DL(N);
13198
13199 // Zap gathers with a zero mask.
13200 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
13201 return CombineTo(N, Res0: PassThru, Res1: MGT->getChain());
13202
13203 if (refineUniformBase(BasePtr, Index, IndexIsScaled: MGT->isIndexScaled(), DAG, DL)) {
13204 SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
13205 return DAG.getMaskedGather(
13206 VTs: DAG.getVTList(VT1: N->getValueType(ResNo: 0), VT2: MVT::Other), MemVT: MGT->getMemoryVT(), dl: DL,
13207 Ops, MMO: MGT->getMemOperand(), IndexType, ExtTy: MGT->getExtensionType());
13208 }
13209
13210 if (refineIndexType(Index, IndexType, DataVT: N->getValueType(ResNo: 0), DAG)) {
13211 SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
13212 return DAG.getMaskedGather(
13213 VTs: DAG.getVTList(VT1: N->getValueType(ResNo: 0), VT2: MVT::Other), MemVT: MGT->getMemoryVT(), dl: DL,
13214 Ops, MMO: MGT->getMemOperand(), IndexType, ExtTy: MGT->getExtensionType());
13215 }
13216
13217 return SDValue();
13218}
13219
13220SDValue DAGCombiner::visitMLOAD(SDNode *N) {
13221 MaskedLoadSDNode *MLD = cast<MaskedLoadSDNode>(Val: N);
13222 SDValue Mask = MLD->getMask();
13223
13224 // Zap masked loads with a zero mask.
13225 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
13226 return CombineTo(N, Res0: MLD->getPassThru(), Res1: MLD->getChain());
13227
13228  // If this is a masked load with an all-ones mask, we can use an unmasked load.
13229 // FIXME: Can we do this for indexed, expanding, or extending loads?
13230 if (ISD::isConstantSplatVectorAllOnes(N: Mask.getNode()) && MLD->isUnindexed() &&
13231 !MLD->isExpandingLoad() && MLD->getExtensionType() == ISD::NON_EXTLOAD) {
13232 SDValue NewLd = DAG.getLoad(
13233 VT: N->getValueType(ResNo: 0), dl: SDLoc(N), Chain: MLD->getChain(), Ptr: MLD->getBasePtr(),
13234 PtrInfo: MLD->getPointerInfo(), Alignment: MLD->getBaseAlign(),
13235 MMOFlags: MLD->getMemOperand()->getFlags(), AAInfo: MLD->getAAInfo(), Ranges: MLD->getRanges());
13236 return CombineTo(N, Res0: NewLd, Res1: NewLd.getValue(R: 1));
13237 }
13238
13239 // Try transforming N to an indexed load.
13240 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
13241 return SDValue(N, 0);
13242
13243 return SDValue();
13244}
13245
13246SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
13247 MaskedHistogramSDNode *HG = cast<MaskedHistogramSDNode>(Val: N);
13248 SDValue Chain = HG->getChain();
13249 SDValue Inc = HG->getInc();
13250 SDValue Mask = HG->getMask();
13251 SDValue BasePtr = HG->getBasePtr();
13252 SDValue Index = HG->getIndex();
13253 SDLoc DL(HG);
13254
13255 EVT MemVT = HG->getMemoryVT();
13256 EVT DataVT = Index.getValueType();
13257 MachineMemOperand *MMO = HG->getMemOperand();
13258 ISD::MemIndexType IndexType = HG->getIndexType();
13259
13260 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
13261 return Chain;
13262
13263 if (refineUniformBase(BasePtr, Index, IndexIsScaled: HG->isIndexScaled(), DAG, DL) ||
13264 refineIndexType(Index, IndexType, DataVT, DAG)) {
13265 SDValue Ops[] = {Chain, Inc, Mask, BasePtr, Index,
13266 HG->getScale(), HG->getIntID()};
13267 return DAG.getMaskedHistogram(VTs: DAG.getVTList(VT: MVT::Other), MemVT, dl: DL, Ops,
13268 MMO, IndexType);
13269 }
13270
13271 return SDValue();
13272}
13273
13274SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
13275 if (SDValue Res = foldPartialReduceMLAMulOp(N))
13276 return Res;
13277 if (SDValue Res = foldPartialReduceAdd(N))
13278 return Res;
13279 return SDValue();
13280}
13281
13282// partial_reduce_*mla(acc, mul(*ext(a), *ext(b)), splat(1))
13283// -> partial_reduce_*mla(acc, a, b)
13284//
13285// partial_reduce_*mla(acc, mul(*ext(x), splat(C)), splat(1))
13286// -> partial_reduce_*mla(acc, x, splat(C))
13287//
13288// partial_reduce_*mla(acc, sel(p, mul(*ext(a), *ext(b)), splat(0)), splat(1))
13289// -> partial_reduce_*mla(acc, sel(p, a, splat(0)), b)
13290//
13291// partial_reduce_*mla(acc, sel(p, mul(*ext(a), splat(C)), splat(0)), splat(1))
13292// -> partial_reduce_*mla(acc, sel(p, a, splat(0)), splat(C))
13293SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
13294 SDLoc DL(N);
13295 auto *Context = DAG.getContext();
13296 SDValue Acc = N->getOperand(Num: 0);
13297 SDValue Op1 = N->getOperand(Num: 1);
13298 SDValue Op2 = N->getOperand(Num: 2);
13299 unsigned Opc = Op1->getOpcode();
13300
13301 // Handle predication by moving the SELECT into the operand of the MUL.
13302 SDValue Pred;
13303 if (Opc == ISD::VSELECT && (isZeroOrZeroSplat(N: Op1->getOperand(Num: 2)) ||
13304 isZeroOrZeroSplatFP(N: Op1->getOperand(Num: 2)))) {
13305 Pred = Op1->getOperand(Num: 0);
13306 Op1 = Op1->getOperand(Num: 1);
13307 Opc = Op1->getOpcode();
13308 }
13309
13310 if (Opc != ISD::MUL && Opc != ISD::FMUL && Opc != ISD::SHL)
13311 return SDValue();
13312
13313 SDValue LHS = Op1->getOperand(Num: 0);
13314 SDValue RHS = Op1->getOperand(Num: 1);
13315
13316 // Try to treat (shl %a, %c) as (mul %a, (1 << %c)) for constant %c.
13317 if (Opc == ISD::SHL) {
13318 APInt C;
13319 if (!ISD::isConstantSplatVector(N: RHS.getNode(), SplatValue&: C))
13320 return SDValue();
13321
13322 RHS =
13323 DAG.getSplatVector(VT: RHS.getValueType(), DL,
13324 Op: DAG.getConstant(Val: APInt(C.getBitWidth(), 1).shl(ShiftAmt: C), DL,
13325 VT: RHS.getValueType().getScalarType()));
13326 Opc = ISD::MUL;
13327 }
13328
13329 if (!(Opc == ISD::MUL && llvm::isOneOrOneSplat(V: Op2)) &&
13330 !(Opc == ISD::FMUL && llvm::isOneOrOneSplatFP(V: Op2)))
13331 return SDValue();
13332
13333 auto IsIntOrFPExtOpcode = [](unsigned int Opcode) {
13334 return (ISD::isExtOpcode(Opcode) || Opcode == ISD::FP_EXTEND);
13335 };
13336
13337 unsigned LHSOpcode = LHS->getOpcode();
13338 if (!IsIntOrFPExtOpcode(LHSOpcode))
13339 return SDValue();
13340
13341 SDValue LHSExtOp = LHS->getOperand(Num: 0);
13342 EVT LHSExtOpVT = LHSExtOp.getValueType();
13343
13344 // When Pred is non-zero, set Op = select(Pred, Op, splat(0)) and freeze
13345 // OtherOp to keep the same semantics when moving the selects into the MUL
13346 // operands.
13347 auto ApplyPredicate = [&](SDValue &Op, SDValue &OtherOp) {
13348 if (Pred) {
13349 EVT OpVT = Op.getValueType();
13350 SDValue Zero = OpVT.isFloatingPoint() ? DAG.getConstantFP(Val: 0.0, DL, VT: OpVT)
13351 : DAG.getConstant(Val: 0, DL, VT: OpVT);
13352 Op = DAG.getSelect(DL, VT: OpVT, Cond: Pred, LHS: Op, RHS: Zero);
13353 OtherOp = DAG.getFreeze(V: OtherOp);
13354 }
13355 };
13356
13357 // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
13358 // -> partial_reduce_*mla(acc, x, C)
13359 APInt C;
13360 if (ISD::isConstantSplatVector(N: RHS.getNode(), SplatValue&: C)) {
13361 // TODO: Make use of partial_reduce_sumla here
13362 APInt CTrunc = C.trunc(width: LHSExtOpVT.getScalarSizeInBits());
13363 unsigned LHSBits = LHS.getValueType().getScalarSizeInBits();
13364 if ((LHSOpcode != ISD::ZERO_EXTEND || CTrunc.zext(width: LHSBits) != C) &&
13365 (LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(width: LHSBits) != C))
13366 return SDValue();
13367
13368 unsigned NewOpcode = LHSOpcode == ISD::SIGN_EXTEND
13369 ? ISD::PARTIAL_REDUCE_SMLA
13370 : ISD::PARTIAL_REDUCE_UMLA;
13371
13372 // Only perform these combines if the target supports folding
13373 // the extends into the operation.
13374 if (!TLI.isPartialReduceMLALegalOrCustom(
13375 Opc: NewOpcode, AccVT: TLI.getTypeToTransformTo(Context&: *Context, VT: N->getValueType(ResNo: 0)),
13376 InputVT: TLI.getTypeToTransformTo(Context&: *Context, VT: LHSExtOpVT)))
13377 return SDValue();
13378
13379 SDValue C = DAG.getConstant(Val: CTrunc, DL, VT: LHSExtOpVT);
13380 ApplyPredicate(C, LHSExtOp);
13381 return DAG.getNode(Opcode: NewOpcode, DL, VT: N->getValueType(ResNo: 0), N1: Acc, N2: LHSExtOp, N3: C);
13382 }
13383
13384 unsigned RHSOpcode = RHS->getOpcode();
13385 if (!IsIntOrFPExtOpcode(RHSOpcode))
13386 return SDValue();
13387
13388 SDValue RHSExtOp = RHS->getOperand(Num: 0);
13389 if (LHSExtOpVT != RHSExtOp.getValueType())
13390 return SDValue();
13391
13392 unsigned NewOpc;
13393 if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::SIGN_EXTEND)
13394 NewOpc = ISD::PARTIAL_REDUCE_SMLA;
13395 else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
13396 NewOpc = ISD::PARTIAL_REDUCE_UMLA;
13397 else if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
13398 NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
13399 else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) {
13400 NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
13401 std::swap(a&: LHSExtOp, b&: RHSExtOp);
13402 } else if (LHSOpcode == ISD::FP_EXTEND && RHSOpcode == ISD::FP_EXTEND) {
13403 NewOpc = ISD::PARTIAL_REDUCE_FMLA;
13404 } else
13405 return SDValue();
13406  // For a 2-stage extend, the signedness of both of the extends must match.
13407  // If the mul has the same element type as the accumulator, there is no outer
13408  // extend, and thus we can simply use the inner extends to pick the result node.
13409 // TODO: extend to handle nonneg zext as sext
13410 EVT AccElemVT = Acc.getValueType().getVectorElementType();
13411 if (Op1.getValueType().getVectorElementType() != AccElemVT &&
13412 NewOpc != N->getOpcode())
13413 return SDValue();
13414
13415 // Only perform these combines if the target supports folding
13416 // the extends into the operation.
13417 if (!TLI.isPartialReduceMLALegalOrCustom(
13418 Opc: NewOpc, AccVT: TLI.getTypeToTransformTo(Context&: *Context, VT: N->getValueType(ResNo: 0)),
13419 InputVT: TLI.getTypeToTransformTo(Context&: *Context, VT: LHSExtOpVT)))
13420 return SDValue();
13421
13422 ApplyPredicate(RHSExtOp, LHSExtOp);
13423 return DAG.getNode(Opcode: NewOpc, DL, VT: N->getValueType(ResNo: 0), N1: Acc, N2: LHSExtOp, N3: RHSExtOp);
13424}
13425
13426// partial.reduce.*mla(acc, *ext(op), splat(1))
13427// -> partial.reduce.*mla(acc, op, splat(trunc(1)))
13428// partial.reduce.sumla(acc, sext(op), splat(1))
13429// -> partial.reduce.smla(acc, op, splat(trunc(1)))
13430//
13431// partial.reduce.*mla(acc, sel(p, *ext(op), splat(0)), splat(1))
13432// -> partial.reduce.*mla(acc, sel(p, op, splat(0)), splat(trunc(1)))
13433SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
13434 SDLoc DL(N);
13435 SDValue Acc = N->getOperand(Num: 0);
13436 SDValue Op1 = N->getOperand(Num: 1);
13437 SDValue Op2 = N->getOperand(Num: 2);
13438
13439 if (!llvm::isOneOrOneSplat(V: Op2) && !llvm::isOneOrOneSplatFP(V: Op2))
13440 return SDValue();
13441
13442 SDValue Pred;
13443 unsigned Op1Opcode = Op1.getOpcode();
13444 if (Op1Opcode == ISD::VSELECT && (isZeroOrZeroSplat(N: Op1->getOperand(Num: 2)) ||
13445 isZeroOrZeroSplatFP(N: Op1->getOperand(Num: 2)))) {
13446 Pred = Op1->getOperand(Num: 0);
13447 Op1 = Op1->getOperand(Num: 1);
13448 Op1Opcode = Op1->getOpcode();
13449 }
13450
13451 if (!ISD::isExtOpcode(Opcode: Op1Opcode) && Op1Opcode != ISD::FP_EXTEND)
13452 return SDValue();
13453
13454 bool Op1IsSigned =
13455 Op1Opcode == ISD::SIGN_EXTEND || Op1Opcode == ISD::FP_EXTEND;
13456 bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA;
13457 EVT AccElemVT = Acc.getValueType().getVectorElementType();
13458 if (Op1IsSigned != NodeIsSigned &&
13459 Op1.getValueType().getVectorElementType() != AccElemVT)
13460 return SDValue();
13461
13462 unsigned NewOpcode = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
13463 ? ISD::PARTIAL_REDUCE_FMLA
13464 : Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA
13465 : ISD::PARTIAL_REDUCE_UMLA;
13466
13467 SDValue UnextOp1 = Op1.getOperand(i: 0);
13468 EVT UnextOp1VT = UnextOp1.getValueType();
13469 auto *Context = DAG.getContext();
13470 if (!TLI.isPartialReduceMLALegalOrCustom(
13471 Opc: NewOpcode, AccVT: TLI.getTypeToTransformTo(Context&: *Context, VT: N->getValueType(ResNo: 0)),
13472 InputVT: TLI.getTypeToTransformTo(Context&: *Context, VT: UnextOp1VT)))
13473 return SDValue();
13474
13475 SDValue Constant = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
13476 ? DAG.getConstantFP(Val: 1, DL, VT: UnextOp1VT)
13477 : DAG.getConstant(Val: 1, DL, VT: UnextOp1VT);
13478
13479 if (Pred) {
13480 SDValue Zero = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
13481 ? DAG.getConstantFP(Val: 0, DL, VT: UnextOp1VT)
13482 : DAG.getConstant(Val: 0, DL, VT: UnextOp1VT);
13483 Constant = DAG.getSelect(DL, VT: UnextOp1VT, Cond: Pred, LHS: Constant, RHS: Zero);
13484 }
13485 return DAG.getNode(Opcode: NewOpcode, DL, VT: N->getValueType(ResNo: 0), N1: Acc, N2: UnextOp1,
13486 N3: Constant);
13487}
13488
13489SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
13490 auto *SLD = cast<VPStridedLoadSDNode>(Val: N);
13491 EVT EltVT = SLD->getValueType(ResNo: 0).getVectorElementType();
13492 // Combine strided loads with unit-stride to a regular VP load.
13493 if (auto *CStride = dyn_cast<ConstantSDNode>(Val: SLD->getStride());
13494 CStride && CStride->getZExtValue() == EltVT.getStoreSize()) {
13495 SDValue NewLd = DAG.getLoadVP(
13496 AM: SLD->getAddressingMode(), ExtType: SLD->getExtensionType(), VT: SLD->getValueType(ResNo: 0),
13497 dl: SDLoc(N), Chain: SLD->getChain(), Ptr: SLD->getBasePtr(), Offset: SLD->getOffset(),
13498 Mask: SLD->getMask(), EVL: SLD->getVectorLength(), MemVT: SLD->getMemoryVT(),
13499 MMO: SLD->getMemOperand(), IsExpanding: SLD->isExpandingLoad());
13500 return CombineTo(N, Res0: NewLd, Res1: NewLd.getValue(R: 1));
13501 }
13502 return SDValue();
13503}
13504
13505/// A vector select of 2 constant vectors can be simplified to math/logic to
13506/// avoid a variable select instruction and possibly avoid constant loads.
13507SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) {
13508 SDValue Cond = N->getOperand(Num: 0);
13509 SDValue N1 = N->getOperand(Num: 1);
13510 SDValue N2 = N->getOperand(Num: 2);
13511 EVT VT = N->getValueType(ResNo: 0);
13512 if (!Cond.hasOneUse() || Cond.getScalarValueSizeInBits() != 1 ||
13513 !shouldConvertSelectOfConstantsToMath(Cond, VT, TLI) ||
13514 !ISD::isBuildVectorOfConstantSDNodes(N: N1.getNode()) ||
13515 !ISD::isBuildVectorOfConstantSDNodes(N: N2.getNode()))
13516 return SDValue();
13517
13518 // Check if we can use the condition value to increment/decrement a single
13519 // constant value. This simplifies a select to an add and removes a constant
13520 // load/materialization from the general case.
13521 bool AllAddOne = true;
13522 bool AllSubOne = true;
13523 unsigned Elts = VT.getVectorNumElements();
13524 for (unsigned i = 0; i != Elts; ++i) {
13525 SDValue N1Elt = N1.getOperand(i);
13526 SDValue N2Elt = N2.getOperand(i);
13527 if (N1Elt.isUndef())
13528 continue;
13529 // N2 should not contain undef values since it will be reused in the fold.
13530 if (N2Elt.isUndef() || N1Elt.getValueType() != N2Elt.getValueType()) {
13531 AllAddOne = false;
13532 AllSubOne = false;
13533 break;
13534 }
13535
13536 const APInt &C1 = N1Elt->getAsAPIntVal();
13537 const APInt &C2 = N2Elt->getAsAPIntVal();
13538 if (C1 != C2 + 1)
13539 AllAddOne = false;
13540 if (C1 != C2 - 1)
13541 AllSubOne = false;
13542 }
13543
13544 // Further simplifications for the extra-special cases where the constants are
13545 // all 0 or all -1 should be implemented as folds of these patterns.
13546 SDLoc DL(N);
13547 if (AllAddOne || AllSubOne) {
13548 // vselect <N x i1> Cond, C+1, C --> add (zext Cond), C
13549 // vselect <N x i1> Cond, C-1, C --> add (sext Cond), C
13550 auto ExtendOpcode = AllAddOne ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND;
13551 SDValue ExtendedCond = DAG.getNode(Opcode: ExtendOpcode, DL, VT, Operand: Cond);
13552 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: ExtendedCond, N2);
13553 }
13554
13555 // select Cond, Pow2C, 0 --> (zext Cond) << log2(Pow2C)
13556 APInt Pow2C;
13557 if (ISD::isConstantSplatVector(N: N1.getNode(), SplatValue&: Pow2C) && Pow2C.isPowerOf2() &&
13558 isNullOrNullSplat(V: N2)) {
13559 SDValue ZextCond = DAG.getZExtOrTrunc(Op: Cond, DL, VT);
13560 SDValue ShAmtC = DAG.getConstant(Val: Pow2C.exactLogBase2(), DL, VT);
13561 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: ZextCond, N2: ShAmtC);
13562 }
13563
13564 if (SDValue V = foldSelectOfConstantsUsingSra(N, DL, DAG))
13565 return V;
13566
13567 // The general case for select-of-constants:
13568 // vselect <N x i1> Cond, C1, C2 --> xor (and (sext Cond), (C1^C2)), C2
13569 // ...but that only makes sense if a vselect is slower than 2 logic ops, so
13570 // leave that to a machine-specific pass.
13571 return SDValue();
13572}
13573
13574SDValue DAGCombiner::visitVP_SELECT(SDNode *N) {
13575 SDValue N0 = N->getOperand(Num: 0);
13576 SDValue N1 = N->getOperand(Num: 1);
13577 SDValue N2 = N->getOperand(Num: 2);
13578 SDLoc DL(N);
13579
13580 if (SDValue V = DAG.simplifySelect(Cond: N0, TVal: N1, FVal: N2))
13581 return V;
13582
13583 if (SDValue V = foldBoolSelectToLogic<VPMatchContext>(N, DL, DAG))
13584 return V;
13585
13586 return SDValue();
13587}
13588
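// Fold a vselect whose true/false operand is an all-ones or all-zeros splat
// and whose condition is known to be a sign-splat mask into plain bitwise ops:
//   vselect Cond, -1, 0 --> bitcast Cond
//   vselect Cond, -1, x --> or Cond, x
//   vselect Cond, x, 0  --> and Cond, x
//   vselect Cond, 0, x  --> and (not Cond), x   (only if the invert is free)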
13589static SDValue combineVSelectWithAllOnesOrZeros(SDValue Cond, SDValue TVal,
13590 SDValue FVal,
13591 const TargetLowering &TLI,
13592 SelectionDAG &DAG,
13593 const SDLoc &DL) {
13594 EVT VT = TVal.getValueType();
13595 if (!TLI.isTypeLegal(VT))
13596 return SDValue();
13597
13598 EVT CondVT = Cond.getValueType();
13599 assert(CondVT.isVector() && "Vector select expects a vector selector!");
13600
13601 bool IsTAllZero = ISD::isConstantSplatVectorAllZeros(N: TVal.getNode());
13602 bool IsTAllOne = ISD::isConstantSplatVectorAllOnes(N: TVal.getNode());
13603 bool IsFAllZero = ISD::isConstantSplatVectorAllZeros(N: FVal.getNode());
13604 bool IsFAllOne = ISD::isConstantSplatVectorAllOnes(N: FVal.getNode());
13605
13606  // If neither operand is an all-zeros or all-ones splat, there is nothing to fold.
13607 if (!IsTAllZero && !IsTAllOne && !IsFAllZero && !IsFAllOne)
13608 return SDValue();
13609
13610 // select Cond, 0, 0 → 0
13611 if (IsTAllZero && IsFAllZero) {
13612 return VT.isFloatingPoint() ? DAG.getConstantFP(Val: 0.0, DL, VT)
13613 : DAG.getConstant(Val: 0, DL, VT);
13614 }
13615
13616 // check select(setgt lhs, -1), 1, -1 --> or (sra lhs, bitwidth - 1), 1
13617 APInt TValAPInt;
13618 if (Cond.getOpcode() == ISD::SETCC &&
13619 Cond.getOperand(i: 2) == DAG.getCondCode(Cond: ISD::SETGT) &&
13620 Cond.getOperand(i: 0).getValueType() == VT && VT.isSimple() &&
13621 ISD::isConstantSplatVector(N: TVal.getNode(), SplatValue&: TValAPInt) &&
13622 TValAPInt.isOne() &&
13623 ISD::isConstantSplatVectorAllOnes(N: Cond.getOperand(i: 1).getNode()) &&
13624 ISD::isConstantSplatVectorAllOnes(N: FVal.getNode())) {
13625 return SDValue();
13626 }
13627
13628 // To use the condition operand as a bitwise mask, it must have elements that
13629 // are the same size as the select elements. i.e, the condition operand must
13630 // have already been promoted from the IR select condition type <N x i1>.
13631 // Don't check if the types themselves are equal because that excludes
13632 // vector floating-point selects.
13633 if (CondVT.getScalarSizeInBits() != VT.getScalarSizeInBits())
13634 return SDValue();
13635
13636 // Cond value must be 'sign splat' to be converted to a logical op.
13637 if (DAG.ComputeNumSignBits(Op: Cond) != CondVT.getScalarSizeInBits())
13638 return SDValue();
13639
13640 // Try inverting Cond and swapping T/F if it gives all-ones/all-zeros form
13641 if (!IsTAllOne && !IsFAllZero && Cond.hasOneUse() &&
13642 Cond.getOpcode() == ISD::SETCC &&
13643 TLI.getSetCCResultType(DL: DAG.getDataLayout(), Context&: *DAG.getContext(), VT) ==
13644 CondVT) {
13645 if (IsTAllZero || IsFAllOne) {
13646 SDValue CC = Cond.getOperand(i: 2);
13647 ISD::CondCode InverseCC = ISD::getSetCCInverse(
13648 Operation: cast<CondCodeSDNode>(Val&: CC)->get(), Type: Cond.getOperand(i: 0).getValueType());
13649 Cond = DAG.getSetCC(DL, VT: CondVT, LHS: Cond.getOperand(i: 0), RHS: Cond.getOperand(i: 1),
13650 Cond: InverseCC);
13651 std::swap(a&: TVal, b&: FVal);
13652 std::swap(a&: IsTAllOne, b&: IsFAllOne);
13653 std::swap(a&: IsTAllZero, b&: IsFAllZero);
13654 }
13655 }
13656
13657 assert(DAG.ComputeNumSignBits(Cond) == CondVT.getScalarSizeInBits() &&
13658 "Select condition no longer all-sign bits");
13659
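// For example, with a v4i32 sign-splat Cond and VT = v4f32, the folds below
// perform the bitwise op in v4i32 and bitcast the result back to v4f32.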
13660 // select Cond, -1, 0 → bitcast Cond
13661 if (IsTAllOne && IsFAllZero)
13662 return DAG.getBitcast(VT, V: Cond);
13663
13664 // select Cond, -1, x → or Cond, x
13665 if (IsTAllOne) {
13666 SDValue X = DAG.getBitcast(VT: CondVT, V: DAG.getFreeze(V: FVal));
13667 SDValue Or = DAG.getNode(Opcode: ISD::OR, DL, VT: CondVT, N1: Cond, N2: X);
13668 return DAG.getBitcast(VT, V: Or);
13669 }
13670
13671 // select Cond, x, 0 → and Cond, x
13672 if (IsFAllZero) {
13673 SDValue X = DAG.getBitcast(VT: CondVT, V: DAG.getFreeze(V: TVal));
13674 SDValue And = DAG.getNode(Opcode: ISD::AND, DL, VT: CondVT, N1: Cond, N2: X);
13675 return DAG.getBitcast(VT, V: And);
13676 }
13677
13678 // select Cond, 0, x -> and not(Cond), x
13679 if (IsTAllZero &&
13680 (isBitwiseNot(V: peekThroughBitcasts(V: Cond)) || TLI.hasAndNot(X: Cond))) {
13681 SDValue X = DAG.getBitcast(VT: CondVT, V: DAG.getFreeze(V: FVal));
13682 SDValue And =
13683 DAG.getNode(Opcode: ISD::AND, DL, VT: CondVT, N1: DAG.getNOT(DL, Val: Cond, VT: CondVT), N2: X);
13684 return DAG.getBitcast(VT, V: And);
13685 }
13686
13687 return SDValue();
13688}
13689
13690SDValue DAGCombiner::visitVSELECT(SDNode *N) {
13691 SDValue N0 = N->getOperand(Num: 0);
13692 SDValue N1 = N->getOperand(Num: 1);
13693 SDValue N2 = N->getOperand(Num: 2);
13694 EVT VT = N->getValueType(ResNo: 0);
13695 SDLoc DL(N);
13696
13697 if (SDValue V = DAG.simplifySelect(Cond: N0, TVal: N1, FVal: N2))
13698 return V;
13699
13700 if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DL, DAG))
13701 return V;
13702
13703 // vselect (not Cond), N1, N2 -> vselect Cond, N2, N1
13704 if (!TLI.isTargetCanonicalSelect(N))
13705 if (SDValue F = extractBooleanFlip(V: N0, DAG, TLI, Force: false))
13706 return DAG.getSelect(DL, VT, Cond: F, LHS: N2, RHS: N1);
13707
13708 // select (sext m), (add X, C), X --> (add X, (and C, (sext m)))
13709 if (N1.getOpcode() == ISD::ADD && N1.getOperand(i: 0) == N2 && N1->hasOneUse() &&
13710 DAG.isConstantIntBuildVectorOrConstantInt(N: N1.getOperand(i: 1)) &&
13711 N0.getScalarValueSizeInBits() == N1.getScalarValueSizeInBits() &&
13712 TLI.getBooleanContents(Type: N0.getValueType()) ==
13713 TargetLowering::ZeroOrNegativeOneBooleanContent) {
13714 return DAG.getNode(
13715 Opcode: ISD::ADD, DL, VT: N1.getValueType(), N1: N2,
13716 N2: DAG.getNode(Opcode: ISD::AND, DL, VT: N0.getValueType(), N1: N1.getOperand(i: 1), N2: N0));
13717 }
13718
13719 // Canonicalize integer abs.
13720 // vselect (setg[te] X, 0), X, -X ->
13721 // vselect (setgt X, -1), X, -X ->
13722 // vselect (setl[te] X, 0), -X, X ->
13723 // Y = sra (X, size(X)-1); xor (add (X, Y), Y)
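// e.g. for i32 X = -5: Y = sra(-5, 31) = -1, add = -5 + -1 = -6, and
// xor = -6 ^ -1 = 5 = |X|.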
13724 if (N0.getOpcode() == ISD::SETCC) {
13725 SDValue LHS = N0.getOperand(i: 0), RHS = N0.getOperand(i: 1);
13726 ISD::CondCode CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get();
13727 bool isAbs = false;
13728 bool RHSIsAllZeros = ISD::isBuildVectorAllZeros(N: RHS.getNode());
13729
13730 if (((RHSIsAllZeros && (CC == ISD::SETGT || CC == ISD::SETGE)) ||
13731 (ISD::isBuildVectorAllOnes(N: RHS.getNode()) && CC == ISD::SETGT)) &&
13732 N1 == LHS && N2.getOpcode() == ISD::SUB && N1 == N2.getOperand(i: 1))
13733 isAbs = ISD::isBuildVectorAllZeros(N: N2.getOperand(i: 0).getNode());
13734 else if ((RHSIsAllZeros && (CC == ISD::SETLT || CC == ISD::SETLE)) &&
13735 N2 == LHS && N1.getOpcode() == ISD::SUB && N2 == N1.getOperand(i: 1))
13736 isAbs = ISD::isBuildVectorAllZeros(N: N1.getOperand(i: 0).getNode());
13737
13738 if (isAbs) {
13739 if (TLI.isOperationLegalOrCustom(Op: ISD::ABS, VT))
13740 return DAG.getNode(Opcode: ISD::ABS, DL, VT, Operand: LHS);
13741
13742 SDValue Shift = DAG.getNode(
13743 Opcode: ISD::SRA, DL, VT, N1: LHS,
13744 N2: DAG.getShiftAmountConstant(Val: VT.getScalarSizeInBits() - 1, VT, DL));
13745 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: LHS, N2: Shift);
13746 AddToWorklist(N: Shift.getNode());
13747 AddToWorklist(N: Add.getNode());
13748 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: Add, N2: Shift);
13749 }
13750
13751 // vselect x, y (fcmp lt x, y) -> fminnum x, y
13752 // vselect x, y (fcmp gt x, y) -> fmaxnum x, y
13753 //
13754 // This is OK if we don't care about what happens if either operand is a
13755 // NaN.
13756 //
13757 if (N0.hasOneUse() &&
13758 isLegalToCombineMinNumMaxNum(DAG, LHS, RHS, Flags: N->getFlags(), TLI)) {
13759 if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, LHS, RHS, True: N1, False: N2, CC))
13760 return FMinMax;
13761 }
13762
13763 if (SDValue S = PerformMinMaxFpToSatCombine(N0: LHS, N1: RHS, N2: N1, N3: N2, CC, DAG))
13764 return S;
13765 if (SDValue S = PerformUMinFpToSatCombine(N0: LHS, N1: RHS, N2: N1, N3: N2, CC, DAG))
13766 return S;
13767
13768 // If this select has a condition (setcc) with narrower operands than the
13769 // select, try to widen the compare to match the select width.
13770 // TODO: This should be extended to handle any constant.
13771 // TODO: This could be extended to handle non-loading patterns, but that
13772 // requires thorough testing to avoid regressions.
13773 if (isNullOrNullSplat(V: RHS)) {
13774 EVT NarrowVT = LHS.getValueType();
13775 EVT WideVT = N1.getValueType().changeVectorElementTypeToInteger();
13776 EVT SetCCVT = getSetCCResultType(VT: LHS.getValueType());
13777 unsigned SetCCWidth = SetCCVT.getScalarSizeInBits();
13778 unsigned WideWidth = WideVT.getScalarSizeInBits();
13779 bool IsSigned = isSignedIntSetCC(Code: CC);
13780 auto LoadExtOpcode = IsSigned ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
13781 if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() &&
13782 SetCCWidth != 1 && SetCCWidth < WideWidth &&
13783 TLI.isLoadExtLegalOrCustom(ExtType: LoadExtOpcode, ValVT: WideVT, MemVT: NarrowVT) &&
13784 TLI.isOperationLegalOrCustom(Op: ISD::SETCC, VT: WideVT)) {
13785 // Both compare operands can be widened for free. The LHS can use an
13786 // extended load, and the RHS is a constant:
13787 // vselect (ext (setcc load(X), C)), N1, N2 -->
13788 // vselect (setcc extload(X), C'), N1, N2
13789 auto ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
13790 SDValue WideLHS = DAG.getNode(Opcode: ExtOpcode, DL, VT: WideVT, Operand: LHS);
13791 SDValue WideRHS = DAG.getNode(Opcode: ExtOpcode, DL, VT: WideVT, Operand: RHS);
13792 EVT WideSetCCVT = getSetCCResultType(VT: WideVT);
13793 SDValue WideSetCC = DAG.getSetCC(DL, VT: WideSetCCVT, LHS: WideLHS, RHS: WideRHS, Cond: CC);
13794 return DAG.getSelect(DL, VT: N1.getValueType(), Cond: WideSetCC, LHS: N1, RHS: N2);
13795 }
13796 }
13797
13798 if (SDValue ABD = foldSelectToABD(LHS, RHS, True: N1, False: N2, CC, DL))
13799 return ABD;
13800
13801 // Match VSELECTs into add with unsigned saturation.
13802 if (hasOperation(Opcode: ISD::UADDSAT, VT)) {
13803 // Check if one of the arms of the VSELECT is a vector with all bits set.
13804 // If it's on the left side, invert the predicate to simplify the logic below.
13805 SDValue Other;
13806 ISD::CondCode SatCC = CC;
13807 if (ISD::isConstantSplatVectorAllOnes(N: N1.getNode())) {
13808 Other = N2;
13809 SatCC = ISD::getSetCCInverse(Operation: SatCC, Type: VT.getScalarType());
13810 } else if (ISD::isConstantSplatVectorAllOnes(N: N2.getNode())) {
13811 Other = N1;
13812 }
13813
13814 if (Other && Other.getOpcode() == ISD::ADD) {
13815 SDValue CondLHS = LHS, CondRHS = RHS;
13816 SDValue OpLHS = Other.getOperand(i: 0), OpRHS = Other.getOperand(i: 1);
13817
13818 // Canonicalize condition operands.
13819 if (SatCC == ISD::SETUGE) {
13820 std::swap(a&: CondLHS, b&: CondRHS);
13821 SatCC = ISD::SETULE;
13822 }
13823
13824 // We can test against either of the addition operands.
13825 // x <= x+y ? x+y : ~0 --> uaddsat x, y
13826 // x+y >= x ? x+y : ~0 --> uaddsat x, y
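// e.g. for i8 x = 200, y = 100: x+y wraps to 44, so 'x <= x+y' is false and
// ~0 (255) is selected, which equals uaddsat(200, 100).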
13827 if (SatCC == ISD::SETULE && Other == CondRHS &&
13828 (OpLHS == CondLHS || OpRHS == CondLHS))
13829 return DAG.getNode(Opcode: ISD::UADDSAT, DL, VT, N1: OpLHS, N2: OpRHS);
13830
13831 if (OpRHS.getOpcode() == CondRHS.getOpcode() &&
13832 (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
13833 OpRHS.getOpcode() == ISD::SPLAT_VECTOR) &&
13834 CondLHS == OpLHS) {
13835 // If the RHS is a constant we have to reverse the const
13836 // canonicalization.
13837 // x >= ~C ? x+C : ~0 --> uaddsat x, C
13838 auto MatchUADDSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
13839 return Cond->getAPIntValue() == ~Op->getAPIntValue();
13840 };
13841 if (SatCC == ISD::SETULE &&
13842 ISD::matchBinaryPredicate(LHS: OpRHS, RHS: CondRHS, Match: MatchUADDSAT))
13843 return DAG.getNode(Opcode: ISD::UADDSAT, DL, VT, N1: OpLHS, N2: OpRHS);
13844 }
13845 }
13846 }
13847
13848 // Match VSELECTs into sub with unsigned saturation.
13849 if (hasOperation(Opcode: ISD::USUBSAT, VT)) {
13850 // Check if one of the arms of the VSELECT is a zero vector. If it's on
13851 // the left side, invert the predicate to simplify the logic below.
13852 SDValue Other;
13853 ISD::CondCode SatCC = CC;
13854 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode())) {
13855 Other = N2;
13856 SatCC = ISD::getSetCCInverse(Operation: SatCC, Type: VT.getScalarType());
13857 } else if (ISD::isConstantSplatVectorAllZeros(N: N2.getNode())) {
13858 Other = N1;
13859 }
13860
13861 // zext(x) >= y ? trunc(zext(x) - y) : 0
13862 // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
13863 // zext(x) > y ? trunc(zext(x) - y) : 0
13864 // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
13865 if (Other && Other.getOpcode() == ISD::TRUNCATE &&
13866 Other.getOperand(i: 0).getOpcode() == ISD::SUB &&
13867 (SatCC == ISD::SETUGE || SatCC == ISD::SETUGT)) {
13868 SDValue OpLHS = Other.getOperand(i: 0).getOperand(i: 0);
13869 SDValue OpRHS = Other.getOperand(i: 0).getOperand(i: 1);
13870 if (LHS == OpLHS && RHS == OpRHS && LHS.getOpcode() == ISD::ZERO_EXTEND)
13871 if (SDValue R = getTruncatedUSUBSAT(DstVT: VT, SrcVT: LHS.getValueType(), LHS, RHS,
13872 DAG, DL))
13873 return R;
13874 }
13875
13876 if (Other && Other.getNumOperands() == 2) {
13877 SDValue CondRHS = RHS;
13878 SDValue OpLHS = Other.getOperand(i: 0), OpRHS = Other.getOperand(i: 1);
13879
13880 if (OpLHS == LHS) {
13881 // Look for a general sub with unsigned saturation first.
13882 // x >= y ? x-y : 0 --> usubsat x, y
13883 // x > y ? x-y : 0 --> usubsat x, y
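// e.g. for i8 x = 10, y = 20: the compare is false, so 0 is selected,
// which equals usubsat(10, 20).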
13884 if ((SatCC == ISD::SETUGE || SatCC == ISD::SETUGT) &&
13885 Other.getOpcode() == ISD::SUB && OpRHS == CondRHS)
13886 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: OpLHS, N2: OpRHS);
13887
13888 if (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
13889 OpRHS.getOpcode() == ISD::SPLAT_VECTOR) {
13890 if (CondRHS.getOpcode() == ISD::BUILD_VECTOR ||
13891 CondRHS.getOpcode() == ISD::SPLAT_VECTOR) {
13892 // If the RHS is a constant we have to reverse the const
13893 // canonicalization.
13894 // x > C-1 ? x+-C : 0 --> usubsat x, C
13895 auto MatchUSUBSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
13896 return (!Op && !Cond) ||
13897 (Op && Cond &&
13898 Cond->getAPIntValue() == (-Op->getAPIntValue() - 1));
13899 };
13900 if (SatCC == ISD::SETUGT && Other.getOpcode() == ISD::ADD &&
13901 ISD::matchBinaryPredicate(LHS: OpRHS, RHS: CondRHS, Match: MatchUSUBSAT,
13902 /*AllowUndefs*/ true)) {
13903 OpRHS = DAG.getNegative(Val: OpRHS, DL, VT);
13904 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: OpLHS, N2: OpRHS);
13905 }
13906
13907 // Another special case: If C was a sign bit, the sub has been
13908 // canonicalized into a xor.
13909 // FIXME: Would it be better to use computeKnownBits to
13910 // determine whether it's safe to decanonicalize the xor?
13911 // x s< 0 ? x^C : 0 --> usubsat x, C
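// e.g. for i8 with C = 0x80: if x is negative (unsigned x >= 128), then
// x ^ 0x80 == x - 128 == usubsat(x, 128); otherwise usubsat(x, 128) == 0,
// matching the selected zero.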
13912 APInt SplatValue;
13913 if (SatCC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
13914 ISD::isConstantSplatVector(N: OpRHS.getNode(), SplatValue) &&
13915 ISD::isConstantSplatVectorAllZeros(N: CondRHS.getNode()) &&
13916 SplatValue.isSignMask()) {
13917 // Note that we have to rebuild the RHS constant here to
13918 // ensure we don't rely on particular values of undef lanes.
13919 OpRHS = DAG.getConstant(Val: SplatValue, DL, VT);
13920 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: OpLHS, N2: OpRHS);
13921 }
13922 }
13923 }
13924 }
13925 }
13926 }
13927
13928 // (vselect (ugt x, C), (add x, ~C), x) -> (umin (add x, ~C), x)
13929 // (vselect (ult x, C), x, (add x, -C)) -> (umin x, (add x, -C))
13930 if (SDValue UMin = foldSelectToUMin(LHS, RHS, True: N1, False: N2, CC, DL))
13931 return UMin;
13932 }
13933
13934 if (SimplifySelectOps(SELECT: N, LHS: N1, RHS: N2))
13935 return SDValue(N, 0); // Don't revisit N.
13936
13937 // Fold (vselect all_ones, N1, N2) -> N1
13938 if (ISD::isConstantSplatVectorAllOnes(N: N0.getNode()))
13939 return N1;
13940 // Fold (vselect all_zeros, N1, N2) -> N2
13941 if (ISD::isConstantSplatVectorAllZeros(N: N0.getNode()))
13942 return N2;
13943
13944 // ConvertSelectToConcatVector assumes that both of the above checks for
13945 // (vselect (build_vector all{ones,zeros}) ...) have already been made and
13946 // addressed.
13947 if (N1.getOpcode() == ISD::CONCAT_VECTORS &&
13948 N2.getOpcode() == ISD::CONCAT_VECTORS &&
13949 ISD::isBuildVectorOfConstantSDNodes(N: N0.getNode())) {
13950 if (SDValue CV = ConvertSelectToConcatVector(N, DAG))
13951 return CV;
13952 }
13953
13954 if (SDValue V = foldVSelectOfConstants(N))
13955 return V;
13956
13957 if (hasOperation(Opcode: ISD::SRA, VT))
13958 if (SDValue V = foldVSelectToSignBitSplatMask(N, DAG))
13959 return V;
13960
13961 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
13962 return SDValue(N, 0);
13963
13964 if (SDValue V = combineVSelectWithAllOnesOrZeros(Cond: N0, TVal: N1, FVal: N2, TLI, DAG, DL))
13965 return V;
13966
13967 return SDValue();
13968}
13969
13970SDValue DAGCombiner::visitSELECT_CC(SDNode *N) {
13971 SDValue N0 = N->getOperand(Num: 0);
13972 SDValue N1 = N->getOperand(Num: 1);
13973 SDValue N2 = N->getOperand(Num: 2);
13974 SDValue N3 = N->getOperand(Num: 3);
13975 SDValue N4 = N->getOperand(Num: 4);
13976 ISD::CondCode CC = cast<CondCodeSDNode>(Val&: N4)->get();
13977 SDLoc DL(N);
13978
13979 // fold select_cc lhs, rhs, x, x, cc -> x
13980 if (N2 == N3)
13981 return N2;
13982
13983 // select_cc bool, 0, x, y, seteq -> select bool, y, x
13984 if (CC == ISD::SETEQ && !LegalTypes && N0.getValueType() == MVT::i1 &&
13985 isNullConstant(V: N1))
13986 return DAG.getSelect(DL, VT: N2.getValueType(), Cond: N0, LHS: N3, RHS: N2);
13987
13988 // Determine if the condition we're dealing with is constant
13989 if (SDValue SCC = SimplifySetCC(VT: getSetCCResultType(VT: N0.getValueType()), N0, N1,
13990 Cond: CC, DL, foldBooleans: false)) {
13991 AddToWorklist(N: SCC.getNode());
13992
13993 // cond always true -> true val
13994 // cond always false -> false val
13995 if (auto *SCCC = dyn_cast<ConstantSDNode>(Val: SCC.getNode()))
13996 return SCCC->isZero() ? N3 : N2;
13997
13998 // When the condition is UNDEF, just return the first operand. This is
13999 // consistent with DAG creation, where no setcc node is created in this case.
14000 if (SCC->isUndef())
14001 return N2;
14002
14003 // Fold to a simpler select_cc
14004 if (SCC.getOpcode() == ISD::SETCC) {
14005 return DAG.getNode(Opcode: ISD::SELECT_CC, DL, VT: N2.getValueType(),
14006 N1: SCC.getOperand(i: 0), N2: SCC.getOperand(i: 1), N3: N2, N4: N3,
14007 N5: SCC.getOperand(i: 2), Flags: SCC->getFlags());
14008 }
14009 }
14010
14011 // If we can fold this based on the true/false value, do so.
14012 if (SimplifySelectOps(SELECT: N, LHS: N2, RHS: N3))
14013 return SDValue(N, 0); // Don't revisit N.
14014
14015 // fold select_cc into other things, such as min/max/abs
14016 return SimplifySelectCC(DL, N0, N1, N2, N3, CC);
14017}
14018
14019SDValue DAGCombiner::visitSETCC(SDNode *N) {
14020 // setcc is very commonly used as an argument to brcond. This pattern
14021 // also lends itself to numerous combines and, as a result, it is desirable
14022 // to keep the argument to a brcond as a setcc as much as possible.
14023 bool PreferSetCC =
14024 N->hasOneUse() && N->user_begin()->getOpcode() == ISD::BRCOND;
14025
14026 ISD::CondCode Cond = cast<CondCodeSDNode>(Val: N->getOperand(Num: 2))->get();
14027 EVT VT = N->getValueType(ResNo: 0);
14028 SDValue N0 = N->getOperand(Num: 0), N1 = N->getOperand(Num: 1);
14029 SDLoc DL(N);
14030
14031 if (SDValue Combined = SimplifySetCC(VT, N0, N1, Cond, DL, foldBooleans: !PreferSetCC)) {
14032 // If we prefer to have a setcc, and we don't, we'll try our best to
14033 // recreate one using rebuildSetCC.
14034 if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) {
14035 SDValue NewSetCC = rebuildSetCC(N: Combined);
14036
14037 // We don't have anything interesting to combine to.
14038 if (NewSetCC.getNode() == N)
14039 return SDValue();
14040
14041 if (NewSetCC)
14042 return NewSetCC;
14043 }
14044 return Combined;
14045 }
14046
14047 // Optimize
14048 // 1) (icmp eq/ne (and X, C0), (shift X, C1))
14049 // or
14050 // 2) (icmp eq/ne X, (rotate X, C1))
14051 // If C0 is a mask or shifted mask and the shift amount (C1) isolates the
14052 // remaining bits (i.e. something like `(x64 & UINT32_MAX) == (x64 >> 32)`),
14053 // then:
14054 // If C1 is a power of 2, the rotate and shift+and versions are
14055 // equivalent, so we can interchange them depending on target preference.
14056 // Otherwise, if we have the shift+and version, we can interchange srl/shl,
14057 // which in turn affects the constant C0. We can use this to get better
14058 // constants, again determined by target preference.
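// For example, on i32 the tests (X & 0xFFFF) == (X srl 16) and
// (X & 0xFFFF0000) == (X shl 16) are equivalent (both compare the low and
// high halves of X), so we can use whichever mask/shift form the target
// prefers.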
14059 if (Cond == ISD::SETNE || Cond == ISD::SETEQ) {
14060 auto IsAndWithShift = [](SDValue A, SDValue B) {
14061 return A.getOpcode() == ISD::AND &&
14062 (B.getOpcode() == ISD::SRL || B.getOpcode() == ISD::SHL) &&
14063 A.getOperand(i: 0) == B.getOperand(i: 0);
14064 };
14065 auto IsRotateWithOp = [](SDValue A, SDValue B) {
14066 return (B.getOpcode() == ISD::ROTL || B.getOpcode() == ISD::ROTR) &&
14067 B.getOperand(i: 0) == A;
14068 };
14069 SDValue AndOrOp = SDValue(), ShiftOrRotate = SDValue();
14070 bool IsRotate = false;
14071
14072 // Find either shift+and or rotate pattern.
14073 if (IsAndWithShift(N0, N1)) {
14074 AndOrOp = N0;
14075 ShiftOrRotate = N1;
14076 } else if (IsAndWithShift(N1, N0)) {
14077 AndOrOp = N1;
14078 ShiftOrRotate = N0;
14079 } else if (IsRotateWithOp(N0, N1)) {
14080 IsRotate = true;
14081 AndOrOp = N0;
14082 ShiftOrRotate = N1;
14083 } else if (IsRotateWithOp(N1, N0)) {
14084 IsRotate = true;
14085 AndOrOp = N1;
14086 ShiftOrRotate = N0;
14087 }
14088
14089 if (AndOrOp && ShiftOrRotate && ShiftOrRotate.hasOneUse() &&
14090 (IsRotate || AndOrOp.hasOneUse())) {
14091 EVT OpVT = N0.getValueType();
14092 // Get the constant shift/rotate amount and possibly the mask (if it's the
14093 // shift+and variant).
14094 auto GetAPIntValue = [](SDValue Op) -> std::optional<APInt> {
14095 ConstantSDNode *CNode = isConstOrConstSplat(N: Op, /*AllowUndefs*/ false,
14096 /*AllowTrunc*/ AllowTruncation: false);
14097 if (CNode == nullptr)
14098 return std::nullopt;
14099 return CNode->getAPIntValue();
14100 };
14101 std::optional<APInt> AndCMask =
14102 IsRotate ? std::nullopt : GetAPIntValue(AndOrOp.getOperand(i: 1));
14103 std::optional<APInt> ShiftCAmt =
14104 GetAPIntValue(ShiftOrRotate.getOperand(i: 1));
14105 unsigned NumBits = OpVT.getScalarSizeInBits();
14106
14107 // We found constants.
14108 if (ShiftCAmt && (IsRotate || AndCMask) && ShiftCAmt->ult(RHS: NumBits)) {
14109 unsigned ShiftOpc = ShiftOrRotate.getOpcode();
14110 // Check that the constants meet the constraints.
14111 bool CanTransform = IsRotate;
14112 if (!CanTransform) {
14113 // Check that the mask and shift complement each other.
14114 CanTransform = *ShiftCAmt == (~*AndCMask).popcount();
14115 // Check that we are comparing all bits
14116 CanTransform &= (*ShiftCAmt + AndCMask->popcount()) == NumBits;
14117 // Check that the and mask is correct for the shift
14118 CanTransform &=
14119 ShiftOpc == ISD::SHL ? (~*AndCMask).isMask() : AndCMask->isMask();
14120 }
14121
14122 // See if the target prefers another shift/rotate opcode.
14123 unsigned NewShiftOpc = TLI.preferedOpcodeForCmpEqPiecesOfOperand(
14124 VT: OpVT, ShiftOpc, MayTransformRotate: ShiftCAmt->isPowerOf2(), ShiftOrRotateAmt: *ShiftCAmt, AndMask: AndCMask);
14125 // Transform is valid and we have a new preference.
14126 if (CanTransform && NewShiftOpc != ShiftOpc) {
14127 SDValue NewShiftOrRotate =
14128 DAG.getNode(Opcode: NewShiftOpc, DL, VT: OpVT, N1: ShiftOrRotate.getOperand(i: 0),
14129 N2: ShiftOrRotate.getOperand(i: 1));
14130 SDValue NewAndOrOp = SDValue();
14131
14132 if (NewShiftOpc == ISD::SHL || NewShiftOpc == ISD::SRL) {
14133 APInt NewMask =
14134 NewShiftOpc == ISD::SHL
14135 ? APInt::getHighBitsSet(numBits: NumBits,
14136 hiBitsSet: NumBits - ShiftCAmt->getZExtValue())
14137 : APInt::getLowBitsSet(numBits: NumBits,
14138 loBitsSet: NumBits - ShiftCAmt->getZExtValue());
14139 NewAndOrOp =
14140 DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: ShiftOrRotate.getOperand(i: 0),
14141 N2: DAG.getConstant(Val: NewMask, DL, VT: OpVT));
14142 } else {
14143 NewAndOrOp = ShiftOrRotate.getOperand(i: 0);
14144 }
14145
14146 return DAG.getSetCC(DL, VT, LHS: NewAndOrOp, RHS: NewShiftOrRotate, Cond);
14147 }
14148 }
14149 }
14150 }
14151 return SDValue();
14152}
14153
14154SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) {
14155 SDValue LHS = N->getOperand(Num: 0);
14156 SDValue RHS = N->getOperand(Num: 1);
14157 SDValue Carry = N->getOperand(Num: 2);
14158 SDValue Cond = N->getOperand(Num: 3);
14159
14160 // If Carry is false, fold to a regular SETCC.
14161 if (isNullConstant(V: Carry))
14162 return DAG.getNode(Opcode: ISD::SETCC, DL: SDLoc(N), VTList: N->getVTList(), N1: LHS, N2: RHS, N3: Cond);
14163
14164 return SDValue();
14165}
14166
14167 /// Check that N satisfies all of the following:
14168 /// N is used once.
14169 /// N is a load.
14170 /// The load is compatible with ExtOpcode, meaning that if the load has an
14171 /// explicit zero/sign extension, ExtOpcode must perform the same kind of
14172 /// extension.
14173 /// Otherwise (no explicit extension), any ExtOpcode is compatible.
14174static bool isCompatibleLoad(SDValue N, unsigned ExtOpcode) {
14175 if (!N.hasOneUse())
14176 return false;
14177
14178 if (!isa<LoadSDNode>(Val: N))
14179 return false;
14180
14181 LoadSDNode *Load = cast<LoadSDNode>(Val&: N);
14182 ISD::LoadExtType LoadExt = Load->getExtensionType();
14183 if (LoadExt == ISD::NON_EXTLOAD || LoadExt == ISD::EXTLOAD)
14184 return true;
14185
14186 // Now LoadExt is either SEXTLOAD or ZEXTLOAD, ExtOpcode must have the same
14187 // extension.
14188 if ((LoadExt == ISD::SEXTLOAD && ExtOpcode != ISD::SIGN_EXTEND) ||
14189 (LoadExt == ISD::ZEXTLOAD && ExtOpcode != ISD::ZERO_EXTEND))
14190 return false;
14191
14192 return true;
14193}
14194
14195/// Fold
14196/// (sext (select c, load x, load y)) -> (select c, sextload x, sextload y)
14197/// (zext (select c, load x, load y)) -> (select c, zextload x, zextload y)
14198/// (aext (select c, load x, load y)) -> (select c, extload x, extload y)
14199/// This function is called by the DAGCombiner when visiting sext/zext/aext
14200/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
14201static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
14202 SelectionDAG &DAG, const SDLoc &DL,
14203 CombineLevel Level) {
14204 unsigned Opcode = N->getOpcode();
14205 SDValue N0 = N->getOperand(Num: 0);
14206 EVT VT = N->getValueType(ResNo: 0);
14207 assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
14208 Opcode == ISD::ANY_EXTEND) &&
14209 "Expected EXTEND dag node in input!");
14210
14211 SDValue Cond, Op1, Op2;
14212 if (!sd_match(N: N0, P: m_OneUse(P: m_SelectLike(Cond: m_Value(N&: Cond), T: m_Value(N&: Op1),
14213 F: m_Value(N&: Op2)))))
14214 return SDValue();
14215
14216 if (!isCompatibleLoad(N: Op1, ExtOpcode: Opcode) || !isCompatibleLoad(N: Op2, ExtOpcode: Opcode))
14217 return SDValue();
14218
14219 auto ExtLoadOpcode = ISD::EXTLOAD;
14220 if (Opcode == ISD::SIGN_EXTEND)
14221 ExtLoadOpcode = ISD::SEXTLOAD;
14222 else if (Opcode == ISD::ZERO_EXTEND)
14223 ExtLoadOpcode = ISD::ZEXTLOAD;
14224
14225 // An illegal VSELECT created after legalization (DAG Combine2) may cause
14226 // instruction selection to fail, so conservatively check the OperationAction.
14227 LoadSDNode *Load1 = cast<LoadSDNode>(Val&: Op1);
14228 LoadSDNode *Load2 = cast<LoadSDNode>(Val&: Op2);
14229 if (!TLI.isLoadExtLegal(ExtType: ExtLoadOpcode, ValVT: VT, MemVT: Load1->getMemoryVT()) ||
14230 !TLI.isLoadExtLegal(ExtType: ExtLoadOpcode, ValVT: VT, MemVT: Load2->getMemoryVT()) ||
14231 (N0->getOpcode() == ISD::VSELECT && Level >= AfterLegalizeTypes &&
14232 TLI.getOperationAction(Op: ISD::VSELECT, VT) != TargetLowering::Legal))
14233 return SDValue();
14234
14235 SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Operand: Op1);
14236 SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Operand: Op2);
14237 return DAG.getSelect(DL, VT, Cond, LHS: Ext1, RHS: Ext2);
14238}
14239
14240/// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
14241/// a build_vector of constants.
14242/// This function is called by the DAGCombiner when visiting sext/zext/aext
14243/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
14244/// Vector extends are not folded if operations are legal; this is to
14245/// avoid introducing illegal build_vector dag nodes.
14246static SDValue tryToFoldExtendOfConstant(SDNode *N, const SDLoc &DL,
14247 const TargetLowering &TLI,
14248 SelectionDAG &DAG, bool LegalTypes) {
14249 unsigned Opcode = N->getOpcode();
14250 SDValue N0 = N->getOperand(Num: 0);
14251 EVT VT = N->getValueType(ResNo: 0);
14252
14253 assert((ISD::isExtOpcode(Opcode) || ISD::isExtVecInRegOpcode(Opcode)) &&
14254 "Expected EXTEND dag node in input!");
14255
14256 // fold (sext c1) -> c1
14257 // fold (zext c1) -> c1
14258 // fold (aext c1) -> c1
14259 if (isa<ConstantSDNode>(Val: N0))
14260 return DAG.getNode(Opcode, DL, VT, Operand: N0);
14261
14262 // fold (sext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
14263 // fold (zext (select cond, c1, c2)) -> (select cond, zext c1, zext c2)
14264 // fold (aext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
14265 if (N0->getOpcode() == ISD::SELECT) {
14266 SDValue Op1 = N0->getOperand(Num: 1);
14267 SDValue Op2 = N0->getOperand(Num: 2);
14268 if (isa<ConstantSDNode>(Val: Op1) && isa<ConstantSDNode>(Val: Op2) &&
14269 (Opcode != ISD::ZERO_EXTEND || !TLI.isZExtFree(FromTy: N0.getValueType(), ToTy: VT))) {
14270 // For any_extend, choose sign extension of the constants to allow a
14271 // possible further transform to sign_extend_inreg, i.e.:
14272 //
14273 // t1: i8 = select t0, Constant:i8<-1>, Constant:i8<0>
14274 // t2: i64 = any_extend t1
14275 // -->
14276 // t3: i64 = select t0, Constant:i64<-1>, Constant:i64<0>
14277 // -->
14278 // t4: i64 = sign_extend_inreg t3
14279 unsigned FoldOpc = Opcode;
14280 if (FoldOpc == ISD::ANY_EXTEND)
14281 FoldOpc = ISD::SIGN_EXTEND;
14282 return DAG.getSelect(DL, VT, Cond: N0->getOperand(Num: 0),
14283 LHS: DAG.getNode(Opcode: FoldOpc, DL, VT, Operand: Op1),
14284 RHS: DAG.getNode(Opcode: FoldOpc, DL, VT, Operand: Op2));
14285 }
14286 }
14287
14288 // fold (sext (build_vector AllConstants)) -> (build_vector AllConstants)
14289 // fold (zext (build_vector AllConstants)) -> (build_vector AllConstants)
14290 // fold (aext (build_vector AllConstants)) -> (build_vector AllConstants)
14291 EVT SVT = VT.getScalarType();
14292 if (!(VT.isVector() && (!LegalTypes || TLI.isTypeLegal(VT: SVT)) &&
14293 ISD::isBuildVectorOfConstantSDNodes(N: N0.getNode())))
14294 return SDValue();
14295
14296 // We can fold this node into a build_vector.
14297 unsigned VTBits = SVT.getSizeInBits();
14298 unsigned EVTBits = N0->getValueType(ResNo: 0).getScalarSizeInBits();
14299 SmallVector<SDValue, 8> Elts;
14300 unsigned NumElts = VT.getVectorNumElements();
14301
14302 for (unsigned i = 0; i != NumElts; ++i) {
14303 SDValue Op = N0.getOperand(i);
14304 if (Op.isUndef()) {
14305 if (Opcode == ISD::ANY_EXTEND || Opcode == ISD::ANY_EXTEND_VECTOR_INREG)
14306 Elts.push_back(Elt: DAG.getUNDEF(VT: SVT));
14307 else
14308 Elts.push_back(Elt: DAG.getConstant(Val: 0, DL, VT: SVT));
14309 continue;
14310 }
14311
14312 SDLoc DL(Op);
14313 // Get the constant value and if needed trunc it to the size of the type.
14314 // Nodes like build_vector might have constants wider than the scalar type.
14315 APInt C = Op->getAsAPIntVal().zextOrTrunc(width: EVTBits);
14316 if (Opcode == ISD::SIGN_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG)
14317 Elts.push_back(Elt: DAG.getConstant(Val: C.sext(width: VTBits), DL, VT: SVT));
14318 else
14319 Elts.push_back(Elt: DAG.getConstant(Val: C.zext(width: VTBits), DL, VT: SVT));
14320 }
14321
14322 return DAG.getBuildVector(VT, DL, Ops: Elts);
14323}
14324
14325 // ExtendUsesToFormExtLoad - Try to extend the uses of a load to enable the
14326 // "fold ({s|z|a}ext (load x)) -> ({s|z|a}ext (truncate ({s|z|a}extload x)))"
14327 // transformation. Returns true if the extensions are possible and the
14328 // above-mentioned transformation is profitable.
14329static bool ExtendUsesToFormExtLoad(EVT VT, SDNode *N, SDValue N0,
14330 unsigned ExtOpc,
14331 SmallVectorImpl<SDNode *> &ExtendNodes,
14332 const TargetLowering &TLI) {
14333 bool HasCopyToRegUses = false;
14334 bool isTruncFree = TLI.isTruncateFree(FromVT: VT, ToVT: N0.getValueType());
14335 for (SDUse &Use : N0->uses()) {
14336 SDNode *User = Use.getUser();
14337 if (User == N)
14338 continue;
14339 if (Use.getResNo() != N0.getResNo())
14340 continue;
14341 // FIXME: Only extend SETCC N, N and SETCC N, c for now.
14342 if (ExtOpc != ISD::ANY_EXTEND && User->getOpcode() == ISD::SETCC) {
14343 ISD::CondCode CC = cast<CondCodeSDNode>(Val: User->getOperand(Num: 2))->get();
14344 if (ExtOpc == ISD::ZERO_EXTEND && ISD::isSignedIntSetCC(Code: CC))
14345 // Sign bits will be lost after a zext.
14346 return false;
14347 bool Add = false;
14348 for (unsigned i = 0; i != 2; ++i) {
14349 SDValue UseOp = User->getOperand(Num: i);
14350 if (UseOp == N0)
14351 continue;
14352 if (!isa<ConstantSDNode>(Val: UseOp))
14353 return false;
14354 Add = true;
14355 }
14356 if (Add)
14357 ExtendNodes.push_back(Elt: User);
14358 continue;
14359 }
14360 // If truncates aren't free and there are users we can't
14361 // extend, it isn't worthwhile.
14362 if (!isTruncFree)
14363 return false;
14364 // Remember if this value is live-out.
14365 if (User->getOpcode() == ISD::CopyToReg)
14366 HasCopyToRegUses = true;
14367 }
14368
14369 if (HasCopyToRegUses) {
14370 bool BothLiveOut = false;
14371 for (SDUse &Use : N->uses()) {
14372 if (Use.getResNo() == 0 && Use.getUser()->getOpcode() == ISD::CopyToReg) {
14373 BothLiveOut = true;
14374 break;
14375 }
14376 }
14377 if (BothLiveOut)
14378 // Both unextended and extended values are live out. There had better be
14379 // a good reason for the transformation.
14380 return !ExtendNodes.empty();
14381 }
14382 return true;
14383}
14384
14385void DAGCombiner::ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
14386 SDValue OrigLoad, SDValue ExtLoad,
14387 ISD::NodeType ExtType) {
14388 // Extend SetCC uses if necessary.
14389 SDLoc DL(ExtLoad);
14390 for (SDNode *SetCC : SetCCs) {
14391 SmallVector<SDValue, 4> Ops;
14392
14393 for (unsigned j = 0; j != 2; ++j) {
14394 SDValue SOp = SetCC->getOperand(Num: j);
14395 if (SOp == OrigLoad)
14396 Ops.push_back(Elt: ExtLoad);
14397 else
14398 Ops.push_back(Elt: DAG.getNode(Opcode: ExtType, DL, VT: ExtLoad->getValueType(ResNo: 0), Operand: SOp));
14399 }
14400
14401 Ops.push_back(Elt: SetCC->getOperand(Num: 2));
14402 CombineTo(N: SetCC, Res: DAG.getNode(Opcode: ISD::SETCC, DL, VT: SetCC->getValueType(ResNo: 0), Ops));
14403 }
14404}
14405
14406// FIXME: Bring more similar combines here, common to sext/zext (maybe aext?).
14407SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
14408 SDValue N0 = N->getOperand(Num: 0);
14409 EVT DstVT = N->getValueType(ResNo: 0);
14410 EVT SrcVT = N0.getValueType();
14411
14412 assert((N->getOpcode() == ISD::SIGN_EXTEND ||
14413 N->getOpcode() == ISD::ZERO_EXTEND) &&
14414 "Unexpected node type (not an extend)!");
14415
14416 // fold (sext (load x)) to multiple smaller sextloads; same for zext.
14417 // For example, on a target with legal v4i32, but illegal v8i32, turn:
14418 // (v8i32 (sext (v8i16 (load x))))
14419 // into:
14420 // (v8i32 (concat_vectors (v4i32 (sextload x)),
14421 // (v4i32 (sextload (x + 16)))))
14422 // Where uses of the original load, i.e.:
14423 // (v8i16 (load x))
14424 // are replaced with:
14425 // (v8i16 (truncate
14426 // (v8i32 (concat_vectors (v4i32 (sextload x)),
14427 // (v4i32 (sextload (x + 16)))))))
14428 //
14429 // This combine is only applicable to illegal, but splittable, vectors.
14430 // All legal types, and illegal non-vector types, are handled elsewhere.
14431 // This combine is controlled by TargetLowering::isVectorLoadExtDesirable.
14432 //
14433 if (N0->getOpcode() != ISD::LOAD)
14434 return SDValue();
14435
14436 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
14437
14438 if (!ISD::isNON_EXTLoad(N: LN0) || !ISD::isUNINDEXEDLoad(N: LN0) ||
14439 !N0.hasOneUse() || !LN0->isSimple() ||
14440 !DstVT.isVector() || !DstVT.isPow2VectorType() ||
14441 !TLI.isVectorLoadExtDesirable(ExtVal: SDValue(N, 0)))
14442 return SDValue();
14443
14444 SmallVector<SDNode *, 4> SetCCs;
14445 if (!ExtendUsesToFormExtLoad(VT: DstVT, N, N0, ExtOpc: N->getOpcode(), ExtendNodes&: SetCCs, TLI))
14446 return SDValue();
14447
14448 ISD::LoadExtType ExtType =
14449 N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
14450
14451 // Try to split the vector types to get down to legal types.
14452 EVT SplitSrcVT = SrcVT;
14453 EVT SplitDstVT = DstVT;
14454 while (!TLI.isLoadExtLegalOrCustom(ExtType, ValVT: SplitDstVT, MemVT: SplitSrcVT) &&
14455 SplitSrcVT.getVectorNumElements() > 1) {
14456 SplitDstVT = DAG.GetSplitDestVTs(VT: SplitDstVT).first;
14457 SplitSrcVT = DAG.GetSplitDestVTs(VT: SplitSrcVT).first;
14458 }
14459
14460 if (!TLI.isLoadExtLegalOrCustom(ExtType, ValVT: SplitDstVT, MemVT: SplitSrcVT))
14461 return SDValue();
14462
14463 assert(!DstVT.isScalableVector() && "Unexpected scalable vector type");
14464
14465 SDLoc DL(N);
14466 const unsigned NumSplits =
14467 DstVT.getVectorNumElements() / SplitDstVT.getVectorNumElements();
14468 const unsigned Stride = SplitSrcVT.getStoreSize();
14469 SmallVector<SDValue, 4> Loads;
14470 SmallVector<SDValue, 4> Chains;
14471
14472 SDValue BasePtr = LN0->getBasePtr();
14473 for (unsigned Idx = 0; Idx < NumSplits; Idx++) {
14474 const unsigned Offset = Idx * Stride;
14475
14476 SDValue SplitLoad =
14477 DAG.getExtLoad(ExtType, dl: SDLoc(LN0), VT: SplitDstVT, Chain: LN0->getChain(),
14478 Ptr: BasePtr, PtrInfo: LN0->getPointerInfo().getWithOffset(O: Offset),
14479 MemVT: SplitSrcVT, Alignment: LN0->getBaseAlign(),
14480 MMOFlags: LN0->getMemOperand()->getFlags(), AAInfo: LN0->getAAInfo());
14481
14482 BasePtr = DAG.getMemBasePlusOffset(Base: BasePtr, Offset: TypeSize::getFixed(ExactSize: Stride), DL);
14483
14484 Loads.push_back(Elt: SplitLoad.getValue(R: 0));
14485 Chains.push_back(Elt: SplitLoad.getValue(R: 1));
14486 }
14487
14488 SDValue NewChain = DAG.getNode(Opcode: ISD::TokenFactor, DL, VT: MVT::Other, Ops: Chains);
14489 SDValue NewValue = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT: DstVT, Ops: Loads);
14490
14491 // Simplify the TokenFactor.
14492 AddToWorklist(N: NewChain.getNode());
14493
14494 CombineTo(N, Res: NewValue);
14495
14496 // Replace uses of the original load (before extension)
14497 // with a truncate of the concatenated sextloaded vectors.
14498 SDValue Trunc =
14499 DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N0), VT: N0.getValueType(), Operand: NewValue);
14500 ExtendSetCCUses(SetCCs, OrigLoad: N0, ExtLoad: NewValue, ExtType: (ISD::NodeType)N->getOpcode());
14501 CombineTo(N: N0.getNode(), Res0: Trunc, Res1: NewChain);
14502 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14503}
14504
14505// fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
14506// (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
14507SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) {
14508 assert(N->getOpcode() == ISD::ZERO_EXTEND);
14509 EVT VT = N->getValueType(ResNo: 0);
14510 EVT OrigVT = N->getOperand(Num: 0).getValueType();
14511 if (TLI.isZExtFree(FromTy: OrigVT, ToTy: VT))
14512 return SDValue();
14513
14514 // and/or/xor
14515 SDValue N0 = N->getOperand(Num: 0);
14516 if (!ISD::isBitwiseLogicOp(Opcode: N0.getOpcode()) ||
14517 N0.getOperand(i: 1).getOpcode() != ISD::Constant ||
14518 (LegalOperations && !TLI.isOperationLegal(Op: N0.getOpcode(), VT)))
14519 return SDValue();
14520
14521 // shl/shr
14522 SDValue N1 = N0->getOperand(Num: 0);
14523 if (!(N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) ||
14524 N1.getOperand(i: 1).getOpcode() != ISD::Constant ||
14525 (LegalOperations && !TLI.isOperationLegal(Op: N1.getOpcode(), VT)))
14526 return SDValue();
14527
14528 // load
14529 if (!isa<LoadSDNode>(Val: N1.getOperand(i: 0)))
14530 return SDValue();
14531 LoadSDNode *Load = cast<LoadSDNode>(Val: N1.getOperand(i: 0));
14532 EVT MemVT = Load->getMemoryVT();
14533 if (!TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: VT, MemVT) ||
14534 Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed())
14535 return SDValue();
14536
14537
14538 // If the shift op is SHL, the logic op must be AND, otherwise the result
14539 // will be wrong.
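// (The zero-extended and-mask has zero high bits, so it clears any bits that
// the widened shl moves above the original value width; with or/xor those
// bits would survive and change the result.)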
14540 if (N1.getOpcode() == ISD::SHL && N0.getOpcode() != ISD::AND)
14541 return SDValue();
14542
14543 if (!N0.hasOneUse() || !N1.hasOneUse())
14544 return SDValue();
14545
14546 SmallVector<SDNode*, 4> SetCCs;
14547 if (!ExtendUsesToFormExtLoad(VT, N: N1.getNode(), N0: N1.getOperand(i: 0),
14548 ExtOpc: ISD::ZERO_EXTEND, ExtendNodes&: SetCCs, TLI))
14549 return SDValue();
14550
14551 // Actually do the transformation.
14552 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl: SDLoc(Load), VT,
14553 Chain: Load->getChain(), Ptr: Load->getBasePtr(),
14554 MemVT: Load->getMemoryVT(), MMO: Load->getMemOperand());
14555
14556 SDLoc DL1(N1);
14557 SDValue Shift = DAG.getNode(Opcode: N1.getOpcode(), DL: DL1, VT, N1: ExtLoad,
14558 N2: N1.getOperand(i: 1));
14559
14560 APInt Mask = N0.getConstantOperandAPInt(i: 1).zext(width: VT.getSizeInBits());
14561 SDLoc DL0(N0);
14562 SDValue And = DAG.getNode(Opcode: N0.getOpcode(), DL: DL0, VT, N1: Shift,
14563 N2: DAG.getConstant(Val: Mask, DL: DL0, VT));
14564
14565 ExtendSetCCUses(SetCCs, OrigLoad: N1.getOperand(i: 0), ExtLoad, ExtType: ISD::ZERO_EXTEND);
14566 CombineTo(N, Res: And);
14567 if (SDValue(Load, 0).hasOneUse()) {
14568 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Load, 1), To: ExtLoad.getValue(R: 1));
14569 } else {
14570 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(Load),
14571 VT: Load->getValueType(ResNo: 0), Operand: ExtLoad);
14572 CombineTo(N: Load, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
14573 }
14574
14575 // N0 is dead at this point.
14576 recursivelyDeleteUnusedNodes(N: N0.getNode());
14577
14578 return SDValue(N,0); // Return N so it doesn't get rechecked!
14579}
14580
14581/// If we're narrowing or widening the result of a vector select and the final
14582/// size is the same size as a setcc (compare) feeding the select, then try to
14583/// apply the cast operation to the select's operands because matching vector
14584/// sizes for a select condition and other operands should be more efficient.
14585SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
14586 unsigned CastOpcode = Cast->getOpcode();
14587 assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
14588 CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
14589 CastOpcode == ISD::FP_ROUND) &&
14590 "Unexpected opcode for vector select narrowing/widening");
14591
14592 // We only do this transform before legal ops because the pattern may be
14593 // obfuscated by target-specific operations after legalization. Do not create
14594 // an illegal select op, however, because that may be difficult to lower.
14595 EVT VT = Cast->getValueType(ResNo: 0);
14596 if (LegalOperations || !TLI.isOperationLegalOrCustom(Op: ISD::VSELECT, VT))
14597 return SDValue();
14598
14599 SDValue VSel = Cast->getOperand(Num: 0);
14600 if (VSel.getOpcode() != ISD::VSELECT || !VSel.hasOneUse() ||
14601 VSel.getOperand(i: 0).getOpcode() != ISD::SETCC)
14602 return SDValue();
14603
14604 // Does the setcc have the same vector size as the casted select?
14605 SDValue SetCC = VSel.getOperand(i: 0);
14606 EVT SetCCVT = getSetCCResultType(VT: SetCC.getOperand(i: 0).getValueType());
14607 if (SetCCVT.getSizeInBits() != VT.getSizeInBits())
14608 return SDValue();
14609
14610 // cast (vsel (setcc X), A, B) --> vsel (setcc X), (cast A), (cast B)
14611 SDValue A = VSel.getOperand(i: 1);
14612 SDValue B = VSel.getOperand(i: 2);
14613 SDValue CastA, CastB;
14614 SDLoc DL(Cast);
14615 if (CastOpcode == ISD::FP_ROUND) {
14616 // FP_ROUND (fptrunc) has an extra flag operand to pass along.
14617 CastA = DAG.getNode(Opcode: CastOpcode, DL, VT, N1: A, N2: Cast->getOperand(Num: 1));
14618 CastB = DAG.getNode(Opcode: CastOpcode, DL, VT, N1: B, N2: Cast->getOperand(Num: 1));
14619 } else {
14620 CastA = DAG.getNode(Opcode: CastOpcode, DL, VT, Operand: A);
14621 CastB = DAG.getNode(Opcode: CastOpcode, DL, VT, Operand: B);
14622 }
14623 return DAG.getNode(Opcode: ISD::VSELECT, DL, VT, N1: SetCC, N2: CastA, N3: CastB);
14624}
14625
14626// fold ([s|z]ext ([s|z]extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
14627// fold ([s|z]ext ( extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
14628static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner,
14629 const TargetLowering &TLI, EVT VT,
14630 bool LegalOperations, SDNode *N,
14631 SDValue N0, ISD::LoadExtType ExtLoadType) {
14632 SDNode *N0Node = N0.getNode();
14633 bool isAExtLoad = (ExtLoadType == ISD::SEXTLOAD) ? ISD::isSEXTLoad(N: N0Node)
14634 : ISD::isZEXTLoad(N: N0Node);
14635 if ((!isAExtLoad && !ISD::isEXTLoad(N: N0Node)) ||
14636 !ISD::isUNINDEXEDLoad(N: N0Node) || !N0.hasOneUse())
14637 return SDValue();
14638
14639 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
14640 EVT MemVT = LN0->getMemoryVT();
14641 if ((LegalOperations || !LN0->isSimple() ||
14642 VT.isVector()) &&
14643 !TLI.isLoadExtLegal(ExtType: ExtLoadType, ValVT: VT, MemVT))
14644 return SDValue();
14645
14646 SDValue ExtLoad =
14647 DAG.getExtLoad(ExtType: ExtLoadType, dl: SDLoc(LN0), VT, Chain: LN0->getChain(),
14648 Ptr: LN0->getBasePtr(), MemVT, MMO: LN0->getMemOperand());
14649 Combiner.CombineTo(N, Res: ExtLoad);
14650 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN0, 1), To: ExtLoad.getValue(R: 1));
14651 if (LN0->use_empty())
14652 Combiner.recursivelyDeleteUnusedNodes(N: LN0);
14653 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14654}
14655
14656// fold ([s|z]ext (load x)) -> ([s|z]ext (truncate ([s|z]extload x)))
14657// Only generate vector extloads when 1) they're legal, and 2) they are
14658// deemed desirable by the target. NonNegZExt can be set to true if a zero
14659// extend has the nonneg flag to allow use of sextload if profitable.
14660static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
14661 const TargetLowering &TLI, EVT VT,
14662 bool LegalOperations, SDNode *N, SDValue N0,
14663 ISD::LoadExtType ExtLoadType,
14664 ISD::NodeType ExtOpc,
14665 bool NonNegZExt = false) {
14666 bool Frozen = N0.getOpcode() == ISD::FREEZE;
14667 SDValue Freeze = Frozen ? N0 : SDValue();
14668 auto *Load = dyn_cast<LoadSDNode>(Val: Frozen ? N0.getOperand(i: 0) : N0);
14669 // TODO: Support multiple uses of the load when frozen.
14670 if (!Load || !ISD::isNON_EXTLoad(N: Load) || !ISD::isUNINDEXEDLoad(N: Load) ||
14671 (Frozen && !Load->hasNUsesOfValue(NUses: 1, Value: 0)))
14672 return {};
14673
14674 // If this is zext nneg, see if it would make sense to treat it as a sext.
14675 if (NonNegZExt) {
14676 assert(ExtLoadType == ISD::ZEXTLOAD && ExtOpc == ISD::ZERO_EXTEND &&
14677 "Unexpected load type or opcode");
14678 for (SDNode *User : Load->users()) {
14679 if (User->getOpcode() == ISD::SETCC) {
14680 ISD::CondCode CC = cast<CondCodeSDNode>(Val: User->getOperand(Num: 2))->get();
14681 if (ISD::isSignedIntSetCC(Code: CC)) {
14682 ExtLoadType = ISD::SEXTLOAD;
14683 ExtOpc = ISD::SIGN_EXTEND;
14684 break;
14685 }
14686 }
14687 }
14688 }
14689
14690 // TODO: The isFixedLengthVector() check should be removed, with any negative
14691 // effects on code generation instead being addressed by the target's
14692 // implementation of isVectorLoadExtDesirable().
14693 if ((LegalOperations || VT.isFixedLengthVector() || !Load->isSimple()) &&
14694 !TLI.isLoadExtLegal(ExtType: ExtLoadType, ValVT: VT, MemVT: Load->getValueType(ResNo: 0)))
14695 return {};
14696
14697 bool DoXform = true;
14698 SmallVector<SDNode *, 4> SetCCs;
14699 if (!N0->hasOneUse())
14700 DoXform = ExtendUsesToFormExtLoad(VT, N, N0: Frozen ? Freeze : SDValue(Load, 0),
14701 ExtOpc, ExtendNodes&: SetCCs, TLI);
14702 if (VT.isVector())
14703 DoXform &= TLI.isVectorLoadExtDesirable(ExtVal: SDValue(N, 0));
14704 if (!DoXform)
14705 return {};
14706
14707 SDLoc DL(Load);
14708 // If the load value is used only by N, replace it via CombineTo N.
14709 bool NoReplaceTrunc = N0.hasOneUse();
14710 SDValue ExtLoad =
14711 DAG.getExtLoad(ExtType: ExtLoadType, dl: DL, VT, Chain: Load->getChain(), Ptr: Load->getBasePtr(),
14712 MemVT: Load->getValueType(ResNo: 0), MMO: Load->getMemOperand());
14713 SDValue Res = ExtLoad;
14714 if (Frozen) {
14715 Res = DAG.getFreeze(V: ExtLoad);
14716 Res = DAG.getNode(Opcode: ExtLoadType == ISD::SEXTLOAD ? ISD::AssertSext
14717 : ISD::AssertZext,
14718 DL, VT: Res.getValueType(), N1: Res,
14719 N2: DAG.getValueType(Load->getValueType(ResNo: 0).getScalarType()));
14720 }
14721 Combiner.ExtendSetCCUses(SetCCs, OrigLoad: N0, ExtLoad: Res, ExtType: ExtOpc);
14722 Combiner.CombineTo(N, Res);
14723 if (NoReplaceTrunc) {
14724 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Load, 1), To: ExtLoad.getValue(R: 1));
14725 Combiner.recursivelyDeleteUnusedNodes(N: N0.getNode());
14726 } else {
14727 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: Load->getValueType(ResNo: 0), Operand: Res);
14728 if (Frozen) {
14729 Combiner.CombineTo(N: Freeze.getNode(), Res: Trunc);
14730 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Load, 1), To: ExtLoad.getValue(R: 1));
14731 } else {
14732 Combiner.CombineTo(N: Load, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
14733 }
14734 }
14735 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14736}
14737
14738static SDValue
14739tryToFoldExtOfMaskedLoad(SelectionDAG &DAG, const TargetLowering &TLI, EVT VT,
14740 bool LegalOperations, SDNode *N, SDValue N0,
14741 ISD::LoadExtType ExtLoadType, ISD::NodeType ExtOpc) {
14742 if (!N0.hasOneUse())
14743 return SDValue();
14744
14745 MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(Val&: N0);
14746 if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD)
14747 return SDValue();
14748
14749 if ((LegalOperations || !cast<MaskedLoadSDNode>(Val&: N0)->isSimple()) &&
14750 !TLI.isLoadExtLegalOrCustom(ExtType: ExtLoadType, ValVT: VT, MemVT: Ld->getValueType(ResNo: 0)))
14751 return SDValue();
14752
14753 if (!TLI.isVectorLoadExtDesirable(ExtVal: SDValue(N, 0)))
14754 return SDValue();
14755
14756 SDLoc dl(Ld);
14757 SDValue PassThru = DAG.getNode(Opcode: ExtOpc, DL: dl, VT, Operand: Ld->getPassThru());
14758 SDValue NewLoad = DAG.getMaskedLoad(
14759 VT, dl, Chain: Ld->getChain(), Base: Ld->getBasePtr(), Offset: Ld->getOffset(), Mask: Ld->getMask(),
14760 Src0: PassThru, MemVT: Ld->getMemoryVT(), MMO: Ld->getMemOperand(), AM: Ld->getAddressingMode(),
14761 ExtLoadType, IsExpanding: Ld->isExpandingLoad());
14762 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Ld, 1), To: SDValue(NewLoad.getNode(), 1));
14763 return NewLoad;
14764}
14765
14766// fold ([s|z]ext (atomic_load)) -> ([s|z]ext (truncate ([s|z]ext atomic_load)))
14767static SDValue tryToFoldExtOfAtomicLoad(SelectionDAG &DAG,
14768 const TargetLowering &TLI, EVT VT,
14769 SDValue N0,
14770 ISD::LoadExtType ExtLoadType) {
14771 auto *ALoad = dyn_cast<AtomicSDNode>(Val&: N0);
14772 if (!ALoad || ALoad->getOpcode() != ISD::ATOMIC_LOAD)
14773 return {};
14774 EVT MemoryVT = ALoad->getMemoryVT();
14775 if (!TLI.isAtomicLoadExtLegal(ExtType: ExtLoadType, ValVT: VT, MemVT: MemoryVT))
14776 return {};
14777 // Can't fold into ALoad if it is already extending differently.
14778 ISD::LoadExtType ALoadExtTy = ALoad->getExtensionType();
14779 if ((ALoadExtTy == ISD::ZEXTLOAD && ExtLoadType == ISD::SEXTLOAD) ||
14780 (ALoadExtTy == ISD::SEXTLOAD && ExtLoadType == ISD::ZEXTLOAD))
14781 return {};
14782
14783 EVT OrigVT = ALoad->getValueType(ResNo: 0);
14784 assert(OrigVT.getSizeInBits() < VT.getSizeInBits() && "VT should be wider.");
14785 auto *NewALoad = cast<AtomicSDNode>(Val: DAG.getAtomicLoad(
14786 ExtType: ExtLoadType, dl: SDLoc(ALoad), MemVT: MemoryVT, VT, Chain: ALoad->getChain(),
14787 Ptr: ALoad->getBasePtr(), MMO: ALoad->getMemOperand()));
14788 DAG.ReplaceAllUsesOfValueWith(
14789 From: SDValue(ALoad, 0),
14790 To: DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(ALoad), VT: OrigVT, Operand: SDValue(NewALoad, 0)));
14791 // Update the chain uses.
14792 DAG.ReplaceAllUsesOfValueWith(From: SDValue(ALoad, 1), To: SDValue(NewALoad, 1));
14793 return SDValue(NewALoad, 0);
14794}
14795
14796static SDValue foldExtendedSignBitTest(SDNode *N, SelectionDAG &DAG,
14797 bool LegalOperations) {
14798 assert((N->getOpcode() == ISD::SIGN_EXTEND ||
14799 N->getOpcode() == ISD::ZERO_EXTEND) && "Expected sext or zext");
14800
14801 SDValue SetCC = N->getOperand(Num: 0);
14802 if (LegalOperations || SetCC.getOpcode() != ISD::SETCC ||
14803 !SetCC.hasOneUse() || SetCC.getValueType() != MVT::i1)
14804 return SDValue();
14805
14806 SDValue X = SetCC.getOperand(i: 0);
14807 SDValue Ones = SetCC.getOperand(i: 1);
14808 ISD::CondCode CC = cast<CondCodeSDNode>(Val: SetCC.getOperand(i: 2))->get();
14809 EVT VT = N->getValueType(ResNo: 0);
14810 EVT XVT = X.getValueType();
14811 // setge X, C is canonicalized to setgt, so we do not need to match that
14812 // pattern. The setlt sibling is folded in SimplifySelectCC() because it does
14813 // not require the 'not' op.
14814 if (CC == ISD::SETGT && isAllOnesConstant(V: Ones) && VT == XVT) {
14815 // Invert and smear/shift the sign bit:
14816 // sext i1 (setgt iN X, -1) --> sra (not X), (N - 1)
14817 // zext i1 (setgt iN X, -1) --> srl (not X), (N - 1)
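// e.g. for i32 X = 5: setgt 5, -1 is true, so sext i1 gives -1, and indeed
// sra (not 5), 31 = sra -6, 31 = -1; for X = -5 the setcc is false and
// sra (not -5), 31 = sra 4, 31 = 0.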
14818 SDLoc DL(N);
14819 unsigned ShCt = VT.getSizeInBits() - 1;
14820 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
14821 if (!TLI.shouldAvoidTransformToShift(VT, Amount: ShCt)) {
14822 SDValue NotX = DAG.getNOT(DL, Val: X, VT);
14823 SDValue ShiftAmount = DAG.getConstant(Val: ShCt, DL, VT);
14824 auto ShiftOpcode =
14825 N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SRA : ISD::SRL;
14826 return DAG.getNode(Opcode: ShiftOpcode, DL, VT, N1: NotX, N2: ShiftAmount);
14827 }
14828 }
14829 return SDValue();
14830}
14831
14832SDValue DAGCombiner::foldSextSetcc(SDNode *N) {
14833 SDValue N0 = N->getOperand(Num: 0);
14834 if (N0.getOpcode() != ISD::SETCC)
14835 return SDValue();
14836
14837 SDValue N00 = N0.getOperand(i: 0);
14838 SDValue N01 = N0.getOperand(i: 1);
14839 ISD::CondCode CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get();
14840 EVT VT = N->getValueType(ResNo: 0);
14841 EVT N00VT = N00.getValueType();
14842 SDLoc DL(N);
14843
14844 // Propagate fast-math-flags.
14845 SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
14846
14847 // On some architectures (such as SSE/NEON/etc) the SETCC result type is
14848 // the same size as the compared operands. Try to optimize sext(setcc())
14849 // if this is the case.
14850 if (VT.isVector() && !LegalOperations &&
14851 TLI.getBooleanContents(Type: N00VT) ==
14852 TargetLowering::ZeroOrNegativeOneBooleanContent) {
14853 EVT SVT = getSetCCResultType(VT: N00VT);
14854
14855 // If we already have the desired type, don't change it.
14856 if (SVT != N0.getValueType()) {
14857 // We know that the # elements of the results is the same as the
14858 // # elements of the compare (and the # elements of the compare result
14859 // for that matter). Check to see that they are the same size. If so,
14860 // we know that the element size of the sext'd result matches the
14861 // element size of the compare operands.
14862 if (VT.getSizeInBits() == SVT.getSizeInBits())
14863 return DAG.getSetCC(DL, VT, LHS: N00, RHS: N01, Cond: CC);
14864
14865 // If the desired elements are smaller or larger than the source
14866 // elements, we can use a matching integer vector type and then
14867 // truncate/sign extend.
14868 EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger();
14869 if (SVT == MatchingVecType) {
14870 SDValue VsetCC = DAG.getSetCC(DL, VT: MatchingVecType, LHS: N00, RHS: N01, Cond: CC);
14871 return DAG.getSExtOrTrunc(Op: VsetCC, DL, VT);
14872 }
14873 }
14874
14875 // Try to eliminate the sext of a setcc by zexting the compare operands.
14876 if (N0.hasOneUse() && TLI.isOperationLegalOrCustom(Op: ISD::SETCC, VT) &&
14877 !TLI.isOperationLegalOrCustom(Op: ISD::SETCC, VT: SVT)) {
14878 bool IsSignedCmp = ISD::isSignedIntSetCC(Code: CC);
14879 unsigned LoadOpcode = IsSignedCmp ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
14880 unsigned ExtOpcode = IsSignedCmp ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
14881
14882 // We have an unsupported narrow vector compare op that would be legal
14883 // if extended to the destination type. See if the compare operands
14884 // can be freely extended to the destination type.
14885 auto IsFreeToExtend = [&](SDValue V) {
14886 if (isConstantOrConstantVector(N: V, /*NoOpaques*/ true))
14887 return true;
14888 // Match a simple, non-extended load that can be converted to a
14889 // legal {z/s}ext-load.
14890 // TODO: Allow widening of an existing {z/s}ext-load?
14891 if (!(ISD::isNON_EXTLoad(N: V.getNode()) &&
14892 ISD::isUNINDEXEDLoad(N: V.getNode()) &&
14893 cast<LoadSDNode>(Val&: V)->isSimple() &&
14894 TLI.isLoadExtLegal(ExtType: LoadOpcode, ValVT: VT, MemVT: V.getValueType())))
14895 return false;
14896
14897 // Non-chain users of this value must either be the setcc in this
14898 // sequence or extends that can be folded into the new {z/s}ext-load.
14899 for (SDUse &Use : V->uses()) {
14900 // Skip uses of the chain and the setcc.
14901 SDNode *User = Use.getUser();
14902 if (Use.getResNo() != 0 || User == N0.getNode())
14903 continue;
14904 // Extra users must have exactly the same cast we are about to create.
14905 // TODO: This restriction could be eased if ExtendUsesToFormExtLoad()
14906 // is enhanced similarly.
14907 if (User->getOpcode() != ExtOpcode || User->getValueType(ResNo: 0) != VT)
14908 return false;
14909 }
14910 return true;
14911 };
14912
14913 if (IsFreeToExtend(N00) && IsFreeToExtend(N01)) {
14914 SDValue Ext0 = DAG.getNode(Opcode: ExtOpcode, DL, VT, Operand: N00);
14915 SDValue Ext1 = DAG.getNode(Opcode: ExtOpcode, DL, VT, Operand: N01);
14916 return DAG.getSetCC(DL, VT, LHS: Ext0, RHS: Ext1, Cond: CC);
14917 }
14918 }
14919 }
14920
14921 // sext(setcc x, y, cc) -> (select (setcc x, y, cc), T, 0)
14922 // Here, T can be 1 or -1, depending on the type of the setcc and
14923 // getBooleanContents().
14924 unsigned SetCCWidth = N0.getScalarValueSizeInBits();
14925
14926 // To determine the "true" side of the select, we need to know the high bit
14927 // of the value returned by the setcc if it evaluates to true.
14928 // If the type of the setcc is i1, then the true case of the select is just
14929 // sext(i1 1), that is, -1.
14930 // If the type of the setcc is larger (say, i8) then the value of the high
14931 // bit depends on getBooleanContents(), so ask TLI for a real "true" value
14932 // of the appropriate width.
14933 SDValue ExtTrueVal = (SetCCWidth == 1)
14934 ? DAG.getAllOnesConstant(DL, VT)
14935 : DAG.getBoolConstant(V: true, DL, VT, OpVT: N00VT);
14936 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
14937 if (SDValue SCC = SimplifySelectCC(DL, N0: N00, N1: N01, N2: ExtTrueVal, N3: Zero, CC, NotExtCompare: true))
14938 return SCC;
14939
14940 if (!VT.isVector() && !shouldConvertSelectOfConstantsToMath(Cond: N0, VT, TLI)) {
14941 EVT SetCCVT = getSetCCResultType(VT: N00VT);
14942 // Don't do this transform for i1 because there's a select transform
14943 // that would reverse it.
14944 // TODO: We should not do this transform at all without a target hook
14945 // because a sext is likely cheaper than a select?
14946 if (SetCCVT.getScalarSizeInBits() != 1 &&
14947 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SETCC, VT: N00VT))) {
14948 SDValue SetCC = DAG.getSetCC(DL, VT: SetCCVT, LHS: N00, RHS: N01, Cond: CC);
14949 return DAG.getSelect(DL, VT, Cond: SetCC, LHS: ExtTrueVal, RHS: Zero);
14950 }
14951 }
14952
14953 return SDValue();
14954}
14955
14956SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
14957 SDValue N0 = N->getOperand(Num: 0);
14958 EVT VT = N->getValueType(ResNo: 0);
14959 SDLoc DL(N);
14960
14961 if (VT.isVector())
14962 if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
14963 return FoldedVOp;
14964
14965 // sext(undef) = 0 because the top bit will all be the same.
14966 if (N0.isUndef())
14967 return DAG.getConstant(Val: 0, DL, VT);
14968
14969 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
14970 return Res;
14971
14972 // fold (sext (sext x)) -> (sext x)
14973 // fold (sext (aext x)) -> (sext x)
14974 if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
14975 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: N0.getOperand(i: 0));
14976
14977 // fold (sext (aext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
14978 // fold (sext (sext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
14979 if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
14980 N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG)
14981 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_VECTOR_INREG, DL: SDLoc(N), VT,
14982 Operand: N0.getOperand(i: 0));
14983
14984 if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
14985 SDValue N00 = N0.getOperand(i: 0);
14986 EVT ExtVT = cast<VTSDNode>(Val: N0->getOperand(Num: 1))->getVT();
14987 if (N00.getOpcode() == ISD::TRUNCATE || TLI.isTruncateFree(Val: N00, VT2: ExtVT)) {
14988 // fold (sext (sext_inreg x)) -> (sext (trunc x))
14989 if ((!LegalTypes || TLI.isTypeLegal(VT: ExtVT))) {
14990 SDValue T = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: ExtVT, Operand: N00);
14991 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: T);
14992 }
14993
14994 // If the trunc wasn't legal, try to fold to (sext_inreg (anyext x))
14995 if (!LegalTypes || TLI.isTypeLegal(VT)) {
14996 SDValue ExtSrc = DAG.getAnyExtOrTrunc(Op: N00, DL, VT);
14997 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT, N1: ExtSrc,
14998 N2: N0->getOperand(Num: 1));
14999 }
15000 }
15001 }
15002
15003 if (N0.getOpcode() == ISD::TRUNCATE) {
15004 // fold (sext (truncate (load x))) -> (sext (smaller load x))
15005 // fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
15006 if (SDValue NarrowLoad = reduceLoadWidth(N: N0.getNode())) {
15007 SDNode *oye = N0.getOperand(i: 0).getNode();
15008 if (NarrowLoad.getNode() != N0.getNode()) {
15009 CombineTo(N: N0.getNode(), Res: NarrowLoad);
15010 // CombineTo deleted the truncate, if needed, but not what's under it.
15011 AddToWorklist(N: oye);
15012 }
15013 return SDValue(N, 0); // Return N so it doesn't get rechecked!
15014 }
15015
15016 // See if the value being truncated is already sign extended. If so, just
15017 // eliminate the trunc/sext pair.
15018 SDValue Op = N0.getOperand(i: 0);
15019 unsigned OpBits = Op.getScalarValueSizeInBits();
15020 unsigned MidBits = N0.getScalarValueSizeInBits();
15021 unsigned DestBits = VT.getScalarSizeInBits();
15022
15023 if (N0->getFlags().hasNoSignedWrap() ||
15024 DAG.ComputeNumSignBits(Op) > OpBits - MidBits) {
15025 if (OpBits == DestBits) {
15026        // Op is i32, Mid is i8, and Dest is i32. If Op has more than 24 sign
15027        // bits, it already matches the sext result and can be used directly.
15028 return Op;
15029 }
15030
15031 if (OpBits < DestBits) {
15032 // Op is i32, Mid is i8, and Dest is i64. If Op has more than 24 sign
15033 // bits, just sext from i32.
15034 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: Op);
15035 }
15036
15037 // Op is i64, Mid is i8, and Dest is i32. If Op has more than 56 sign
15038 // bits, just truncate to i32.
15039 SDNodeFlags Flags;
15040 Flags.setNoSignedWrap(true);
15041 Flags.setNoUnsignedWrap(N0->getFlags().hasNoUnsignedWrap());
15042 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Op, Flags);
15043 }
15044
15045    // fold (sext (truncate x)) -> (sext_inreg x).
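    // For example, (sext (trunc i64 X to i8) to i32) can become
    // (sext_inreg (trunc X to i32), i8), and (sext (trunc i16 X to i8) to i32)
    // can become (sext_inreg (anyext X to i32), i8) (illustrative types only).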
15046 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::SIGN_EXTEND_INREG,
15047 VT: N0.getValueType())) {
15048 if (OpBits < DestBits)
15049 Op = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL: SDLoc(N0), VT, Operand: Op);
15050 else if (OpBits > DestBits)
15051 Op = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N0), VT, Operand: Op);
15052 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT, N1: Op,
15053 N2: DAG.getValueType(N0.getValueType()));
15054 }
15055 }
15056
15057 // Try to simplify (sext (load x)).
15058 if (SDValue foldedExt =
15059 tryToFoldExtOfLoad(DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0,
15060 ExtLoadType: ISD::SEXTLOAD, ExtOpc: ISD::SIGN_EXTEND))
15061 return foldedExt;
15062
15063 if (SDValue foldedExt =
15064 tryToFoldExtOfMaskedLoad(DAG, TLI, VT, LegalOperations, N, N0,
15065 ExtLoadType: ISD::SEXTLOAD, ExtOpc: ISD::SIGN_EXTEND))
15066 return foldedExt;
15067
15068 // fold (sext (load x)) to multiple smaller sextloads.
15069 // Only on illegal but splittable vectors.
15070 if (SDValue ExtLoad = CombineExtLoad(N))
15071 return ExtLoad;
15072
15073 // Try to simplify (sext (sextload x)).
15074 if (SDValue foldedExt = tryToFoldExtOfExtload(
15075 DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0, ExtLoadType: ISD::SEXTLOAD))
15076 return foldedExt;
15077
15078 // Try to simplify (sext (atomic_load x)).
15079 if (SDValue foldedExt =
15080 tryToFoldExtOfAtomicLoad(DAG, TLI, VT, N0, ExtLoadType: ISD::SEXTLOAD))
15081 return foldedExt;
15082
15083 // fold (sext (and/or/xor (load x), cst)) ->
15084 // (and/or/xor (sextload x), (sext cst))
15085 if (ISD::isBitwiseLogicOp(Opcode: N0.getOpcode()) &&
15086 isa<LoadSDNode>(Val: N0.getOperand(i: 0)) &&
15087 N0.getOperand(i: 1).getOpcode() == ISD::Constant &&
15088 (!LegalOperations && TLI.isOperationLegal(Op: N0.getOpcode(), VT))) {
15089 LoadSDNode *LN00 = cast<LoadSDNode>(Val: N0.getOperand(i: 0));
15090 EVT MemVT = LN00->getMemoryVT();
15091 if (TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: VT, MemVT) &&
15092 LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
15093 SmallVector<SDNode*, 4> SetCCs;
15094 bool DoXform = ExtendUsesToFormExtLoad(VT, N: N0.getNode(), N0: N0.getOperand(i: 0),
15095 ExtOpc: ISD::SIGN_EXTEND, ExtendNodes&: SetCCs, TLI);
15096 if (DoXform) {
15097 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::SEXTLOAD, dl: SDLoc(LN00), VT,
15098 Chain: LN00->getChain(), Ptr: LN00->getBasePtr(),
15099 MemVT: LN00->getMemoryVT(),
15100 MMO: LN00->getMemOperand());
15101 APInt Mask = N0.getConstantOperandAPInt(i: 1).sext(width: VT.getSizeInBits());
15102 SDValue And = DAG.getNode(Opcode: N0.getOpcode(), DL, VT,
15103 N1: ExtLoad, N2: DAG.getConstant(Val: Mask, DL, VT));
15104 ExtendSetCCUses(SetCCs, OrigLoad: N0.getOperand(i: 0), ExtLoad, ExtType: ISD::SIGN_EXTEND);
15105 bool NoReplaceTruncAnd = !N0.hasOneUse();
15106 bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
15107 CombineTo(N, Res: And);
15108 // If N0 has multiple uses, change other uses as well.
15109 if (NoReplaceTruncAnd) {
15110 SDValue TruncAnd =
15111 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: N0.getValueType(), Operand: And);
15112 CombineTo(N: N0.getNode(), Res: TruncAnd);
15113 }
15114 if (NoReplaceTrunc) {
15115 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN00, 1), To: ExtLoad.getValue(R: 1));
15116 } else {
15117 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(LN00),
15118 VT: LN00->getValueType(ResNo: 0), Operand: ExtLoad);
15119 CombineTo(N: LN00, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
15120 }
15121 return SDValue(N,0); // Return N so it doesn't get rechecked!
15122 }
15123 }
15124 }
15125
15126 if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
15127 return V;
15128
15129 if (SDValue V = foldSextSetcc(N))
15130 return V;
15131
15132 // fold (sext x) -> (zext x) if the sign bit is known zero.
15133 if (!TLI.isSExtCheaperThanZExt(FromTy: N0.getValueType(), ToTy: VT) &&
15134 (!LegalOperations || TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT)) &&
15135 DAG.SignBitIsZero(Op: N0))
15136 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0, Flags: SDNodeFlags::NonNeg);
15137
15138 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
15139 return NewVSel;
15140
15141 // Eliminate this sign extend by doing a negation in the destination type:
15142 // sext i32 (0 - (zext i8 X to i32)) to i64 --> 0 - (zext i8 X to i64)
15143 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
15144 isNullOrNullSplat(V: N0.getOperand(i: 0)) &&
15145 N0.getOperand(i: 1).getOpcode() == ISD::ZERO_EXTEND &&
15146 TLI.isOperationLegalOrCustom(Op: ISD::SUB, VT)) {
15147 SDValue Zext = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1).getOperand(i: 0), DL, VT);
15148 return DAG.getNegative(Val: Zext, DL, VT);
15149 }
15150 // Eliminate this sign extend by doing a decrement in the destination type:
15151 // sext i32 ((zext i8 X to i32) + (-1)) to i64 --> (zext i8 X to i64) + (-1)
15152 if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
15153 isAllOnesOrAllOnesSplat(V: N0.getOperand(i: 1)) &&
15154 N0.getOperand(i: 0).getOpcode() == ISD::ZERO_EXTEND &&
15155 TLI.isOperationLegalOrCustom(Op: ISD::ADD, VT)) {
15156 SDValue Zext = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 0).getOperand(i: 0), DL, VT);
15157 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Zext, N2: DAG.getAllOnesConstant(DL, VT));
15158 }
15159
15160 // fold sext (not i1 X) -> add (zext i1 X), -1
15161 // TODO: This could be extended to handle bool vectors.
15162 if (N0.getValueType() == MVT::i1 && isBitwiseNot(V: N0) && N0.hasOneUse() &&
15163 (!LegalOperations || (TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT) &&
15164 TLI.isOperationLegal(Op: ISD::ADD, VT)))) {
15165 // If we can eliminate the 'not', the sext form should be better
15166 if (SDValue NewXor = visitXOR(N: N0.getNode())) {
15167 // Returning N0 is a form of in-visit replacement that may have
15168 // invalidated N0.
15169 if (NewXor.getNode() == N0.getNode()) {
15170 // Return SDValue here as the xor should have already been replaced in
15171 // this sext.
15172 return SDValue();
15173 }
15174
15175 // Return a new sext with the new xor.
15176 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: NewXor);
15177 }
15178
15179 SDValue Zext = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0.getOperand(i: 0));
15180 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Zext, N2: DAG.getAllOnesConstant(DL, VT));
15181 }
15182
15183 if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, DL, Level))
15184 return Res;
15185
15186 return SDValue();
15187}
15188
15189/// Given an extending node with a pop-count operand, if the target does not
15190/// support a pop-count in the narrow source type but does support it in the
15191/// destination type, widen the pop-count to the destination type.
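/// For example, (i32 zext (i16 ctpop X)) can become (i32 ctpop (i32 zext X))
/// on a target where CTPOP is legal or custom for i32 but not for i16
/// (illustrative types only).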
15192static SDValue widenCtPop(SDNode *Extend, SelectionDAG &DAG, const SDLoc &DL) {
15193 assert((Extend->getOpcode() == ISD::ZERO_EXTEND ||
15194 Extend->getOpcode() == ISD::ANY_EXTEND) &&
15195 "Expected extend op");
15196
15197 SDValue CtPop = Extend->getOperand(Num: 0);
15198 if (CtPop.getOpcode() != ISD::CTPOP || !CtPop.hasOneUse())
15199 return SDValue();
15200
15201 EVT VT = Extend->getValueType(ResNo: 0);
15202 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
15203 if (TLI.isOperationLegalOrCustom(Op: ISD::CTPOP, VT: CtPop.getValueType()) ||
15204 !TLI.isOperationLegalOrCustom(Op: ISD::CTPOP, VT))
15205 return SDValue();
15206
15207 // zext (ctpop X) --> ctpop (zext X)
15208 SDValue NewZext = DAG.getZExtOrTrunc(Op: CtPop.getOperand(i: 0), DL, VT);
15209 return DAG.getNode(Opcode: ISD::CTPOP, DL, VT, Operand: NewZext);
15210}
15211
15212// If we have (zext (abs X)) where X is a type that will be promoted by type
15213// legalization, convert to (abs (sext X)). But don't extend past a legal type.
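// For example, on a target that promotes i16 to i32, (i64 zext (i16 abs X))
// becomes (i64 zext (i32 abs (i32 sext X))) (illustrative types only).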
15214static SDValue widenAbs(SDNode *Extend, SelectionDAG &DAG) {
15215 assert(Extend->getOpcode() == ISD::ZERO_EXTEND && "Expected zero extend.");
15216
15217 EVT VT = Extend->getValueType(ResNo: 0);
15218 if (VT.isVector())
15219 return SDValue();
15220
15221 SDValue Abs = Extend->getOperand(Num: 0);
15222 if (Abs.getOpcode() != ISD::ABS || !Abs.hasOneUse())
15223 return SDValue();
15224
15225 EVT AbsVT = Abs.getValueType();
15226 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
15227 if (TLI.getTypeAction(Context&: *DAG.getContext(), VT: AbsVT) !=
15228 TargetLowering::TypePromoteInteger)
15229 return SDValue();
15230
15231 EVT LegalVT = TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: AbsVT);
15232
15233 SDValue SExt =
15234 DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL: SDLoc(Abs), VT: LegalVT, Operand: Abs.getOperand(i: 0));
15235 SDValue NewAbs = DAG.getNode(Opcode: ISD::ABS, DL: SDLoc(Abs), VT: LegalVT, Operand: SExt);
15236 return DAG.getZExtOrTrunc(Op: NewAbs, DL: SDLoc(Extend), VT);
15237}
15238
15239SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
15240 SDValue N0 = N->getOperand(Num: 0);
15241 EVT VT = N->getValueType(ResNo: 0);
15242 SDLoc DL(N);
15243
15244 if (VT.isVector())
15245 if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
15246 return FoldedVOp;
15247
15248 // zext(undef) = 0
15249 if (N0.isUndef())
15250 return DAG.getConstant(Val: 0, DL, VT);
15251
15252 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
15253 return Res;
15254
15255 // fold (zext (zext x)) -> (zext x)
15256 // fold (zext (aext x)) -> (zext x)
15257 if (N0.getOpcode() == ISD::ZERO_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
15258 SDNodeFlags Flags;
15259 if (N0.getOpcode() == ISD::ZERO_EXTEND)
15260 Flags.setNonNeg(N0->getFlags().hasNonNeg());
15261 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0.getOperand(i: 0), Flags);
15262 }
15263
15264 // fold (zext (aext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
15265 // fold (zext (zext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
15266 if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
15267 N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG)
15268 return DAG.getNode(Opcode: ISD::ZERO_EXTEND_VECTOR_INREG, DL, VT, Operand: N0.getOperand(i: 0));
15269
15270 // fold (zext (truncate x)) -> (zext x) or
15271 // (zext (truncate x)) -> (truncate x)
15272 // This is valid when the truncated bits of x are already zero.
15273 SDValue Op;
15274 KnownBits Known;
15275 if (isTruncateOf(DAG, N: N0, Op, Known)) {
15276 APInt TruncatedBits =
15277 (Op.getScalarValueSizeInBits() == N0.getScalarValueSizeInBits()) ?
15278 APInt(Op.getScalarValueSizeInBits(), 0) :
15279 APInt::getBitsSet(numBits: Op.getScalarValueSizeInBits(),
15280 loBit: N0.getScalarValueSizeInBits(),
15281 hiBit: std::min(a: Op.getScalarValueSizeInBits(),
15282 b: VT.getScalarSizeInBits()));
15283 if (TruncatedBits.isSubsetOf(RHS: Known.Zero)) {
15284 SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, DL, VT);
15285 DAG.salvageDebugInfo(N&: *N0.getNode());
15286
15287 return ZExtOrTrunc;
15288 }
15289 }
15290
15291 // fold (zext (truncate x)) -> (and x, mask)
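  // For example, (zext (trunc i32 X to i8) to i32) can become (and X, 255).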
15292 if (N0.getOpcode() == ISD::TRUNCATE) {
15293 // fold (zext (truncate (load x))) -> (zext (smaller load x))
15294 // fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
15295 if (SDValue NarrowLoad = reduceLoadWidth(N: N0.getNode())) {
15296 SDNode *oye = N0.getOperand(i: 0).getNode();
15297 if (NarrowLoad.getNode() != N0.getNode()) {
15298 CombineTo(N: N0.getNode(), Res: NarrowLoad);
15299 // CombineTo deleted the truncate, if needed, but not what's under it.
15300 AddToWorklist(N: oye);
15301 }
15302 return SDValue(N, 0); // Return N so it doesn't get rechecked!
15303 }
15304
15305 EVT SrcVT = N0.getOperand(i: 0).getValueType();
15306 EVT MinVT = N0.getValueType();
15307
15308 if (N->getFlags().hasNonNeg()) {
15309 SDValue Op = N0.getOperand(i: 0);
15310 unsigned OpBits = SrcVT.getScalarSizeInBits();
15311 unsigned MidBits = MinVT.getScalarSizeInBits();
15312 unsigned DestBits = VT.getScalarSizeInBits();
15313
15314 if (N0->getFlags().hasNoSignedWrap() ||
15315 DAG.ComputeNumSignBits(Op) > OpBits - MidBits) {
15316 if (OpBits == DestBits) {
15317          // Op is i32, Mid is i8, and Dest is i32. If Op has more than 24 sign
15318          // bits, it already matches the extended result and can be used directly.
15319 return Op;
15320 }
15321
15322 if (OpBits < DestBits) {
15323 // Op is i32, Mid is i8, and Dest is i64. If Op has more than 24 sign
15324 // bits, just sext from i32.
15325 // FIXME: This can probably be ZERO_EXTEND nneg?
15326 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: Op);
15327 }
15328
15329 // Op is i64, Mid is i8, and Dest is i32. If Op has more than 56 sign
15330 // bits, just truncate to i32.
15331 SDNodeFlags Flags;
15332 Flags.setNoSignedWrap(true);
15333 Flags.setNoUnsignedWrap(true);
15334 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Op, Flags);
15335 }
15336 }
15337
15338 // Try to mask before the extension to avoid having to generate a larger mask,
15339 // possibly over several sub-vectors.
15340 if (SrcVT.bitsLT(VT) && VT.isVector()) {
15341 if (!LegalOperations || (TLI.isOperationLegal(Op: ISD::AND, VT: SrcVT) &&
15342 TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT))) {
15343 SDValue Op = N0.getOperand(i: 0);
15344 Op = DAG.getZeroExtendInReg(Op, DL, VT: MinVT);
15345 AddToWorklist(N: Op.getNode());
15346 SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, DL, VT);
15347 // Transfer the debug info; the new node is equivalent to N0.
15348 DAG.transferDbgValues(From: N0, To: ZExtOrTrunc);
15349 return ZExtOrTrunc;
15350 }
15351 }
15352
15353 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::AND, VT)) {
15354 SDValue Op = DAG.getAnyExtOrTrunc(Op: N0.getOperand(i: 0), DL, VT);
15355 AddToWorklist(N: Op.getNode());
15356 SDValue And = DAG.getZeroExtendInReg(Op, DL, VT: MinVT);
15357 // We may safely transfer the debug info describing the truncate node over
15358 // to the equivalent and operation.
15359 DAG.transferDbgValues(From: N0, To: And);
15360 return And;
15361 }
15362 }
15363
15364 // Fold (zext (and (trunc x), cst)) -> (and x, cst),
15365 // if either of the casts is not free.
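  // For example, (zext (and (trunc i64 X to i32), 7) to i64) can become
  // (and X, 7) when the truncate or the zext is not free (illustrative types
  // only).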
15366 if (N0.getOpcode() == ISD::AND &&
15367 N0.getOperand(i: 0).getOpcode() == ISD::TRUNCATE &&
15368 N0.getOperand(i: 1).getOpcode() == ISD::Constant &&
15369 (!TLI.isTruncateFree(Val: N0.getOperand(i: 0).getOperand(i: 0), VT2: N0.getValueType()) ||
15370 !TLI.isZExtFree(FromTy: N0.getValueType(), ToTy: VT))) {
15371 SDValue X = N0.getOperand(i: 0).getOperand(i: 0);
15372 X = DAG.getAnyExtOrTrunc(Op: X, DL: SDLoc(X), VT);
15373 APInt Mask = N0.getConstantOperandAPInt(i: 1).zext(width: VT.getSizeInBits());
15374 return DAG.getNode(Opcode: ISD::AND, DL, VT,
15375 N1: X, N2: DAG.getConstant(Val: Mask, DL, VT));
15376 }
15377
15378 // Try to simplify (zext (load x)).
15379 if (SDValue foldedExt = tryToFoldExtOfLoad(
15380 DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0, ExtLoadType: ISD::ZEXTLOAD,
15381 ExtOpc: ISD::ZERO_EXTEND, NonNegZExt: N->getFlags().hasNonNeg()))
15382 return foldedExt;
15383
15384 if (SDValue foldedExt =
15385 tryToFoldExtOfMaskedLoad(DAG, TLI, VT, LegalOperations, N, N0,
15386 ExtLoadType: ISD::ZEXTLOAD, ExtOpc: ISD::ZERO_EXTEND))
15387 return foldedExt;
15388
15389 // fold (zext (load x)) to multiple smaller zextloads.
15390 // Only on illegal but splittable vectors.
15391 if (SDValue ExtLoad = CombineExtLoad(N))
15392 return ExtLoad;
15393
15394 // Try to simplify (zext (atomic_load x)).
15395 if (SDValue foldedExt =
15396 tryToFoldExtOfAtomicLoad(DAG, TLI, VT, N0, ExtLoadType: ISD::ZEXTLOAD))
15397 return foldedExt;
15398
15399 // fold (zext (and/or/xor (load x), cst)) ->
15400 // (and/or/xor (zextload x), (zext cst))
15401 // Unless (and (load x) cst) will match as a zextload already and has
15402 // additional users, or the zext is already free.
15403 if (ISD::isBitwiseLogicOp(Opcode: N0.getOpcode()) && !TLI.isZExtFree(Val: N0, VT2: VT) &&
15404 isa<LoadSDNode>(Val: N0.getOperand(i: 0)) &&
15405 N0.getOperand(i: 1).getOpcode() == ISD::Constant &&
15406 (!LegalOperations && TLI.isOperationLegal(Op: N0.getOpcode(), VT))) {
15407 LoadSDNode *LN00 = cast<LoadSDNode>(Val: N0.getOperand(i: 0));
15408 EVT MemVT = LN00->getMemoryVT();
15409 if (TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: VT, MemVT) &&
15410 LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) {
15411 bool DoXform = true;
15412 SmallVector<SDNode*, 4> SetCCs;
15413 if (!N0.hasOneUse()) {
15414 if (N0.getOpcode() == ISD::AND) {
15415 auto *AndC = cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
15416 EVT LoadResultTy = AndC->getValueType(ResNo: 0);
15417 EVT ExtVT;
15418 if (isAndLoadExtLoad(AndC, LoadN: LN00, LoadResultTy, ExtVT))
15419 DoXform = false;
15420 }
15421 }
15422 if (DoXform)
15423 DoXform = ExtendUsesToFormExtLoad(VT, N: N0.getNode(), N0: N0.getOperand(i: 0),
15424 ExtOpc: ISD::ZERO_EXTEND, ExtendNodes&: SetCCs, TLI);
15425 if (DoXform) {
15426 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl: SDLoc(LN00), VT,
15427 Chain: LN00->getChain(), Ptr: LN00->getBasePtr(),
15428 MemVT: LN00->getMemoryVT(),
15429 MMO: LN00->getMemOperand());
15430 APInt Mask = N0.getConstantOperandAPInt(i: 1).zext(width: VT.getSizeInBits());
15431 SDValue And = DAG.getNode(Opcode: N0.getOpcode(), DL, VT,
15432 N1: ExtLoad, N2: DAG.getConstant(Val: Mask, DL, VT));
15433 ExtendSetCCUses(SetCCs, OrigLoad: N0.getOperand(i: 0), ExtLoad, ExtType: ISD::ZERO_EXTEND);
15434 bool NoReplaceTruncAnd = !N0.hasOneUse();
15435 bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
15436 CombineTo(N, Res: And);
15437 // If N0 has multiple uses, change other uses as well.
15438 if (NoReplaceTruncAnd) {
15439 SDValue TruncAnd =
15440 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: N0.getValueType(), Operand: And);
15441 CombineTo(N: N0.getNode(), Res: TruncAnd);
15442 }
15443 if (NoReplaceTrunc) {
15444 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN00, 1), To: ExtLoad.getValue(R: 1));
15445 } else {
15446 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(LN00),
15447 VT: LN00->getValueType(ResNo: 0), Operand: ExtLoad);
15448 CombineTo(N: LN00, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
15449 }
15450 return SDValue(N,0); // Return N so it doesn't get rechecked!
15451 }
15452 }
15453 }
15454
15455 // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
15456 // (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
15457 if (SDValue ZExtLoad = CombineZExtLogicopShiftLoad(N))
15458 return ZExtLoad;
15459
15460 // Try to simplify (zext (zextload x)).
15461 if (SDValue foldedExt = tryToFoldExtOfExtload(
15462 DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0, ExtLoadType: ISD::ZEXTLOAD))
15463 return foldedExt;
15464
15465 if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
15466 return V;
15467
15468 if (N0.getOpcode() == ISD::SETCC) {
15469 // Propagate fast-math-flags.
15470 SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
15471
15472 // Only do this before legalize for now.
15473 if (!LegalOperations && VT.isVector() &&
15474 N0.getValueType().getVectorElementType() == MVT::i1) {
15475 EVT N00VT = N0.getOperand(i: 0).getValueType();
15476 if (getSetCCResultType(VT: N00VT) == N0.getValueType())
15477 return SDValue();
15478
15479      // We know that the # elements of the result is the same as the #
15480      // elements of the compare (and the # elements of the compare result for
15481      // that matter). Check to see that they are the same size. If so, we know
15482      // that the element size of the zext'd result matches the element size of
15483      // the compare operands.
15484 if (VT.getSizeInBits() == N00VT.getSizeInBits()) {
15485 // zext(setcc) -> zext_in_reg(vsetcc) for vectors.
15486 SDValue VSetCC = DAG.getNode(Opcode: ISD::SETCC, DL, VT, N1: N0.getOperand(i: 0),
15487 N2: N0.getOperand(i: 1), N3: N0.getOperand(i: 2));
15488 return DAG.getZeroExtendInReg(Op: VSetCC, DL, VT: N0.getValueType());
15489 }
15490
15491 // If the desired elements are smaller or larger than the source
15492 // elements we can use a matching integer vector type and then
15493 // truncate/any extend followed by zext_in_reg.
15494 EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
15495 SDValue VsetCC =
15496 DAG.getNode(Opcode: ISD::SETCC, DL, VT: MatchingVectorType, N1: N0.getOperand(i: 0),
15497 N2: N0.getOperand(i: 1), N3: N0.getOperand(i: 2));
15498 return DAG.getZeroExtendInReg(Op: DAG.getAnyExtOrTrunc(Op: VsetCC, DL, VT), DL,
15499 VT: N0.getValueType());
15500 }
15501
15502 // zext(setcc x,y,cc) -> zext(select x, y, true, false, cc)
15503 EVT N0VT = N0.getValueType();
15504 EVT N00VT = N0.getOperand(i: 0).getValueType();
15505 if (SDValue SCC = SimplifySelectCC(
15506 DL, N0: N0.getOperand(i: 0), N1: N0.getOperand(i: 1),
15507 N2: DAG.getBoolConstant(V: true, DL, VT: N0VT, OpVT: N00VT),
15508 N3: DAG.getBoolConstant(V: false, DL, VT: N0VT, OpVT: N00VT),
15509 CC: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get(), NotExtCompare: true))
15510 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: SCC);
15511 }
15512
15513  // (zext (shl/srl (zext x), cst)) -> (shl/srl (zext x), cst)
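  // For example, (zext (shl (zext i8 X to i16), C) to i32) may become a shl of
  // the wider zero-extended value when the inner shl cannot shift out any set
  // bits of X (illustrative types only).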
15514 if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
15515 !TLI.isZExtFree(Val: N0, VT2: VT)) {
15516 SDValue ShVal = N0.getOperand(i: 0);
15517 SDValue ShAmt = N0.getOperand(i: 1);
15518 if (auto *ShAmtC = dyn_cast<ConstantSDNode>(Val&: ShAmt)) {
15519 if (ShVal.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse()) {
15520 if (N0.getOpcode() == ISD::SHL) {
15521 // If the original shl may be shifting out bits, do not perform this
15522 // transformation.
15523 unsigned KnownZeroBits = ShVal.getValueSizeInBits() -
15524 ShVal.getOperand(i: 0).getValueSizeInBits();
15525 if (ShAmtC->getAPIntValue().ugt(RHS: KnownZeroBits)) {
15526 // If the shift is too large, then see if we can deduce that the
15527 // shift is safe anyway.
15528
15529 // Check if the bits being shifted out are known to be zero.
15530 KnownBits KnownShVal = DAG.computeKnownBits(Op: ShVal);
15531 if (ShAmtC->getAPIntValue().ugt(RHS: KnownShVal.countMinLeadingZeros()))
15532 return SDValue();
15533 }
15534 }
15535
15536 // Ensure that the shift amount is wide enough for the shifted value.
15537 if (Log2_32_Ceil(Value: VT.getSizeInBits()) > ShAmt.getValueSizeInBits())
15538 ShAmt = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: MVT::i32, Operand: ShAmt);
15539
15540 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT,
15541 N1: DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: ShVal), N2: ShAmt);
15542 }
15543 }
15544 }
15545
15546 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
15547 return NewVSel;
15548
15549 if (SDValue NewCtPop = widenCtPop(Extend: N, DAG, DL))
15550 return NewCtPop;
15551
15552 if (SDValue V = widenAbs(Extend: N, DAG))
15553 return V;
15554
15555 if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, DL, Level))
15556 return Res;
15557
15558 // CSE zext nneg with sext if the zext is not free.
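  // A zext with the nneg flag produces the same value as a sext of the same
  // operand, so if an equivalent sext node already exists, reuse it instead of
  // keeping both nodes alive.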
15559 if (N->getFlags().hasNonNeg() && !TLI.isZExtFree(FromTy: N0.getValueType(), ToTy: VT)) {
15560 SDNode *CSENode = DAG.getNodeIfExists(Opcode: ISD::SIGN_EXTEND, VTList: N->getVTList(), Ops: N0);
15561 if (CSENode)
15562 return SDValue(CSENode, 0);
15563 }
15564
15565 return SDValue();
15566}
15567
15568SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
15569 SDValue N0 = N->getOperand(Num: 0);
15570 EVT VT = N->getValueType(ResNo: 0);
15571 SDLoc DL(N);
15572
15573 // aext(undef) = undef
15574 if (N0.isUndef())
15575 return DAG.getUNDEF(VT);
15576
15577 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
15578 return Res;
15579
15580 // fold (aext (aext x)) -> (aext x)
15581 // fold (aext (zext x)) -> (zext x)
15582 // fold (aext (sext x)) -> (sext x)
15583 if (N0.getOpcode() == ISD::ANY_EXTEND || N0.getOpcode() == ISD::ZERO_EXTEND ||
15584 N0.getOpcode() == ISD::SIGN_EXTEND) {
15585 SDNodeFlags Flags;
15586 if (N0.getOpcode() == ISD::ZERO_EXTEND)
15587 Flags.setNonNeg(N0->getFlags().hasNonNeg());
15588 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, Operand: N0.getOperand(i: 0), Flags);
15589 }
15590
15591 // fold (aext (aext_extend_vector_inreg x)) -> (aext_extend_vector_inreg x)
15592 // fold (aext (zext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
15593 // fold (aext (sext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
15594 if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
15595 N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG ||
15596 N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG)
15597 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, Operand: N0.getOperand(i: 0));
15598
15599 // fold (aext (truncate (load x))) -> (aext (smaller load x))
15600 // fold (aext (truncate (srl (load x), c))) -> (aext (small load (x+c/n)))
15601 if (N0.getOpcode() == ISD::TRUNCATE) {
15602 if (SDValue NarrowLoad = reduceLoadWidth(N: N0.getNode())) {
15603 SDNode *oye = N0.getOperand(i: 0).getNode();
15604 if (NarrowLoad.getNode() != N0.getNode()) {
15605 CombineTo(N: N0.getNode(), Res: NarrowLoad);
15606 // CombineTo deleted the truncate, if needed, but not what's under it.
15607 AddToWorklist(N: oye);
15608 }
15609 return SDValue(N, 0); // Return N so it doesn't get rechecked!
15610 }
15611 }
15612
15613 // fold (aext (truncate x))
15614 if (N0.getOpcode() == ISD::TRUNCATE)
15615 return DAG.getAnyExtOrTrunc(Op: N0.getOperand(i: 0), DL, VT);
15616
15617 // Fold (aext (and (trunc x), cst)) -> (and x, cst)
15618 // if the trunc is not free.
15619 if (N0.getOpcode() == ISD::AND &&
15620 N0.getOperand(i: 0).getOpcode() == ISD::TRUNCATE &&
15621 N0.getOperand(i: 1).getOpcode() == ISD::Constant &&
15622 !TLI.isTruncateFree(Val: N0.getOperand(i: 0).getOperand(i: 0), VT2: N0.getValueType())) {
15623 SDValue X = DAG.getAnyExtOrTrunc(Op: N0.getOperand(i: 0).getOperand(i: 0), DL, VT);
15624 SDValue Y = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT, Operand: N0.getOperand(i: 1));
15625 assert(isa<ConstantSDNode>(Y) && "Expected constant to be folded!");
15626 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: X, N2: Y);
15627 }
15628
15629 // fold (aext (load x)) -> (aext (truncate (extload x)))
15630 // None of the supported targets knows how to perform load and any_ext
15631 // on vectors in one instruction, so attempt to fold to zext instead.
15632 if (VT.isVector()) {
15633 // Try to simplify (zext (load x)).
15634 if (SDValue foldedExt =
15635 tryToFoldExtOfLoad(DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0,
15636 ExtLoadType: ISD::ZEXTLOAD, ExtOpc: ISD::ZERO_EXTEND))
15637 return foldedExt;
15638 } else if (ISD::isNON_EXTLoad(N: N0.getNode()) &&
15639 ISD::isUNINDEXEDLoad(N: N0.getNode()) &&
15640 TLI.isLoadExtLegalOrCustom(ExtType: ISD::EXTLOAD, ValVT: VT, MemVT: N0.getValueType())) {
15641 bool DoXform = true;
15642 SmallVector<SDNode *, 4> SetCCs;
15643 if (!N0.hasOneUse())
15644 DoXform =
15645 ExtendUsesToFormExtLoad(VT, N, N0, ExtOpc: ISD::ANY_EXTEND, ExtendNodes&: SetCCs, TLI);
15646 if (DoXform) {
15647 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
15648 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::EXTLOAD, dl: DL, VT, Chain: LN0->getChain(),
15649 Ptr: LN0->getBasePtr(), MemVT: N0.getValueType(),
15650 MMO: LN0->getMemOperand());
15651 ExtendSetCCUses(SetCCs, OrigLoad: N0, ExtLoad, ExtType: ISD::ANY_EXTEND);
15652 // If the load value is used only by N, replace it via CombineTo N.
15653 bool NoReplaceTrunc = N0.hasOneUse();
15654 CombineTo(N, Res: ExtLoad);
15655 if (NoReplaceTrunc) {
15656 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN0, 1), To: ExtLoad.getValue(R: 1));
15657 recursivelyDeleteUnusedNodes(N: LN0);
15658 } else {
15659 SDValue Trunc =
15660 DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N0), VT: N0.getValueType(), Operand: ExtLoad);
15661 CombineTo(N: LN0, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
15662 }
15663 return SDValue(N, 0); // Return N so it doesn't get rechecked!
15664 }
15665 }
15666
15667 // fold (aext (zextload x)) -> (aext (truncate (zextload x)))
15668 // fold (aext (sextload x)) -> (aext (truncate (sextload x)))
15669 // fold (aext ( extload x)) -> (aext (truncate (extload x)))
15670 if (N0.getOpcode() == ISD::LOAD && !ISD::isNON_EXTLoad(N: N0.getNode()) &&
15671 ISD::isUNINDEXEDLoad(N: N0.getNode()) && N0.hasOneUse()) {
15672 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
15673 ISD::LoadExtType ExtType = LN0->getExtensionType();
15674 EVT MemVT = LN0->getMemoryVT();
15675 if (!LegalOperations || TLI.isLoadExtLegal(ExtType, ValVT: VT, MemVT)) {
15676 SDValue ExtLoad =
15677 DAG.getExtLoad(ExtType, dl: DL, VT, Chain: LN0->getChain(), Ptr: LN0->getBasePtr(),
15678 MemVT, MMO: LN0->getMemOperand());
15679 CombineTo(N, Res: ExtLoad);
15680 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN0, 1), To: ExtLoad.getValue(R: 1));
15681 recursivelyDeleteUnusedNodes(N: LN0);
15682 return SDValue(N, 0); // Return N so it doesn't get rechecked!
15683 }
15684 }
15685
15686 if (N0.getOpcode() == ISD::SETCC) {
15687 // Propagate fast-math-flags.
15688 SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
15689
15690 // For vectors:
15691 // aext(setcc) -> vsetcc
15692 // aext(setcc) -> truncate(vsetcc)
15693 // aext(setcc) -> aext(vsetcc)
15694 // Only do this before legalize for now.
15695 if (VT.isVector() && !LegalOperations) {
15696 EVT N00VT = N0.getOperand(i: 0).getValueType();
15697 if (getSetCCResultType(VT: N00VT) == N0.getValueType())
15698 return SDValue();
15699
15700      // We know that the # elements of the result is the same as the
15701      // # elements of the compare (and the # elements of the compare result
15702      // for that matter). Check to see that they are the same size. If so,
15703      // we know that the element size of the extended result matches the
15704      // element size of the compare operands.
15705 if (VT.getSizeInBits() == N00VT.getSizeInBits())
15706 return DAG.getSetCC(DL, VT, LHS: N0.getOperand(i: 0), RHS: N0.getOperand(i: 1),
15707 Cond: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get());
15708
15709 // If the desired elements are smaller or larger than the source
15710 // elements we can use a matching integer vector type and then
15711 // truncate/any extend
15712 EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
15713 SDValue VsetCC = DAG.getSetCC(
15714 DL, VT: MatchingVectorType, LHS: N0.getOperand(i: 0), RHS: N0.getOperand(i: 1),
15715 Cond: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get());
15716 return DAG.getAnyExtOrTrunc(Op: VsetCC, DL, VT);
15717 }
15718
15719 // aext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc
15720 if (SDValue SCC = SimplifySelectCC(
15721 DL, N0: N0.getOperand(i: 0), N1: N0.getOperand(i: 1), N2: DAG.getConstant(Val: 1, DL, VT),
15722 N3: DAG.getConstant(Val: 0, DL, VT),
15723 CC: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get(), NotExtCompare: true))
15724 return SCC;
15725 }
15726
15727 if (SDValue NewCtPop = widenCtPop(Extend: N, DAG, DL))
15728 return NewCtPop;
15729
15730 if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, DL, Level))
15731 return Res;
15732
15733 return SDValue();
15734}
15735
15736SDValue DAGCombiner::visitAssertExt(SDNode *N) {
15737 unsigned Opcode = N->getOpcode();
15738 SDValue N0 = N->getOperand(Num: 0);
15739 SDValue N1 = N->getOperand(Num: 1);
15740 EVT AssertVT = cast<VTSDNode>(Val&: N1)->getVT();
15741
15742 // fold (assert?ext (assert?ext x, vt), vt) -> (assert?ext x, vt)
15743 if (N0.getOpcode() == Opcode &&
15744 AssertVT == cast<VTSDNode>(Val: N0.getOperand(i: 1))->getVT())
15745 return N0;
15746
15747 if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
15748 N0.getOperand(i: 0).getOpcode() == Opcode) {
15749    // We have an assert, truncate, assert sandwich. Make one stronger assert
15750    // by applying the smaller of the two asserted types to the larger source
15751    // value. This eliminates the later assert:
15752 // assert (trunc (assert X, i8) to iN), i1 --> trunc (assert X, i1) to iN
15753 // assert (trunc (assert X, i1) to iN), i8 --> trunc (assert X, i1) to iN
15754 SDLoc DL(N);
15755 SDValue BigA = N0.getOperand(i: 0);
15756 EVT BigA_AssertVT = cast<VTSDNode>(Val: BigA.getOperand(i: 1))->getVT();
15757 EVT MinAssertVT = AssertVT.bitsLT(VT: BigA_AssertVT) ? AssertVT : BigA_AssertVT;
15758 SDValue MinAssertVTVal = DAG.getValueType(MinAssertVT);
15759 SDValue NewAssert = DAG.getNode(Opcode, DL, VT: BigA.getValueType(),
15760 N1: BigA.getOperand(i: 0), N2: MinAssertVTVal);
15761 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: N->getValueType(ResNo: 0), Operand: NewAssert);
15762 }
15763
15764  // If we have (AssertZext (truncate (AssertSext X, iX)), iY) and Y is smaller
15765  // than X, just move the AssertZext in front of the truncate and drop the
15766  // AssertSext.
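  // For example, with X of type i64:
  // (AssertZext (trunc (AssertSext X, i32) to i16), i8)
  //   -> (trunc (AssertZext X, i8) to i16)   (illustrative types only).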
15767 if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
15768 N0.getOperand(i: 0).getOpcode() == ISD::AssertSext &&
15769 Opcode == ISD::AssertZext) {
15770 SDValue BigA = N0.getOperand(i: 0);
15771 EVT BigA_AssertVT = cast<VTSDNode>(Val: BigA.getOperand(i: 1))->getVT();
15772 if (AssertVT.bitsLT(VT: BigA_AssertVT)) {
15773 SDLoc DL(N);
15774 SDValue NewAssert = DAG.getNode(Opcode, DL, VT: BigA.getValueType(),
15775 N1: BigA.getOperand(i: 0), N2: N1);
15776 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: N->getValueType(ResNo: 0), Operand: NewAssert);
15777 }
15778 }
15779
15780 if (Opcode == ISD::AssertZext && N0.getOpcode() == ISD::AND &&
15781 isa<ConstantSDNode>(Val: N0.getOperand(i: 1))) {
15782 const APInt &Mask = N0.getConstantOperandAPInt(i: 1);
15783
15784 // If we have (AssertZext (and (AssertSext X, iX), M), iY) and Y is smaller
15785 // than X, and the And doesn't change the lower iX bits, we can move the
15786 // AssertZext in front of the And and drop the AssertSext.
15787 if (N0.getOperand(i: 0).getOpcode() == ISD::AssertSext && N0.hasOneUse()) {
15788 SDValue BigA = N0.getOperand(i: 0);
15789 EVT BigA_AssertVT = cast<VTSDNode>(Val: BigA.getOperand(i: 1))->getVT();
15790 if (AssertVT.bitsLT(VT: BigA_AssertVT) &&
15791 Mask.countr_one() >= BigA_AssertVT.getScalarSizeInBits()) {
15792 SDLoc DL(N);
15793 SDValue NewAssert =
15794 DAG.getNode(Opcode, DL, VT: N->getValueType(ResNo: 0), N1: BigA.getOperand(i: 0), N2: N1);
15795 return DAG.getNode(Opcode: ISD::AND, DL, VT: N->getValueType(ResNo: 0), N1: NewAssert,
15796 N2: N0.getOperand(i: 1));
15797 }
15798 }
15799
15800 // Remove AssertZext entirely if the mask guarantees the assertion cannot
15801 // fail.
15802 // TODO: Use KB countMinLeadingZeros to handle non-constant masks?
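    // For example, (AssertZext (and X, 255), i16) can be replaced by the AND
    // itself: the mask already clears every bit above bit 7, so the i16
    // zero-extension assertion is trivially satisfied.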
15803 if (Mask.isIntN(N: AssertVT.getScalarSizeInBits()))
15804 return N0;
15805 }
15806
15807 return SDValue();
15808}
15809
15810SDValue DAGCombiner::visitAssertAlign(SDNode *N) {
15811 SDLoc DL(N);
15812
15813 Align AL = cast<AssertAlignSDNode>(Val: N)->getAlign();
15814 SDValue N0 = N->getOperand(Num: 0);
15815
15816 // Fold (assertalign (assertalign x, AL0), AL1) ->
15817 // (assertalign x, max(AL0, AL1))
15818 if (auto *AAN = dyn_cast<AssertAlignSDNode>(Val&: N0))
15819 return DAG.getAssertAlign(DL, V: N0.getOperand(i: 0),
15820 A: std::max(a: AL, b: AAN->getAlign()));
15821
15822  // In rare cases, there are trivial arithmetic ops in the source operands. Sink
15823  // this assert down to the source operands so that those arithmetic ops can be
15824  // exposed to DAG combining.
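  // For example, (assertalign (add X, 32), align 16) can become
  // (add (assertalign X, align 16), 32), since the constant offset 32 already
  // satisfies a 16-byte alignment (illustrative example).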
15825 switch (N0.getOpcode()) {
15826 default:
15827 break;
15828 case ISD::ADD:
15829 case ISD::PTRADD:
15830 case ISD::SUB: {
15831 unsigned AlignShift = Log2(A: AL);
15832 SDValue LHS = N0.getOperand(i: 0);
15833 SDValue RHS = N0.getOperand(i: 1);
15834 unsigned LHSAlignShift = DAG.computeKnownBits(Op: LHS).countMinTrailingZeros();
15835 unsigned RHSAlignShift = DAG.computeKnownBits(Op: RHS).countMinTrailingZeros();
15836 if (LHSAlignShift >= AlignShift || RHSAlignShift >= AlignShift) {
15837 if (LHSAlignShift < AlignShift)
15838 LHS = DAG.getAssertAlign(DL, V: LHS, A: AL);
15839 if (RHSAlignShift < AlignShift)
15840 RHS = DAG.getAssertAlign(DL, V: RHS, A: AL);
15841 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT: N0.getValueType(), N1: LHS, N2: RHS);
15842 }
15843 break;
15844 }
15845 }
15846
15847 return SDValue();
15848}
15849
15850/// If the result of a load is shifted/masked/truncated to an effectively
15851/// narrower type, try to transform the load to a narrower type and/or
15852/// use an extending load.
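/// For example, on a little-endian target, (i32 (srl (load i32 ptr), 16)) may
/// become an i16 zero-extending load from ptr+2 when that extending load is
/// legal for the target (illustrative types only).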
15853SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
15854 unsigned Opc = N->getOpcode();
15855
15856 ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
15857 SDValue N0 = N->getOperand(Num: 0);
15858 EVT VT = N->getValueType(ResNo: 0);
15859 EVT ExtVT = VT;
15860
15861 // This transformation isn't valid for vector loads.
15862 if (VT.isVector())
15863 return SDValue();
15864
15865  // The ShAmt variable is used to indicate that we've consumed a right
15866  // shift. I.e. we want to narrow the width of the load by skipping the ShAmt
15867  // least significant bits.
15868 unsigned ShAmt = 0;
15869  // A special case is when the least significant bits from the load are masked
15870  // away, but using an AND rather than a right shift. ShiftedOffset is used to
15871  // indicate that the narrowed load should be left-shifted ShiftedOffset bits
15872  // to get the result.
15873 unsigned ShiftedOffset = 0;
15874  // Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT then
15875  // sign-extending back to VT.
15876 if (Opc == ISD::SIGN_EXTEND_INREG) {
15877 ExtType = ISD::SEXTLOAD;
15878 ExtVT = cast<VTSDNode>(Val: N->getOperand(Num: 1))->getVT();
15879 } else if (Opc == ISD::SRL || Opc == ISD::SRA) {
15880 // Another special-case: SRL/SRA is basically zero/sign-extending a narrower
15881 // value, or it may be shifting a higher subword, half or byte into the
15882 // lowest bits.
15883
15884 // Only handle shift with constant shift amount, and the shiftee must be a
15885 // load.
15886 auto *LN = dyn_cast<LoadSDNode>(Val&: N0);
15887 auto *N1C = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
15888 if (!N1C || !LN)
15889 return SDValue();
15890 // If the shift amount is larger than the memory type then we're not
15891 // accessing any of the loaded bytes.
15892 ShAmt = N1C->getZExtValue();
15893 uint64_t MemoryWidth = LN->getMemoryVT().getScalarSizeInBits();
15894 if (MemoryWidth <= ShAmt)
15895 return SDValue();
15896 // Attempt to fold away the SRL by using ZEXTLOAD and SRA by using SEXTLOAD.
15897 ExtType = Opc == ISD::SRL ? ISD::ZEXTLOAD : ISD::SEXTLOAD;
15898 ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: MemoryWidth - ShAmt);
15899 // If original load is a SEXTLOAD then we can't simply replace it by a
15900 // ZEXTLOAD (we could potentially replace it by a more narrow SEXTLOAD
15901 // followed by a ZEXT, but that is not handled at the moment). Similarly if
15902 // the original load is a ZEXTLOAD and we want to use a SEXTLOAD.
15903 if ((LN->getExtensionType() == ISD::SEXTLOAD ||
15904 LN->getExtensionType() == ISD::ZEXTLOAD) &&
15905 LN->getExtensionType() != ExtType)
15906 return SDValue();
15907 } else if (Opc == ISD::AND) {
15908 // An AND with a constant mask is the same as a truncate + zero-extend.
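    // For example, a mask of 255 behaves like a truncate to i8 followed by a
    // zero-extend, so the load may become an i8 ZEXTLOAD (illustrative only).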
15909 auto AndC = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
15910 if (!AndC)
15911 return SDValue();
15912
15913 const APInt &Mask = AndC->getAPIntValue();
15914 unsigned ActiveBits = 0;
15915 if (Mask.isMask()) {
15916 ActiveBits = Mask.countr_one();
15917 } else if (Mask.isShiftedMask(MaskIdx&: ShAmt, MaskLen&: ActiveBits)) {
15918 ShiftedOffset = ShAmt;
15919 } else {
15920 return SDValue();
15921 }
15922
15923 ExtType = ISD::ZEXTLOAD;
15924 ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ActiveBits);
15925 }
15926
15927 // In case Opc==SRL we've already prepared ExtVT/ExtType/ShAmt based on doing
15928 // a right shift. Here we redo some of those checks, to possibly adjust the
15929 // ExtVT even further based on "a masking AND". We could also end up here for
15930 // other reasons (e.g. based on Opc==TRUNCATE) and that is why some checks
15931 // need to be done here as well.
15932 if (Opc == ISD::SRL || N0.getOpcode() == ISD::SRL) {
15933 SDValue SRL = Opc == ISD::SRL ? SDValue(N, 0) : N0;
15934 // Bail out when the SRL has more than one use. This is done for historical
15935    // (undocumented) reasons. Maybe the intent was to guard the AND-masking
15936    // check below? And maybe it is not profitable to do the transform when the
15937    // SRL has multiple uses and we get here with Opc!=ISD::SRL?
15938    // FIXME: Can't we just skip this check for the Opc==ISD::SRL case?
15939 if (!SRL.hasOneUse())
15940 return SDValue();
15941
15942 // Only handle shift with constant shift amount, and the shiftee must be a
15943 // load.
15944 auto *LN = dyn_cast<LoadSDNode>(Val: SRL.getOperand(i: 0));
15945 auto *SRL1C = dyn_cast<ConstantSDNode>(Val: SRL.getOperand(i: 1));
15946 if (!SRL1C || !LN)
15947 return SDValue();
15948
15949 // If the shift amount is larger than the input type then we're not
15950 // accessing any of the loaded bytes. If the load was a zextload/extload
15951 // then the result of the shift+trunc is zero/undef (handled elsewhere).
15952 ShAmt = SRL1C->getZExtValue();
15953 uint64_t MemoryWidth = LN->getMemoryVT().getSizeInBits();
15954 if (ShAmt >= MemoryWidth)
15955 return SDValue();
15956
15957 // Because a SRL must be assumed to *need* to zero-extend the high bits
15958 // (as opposed to anyext the high bits), we can't combine the zextload
15959 // lowering of SRL and an sextload.
15960 if (LN->getExtensionType() == ISD::SEXTLOAD)
15961 return SDValue();
15962
15963    // Avoid reading outside the memory accessed by the original load (which
15964    // could happen if we only adjust the load base pointer by ShAmt). Instead we
15965 // try to narrow the load even further. The typical scenario here is:
15966 // (i64 (truncate (i96 (srl (load x), 64)))) ->
15967 // (i64 (truncate (i96 (zextload (load i32 + offset) from i32))))
15968 if (ExtVT.getScalarSizeInBits() > MemoryWidth - ShAmt) {
15969 // Don't replace sextload by zextload.
15970 if (ExtType == ISD::SEXTLOAD)
15971 return SDValue();
15972 // Narrow the load.
15973 ExtType = ISD::ZEXTLOAD;
15974 ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: MemoryWidth - ShAmt);
15975 }
15976
15977 // If the SRL is only used by a masking AND, we may be able to adjust
15978 // the ExtVT to make the AND redundant.
15979 SDNode *Mask = *(SRL->user_begin());
15980 if (SRL.hasOneUse() && Mask->getOpcode() == ISD::AND &&
15981 isa<ConstantSDNode>(Val: Mask->getOperand(Num: 1))) {
15982 unsigned Offset, ActiveBits;
15983 const APInt& ShiftMask = Mask->getConstantOperandAPInt(Num: 1);
15984 if (ShiftMask.isMask()) {
15985 EVT MaskedVT =
15986 EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ShiftMask.countr_one());
15987 // If the mask is smaller, recompute the type.
15988 if ((ExtVT.getScalarSizeInBits() > MaskedVT.getScalarSizeInBits()) &&
15989 TLI.isLoadExtLegal(ExtType, ValVT: SRL.getValueType(), MemVT: MaskedVT))
15990 ExtVT = MaskedVT;
15991 } else if (ExtType == ISD::ZEXTLOAD &&
15992 ShiftMask.isShiftedMask(MaskIdx&: Offset, MaskLen&: ActiveBits) &&
15993 (Offset + ShAmt) < VT.getScalarSizeInBits()) {
15994 EVT MaskedVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ActiveBits);
15995 // If the mask is shifted we can use a narrower load and a shl to insert
15996 // the trailing zeros.
15997 if (((Offset + ActiveBits) <= ExtVT.getScalarSizeInBits()) &&
15998 TLI.isLoadExtLegal(ExtType, ValVT: SRL.getValueType(), MemVT: MaskedVT)) {
15999 ExtVT = MaskedVT;
16000 ShAmt = Offset + ShAmt;
16001 ShiftedOffset = Offset;
16002 }
16003 }
16004 }
16005
16006 N0 = SRL.getOperand(i: 0);
16007 }
16008
16009 // If the load is shifted left (and the result isn't shifted back right), we
16010 // can fold a truncate through the shift. The typical scenario is that N
16011 // points at a TRUNCATE here so the attempted fold is:
16012 // (truncate (shl (load x), c))) -> (shl (narrow load x), c)
16013 // ShLeftAmt will indicate how much a narrowed load should be shifted left.
16014 unsigned ShLeftAmt = 0;
16015 if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
16016 ExtVT == VT && TLI.isNarrowingProfitable(N, SrcVT: N0.getValueType(), DestVT: VT)) {
16017 if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1))) {
16018 ShLeftAmt = N01->getZExtValue();
16019 N0 = N0.getOperand(i: 0);
16020 }
16021 }
16022
16023 // If we haven't found a load, we can't narrow it.
16024 if (!isa<LoadSDNode>(Val: N0))
16025 return SDValue();
16026
16027 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
16028 // Reducing the width of a volatile load is illegal. For atomics, we may be
16029 // able to reduce the width provided we never widen again. (see D66309)
16030 if (!LN0->isSimple() ||
16031 !isLegalNarrowLdSt(LDST: LN0, ExtType, MemVT&: ExtVT, ShAmt))
16032 return SDValue();
16033
16034 auto AdjustBigEndianShift = [&](unsigned ShAmt) {
16035 unsigned LVTStoreBits =
16036 LN0->getMemoryVT().getStoreSizeInBits().getFixedValue();
16037 unsigned EVTStoreBits = ExtVT.getStoreSizeInBits().getFixedValue();
16038 return LVTStoreBits - EVTStoreBits - ShAmt;
16039 };
16040
16041 // We need to adjust the pointer to the load by ShAmt bits in order to load
16042 // the correct bytes.
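  // For example, when an SRL by 16 narrows an i32 load to its high i16 half,
  // the byte offset is 2 on a little-endian target but 0 on a big-endian
  // target, which is what AdjustBigEndianShift accounts for (illustrative
  // example).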
16043 unsigned PtrAdjustmentInBits =
16044 DAG.getDataLayout().isBigEndian() ? AdjustBigEndianShift(ShAmt) : ShAmt;
16045
16046 uint64_t PtrOff = PtrAdjustmentInBits / 8;
16047 SDLoc DL(LN0);
16048 // The original load itself didn't wrap, so an offset within it doesn't.
16049 SDValue NewPtr =
16050 DAG.getMemBasePlusOffset(Base: LN0->getBasePtr(), Offset: TypeSize::getFixed(ExactSize: PtrOff),
16051 DL, Flags: SDNodeFlags::NoUnsignedWrap);
16052 AddToWorklist(N: NewPtr.getNode());
16053
16054 SDValue Load;
16055 if (ExtType == ISD::NON_EXTLOAD) {
16056 const MDNode *OldRanges = LN0->getRanges();
16057 const MDNode *NewRanges = nullptr;
16058 // If LSBs are loaded and the truncated ConstantRange for the OldRanges
16059 // metadata is not the full-set for the new width then create a NewRanges
16060 // metadata for the truncated load
16061 if (ShAmt == 0 && OldRanges) {
16062 ConstantRange CR = getConstantRangeFromMetadata(RangeMD: *OldRanges);
16063 unsigned BitSize = VT.getScalarSizeInBits();
16064
16065 // It is possible for an 8-bit extending load with 8-bit range
16066 // metadata to be narrowed to an 8-bit load. This guard is necessary to
16067      // ensure that the truncation is to a strictly smaller width.
16068 if (CR.getBitWidth() > BitSize) {
16069 ConstantRange TruncatedCR = CR.truncate(BitWidth: BitSize);
16070 if (!TruncatedCR.isFullSet()) {
16071 Metadata *Bounds[2] = {
16072 ConstantAsMetadata::get(
16073 C: ConstantInt::get(Context&: *DAG.getContext(), V: TruncatedCR.getLower())),
16074 ConstantAsMetadata::get(
16075 C: ConstantInt::get(Context&: *DAG.getContext(), V: TruncatedCR.getUpper()))};
16076 NewRanges = MDNode::get(Context&: *DAG.getContext(), MDs: Bounds);
16077 }
16078 } else if (CR.getBitWidth() == BitSize)
16079 NewRanges = OldRanges;
16080 }
16081 Load = DAG.getLoad(VT, dl: DL, Chain: LN0->getChain(), Ptr: NewPtr,
16082 PtrInfo: LN0->getPointerInfo().getWithOffset(O: PtrOff),
16083 Alignment: LN0->getBaseAlign(), MMOFlags: LN0->getMemOperand()->getFlags(),
16084 AAInfo: LN0->getAAInfo(), Ranges: NewRanges);
16085 } else
16086 Load = DAG.getExtLoad(ExtType, dl: DL, VT, Chain: LN0->getChain(), Ptr: NewPtr,
16087 PtrInfo: LN0->getPointerInfo().getWithOffset(O: PtrOff), MemVT: ExtVT,
16088 Alignment: LN0->getBaseAlign(), MMOFlags: LN0->getMemOperand()->getFlags(),
16089 AAInfo: LN0->getAAInfo());
16090
16091 // Replace the old load's chain with the new load's chain.
16092 WorklistRemover DeadNodes(*this);
16093 DAG.ReplaceAllUsesOfValueWith(From: N0.getValue(R: 1), To: Load.getValue(R: 1));
16094
16095 // Shift the result left, if we've swallowed a left shift.
16096 SDValue Result = Load;
16097 if (ShLeftAmt != 0) {
16098 // If the shift amount is as large as the result size (but, presumably,
16099 // no larger than the source) then the useful bits of the result are
16100 // zero; we can't simply return the shortened shift, because the result
16101 // of that operation is undefined.
16102 if (ShLeftAmt >= VT.getScalarSizeInBits())
16103 Result = DAG.getConstant(Val: 0, DL, VT);
16104 else
16105 Result = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Result,
16106 N2: DAG.getShiftAmountConstant(Val: ShLeftAmt, VT, DL));
16107 }
16108
16109 if (ShiftedOffset != 0) {
16110    // We're using a shifted mask, so the load now has an offset. This means
16111    // the data has been loaded into lower bytes of the register than it would
16112    // have been otherwise, so we need to shl the loaded data back into the
16113    // correct position.
16114 SDValue ShiftC = DAG.getConstant(Val: ShiftedOffset, DL, VT);
16115 Result = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Result, N2: ShiftC);
16116 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result);
16117 }
16118
16119 // Return the new loaded value.
16120 return Result;
16121}
16122
16123SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
16124 SDValue N0 = N->getOperand(Num: 0);
16125 SDValue N1 = N->getOperand(Num: 1);
16126 EVT VT = N->getValueType(ResNo: 0);
16127 EVT ExtVT = cast<VTSDNode>(Val&: N1)->getVT();
16128 unsigned VTBits = VT.getScalarSizeInBits();
16129 unsigned ExtVTBits = ExtVT.getScalarSizeInBits();
16130 SDLoc DL(N);
16131
16132 // sext_vector_inreg(undef) = 0 because the top bit will all be the same.
16133 if (N0.isUndef())
16134 return DAG.getConstant(Val: 0, DL, VT);
16135
16136 // fold (sext_in_reg c1) -> c1
16137 if (SDValue C =
16138 DAG.FoldConstantArithmetic(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT, Ops: {N0, N1}))
16139 return C;
16140
16141 // If the input is already sign extended, just drop the extension.
16142 if (ExtVTBits >= DAG.ComputeMaxSignificantBits(Op: N0))
16143 return N0;
16144
16145 // fold (sext_in_reg (sext_in_reg x, VT2), VT1) -> (sext_in_reg x, minVT) pt2
16146 if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
16147 ExtVT.bitsLT(VT: cast<VTSDNode>(Val: N0.getOperand(i: 1))->getVT()))
16148 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT, N1: N0.getOperand(i: 0), N2: N1);
16149
16150 // fold (sext_in_reg (sext x)) -> (sext x)
16151 // fold (sext_in_reg (aext x)) -> (sext x)
16152 // if x is small enough or if we know that x has more than 1 sign bit and the
16153 // sign_extend_inreg is extending from one of them.
16154 if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
16155 SDValue N00 = N0.getOperand(i: 0);
16156 unsigned N00Bits = N00.getScalarValueSizeInBits();
16157 if ((N00Bits <= ExtVTBits ||
16158 DAG.ComputeMaxSignificantBits(Op: N00) <= ExtVTBits) &&
16159 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SIGN_EXTEND, VT)))
16160 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: N00);
16161 }
16162
16163 // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_inreg x)
16164 // if x is small enough or if we know that x has more than 1 sign bit and the
16165 // sign_extend_inreg is extending from one of them.
16166 if (ISD::isExtVecInRegOpcode(Opcode: N0.getOpcode())) {
16167 SDValue N00 = N0.getOperand(i: 0);
16168 unsigned N00Bits = N00.getScalarValueSizeInBits();
16169 bool IsZext = N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG;
16170 if ((N00Bits == ExtVTBits ||
16171 (!IsZext && (N00Bits < ExtVTBits ||
16172 DAG.ComputeMaxSignificantBits(Op: N00) <= ExtVTBits))) &&
16173 (!LegalOperations ||
16174 TLI.isOperationLegal(Op: ISD::SIGN_EXTEND_VECTOR_INREG, VT)))
16175 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_VECTOR_INREG, DL, VT, Operand: N00);
16176 }
16177
16178 // fold (sext_in_reg (zext x)) -> (sext x)
16179 // iff we are extending the source sign bit.
16180 if (N0.getOpcode() == ISD::ZERO_EXTEND) {
16181 SDValue N00 = N0.getOperand(i: 0);
16182 if (N00.getScalarValueSizeInBits() == ExtVTBits &&
16183 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SIGN_EXTEND, VT)))
16184 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: N00);
16185 }
16186
16187 // fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero.
16188 if (DAG.MaskedValueIsZero(Op: N0, Mask: APInt::getOneBitSet(numBits: VTBits, BitNo: ExtVTBits - 1)))
16189 return DAG.getZeroExtendInReg(Op: N0, DL, VT: ExtVT);
16190
16191 // fold operands of sext_in_reg based on knowledge that the top bits are not
16192 // demanded.
16193 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
16194 return SDValue(N, 0);
16195
16196 // fold (sext_in_reg (load x)) -> (smaller sextload x)
16197 // fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits))
16198 if (SDValue NarrowLoad = reduceLoadWidth(N))
16199 return NarrowLoad;
16200
16201 // fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
16202 // fold (sext_in_reg (srl X, 23), i8) -> (sra X, 23) iff possible.
16203 // We already fold "(sext_in_reg (srl X, 25), i8) -> srl X, 25" above.
16204 if (N0.getOpcode() == ISD::SRL) {
16205 if (auto *ShAmt = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1)))
16206 if (ShAmt->getAPIntValue().ule(RHS: VTBits - ExtVTBits)) {
16207 // We can turn this into an SRA iff the input to the SRL is already sign
16208 // extended enough.
16209 unsigned InSignBits = DAG.ComputeNumSignBits(Op: N0.getOperand(i: 0));
16210 if (((VTBits - ExtVTBits) - ShAmt->getZExtValue()) < InSignBits)
16211 return DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: N0.getOperand(i: 0),
16212 N2: N0.getOperand(i: 1));
16213 }
16214 }
16215
16216 // fold (sext_inreg (extload x)) -> (sextload x)
16217 // If sextload is not supported by the target, we can only do the combine when
16218 // the load has one use. Doing otherwise can block folding the extload with other
16219 // extends that the target does support.
16220 if (ISD::isEXTLoad(N: N0.getNode()) && ISD::isUNINDEXEDLoad(N: N0.getNode()) &&
16221 ExtVT == cast<LoadSDNode>(Val&: N0)->getMemoryVT() &&
16222 ((!LegalOperations && cast<LoadSDNode>(Val&: N0)->isSimple() &&
16223 N0.hasOneUse()) ||
16224 TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: VT, MemVT: ExtVT))) {
16225 auto *LN0 = cast<LoadSDNode>(Val&: N0);
16226 SDValue ExtLoad =
16227 DAG.getExtLoad(ExtType: ISD::SEXTLOAD, dl: DL, VT, Chain: LN0->getChain(),
16228 Ptr: LN0->getBasePtr(), MemVT: ExtVT, MMO: LN0->getMemOperand());
16229 CombineTo(N, Res: ExtLoad);
16230 CombineTo(N: N0.getNode(), Res0: ExtLoad, Res1: ExtLoad.getValue(R: 1));
16231 AddToWorklist(N: ExtLoad.getNode());
16232 return SDValue(N, 0); // Return N so it doesn't get rechecked!
16233 }
16234
16235 // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use
16236 if (ISD::isZEXTLoad(N: N0.getNode()) && ISD::isUNINDEXEDLoad(N: N0.getNode()) &&
16237 N0.hasOneUse() && ExtVT == cast<LoadSDNode>(Val&: N0)->getMemoryVT() &&
16238 ((!LegalOperations && cast<LoadSDNode>(Val&: N0)->isSimple()) &&
16239 TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: VT, MemVT: ExtVT))) {
16240 auto *LN0 = cast<LoadSDNode>(Val&: N0);
16241 SDValue ExtLoad =
16242 DAG.getExtLoad(ExtType: ISD::SEXTLOAD, dl: DL, VT, Chain: LN0->getChain(),
16243 Ptr: LN0->getBasePtr(), MemVT: ExtVT, MMO: LN0->getMemOperand());
16244 CombineTo(N, Res: ExtLoad);
16245 CombineTo(N: N0.getNode(), Res0: ExtLoad, Res1: ExtLoad.getValue(R: 1));
16246 return SDValue(N, 0); // Return N so it doesn't get rechecked!
16247 }
16248
16249 // fold (sext_inreg (masked_load x)) -> (sext_masked_load x)
16250 // ignore it if the masked load is already sign extended
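// If the load is only reachable through a one-use FREEZE, replace the load
// itself and keep the freeze (now wrapping the extending load) as the result.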
16251 bool Frozen = N0.getOpcode() == ISD::FREEZE && N0.hasOneUse();
16252 if (auto *Ld = dyn_cast<MaskedLoadSDNode>(Val: Frozen ? N0.getOperand(i: 0) : N0)) {
16253 if (ExtVT == Ld->getMemoryVT() && Ld->hasNUsesOfValue(NUses: 1, Value: 0) &&
16254 Ld->getExtensionType() != ISD::LoadExtType::NON_EXTLOAD &&
16255 TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: VT, MemVT: ExtVT)) {
16256 SDValue ExtMaskedLoad = DAG.getMaskedLoad(
16257 VT, dl: DL, Chain: Ld->getChain(), Base: Ld->getBasePtr(), Offset: Ld->getOffset(),
16258 Mask: Ld->getMask(), Src0: Ld->getPassThru(), MemVT: ExtVT, MMO: Ld->getMemOperand(),
16259 AM: Ld->getAddressingMode(), ISD::SEXTLOAD, IsExpanding: Ld->isExpandingLoad());
16260 CombineTo(N, Res: Frozen ? N0 : ExtMaskedLoad);
16261 CombineTo(N: Ld, Res0: ExtMaskedLoad, Res1: ExtMaskedLoad.getValue(R: 1));
16262 return SDValue(N, 0); // Return N so it doesn't get rechecked!
16263 }
16264 }
16265
16266 // fold (sext_inreg (masked_gather x)) -> (sext_masked_gather x)
16267 if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(Val&: N0)) {
16268 if (SDValue(GN0, 0).hasOneUse() && ExtVT == GN0->getMemoryVT() &&
16269 TLI.isVectorLoadExtDesirable(ExtVal: SDValue(N, 0))) {
16270 SDValue Ops[] = {GN0->getChain(), GN0->getPassThru(), GN0->getMask(),
16271 GN0->getBasePtr(), GN0->getIndex(), GN0->getScale()};
16272
16273 SDValue ExtLoad = DAG.getMaskedGather(
16274 VTs: DAG.getVTList(VT1: VT, VT2: MVT::Other), MemVT: ExtVT, dl: DL, Ops, MMO: GN0->getMemOperand(),
16275 IndexType: GN0->getIndexType(), ExtTy: ISD::SEXTLOAD);
16276
16277 CombineTo(N, Res: ExtLoad);
16278 CombineTo(N: N0.getNode(), Res0: ExtLoad, Res1: ExtLoad.getValue(R: 1));
16279 AddToWorklist(N: ExtLoad.getNode());
16280 return SDValue(N, 0); // Return N so it doesn't get rechecked!
16281 }
16282 }
16283
16284 // Form (sext_inreg (bswap >> 16)) or (sext_inreg (rotl (bswap) 16))
16285 if (ExtVTBits <= 16 && N0.getOpcode() == ISD::OR) {
16286 if (SDValue BSwap = MatchBSwapHWordLow(N: N0.getNode(), N0: N0.getOperand(i: 0),
16287 N1: N0.getOperand(i: 1), DemandHighBits: false))
16288 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT, N1: BSwap, N2: N1);
16289 }
16290
16291 // Fold (iM_signext_inreg
16292 // (extract_subvector (zext|anyext|sext iN_v to _) _)
16293 // from iN)
16294 // -> (extract_subvector (signext iN_v to iM))
16295 if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() &&
16296 ISD::isExtOpcode(Opcode: N0.getOperand(i: 0).getOpcode())) {
16297 SDValue InnerExt = N0.getOperand(i: 0);
16298 EVT InnerExtVT = InnerExt->getValueType(ResNo: 0);
16299 SDValue Extendee = InnerExt->getOperand(Num: 0);
16300
16301 if (ExtVTBits == Extendee.getValueType().getScalarSizeInBits() &&
16302 (!LegalOperations ||
16303 TLI.isOperationLegal(Op: ISD::SIGN_EXTEND, VT: InnerExtVT))) {
16304 SDValue SignExtExtendee =
16305 DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT: InnerExtVT, Operand: Extendee);
16306 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT, N1: SignExtExtendee,
16307 N2: N0.getOperand(i: 1));
16308 }
16309 }
16310
16311 return SDValue();
16312}
16313
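/// Fold *_extend_vector_inreg (concat_vectors x, ...) into *_extend x when the
/// in-reg extension only reads the first concatenated operand.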
16314static SDValue foldExtendVectorInregToExtendOfSubvector(
16315 SDNode *N, const SDLoc &DL, const TargetLowering &TLI, SelectionDAG &DAG,
16316 bool LegalOperations) {
16317 unsigned InregOpcode = N->getOpcode();
16318 unsigned Opcode = DAG.getOpcode_EXTEND(Opcode: InregOpcode);
16319
16320 SDValue Src = N->getOperand(Num: 0);
16321 EVT VT = N->getValueType(ResNo: 0);
16322 EVT SrcVT = VT.changeVectorElementType(
16323 Context&: *DAG.getContext(), EltVT: Src.getValueType().getVectorElementType());
16324
16325 assert(ISD::isExtVecInRegOpcode(InregOpcode) &&
16326 "Expected EXTEND_VECTOR_INREG dag node in input!");
16327
16328 // Profitability check: our operand must be a one-use CONCAT_VECTORS.
16329 // FIXME: one-use check may be overly restrictive
16330 if (!Src.hasOneUse() || Src.getOpcode() != ISD::CONCAT_VECTORS)
16331 return SDValue();
16332
16333 // Profitability check: we must be extending exactly one of its operands.
16334 // FIXME: this is probably overly restrictive.
16335 Src = Src.getOperand(i: 0);
16336 if (Src.getValueType() != SrcVT)
16337 return SDValue();
16338
16339 if (LegalOperations && !TLI.isOperationLegal(Op: Opcode, VT))
16340 return SDValue();
16341
16342 return DAG.getNode(Opcode, DL, VT, Operand: Src);
16343}
16344
16345SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
16346 SDValue N0 = N->getOperand(Num: 0);
16347 EVT VT = N->getValueType(ResNo: 0);
16348 SDLoc DL(N);
16349
16350 if (N0.isUndef()) {
16351 // aext_vector_inreg(undef) = undef because the top bits are undefined.
16352 // {s/z}ext_vector_inreg(undef) = 0 because the top bits must be the same.
16353 return N->getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG
16354 ? DAG.getUNDEF(VT)
16355 : DAG.getConstant(Val: 0, DL, VT);
16356 }
16357
16358 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
16359 return Res;
16360
16361 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
16362 return SDValue(N, 0);
16363
16364 if (SDValue R = foldExtendVectorInregToExtendOfSubvector(N, DL, TLI, DAG,
16365 LegalOperations))
16366 return R;
16367
16368 return SDValue();
16369}
16370
16371SDValue DAGCombiner::visitTRUNCATE_USAT_U(SDNode *N) {
16372 EVT VT = N->getValueType(ResNo: 0);
16373 SDValue N0 = N->getOperand(Num: 0);
16374
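// fold (truncate_usat_u (fp_to_uint x)) -> (fp_to_uint_sat x)
// when the target prefers a saturating FP-to-int conversion.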
16375 SDValue FPVal;
16376 if (sd_match(N: N0, P: m_FPToUI(Op: m_Value(N&: FPVal))) &&
16377 DAG.getTargetLoweringInfo().shouldConvertFpToSat(
16378 Op: ISD::FP_TO_UINT_SAT, FPVT: FPVal.getValueType(), VT))
16379 return DAG.getNode(Opcode: ISD::FP_TO_UINT_SAT, DL: SDLoc(N0), VT, N1: FPVal,
16380 N2: DAG.getValueType(VT.getScalarType()));
16381
16382 return SDValue();
16383}
16384
16385/// Detect patterns of truncation with unsigned saturation:
16386///
16387/// (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
16388/// Return the source value x to be truncated or SDValue() if the pattern was
16389/// not matched.
16390///
16391static SDValue detectUSatUPattern(SDValue In, EVT VT) {
16392 unsigned NumDstBits = VT.getScalarSizeInBits();
16393 unsigned NumSrcBits = In.getScalarValueSizeInBits();
16394 // Saturation with truncation. We truncate from InVT to VT.
16395 assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
16396
16397 SDValue Min;
16398 APInt UnsignedMax = APInt::getMaxValue(numBits: NumDstBits).zext(width: NumSrcBits);
16399 if (sd_match(N: In, P: m_UMin(L: m_Value(N&: Min), R: m_SpecificInt(V: UnsignedMax))))
16400 return Min;
16401
16402 return SDValue();
16403}
16404
16405/// Detect patterns of truncation with signed saturation:
16406/// (truncate (smin (smax (x, signed_min_of_dest_type),
16407/// signed_max_of_dest_type)) to dest_type)
16408/// or:
16409/// (truncate (smax (smin (x, signed_max_of_dest_type),
16410/// signed_min_of_dest_type)) to dest_type).
16411///
16412/// Return the source value to be truncated or SDValue() if the pattern was not
16413/// matched.
16414static SDValue detectSSatSPattern(SDValue In, EVT VT) {
16415 unsigned NumDstBits = VT.getScalarSizeInBits();
16416 unsigned NumSrcBits = In.getScalarValueSizeInBits();
16417 // Saturation with truncation. We truncate from InVT to VT.
16418 assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
16419
16420 SDValue Val;
16421 APInt SignedMax = APInt::getSignedMaxValue(numBits: NumDstBits).sext(width: NumSrcBits);
16422 APInt SignedMin = APInt::getSignedMinValue(numBits: NumDstBits).sext(width: NumSrcBits);
16423
16424 if (sd_match(N: In, P: m_SMin(L: m_SMax(L: m_Value(N&: Val), R: m_SpecificInt(V: SignedMin)),
16425 R: m_SpecificInt(V: SignedMax))))
16426 return Val;
16427
16428 if (sd_match(N: In, P: m_SMax(L: m_SMin(L: m_Value(N&: Val), R: m_SpecificInt(V: SignedMax)),
16429 R: m_SpecificInt(V: SignedMin))))
16430 return Val;
16431
16432 return SDValue();
16433}
16434
16435 /// Detect patterns of truncation with unsigned saturation of a signed input: the value is clamped between 0 and the unsigned max of the destination type.
16436static SDValue detectSSatUPattern(SDValue In, EVT VT, SelectionDAG &DAG,
16437 const SDLoc &DL) {
16438 unsigned NumDstBits = VT.getScalarSizeInBits();
16439 unsigned NumSrcBits = In.getScalarValueSizeInBits();
16440 // Saturation with truncation. We truncate from InVT to VT.
16441 assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
16442
16443 SDValue Val;
16444 APInt UnsignedMax = APInt::getMaxValue(numBits: NumDstBits).zext(width: NumSrcBits);
16445 // Min == 0, Max is unsigned max of destination type.
16446 if (sd_match(N: In, P: m_SMax(L: m_SMin(L: m_Value(N&: Val), R: m_SpecificInt(V: UnsignedMax)),
16447 R: m_Zero())))
16448 return Val;
16449
16450 if (sd_match(N: In, P: m_SMin(L: m_SMax(L: m_Value(N&: Val), R: m_Zero()),
16451 R: m_SpecificInt(V: UnsignedMax))))
16452 return Val;
16453
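// Once the value has been clamped to be non-negative by smax, an unsigned min
// against the same bound is equivalent to a signed min, so accept that form too.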
16454 if (sd_match(N: In, P: m_UMin(L: m_SMax(L: m_Value(N&: Val), R: m_Zero()),
16455 R: m_SpecificInt(V: UnsignedMax))))
16456 return Val;
16457
16458 return SDValue();
16459}
16460
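/// Try to fold a truncate of a min/max clamp pattern into one of the
/// saturating truncate nodes (TRUNCATE_SSAT_S, TRUNCATE_SSAT_U or
/// TRUNCATE_USAT_U), if the target considers that opcode legal or custom and
/// desirable for the types involved.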
16461static SDValue foldToSaturated(SDNode *N, EVT &VT, SDValue &Src, EVT &SrcVT,
16462 SDLoc &DL, const TargetLowering &TLI,
16463 SelectionDAG &DAG) {
16464 auto AllowedTruncateSat = [&](unsigned Opc, EVT SrcVT, EVT VT) -> bool {
16465 return (TLI.isOperationLegalOrCustom(Op: Opc, VT: SrcVT) &&
16466 TLI.isTypeDesirableForOp(Opc, VT));
16467 };
16468
16469 if (Src.getOpcode() == ISD::SMIN || Src.getOpcode() == ISD::SMAX) {
16470 if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_S, SrcVT, VT))
16471 if (SDValue SSatVal = detectSSatSPattern(In: Src, VT))
16472 return DAG.getNode(Opcode: ISD::TRUNCATE_SSAT_S, DL, VT, Operand: SSatVal);
16473 if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_U, SrcVT, VT))
16474 if (SDValue SSatVal = detectSSatUPattern(In: Src, VT, DAG, DL))
16475 return DAG.getNode(Opcode: ISD::TRUNCATE_SSAT_U, DL, VT, Operand: SSatVal);
16476 } else if (Src.getOpcode() == ISD::UMIN) {
16477 if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_U, SrcVT, VT))
16478 if (SDValue SSatVal = detectSSatUPattern(In: Src, VT, DAG, DL))
16479 return DAG.getNode(Opcode: ISD::TRUNCATE_SSAT_U, DL, VT, Operand: SSatVal);
16480 if (AllowedTruncateSat(ISD::TRUNCATE_USAT_U, SrcVT, VT))
16481 if (SDValue USatVal = detectUSatUPattern(In: Src, VT))
16482 return DAG.getNode(Opcode: ISD::TRUNCATE_USAT_U, DL, VT, Operand: USatVal);
16483 }
16484
16485 return SDValue();
16486}
16487
16488SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
16489 SDValue N0 = N->getOperand(Num: 0);
16490 EVT VT = N->getValueType(ResNo: 0);
16491 EVT SrcVT = N0.getValueType();
16492 bool isLE = DAG.getDataLayout().isLittleEndian();
16493 SDLoc DL(N);
16494
16495 // trunc(undef) = undef
16496 if (N0.isUndef())
16497 return DAG.getUNDEF(VT);
16498
16499 // fold (truncate (truncate x)) -> (truncate x)
16500 if (N0.getOpcode() == ISD::TRUNCATE)
16501 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 0));
16502
16503 // fold saturated truncate
16504 if (SDValue SaturatedTR = foldToSaturated(N, VT, Src&: N0, SrcVT, DL, TLI, DAG))
16505 return SaturatedTR;
16506
16507 // fold (truncate c1) -> c1
16508 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::TRUNCATE, DL, VT, Ops: {N0}))
16509 return C;
16510
16511 // fold (truncate (ext x)) -> (ext x) or (truncate x) or x
16512 if (N0.getOpcode() == ISD::ZERO_EXTEND ||
16513 N0.getOpcode() == ISD::SIGN_EXTEND ||
16514 N0.getOpcode() == ISD::ANY_EXTEND) {
16515 // if the source is smaller than the dest, we still need an extend.
16516 if (N0.getOperand(i: 0).getValueType().bitsLT(VT)) {
16517 SDNodeFlags Flags;
16518 if (N0.getOpcode() == ISD::ZERO_EXTEND)
16519 Flags.setNonNeg(N0->getFlags().hasNonNeg());
16520 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, Operand: N0.getOperand(i: 0), Flags);
16521 }
16522 // if the source is larger than the dest, then we just need the truncate.
16523 if (N0.getOperand(i: 0).getValueType().bitsGT(VT))
16524 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 0));
16525 // if the source and dest are the same type, we can drop both the extend
16526 // and the truncate.
16527 return N0.getOperand(i: 0);
16528 }
16529
16530 // Try to narrow a truncate-of-sext_in_reg to the destination type:
16531 // trunc (sign_ext_inreg X, iM) to iN --> sign_ext_inreg (trunc X to iN), iM
16532 if (!LegalTypes && N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
16533 N0.hasOneUse()) {
16534 SDValue X = N0.getOperand(i: 0);
16535 SDValue ExtVal = N0.getOperand(i: 1);
16536 EVT ExtVT = cast<VTSDNode>(Val&: ExtVal)->getVT();
16537 if (ExtVT.bitsLT(VT) && TLI.preferSextInRegOfTruncate(TruncVT: VT, VT: SrcVT, ExtVT)) {
16538 SDValue TrX = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: X);
16539 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT, N1: TrX, N2: ExtVal);
16540 }
16541 }
16542
16543 // If this is anyext(trunc), don't fold it, allow ourselves to be folded.
16544 if (N->hasOneUse() && (N->user_begin()->getOpcode() == ISD::ANY_EXTEND))
16545 return SDValue();
16546
16547 // Fold extract-and-trunc into a narrow extract. For example:
16548 // i64 x = EXTRACT_VECTOR_ELT(v2i64 val, i32 1)
16549 // i32 y = TRUNCATE(i64 x)
16550 // -- becomes --
16551 // v16i8 b = BITCAST (v2i64 val)
16552 // i8 x = EXTRACT_VECTOR_ELT(v16i8 b, i32 8)
16553 //
16554 // Note: We only run this optimization after type legalization (which often
16555 // creates this pattern) and before operation legalization after which
16556 // we need to be more careful about the vector instructions that we generate.
16557 if (LegalTypes && !LegalOperations && VT.isScalarInteger() && VT != MVT::i1 &&
16558 N0->hasOneUse()) {
16559 EVT TrTy = N->getValueType(ResNo: 0);
16560 SDValue Src = N0;
16561
16562 // Check for cases where we shift down an upper element before truncation.
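// For example, on little-endian targets,
//   trunc i32 (srl (extract_vector_elt v2i64:x, 0), 32)
// reads element 1 of the v4i32 bitcast of x.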
16563 int EltOffset = 0;
16564 if (Src.getOpcode() == ISD::SRL && Src.getOperand(i: 0)->hasOneUse()) {
16565 if (auto ShAmt = DAG.getValidShiftAmount(V: Src)) {
16566 if ((*ShAmt % TrTy.getSizeInBits()) == 0) {
16567 Src = Src.getOperand(i: 0);
16568 EltOffset = *ShAmt / TrTy.getSizeInBits();
16569 }
16570 }
16571 }
16572
16573 if (Src.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
16574 EVT VecTy = Src.getOperand(i: 0).getValueType();
16575 EVT ExTy = Src.getValueType();
16576
16577 auto EltCnt = VecTy.getVectorElementCount();
16578 unsigned SizeRatio = ExTy.getSizeInBits() / TrTy.getSizeInBits();
16579 auto NewEltCnt = EltCnt * SizeRatio;
16580
16581 EVT NVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: TrTy, EC: NewEltCnt);
16582 assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");
16583
16584 SDValue EltNo = Src->getOperand(Num: 1);
16585 if (isa<ConstantSDNode>(Val: EltNo) && isTypeLegal(VT: NVT)) {
16586 int Elt = EltNo->getAsZExtVal();
16587 int Index = isLE ? (Elt * SizeRatio + EltOffset)
16588 : (Elt * SizeRatio + (SizeRatio - 1) - EltOffset);
16589 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: TrTy,
16590 N1: DAG.getBitcast(VT: NVT, V: Src.getOperand(i: 0)),
16591 N2: DAG.getVectorIdxConstant(Val: Index, DL));
16592 }
16593 }
16594 }
16595
16596 // trunc (select c, a, b) -> select c, (trunc a), (trunc b)
16597 if (N0.getOpcode() == ISD::SELECT && N0.hasOneUse() &&
16598 TLI.isTruncateFree(FromVT: SrcVT, ToVT: VT)) {
16599 if (!LegalOperations ||
16600 (TLI.isOperationLegal(Op: ISD::SELECT, VT: SrcVT) &&
16601 TLI.isNarrowingProfitable(N: N0.getNode(), SrcVT, DestVT: VT))) {
16602 SDLoc SL(N0);
16603 SDValue Cond = N0.getOperand(i: 0);
16604 SDValue TruncOp0 = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SL, VT, Operand: N0.getOperand(i: 1));
16605 SDValue TruncOp1 = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SL, VT, Operand: N0.getOperand(i: 2));
16606 return DAG.getNode(Opcode: ISD::SELECT, DL, VT, N1: Cond, N2: TruncOp0, N3: TruncOp1);
16607 }
16608 }
16609
16610 // trunc (shl x, K) -> shl (trunc x), K => K < VT.getScalarSizeInBits()
16611 if (N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
16612 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SHL, VT)) &&
16613 TLI.isTypeDesirableForOp(ISD::SHL, VT)) {
16614 SDValue Amt = N0.getOperand(i: 1);
16615 KnownBits Known = DAG.computeKnownBits(Op: Amt);
16616 unsigned Size = VT.getScalarSizeInBits();
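// Only fold if the shift amount is provably smaller than the narrow bit width,
// so the shift amount is still in range after the truncation.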
16617 if (Known.countMaxActiveBits() <= Log2_32(Value: Size)) {
16618 EVT AmtVT = TLI.getShiftAmountTy(LHSTy: VT, DL: DAG.getDataLayout());
16619 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 0));
16620 if (AmtVT != Amt.getValueType()) {
16621 Amt = DAG.getZExtOrTrunc(Op: Amt, DL, VT: AmtVT);
16622 AddToWorklist(N: Amt.getNode());
16623 }
16624 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Trunc, N2: Amt);
16625 }
16626 }
16627
16628 if (SDValue V = foldSubToUSubSat(DstVT: VT, N: N0.getNode(), DL))
16629 return V;
16630
16631 if (SDValue ABD = foldABSToABD(N, DL))
16632 return ABD;
16633
16634 // Attempt to pre-truncate BUILD_VECTOR sources.
16635 if (N0.getOpcode() == ISD::BUILD_VECTOR && !LegalOperations &&
16636 N0.hasOneUse() &&
16637 TLI.isTruncateFree(FromVT: SrcVT.getScalarType(), ToVT: VT.getScalarType()) &&
16638 // Avoid creating illegal types if running after type legalizer.
16639 (!LegalTypes || TLI.isTypeLegal(VT: VT.getScalarType()))) {
16640 EVT SVT = VT.getScalarType();
16641 SmallVector<SDValue, 8> TruncOps;
16642 for (const SDValue &Op : N0->op_values()) {
16643 SDValue TruncOp = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: SVT, Operand: Op);
16644 TruncOps.push_back(Elt: TruncOp);
16645 }
16646 return DAG.getBuildVector(VT, DL, Ops: TruncOps);
16647 }
16648
16649 // trunc (splat_vector x) -> splat_vector (trunc x)
16650 if (N0.getOpcode() == ISD::SPLAT_VECTOR &&
16651 (!LegalTypes || TLI.isTypeLegal(VT: VT.getScalarType())) &&
16652 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SPLAT_VECTOR, VT))) {
16653 EVT SVT = VT.getScalarType();
16654 return DAG.getSplatVector(
16655 VT, DL, Op: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: SVT, Operand: N0->getOperand(Num: 0)));
16656 }
16657
16658 // Fold a series of buildvector, bitcast, and truncate if possible.
16659 // For example fold
16660 // (2xi32 trunc (bitcast ((4xi32)buildvector x, x, y, y) 2xi64)) to
16661 // (2xi32 (buildvector x, y)).
16662 if (Level == AfterLegalizeVectorOps && VT.isVector() &&
16663 N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
16664 N0.getOperand(i: 0).getOpcode() == ISD::BUILD_VECTOR &&
16665 N0.getOperand(i: 0).hasOneUse()) {
16666 SDValue BuildVect = N0.getOperand(i: 0);
16667 EVT BuildVectEltTy = BuildVect.getValueType().getVectorElementType();
16668 EVT TruncVecEltTy = VT.getVectorElementType();
16669
16670 // Check that the element types match.
16671 if (BuildVectEltTy == TruncVecEltTy) {
16672 // Now we only need to compute the offset of the truncated elements.
16673 unsigned BuildVecNumElts = BuildVect.getNumOperands();
16674 unsigned TruncVecNumElts = VT.getVectorNumElements();
16675 unsigned TruncEltOffset = BuildVecNumElts / TruncVecNumElts;
16676 unsigned FirstElt = isLE ? 0 : (TruncEltOffset - 1);
16677
16678 assert((BuildVecNumElts % TruncVecNumElts) == 0 &&
16679 "Invalid number of elements");
16680
16681 SmallVector<SDValue, 8> Opnds;
16682 for (unsigned i = FirstElt, e = BuildVecNumElts; i < e;
16683 i += TruncEltOffset)
16684 Opnds.push_back(Elt: BuildVect.getOperand(i));
16685
16686 return DAG.getBuildVector(VT, DL, Ops: Opnds);
16687 }
16688 }
16689
16690 // fold (truncate (load x)) -> (smaller load x)
16691 // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
16692 if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
16693 if (SDValue Reduced = reduceLoadWidth(N))
16694 return Reduced;
16695
16696 // Handle the case where the truncated result is at least as wide as the
16697 // loaded type.
16698 if (N0.hasOneUse() && ISD::isUNINDEXEDLoad(N: N0.getNode())) {
16699 auto *LN0 = cast<LoadSDNode>(Val&: N0);
16700 if (LN0->isSimple() && LN0->getMemoryVT().bitsLE(VT)) {
16701 SDValue NewLoad = DAG.getExtLoad(
16702 ExtType: LN0->getExtensionType(), dl: SDLoc(LN0), VT, Chain: LN0->getChain(),
16703 Ptr: LN0->getBasePtr(), MemVT: LN0->getMemoryVT(), MMO: LN0->getMemOperand());
16704 DAG.ReplaceAllUsesOfValueWith(From: N0.getValue(R: 1), To: NewLoad.getValue(R: 1));
16705 return NewLoad;
16706 }
16707 }
16708 }
16709
16710 // fold (trunc (concat ... x ...)) -> (concat ..., (trunc x), ...),
16711 // where ... are all 'undef'.
16712 if (N0.getOpcode() == ISD::CONCAT_VECTORS && !LegalTypes) {
16713 SmallVector<EVT, 8> VTs;
16714 SDValue V;
16715 unsigned Idx = 0;
16716 unsigned NumDefs = 0;
16717
16718 for (unsigned i = 0, e = N0.getNumOperands(); i != e; ++i) {
16719 SDValue X = N0.getOperand(i);
16720 if (!X.isUndef()) {
16721 V = X;
16722 Idx = i;
16723 NumDefs++;
16724 }
16725 // Stop if more than one member is non-undef.
16726 if (NumDefs > 1)
16727 break;
16728
16729 VTs.push_back(Elt: EVT::getVectorVT(Context&: *DAG.getContext(),
16730 VT: VT.getVectorElementType(),
16731 EC: X.getValueType().getVectorElementCount()));
16732 }
16733
16734 if (NumDefs == 0)
16735 return DAG.getUNDEF(VT);
16736
16737 if (NumDefs == 1) {
16738 assert(V.getNode() && "The single defined operand is empty!");
16739 SmallVector<SDValue, 8> Opnds;
16740 for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
16741 if (i != Idx) {
16742 Opnds.push_back(Elt: DAG.getUNDEF(VT: VTs[i]));
16743 continue;
16744 }
16745 SDValue NV = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(V), VT: VTs[i], Operand: V);
16746 AddToWorklist(N: NV.getNode());
16747 Opnds.push_back(Elt: NV);
16748 }
16749 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT, Ops: Opnds);
16750 }
16751 }
16752
16753 // Fold truncate of a bitcast of a vector to an extract of the low vector
16754 // element.
16755 //
16756 // e.g. trunc (i64 (bitcast v2i32:x)) -> extract_vector_elt v2i32:x, idx
16757 if (N0.getOpcode() == ISD::BITCAST && !VT.isVector()) {
16758 SDValue VecSrc = N0.getOperand(i: 0);
16759 EVT VecSrcVT = VecSrc.getValueType();
16760 if (VecSrcVT.isVector() && VecSrcVT.getScalarType() == VT &&
16761 (!LegalOperations ||
16762 TLI.isOperationLegal(Op: ISD::EXTRACT_VECTOR_ELT, VT: VecSrcVT))) {
16763 unsigned Idx = isLE ? 0 : VecSrcVT.getVectorNumElements() - 1;
16764 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT, N1: VecSrc,
16765 N2: DAG.getVectorIdxConstant(Val: Idx, DL));
16766 }
16767 }
16768
16769 // Simplify the operands using demanded-bits information.
16770 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
16771 return SDValue(N, 0);
16772
16773 // fold (truncate (extract_subvector(ext x))) ->
16774 // (extract_subvector x)
16775 // TODO: This can be generalized to cover cases where the truncate and extract
16776 // do not fully cancel each other out.
16777 if (!LegalTypes && N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
16778 SDValue N00 = N0.getOperand(i: 0);
16779 if (N00.getOpcode() == ISD::SIGN_EXTEND ||
16780 N00.getOpcode() == ISD::ZERO_EXTEND ||
16781 N00.getOpcode() == ISD::ANY_EXTEND) {
16782 if (N00.getOperand(i: 0)->getValueType(ResNo: 0).getVectorElementType() ==
16783 VT.getVectorElementType())
16784 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N0->getOperand(Num: 0)), VT,
16785 N1: N00.getOperand(i: 0), N2: N0.getOperand(i: 1));
16786 }
16787 }
16788
16789 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
16790 return NewVSel;
16791
16792 // Narrow a suitable binary operation with a non-opaque constant operand by
16793 // moving it ahead of the truncate. This is limited to pre-legalization
16794 // because targets may prefer a wider type during later combines and invert
16795 // this transform.
16796 switch (N0.getOpcode()) {
16797 case ISD::ADD:
16798 case ISD::SUB:
16799 case ISD::MUL:
16800 case ISD::AND:
16801 case ISD::OR:
16802 case ISD::XOR:
16803 if (!LegalOperations && N0.hasOneUse() &&
16804 (N0.getOperand(i: 0) == N0.getOperand(i: 1) ||
16805 isConstantOrConstantVector(N: N0.getOperand(i: 0), NoOpaques: true) ||
16806 isConstantOrConstantVector(N: N0.getOperand(i: 1), NoOpaques: true))) {
16807 // TODO: We already restricted this to pre-legalization, but for vectors
16808 // we are extra cautious to not create an unsupported operation.
16809 // Target-specific changes are likely needed to avoid regressions here.
16810 if (VT.isScalarInteger() || TLI.isOperationLegal(Op: N0.getOpcode(), VT)) {
16811 SDValue NarrowL = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 0));
16812 SDValue NarrowR = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 1));
16813 SDNodeFlags Flags;
16814 // Propagate nuw for sub.
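// The nuw flag stays valid only when the wide LHS already fits in the narrow
// type: then the (no-wrap) RHS fits as well and the narrow sub cannot wrap.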
16815 if (N0->getOpcode() == ISD::SUB && N0->getFlags().hasNoUnsignedWrap() &&
16816 DAG.MaskedValueIsZero(
16817 Op: N0->getOperand(Num: 0),
16818 Mask: APInt::getBitsSetFrom(numBits: SrcVT.getScalarSizeInBits(),
16819 loBit: VT.getScalarSizeInBits())))
16820 Flags.setNoUnsignedWrap(true);
16821 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: NarrowL, N2: NarrowR, Flags);
16822 }
16823 }
16824 break;
16825 case ISD::ADDE:
16826 case ISD::UADDO_CARRY:
16827 // (trunc adde(X, Y, Carry)) -> (adde trunc(X), trunc(Y), Carry)
16828 // (trunc uaddo_carry(X, Y, Carry)) ->
16829 // (uaddo_carry trunc(X), trunc(Y), Carry)
16830 // When the adde's carry is not used.
16831 // We only do this for uaddo_carry before operation legalization.
16832 if (((!LegalOperations && N0.getOpcode() == ISD::UADDO_CARRY) ||
16833 TLI.isOperationLegal(Op: N0.getOpcode(), VT)) &&
16834 N0.hasOneUse() && !N0->hasAnyUseOfValue(Value: 1)) {
16835 SDValue X = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 0));
16836 SDValue Y = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 1));
16837 SDVTList VTs = DAG.getVTList(VT1: VT, VT2: N0->getValueType(ResNo: 1));
16838 return DAG.getNode(Opcode: N0.getOpcode(), DL, VTList: VTs, N1: X, N2: Y, N3: N0.getOperand(i: 2));
16839 }
16840 break;
16841 case ISD::USUBSAT:
16842 // Truncate the USUBSAT only if LHS is a known zero-extension; it's not
16843 // enough to know that the upper bits are zero, we must also ensure that we
16844 // don't introduce an extra truncate.
16845 if (!LegalOperations && N0.hasOneUse() &&
16846 N0.getOperand(i: 0).getOpcode() == ISD::ZERO_EXTEND &&
16847 N0.getOperand(i: 0).getOperand(i: 0).getScalarValueSizeInBits() <=
16848 VT.getScalarSizeInBits() &&
16849 hasOperation(Opcode: N0.getOpcode(), VT)) {
16850 return getTruncatedUSUBSAT(DstVT: VT, SrcVT, LHS: N0.getOperand(i: 0), RHS: N0.getOperand(i: 1),
16851 DAG, DL);
16852 }
16853 break;
16854 case ISD::AVGCEILS:
16855 case ISD::AVGCEILU:
16856 // trunc (avgceilu (sext (x), sext (y))) -> avgceils(x, y)
16857 // trunc (avgceils (zext (x), zext (y))) -> avgceilu(x, y)
16858 if (N0.hasOneUse()) {
16859 SDValue Op0 = N0.getOperand(i: 0);
16860 SDValue Op1 = N0.getOperand(i: 1);
16861 if (N0.getOpcode() == ISD::AVGCEILU) {
16862 if (TLI.isOperationLegalOrCustom(Op: ISD::AVGCEILS, VT) &&
16863 Op0.getOpcode() == ISD::SIGN_EXTEND &&
16864 Op1.getOpcode() == ISD::SIGN_EXTEND &&
16865 Op0.getOperand(i: 0).getValueType() == VT &&
16866 Op1.getOperand(i: 0).getValueType() == VT)
16867 return DAG.getNode(Opcode: ISD::AVGCEILS, DL, VT, N1: Op0.getOperand(i: 0),
16868 N2: Op1.getOperand(i: 0));
16869 } else {
16870 if (TLI.isOperationLegalOrCustom(Op: ISD::AVGCEILU, VT) &&
16871 Op0.getOpcode() == ISD::ZERO_EXTEND &&
16872 Op1.getOpcode() == ISD::ZERO_EXTEND &&
16873 Op0.getOperand(i: 0).getValueType() == VT &&
16874 Op1.getOperand(i: 0).getValueType() == VT)
16875 return DAG.getNode(Opcode: ISD::AVGCEILU, DL, VT, N1: Op0.getOperand(i: 0),
16876 N2: Op1.getOperand(i: 0));
16877 }
16878 }
16879 [[fallthrough]];
16880 case ISD::AVGFLOORS:
16881 case ISD::AVGFLOORU:
16882 case ISD::ABDS:
16883 case ISD::ABDU:
16884 // (trunc (avg a, b)) -> (avg (trunc a), (trunc b))
16885 // (trunc (abdu/abds a, b)) -> (abdu/abds (trunc a), (trunc b))
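// These folds are only safe when both inputs already fit in the narrow type:
// the upper bits must be zero for the unsigned forms, or copies of the sign
// bit for the signed forms.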
16886 if (!LegalOperations && N0.hasOneUse() &&
16887 TLI.isOperationLegal(Op: N0.getOpcode(), VT)) {
16888 EVT TruncVT = VT;
16889 unsigned SrcBits = SrcVT.getScalarSizeInBits();
16890 unsigned TruncBits = TruncVT.getScalarSizeInBits();
16891
16892 SDValue A = N0.getOperand(i: 0);
16893 SDValue B = N0.getOperand(i: 1);
16894 bool CanFold = false;
16895
16896 if (N0.getOpcode() == ISD::AVGFLOORU || N0.getOpcode() == ISD::AVGCEILU ||
16897 N0.getOpcode() == ISD::ABDU) {
16898 APInt UpperBits = APInt::getBitsSetFrom(numBits: SrcBits, loBit: TruncBits);
16899 CanFold = DAG.MaskedValueIsZero(Op: B, Mask: UpperBits) &&
16900 DAG.MaskedValueIsZero(Op: A, Mask: UpperBits);
16901 } else {
16902 unsigned NeededBits = SrcBits - TruncBits;
16903 CanFold = DAG.ComputeNumSignBits(Op: B) > NeededBits &&
16904 DAG.ComputeNumSignBits(Op: A) > NeededBits;
16905 }
16906
16907 if (CanFold) {
16908 SDValue NewA = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: TruncVT, Operand: A);
16909 SDValue NewB = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: TruncVT, Operand: B);
16910 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT: TruncVT, N1: NewA, N2: NewB);
16911 }
16912 }
16913 break;
16914 }
16915
16916 return SDValue();
16917}
16918
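/// Return operand \p i of \p N, looking through a MERGE_VALUES node to the
/// value it actually forwards.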
16919static SDNode *getBuildPairElt(SDNode *N, unsigned i) {
16920 SDValue Elt = N->getOperand(Num: i);
16921 if (Elt.getOpcode() != ISD::MERGE_VALUES)
16922 return Elt.getNode();
16923 return Elt.getOperand(i: Elt.getResNo()).getNode();
16924}
16925
16926/// build_pair (load, load) -> load
16927/// if load locations are consecutive.
16928SDValue DAGCombiner::CombineConsecutiveLoads(SDNode *N, EVT VT) {
16929 assert(N->getOpcode() == ISD::BUILD_PAIR);
16930
16931 auto *LD1 = dyn_cast<LoadSDNode>(Val: getBuildPairElt(N, i: 0));
16932 auto *LD2 = dyn_cast<LoadSDNode>(Val: getBuildPairElt(N, i: 1));
16933
16934 // A BUILD_PAIR always has the least significant part in elt 0 and the
16935 // most significant part in elt 1. So when combining into one large load, we
16936 // need to consider the endianness.
16937 if (DAG.getDataLayout().isBigEndian())
16938 std::swap(a&: LD1, b&: LD2);
16939
16940 if (!LD1 || !LD2 || !ISD::isNON_EXTLoad(N: LD1) || !ISD::isNON_EXTLoad(N: LD2) ||
16941 !LD1->hasOneUse() || !LD2->hasOneUse() ||
16942 LD1->getAddressSpace() != LD2->getAddressSpace())
16943 return SDValue();
16944
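// Only merge when the combined load type is legal (or we are still before
// operation legalization), the two loads are directly consecutive in memory,
// and the target reports the wider access as both allowed and fast.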
16945 unsigned LD1Fast = 0;
16946 EVT LD1VT = LD1->getValueType(ResNo: 0);
16947 unsigned LD1Bytes = LD1VT.getStoreSize();
16948 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::LOAD, VT)) &&
16949 DAG.areNonVolatileConsecutiveLoads(LD: LD2, Base: LD1, Bytes: LD1Bytes, Dist: 1) &&
16950 TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT,
16951 MMO: *LD1->getMemOperand(), Fast: &LD1Fast) && LD1Fast)
16952 return DAG.getLoad(VT, dl: SDLoc(N), Chain: LD1->getChain(), Ptr: LD1->getBasePtr(),
16953 PtrInfo: LD1->getPointerInfo(), Alignment: LD1->getAlign());
16954
16955 return SDValue();
16956}
16957
16958static unsigned getPPCf128HiElementSelector(const SelectionDAG &DAG) {
16959 // On little-endian machines, bitcasting from ppcf128 to i128 does swap the Hi
16960 // and Lo parts; on big-endian machines it doesn't.
16961 return DAG.getDataLayout().isBigEndian() ? 1 : 0;
16962}
16963
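/// Replace a bitcast of integer logic on a bitcasted FP value with the
/// equivalent FP operation: AND with the non-sign mask becomes fabs, XOR with
/// the sign mask becomes fneg, and OR with the sign mask becomes fneg (fabs).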
16964SDValue DAGCombiner::foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
16965 const TargetLowering &TLI) {
16966 // If this is not a bitcast to an FP type or if the target doesn't have
16967 // IEEE754-compliant FP logic, we're done.
16968 EVT VT = N->getValueType(ResNo: 0);
16969 SDValue N0 = N->getOperand(Num: 0);
16970 EVT SourceVT = N0.getValueType();
16971
16972 if (!VT.isFloatingPoint())
16973 return SDValue();
16974
16975 // TODO: Handle cases where the integer constant is a different scalar
16976 // bitwidth to the FP.
16977 if (VT.getScalarSizeInBits() != SourceVT.getScalarSizeInBits())
16978 return SDValue();
16979
16980 unsigned FPOpcode;
16981 APInt SignMask;
16982 switch (N0.getOpcode()) {
16983 case ISD::AND:
16984 FPOpcode = ISD::FABS;
16985 SignMask = ~APInt::getSignMask(BitWidth: SourceVT.getScalarSizeInBits());
16986 break;
16987 case ISD::XOR:
16988 FPOpcode = ISD::FNEG;
16989 SignMask = APInt::getSignMask(BitWidth: SourceVT.getScalarSizeInBits());
16990 break;
16991 case ISD::OR:
16992 FPOpcode = ISD::FABS;
16993 SignMask = APInt::getSignMask(BitWidth: SourceVT.getScalarSizeInBits());
16994 break;
16995 default:
16996 return SDValue();
16997 }
16998
16999 if (LegalOperations && !TLI.isOperationLegal(Op: FPOpcode, VT))
17000 return SDValue();
17001
17002 // This needs to be the inverse of logic in foldSignChangeInBitcast.
17003 // FIXME: I don't think looking for bitcast intrinsically makes sense, but
17004 // removing this would require more changes.
17005 auto IsBitCastOrFree = [&TLI, FPOpcode](SDValue Op, EVT VT) {
17006 if (sd_match(N: Op, P: m_BitCast(Op: m_SpecificVT(RefVT: VT))))
17007 return true;
17008
17009 return FPOpcode == ISD::FABS ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);
17010 };
17011
17012 // Fold (bitcast int (and (bitcast fp X to int), 0x7fff...) to fp) -> fabs X
17013 // Fold (bitcast int (xor (bitcast fp X to int), 0x8000...) to fp) -> fneg X
17014 // Fold (bitcast int (or (bitcast fp X to int), 0x8000...) to fp) ->
17015 // fneg (fabs X)
17016 SDValue LogicOp0 = N0.getOperand(i: 0);
17017 ConstantSDNode *LogicOp1 = isConstOrConstSplat(N: N0.getOperand(i: 1), AllowUndefs: true);
17018 if (LogicOp1 && LogicOp1->getAPIntValue() == SignMask &&
17019 IsBitCastOrFree(LogicOp0, VT)) {
17020 SDValue CastOp0 = DAG.getNode(Opcode: ISD::BITCAST, DL: SDLoc(N), VT, Operand: LogicOp0);
17021 SDValue FPOp = DAG.getNode(Opcode: FPOpcode, DL: SDLoc(N), VT, Operand: CastOp0);
17022 NumFPLogicOpsConv++;
17023 if (N0.getOpcode() == ISD::OR)
17024 return DAG.getNode(Opcode: ISD::FNEG, DL: SDLoc(N), VT, Operand: FPOp);
17025 return FPOp;
17026 }
17027
17028 return SDValue();
17029}
17030
17031SDValue DAGCombiner::visitBITCAST(SDNode *N) {
17032 SDValue N0 = N->getOperand(Num: 0);
17033 EVT VT = N->getValueType(ResNo: 0);
17034
17035 if (N0.isUndef())
17036 return DAG.getUNDEF(VT);
17037
17038 // If the input is a BUILD_VECTOR with all constant elements, fold this now.
17039 // Only do this before legalize types, unless both types are integer and the
17040 // scalar type is legal. Only do this before legalize ops, since the target
17041 // may be depending on the bitcast.
17042 // First check to see if this is all constant.
17043 // TODO: Support FP bitcasts after legalize types.
17044 if (VT.isVector() &&
17045 (!LegalTypes ||
17046 (!LegalOperations && VT.isInteger() && N0.getValueType().isInteger() &&
17047 TLI.isTypeLegal(VT: VT.getVectorElementType()))) &&
17048 N0.getOpcode() == ISD::BUILD_VECTOR && N0->hasOneUse() &&
17049 cast<BuildVectorSDNode>(Val&: N0)->isConstant())
17050 return DAG.FoldConstantBuildVector(BV: cast<BuildVectorSDNode>(Val&: N0), DL: SDLoc(N),
17051 DstEltVT: VT.getVectorElementType());
17052
17053 // If the input is a constant, let getNode fold it.
17054 if (isIntOrFPConstant(V: N0)) {
17055 // If we can't allow illegal operations, we need to check that this is just
17056 // an fp -> int or int -> fp conversion and that the resulting operation will
17057 // be legal.
17058 if (!LegalOperations ||
17059 (isa<ConstantSDNode>(Val: N0) && VT.isFloatingPoint() && !VT.isVector() &&
17060 TLI.isOperationLegal(Op: ISD::ConstantFP, VT)) ||
17061 (isa<ConstantFPSDNode>(Val: N0) && VT.isInteger() && !VT.isVector() &&
17062 TLI.isOperationLegal(Op: ISD::Constant, VT))) {
17063 SDValue C = DAG.getBitcast(VT, V: N0);
17064 if (C.getNode() != N)
17065 return C;
17066 }
17067 }
17068
17069 // (conv (conv x, t1), t2) -> (conv x, t2)
17070 if (N0.getOpcode() == ISD::BITCAST)
17071 return DAG.getBitcast(VT, V: N0.getOperand(i: 0));
17072
17073 // fold (conv (logicop (conv x), (c))) -> (logicop x, (conv c))
17074 // iff the current bitwise logicop type isn't legal
17075 if (ISD::isBitwiseLogicOp(Opcode: N0.getOpcode()) && VT.isInteger() &&
17076 !TLI.isTypeLegal(VT: N0.getOperand(i: 0).getValueType())) {
17077 auto IsFreeBitcast = [VT](SDValue V) {
17078 return (V.getOpcode() == ISD::BITCAST &&
17079 V.getOperand(i: 0).getValueType() == VT) ||
17080 (ISD::isBuildVectorOfConstantSDNodes(N: V.getNode()) &&
17081 V->hasOneUse());
17082 };
17083 if (IsFreeBitcast(N0.getOperand(i: 0)) && IsFreeBitcast(N0.getOperand(i: 1)))
17084 return DAG.getNode(Opcode: N0.getOpcode(), DL: SDLoc(N), VT,
17085 N1: DAG.getBitcast(VT, V: N0.getOperand(i: 0)),
17086 N2: DAG.getBitcast(VT, V: N0.getOperand(i: 1)));
17087 }
17088
17089 // fold (conv (load x)) -> (load (conv*)x)
17090 // fold (conv (freeze (load x))) -> (freeze (load (conv*)x))
17091 // If the resultant load doesn't need a higher alignment than the original!
17092 auto CastLoad = [this, &VT](SDValue N0, const SDLoc &DL) {
17093 if (!ISD::isNormalLoad(N: N0.getNode()) || !N0.hasOneUse())
17094 return SDValue();
17095
17096 // Do not remove the cast if the types differ in endian layout.
17097 if (TLI.hasBigEndianPartOrdering(VT: N0.getValueType(), DL: DAG.getDataLayout()) !=
17098 TLI.hasBigEndianPartOrdering(VT, DL: DAG.getDataLayout()))
17099 return SDValue();
17100
17101 // If the load is volatile, we only want to change the load type if the
17102 // resulting load is legal. Otherwise we might increase the number of
17103 // memory accesses. We don't care if the original type was legal or not
17104 // as we assume software couldn't rely on the number of accesses of an
17105 // illegal type.
17106 auto *LN0 = cast<LoadSDNode>(Val&: N0);
17107 if ((LegalOperations || !LN0->isSimple()) &&
17108 !TLI.isOperationLegal(Op: ISD::LOAD, VT))
17109 return SDValue();
17110
17111 if (!TLI.isLoadBitCastBeneficial(LoadVT: N0.getValueType(), BitcastVT: VT, DAG,
17112 MMO: *LN0->getMemOperand()))
17113 return SDValue();
17114
17115 // If the range metadata type does not match the new memory
17116 // operation type, remove the range metadata.
17117 if (const MDNode *MD = LN0->getRanges()) {
17118 ConstantInt *Lower = mdconst::extract<ConstantInt>(MD: MD->getOperand(I: 0));
17119 if (Lower->getBitWidth() != VT.getScalarSizeInBits() || !VT.isInteger()) {
17120 LN0->getMemOperand()->clearRanges();
17121 }
17122 }
17123 SDValue Load = DAG.getLoad(VT, dl: DL, Chain: LN0->getChain(), Ptr: LN0->getBasePtr(),
17124 MMO: LN0->getMemOperand());
17125 DAG.ReplaceAllUsesOfValueWith(From: N0.getValue(R: 1), To: Load.getValue(R: 1));
17126 return Load;
17127 };
17128
17129 if (SDValue NewLd = CastLoad(N0, SDLoc(N)))
17130 return NewLd;
17131
17132 if (N0.getOpcode() == ISD::FREEZE && N0.hasOneUse())
17133 if (SDValue NewLd = CastLoad(N0.getOperand(i: 0), SDLoc(N)))
17134 return DAG.getFreeze(V: NewLd);
17135
17136 if (SDValue V = foldBitcastedFPLogic(N, DAG, TLI))
17137 return V;
17138
17139 // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit)
17140 // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit))
17141 //
17142 // For ppc_fp128:
17143 // fold (bitcast (fneg x)) ->
17144 // flipbit = signbit
17145 // (xor (bitcast x) (build_pair flipbit, flipbit))
17146 //
17147 // fold (bitcast (fabs x)) ->
17148 // flipbit = (and (extract_element (bitcast x), 0), signbit)
17149 // (xor (bitcast x) (build_pair flipbit, flipbit))
17150 // This often reduces constant pool loads.
17151 if (((N0.getOpcode() == ISD::FNEG && !TLI.isFNegFree(VT: N0.getValueType())) ||
17152 (N0.getOpcode() == ISD::FABS && !TLI.isFAbsFree(VT: N0.getValueType()))) &&
17153 N0->hasOneUse() && VT.isInteger() && !VT.isVector() &&
17154 !N0.getValueType().isVector()) {
17155 SDValue NewConv = DAG.getBitcast(VT, V: N0.getOperand(i: 0));
17156 AddToWorklist(N: NewConv.getNode());
17157
17158 SDLoc DL(N);
17159 if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
17160 assert(VT.getSizeInBits() == 128);
17161 SDValue SignBit = DAG.getConstant(
17162 Val: APInt::getSignMask(BitWidth: VT.getSizeInBits() / 2), DL: SDLoc(N0), VT: MVT::i64);
17163 SDValue FlipBit;
17164 if (N0.getOpcode() == ISD::FNEG) {
17165 FlipBit = SignBit;
17166 AddToWorklist(N: FlipBit.getNode());
17167 } else {
17168 assert(N0.getOpcode() == ISD::FABS);
17169 SDValue Hi =
17170 DAG.getNode(Opcode: ISD::EXTRACT_ELEMENT, DL: SDLoc(NewConv), VT: MVT::i64, N1: NewConv,
17171 N2: DAG.getIntPtrConstant(Val: getPPCf128HiElementSelector(DAG),
17172 DL: SDLoc(NewConv)));
17173 AddToWorklist(N: Hi.getNode());
17174 FlipBit = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(N0), VT: MVT::i64, N1: Hi, N2: SignBit);
17175 AddToWorklist(N: FlipBit.getNode());
17176 }
17177 SDValue FlipBits =
17178 DAG.getNode(Opcode: ISD::BUILD_PAIR, DL: SDLoc(N0), VT, N1: FlipBit, N2: FlipBit);
17179 AddToWorklist(N: FlipBits.getNode());
17180 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: NewConv, N2: FlipBits);
17181 }
17182 APInt SignBit = APInt::getSignMask(BitWidth: VT.getSizeInBits());
17183 if (N0.getOpcode() == ISD::FNEG)
17184 return DAG.getNode(Opcode: ISD::XOR, DL, VT,
17185 N1: NewConv, N2: DAG.getConstant(Val: SignBit, DL, VT));
17186 assert(N0.getOpcode() == ISD::FABS);
17187 return DAG.getNode(Opcode: ISD::AND, DL, VT,
17188 N1: NewConv, N2: DAG.getConstant(Val: ~SignBit, DL, VT));
17189 }
17190
17191 // fold (bitconvert (fcopysign cst, x)) ->
17192 // (or (and (bitconvert x), sign), (and cst, (not sign)))
17193 // Note that we don't handle (copysign x, cst) because this can always be
17194 // folded to an fneg or fabs.
17195 //
17196 // For ppc_fp128:
17197 // fold (bitcast (fcopysign cst, x)) ->
17198 // flipbit = (and (extract_element
17199 // (xor (bitcast cst), (bitcast x)), 0),
17200 // signbit)
17201 // (xor (bitcast cst) (build_pair flipbit, flipbit))
17202 if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() &&
17203 isa<ConstantFPSDNode>(Val: N0.getOperand(i: 0)) && VT.isInteger() &&
17204 !VT.isVector()) {
17205 unsigned OrigXWidth = N0.getOperand(i: 1).getValueSizeInBits();
17206 EVT IntXVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: OrigXWidth);
17207 if (isTypeLegal(VT: IntXVT)) {
17208 SDValue X = DAG.getBitcast(VT: IntXVT, V: N0.getOperand(i: 1));
17209 AddToWorklist(N: X.getNode());
17210
17211 // If X has a different width than the result/lhs, sext it or truncate it.
17212 unsigned VTWidth = VT.getSizeInBits();
17213 if (OrigXWidth < VTWidth) {
17214 X = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL: SDLoc(N), VT, Operand: X);
17215 AddToWorklist(N: X.getNode());
17216 } else if (OrigXWidth > VTWidth) {
17217 // To get the sign bit in the right place, we have to shift it right
17218 // before truncating.
17219 SDLoc DL(X);
17220 X = DAG.getNode(Opcode: ISD::SRL, DL,
17221 VT: X.getValueType(), N1: X,
17222 N2: DAG.getConstant(Val: OrigXWidth-VTWidth, DL,
17223 VT: X.getValueType()));
17224 AddToWorklist(N: X.getNode());
17225 X = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(X), VT, Operand: X);
17226 AddToWorklist(N: X.getNode());
17227 }
17228
17229 if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
17230 APInt SignBit = APInt::getSignMask(BitWidth: VT.getSizeInBits() / 2);
17231 SDValue Cst = DAG.getBitcast(VT, V: N0.getOperand(i: 0));
17232 AddToWorklist(N: Cst.getNode());
17233 SDValue X = DAG.getBitcast(VT, V: N0.getOperand(i: 1));
17234 AddToWorklist(N: X.getNode());
17235 SDValue XorResult = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N0), VT, N1: Cst, N2: X);
17236 AddToWorklist(N: XorResult.getNode());
17237 SDValue XorResult64 = DAG.getNode(
17238 Opcode: ISD::EXTRACT_ELEMENT, DL: SDLoc(XorResult), VT: MVT::i64, N1: XorResult,
17239 N2: DAG.getIntPtrConstant(Val: getPPCf128HiElementSelector(DAG),
17240 DL: SDLoc(XorResult)));
17241 AddToWorklist(N: XorResult64.getNode());
17242 SDValue FlipBit =
17243 DAG.getNode(Opcode: ISD::AND, DL: SDLoc(XorResult64), VT: MVT::i64, N1: XorResult64,
17244 N2: DAG.getConstant(Val: SignBit, DL: SDLoc(XorResult64), VT: MVT::i64));
17245 AddToWorklist(N: FlipBit.getNode());
17246 SDValue FlipBits =
17247 DAG.getNode(Opcode: ISD::BUILD_PAIR, DL: SDLoc(N0), VT, N1: FlipBit, N2: FlipBit);
17248 AddToWorklist(N: FlipBits.getNode());
17249 return DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N), VT, N1: Cst, N2: FlipBits);
17250 }
17251 APInt SignBit = APInt::getSignMask(BitWidth: VT.getSizeInBits());
17252 X = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(X), VT,
17253 N1: X, N2: DAG.getConstant(Val: SignBit, DL: SDLoc(X), VT));
17254 AddToWorklist(N: X.getNode());
17255
17256 SDValue Cst = DAG.getBitcast(VT, V: N0.getOperand(i: 0));
17257 Cst = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(Cst), VT,
17258 N1: Cst, N2: DAG.getConstant(Val: ~SignBit, DL: SDLoc(Cst), VT));
17259 AddToWorklist(N: Cst.getNode());
17260
17261 return DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N), VT, N1: X, N2: Cst);
17262 }
17263 }
17264
17265 // bitconvert(build_pair(ld, ld)) -> ld iff load locations are consecutive.
17266 if (N0.getOpcode() == ISD::BUILD_PAIR)
17267 if (SDValue CombineLD = CombineConsecutiveLoads(N: N0.getNode(), VT))
17268 return CombineLD;
17269
17270 // int_vt (bitcast (vec_vt (scalar_to_vector elt_vt:x)))
17271 // => int_vt (any_extend elt_vt:x)
17272 if (N0.getOpcode() == ISD::SCALAR_TO_VECTOR && VT.isScalarInteger()) {
17273 SDValue SrcScalar = N0.getOperand(i: 0);
17274 if (SrcScalar.getValueType().isScalarInteger())
17275 return DAG.getNode(Opcode: ISD::ANY_EXTEND, DL: SDLoc(N), VT, Operand: SrcScalar);
17276 }
17277
17278 // Remove double bitcasts from shuffles - this is often a legacy of
17279 // XformToShuffleWithZero being used to combine bitmaskings (of
17280 // float vectors bitcast to integer vectors) into shuffles.
17281 // bitcast(shuffle(bitcast(s0),bitcast(s1))) -> shuffle(s0,s1)
17282 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT) && VT.isVector() &&
17283 N0->getOpcode() == ISD::VECTOR_SHUFFLE && N0.hasOneUse() &&
17284 VT.getVectorNumElements() >= N0.getValueType().getVectorNumElements() &&
17285 !(VT.getVectorNumElements() % N0.getValueType().getVectorNumElements())) {
17286 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Val&: N0);
17287
17288 // If operands are a bitcast, peek through if it casts the original VT.
17289 // If operands are a constant, just bitcast back to original VT.
17290 auto PeekThroughBitcast = [&](SDValue Op) {
17291 if (Op.getOpcode() == ISD::BITCAST &&
17292 Op.getOperand(i: 0).getValueType() == VT)
17293 return SDValue(Op.getOperand(i: 0));
17294 if (Op.isUndef() || isAnyConstantBuildVector(V: Op))
17295 return DAG.getBitcast(VT, V: Op);
17296 return SDValue();
17297 };
17298
17299 // FIXME: If either input vector is bitcast, try to convert the shuffle to
17300 // the result type of this bitcast. This would eliminate at least one
17301 // bitcast. See the transform in InstCombine.
17302 SDValue SV0 = PeekThroughBitcast(N0->getOperand(Num: 0));
17303 SDValue SV1 = PeekThroughBitcast(N0->getOperand(Num: 1));
17304 if (!(SV0 && SV1))
17305 return SDValue();
17306
17307 int MaskScale =
17308 VT.getVectorNumElements() / N0.getValueType().getVectorNumElements();
17309 SmallVector<int, 8> NewMask;
17310 for (int M : SVN->getMask())
17311 for (int i = 0; i != MaskScale; ++i)
17312 NewMask.push_back(Elt: M < 0 ? -1 : M * MaskScale + i);
17313
17314 SDValue LegalShuffle =
17315 TLI.buildLegalVectorShuffle(VT, DL: SDLoc(N), N0: SV0, N1: SV1, Mask: NewMask, DAG);
17316 if (LegalShuffle)
17317 return LegalShuffle;
17318 }
17319
17320 return SDValue();
17321}
17322
17323SDValue DAGCombiner::visitBUILD_PAIR(SDNode *N) {
17324 EVT VT = N->getValueType(ResNo: 0);
17325 return CombineConsecutiveLoads(N, VT);
17326}
17327
17328SDValue DAGCombiner::visitFREEZE(SDNode *N) {
17329 SDValue N0 = N->getOperand(Num: 0);
17330
17331 if (DAG.isGuaranteedNotToBeUndefOrPoison(Op: N0, /*PoisonOnly*/ false))
17332 return N0;
17333
17334 // If we have frozen and unfrozen users of N0, update so everything uses N.
17335 if (!N0.isUndef() && !N0.hasOneUse()) {
17336 SDValue FrozenN0(N, 0);
17337 // Unfreeze all uses of N to avoid double deleting N from the CSE map.
17338 DAG.ReplaceAllUsesOfValueWith(From: FrozenN0, To: N0);
17339 DAG.ReplaceAllUsesOfValueWith(From: N0, To: FrozenN0);
17340 // ReplaceAllUsesOfValueWith will have also updated the use in N, thus
17341 // creating a cycle in a DAG. Let's undo that by mutating the freeze.
17342 assert(N->getOperand(0) == FrozenN0 && "Expected cycle in DAG");
17343 DAG.UpdateNodeOperands(N, Op: N0);
17344 return FrozenN0;
17345 }
17346
17347 // We currently avoid folding freeze over SRA/SRL, due to the problems seen
17348 // with (freeze (assert ext)) blocking simplifications of SRA/SRL. See for
17349 // example https://reviews.llvm.org/D136529#4120959.
17350 if (N0.getOpcode() == ISD::SRA || N0.getOpcode() == ISD::SRL)
17351 return SDValue();
17352
17353 // Fold freeze(op(x, ...)) -> op(freeze(x), ...).
17354 // Try to push freeze through instructions that propagate but don't produce
17355 // poison as far as possible. If an operand of the freeze meets three
17356 // conditions (it has one use, it does not produce poison, and all but one of
17357 // its operands are guaranteed non-poison, or it is a BUILD_VECTOR or similar),
17358 // push the freeze through to the operands that are not guaranteed non-poison.
17359 // NOTE: we will strip poison-generating flags, so ignore them here.
17360 if (DAG.canCreateUndefOrPoison(Op: N0, /*PoisonOnly*/ false,
17361 /*ConsiderFlags*/ false) ||
17362 N0->getNumValues() != 1 || !N0->hasOneUse())
17363 return SDValue();
17364
17365 // TODO: we should always allow multiple operands; however, this increases the
17366 // likelihood of infinite loops: the ReplaceAllUsesOfValueWith call below can
17367 // make later nodes that share frozen operands fold again while no longer
17368 // being able to confirm other operands are not poison, due to the recursion
17369 // depth limits on isGuaranteedNotToBeUndefOrPoison.
17370 bool AllowMultipleMaybePoisonOperands =
17371 N0.getOpcode() == ISD::SELECT_CC || N0.getOpcode() == ISD::SETCC ||
17372 N0.getOpcode() == ISD::BUILD_VECTOR ||
17373 N0.getOpcode() == ISD::INSERT_SUBVECTOR ||
17374 N0.getOpcode() == ISD::BUILD_PAIR ||
17375 N0.getOpcode() == ISD::VECTOR_SHUFFLE ||
17376 N0.getOpcode() == ISD::CONCAT_VECTORS || N0.getOpcode() == ISD::FMUL;
17377
17378 // Avoid turning a BUILD_VECTOR that can be recognized as "all zeros", "all
17379 // ones" or "constant" into something that depends on FrozenUndef. We can
17380 // instead pick undef values to keep those properties, while at the same time
17381 // folding away the freeze.
17382 // If we implement a more general solution for folding away freeze(undef) in
17383 // the future, then this special handling can be removed.
17384 if (N0.getOpcode() == ISD::BUILD_VECTOR) {
17385 SDLoc DL(N0);
17386 EVT VT = N0.getValueType();
17387 if (llvm::ISD::isBuildVectorAllOnes(N: N0.getNode()) && VT.isInteger())
17388 return DAG.getAllOnesConstant(DL, VT);
17389 if (llvm::ISD::isBuildVectorOfConstantSDNodes(N: N0.getNode())) {
17390 SmallVector<SDValue, 8> NewVecC;
17391 for (const SDValue &Op : N0->op_values())
17392 NewVecC.push_back(
17393 Elt: Op.isUndef() ? DAG.getConstant(Val: 0, DL, VT: Op.getValueType()) : Op);
17394 return DAG.getBuildVector(VT, DL, Ops: NewVecC);
17395 }
17396 }
17397
17398 SmallSet<SDValue, 8> MaybePoisonOperands;
17399 SmallVector<unsigned, 8> MaybePoisonOperandNumbers;
17400 for (auto [OpNo, Op] : enumerate(First: N0->ops())) {
17401 if (DAG.isGuaranteedNotToBeUndefOrPoison(Op, /*PoisonOnly=*/false))
17402 continue;
17403 bool HadMaybePoisonOperands = !MaybePoisonOperands.empty();
17404 bool IsNewMaybePoisonOperand = MaybePoisonOperands.insert(V: Op).second;
17405 if (IsNewMaybePoisonOperand)
17406 MaybePoisonOperandNumbers.push_back(Elt: OpNo);
17407 if (!HadMaybePoisonOperands)
17408 continue;
17409 if (IsNewMaybePoisonOperand && !AllowMultipleMaybePoisonOperands) {
17410 // Multiple maybe-poison ops when not allowed - bail out.
17411 return SDValue();
17412 }
17413 }
17414 // NOTE: the whole op may still not be guaranteed non-undef/non-poison, since
17415 // it could create undef or poison due to its poison-generating flags.
17416 // So not finding any maybe-poison operands is fine.
17417
17418 for (unsigned OpNo : MaybePoisonOperandNumbers) {
17419 // N0 can mutate during iteration, so make sure to refetch the maybe poison
17420 // operands via the operand numbers. The typical scenario is that we have
17421 // something like this
17422 // t262: i32 = freeze t181
17423 // t150: i32 = ctlz_zero_undef t262
17424 // t184: i32 = ctlz_zero_undef t181
17425 // t268: i32 = select_cc t181, Constant:i32<0>, t184, t186, setne:ch
17426 // When freezing the t181 operand we get t262 back, and then the
17427 // ReplaceAllUsesOfValueWith call will not only replace t181 by t262, but
17428 // also recursively replace t184 by t150.
17429 SDValue MaybePoisonOperand = N->getOperand(Num: 0).getOperand(i: OpNo);
17430 // Don't replace every single UNDEF everywhere with frozen UNDEF, though.
17431 if (MaybePoisonOperand.isUndef())
17432 continue;
17433 // First, freeze each offending operand.
17434 SDValue FrozenMaybePoisonOperand = DAG.getFreeze(V: MaybePoisonOperand);
17435 // Then, change all other uses of unfrozen operand to use frozen operand.
17436 DAG.ReplaceAllUsesOfValueWith(From: MaybePoisonOperand, To: FrozenMaybePoisonOperand);
17437 if (FrozenMaybePoisonOperand.getOpcode() == ISD::FREEZE &&
17438 FrozenMaybePoisonOperand.getOperand(i: 0) == FrozenMaybePoisonOperand) {
17439 // But, that also updated the use in the freeze we just created, thus
17440 // creating a cycle in a DAG. Let's undo that by mutating the freeze.
17441 DAG.UpdateNodeOperands(N: FrozenMaybePoisonOperand.getNode(),
17442 Op: MaybePoisonOperand);
17443 }
17444
17445 // This node has been merged with another.
17446 if (N->getOpcode() == ISD::DELETED_NODE)
17447 return SDValue(N, 0);
17448 }
17449
17450 assert(N->getOpcode() != ISD::DELETED_NODE && "Node was deleted!");
17451
17452 // The whole node may have been updated, so the value we were holding
17453 // may no longer be valid. Re-fetch the operand we're `freeze`ing.
17454 N0 = N->getOperand(Num: 0);
17455
17456 // Finally, recreate the node; its operands were updated to use frozen
17457 // operands, so we just need to use its "original" operands.
17458 SmallVector<SDValue> Ops(N0->ops());
  // TODO: ISD::UNDEF and ISD::POISON should get separate handling, but that is
  // best left for a future patch.
17461 for (SDValue &Op : Ops) {
17462 if (Op.isUndef())
17463 Op = DAG.getFreeze(V: Op);
17464 }
17465
17466 SDLoc DL(N0);
17467
17468 // Special case handling for ShuffleVectorSDNode nodes.
17469 if (auto *SVN = dyn_cast<ShuffleVectorSDNode>(Val&: N0))
17470 return DAG.getVectorShuffle(VT: N0.getValueType(), dl: DL, N1: Ops[0], N2: Ops[1],
17471 Mask: SVN->getMask());
17472
  // NOTE: this strips poison-generating flags.
  // Folding freeze(op(x, ...)) -> op(freeze(x), ...) does not require nnan,
  // ninf, nsz, or fast, so those flags are dropped.
  // However, contract, reassoc, afn, and arcp should be preserved, as these
  // fast-math flags do not introduce poison values.
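  // Illustrative example: freeze (fadd nnan x, y) must not become
  // (fadd nnan (freeze x), (freeze y)); if x is a NaN, the nnan flag makes
  // the new fadd produce poison, which the original freeze would have fixed.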
17478 SDNodeFlags SrcFlags = N0->getFlags();
17479 SDNodeFlags SafeFlags;
17480 SafeFlags.setAllowContract(SrcFlags.hasAllowContract());
17481 SafeFlags.setAllowReassociation(SrcFlags.hasAllowReassociation());
17482 SafeFlags.setApproximateFuncs(SrcFlags.hasApproximateFuncs());
17483 SafeFlags.setAllowReciprocal(SrcFlags.hasAllowReciprocal());
17484 return DAG.getNode(Opcode: N0.getOpcode(), DL, VTList: N0->getVTList(), Ops, Flags: SafeFlags);
17485}
17486
// Returns true if floating-point contraction is allowed on the FMUL SDValue
// `N`.
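// Illustrative summary of the check below: contraction permits e.g. folding
// (fadd (fmul x, y), z) into (fma x, y, z), skipping the intermediate rounding
// of the multiply. It is enabled either globally (e.g. -ffp-contract=fast,
// i.e. FPOpFusion::Fast) or per node via the 'contract' fast-math flag.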
17489static bool isContractableFMUL(const TargetOptions &Options, SDValue N) {
17490 assert(N.getOpcode() == ISD::FMUL);
17491
17492 return Options.AllowFPOpFusion == FPOpFusion::Fast ||
17493 N->getFlags().hasAllowContract();
17494}
17495
17496/// Try to perform FMA combining on a given FADD node.
17497template <class MatchContextClass>
17498SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
17499 SDValue N0 = N->getOperand(Num: 0);
17500 SDValue N1 = N->getOperand(Num: 1);
17501 EVT VT = N->getValueType(ResNo: 0);
17502 SDLoc SL(N);
17503 MatchContextClass matcher(DAG, TLI, N);
17504 const TargetOptions &Options = DAG.getTarget().Options;
17505
17506 bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
17507
17508 // Floating-point multiply-add with intermediate rounding.
17509 // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext.
17510 // FIXME: Add VP_FMAD opcode.
17511 bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N));
17512
17513 // Floating-point multiply-add without intermediate rounding.
17514 bool HasFMA =
17515 (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT)) &&
17516 TLI.isFMAFasterThanFMulAndFAdd(MF: DAG.getMachineFunction(), VT);
17517
17518 // No valid opcode, do not combine.
17519 if (!HasFMAD && !HasFMA)
17520 return SDValue();
17521
17522 bool AllowFusionGlobally =
17523 Options.AllowFPOpFusion == FPOpFusion::Fast || HasFMAD;
17524 // If the addition is not contractable, do not combine.
17525 if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
17526 return SDValue();
17527
17528 // Folding fadd (fmul x, y), (fmul x, y) -> fma x, y, (fmul x, y) is never
17529 // beneficial. It does not reduce latency. It increases register pressure. It
17530 // replaces an fadd with an fma which is a more complex instruction, so is
17531 // likely to have a larger encoding, use more functional units, etc.
17532 if (N0 == N1)
17533 return SDValue();
17534
17535 if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
17536 return SDValue();
17537
17538 // Always prefer FMAD to FMA for precision.
17539 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
17540 bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
17541
17542 auto isFusedOp = [&](SDValue N) {
17543 return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD);
17544 };
17545
17546 // Is the node an FMUL and contractable either due to global flags or
17547 // SDNodeFlags.
17548 auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) {
17549 if (!matcher.match(N, ISD::FMUL))
17550 return false;
17551 return AllowFusionGlobally || N->getFlags().hasAllowContract();
17552 };
17553 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
17554 // prefer to fold the multiply with fewer uses.
17555 if (Aggressive && isContractableFMUL(N0) && isContractableFMUL(N1)) {
17556 if (N0->use_size() > N1->use_size())
17557 std::swap(a&: N0, b&: N1);
17558 }
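  // (Illustrative rationale: if the multiply with more uses were folded, it
  // would still be kept alive for its other users, so we would end up
  // computing both the fmul and the fma; folding the less-used multiply gives
  // it a better chance of being removed entirely.)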
17559
17560 // fold (fadd (fmul x, y), z) -> (fma x, y, z)
17561 if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
17562 return matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(i: 0),
17563 N0.getOperand(i: 1), N1);
17564 }
17565
17566 // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
17567 // Note: Commutes FADD operands.
17568 if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
17569 return matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(i: 0),
17570 N1.getOperand(i: 1), N0);
17571 }
17572
17573 // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
17574 // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
17575 // This also works with nested fma instructions:
  //   fadd (fma A, B, (fma C, D, (fmul E, F))), G -->
  //     fma A, B, (fma C, D, (fma E, F, G))
  //   fadd G, (fma A, B, (fma C, D, (fmul E, F))) -->
  //     fma A, B, (fma C, D, (fma E, F, G)).
17580 // This requires reassociation because it changes the order of operations.
17581 bool CanReassociate = N->getFlags().hasAllowReassociation();
17582 if (CanReassociate) {
17583 SDValue FMA, E;
17584 if (isFusedOp(N0) && N0.hasOneUse()) {
17585 FMA = N0;
17586 E = N1;
17587 } else if (isFusedOp(N1) && N1.hasOneUse()) {
17588 FMA = N1;
17589 E = N0;
17590 }
17591
17592 SDValue TmpFMA = FMA;
17593 while (E && isFusedOp(TmpFMA) && TmpFMA.hasOneUse()) {
17594 SDValue FMul = TmpFMA->getOperand(Num: 2);
17595 if (matcher.match(FMul, ISD::FMUL) && FMul.hasOneUse()) {
17596 SDValue C = FMul.getOperand(i: 0);
17597 SDValue D = FMul.getOperand(i: 1);
17598 SDValue CDE = matcher.getNode(PreferredFusedOpcode, SL, VT, C, D, E);
17599 DAG.ReplaceAllUsesOfValueWith(From: FMul, To: CDE);
17600 // Replacing the inner FMul could cause the outer FMA to be simplified
17601 // away.
17602 return FMA.getOpcode() == ISD::DELETED_NODE ? SDValue(N, 0) : FMA;
17603 }
17604
17605 TmpFMA = TmpFMA->getOperand(Num: 2);
17606 }
17607 }
17608
17609 // Look through FP_EXTEND nodes to do more combining.
17610
17611 // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
17612 if (matcher.match(N0, ISD::FP_EXTEND)) {
17613 SDValue N00 = N0.getOperand(i: 0);
17614 if (isContractableFMUL(N00) &&
17615 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17616 SrcVT: N00.getValueType())) {
17617 return matcher.getNode(
17618 PreferredFusedOpcode, SL, VT,
17619 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 0)),
17620 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 1)), N1);
17621 }
17622 }
17623
17624 // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x)
17625 // Note: Commutes FADD operands.
17626 if (matcher.match(N1, ISD::FP_EXTEND)) {
17627 SDValue N10 = N1.getOperand(i: 0);
17628 if (isContractableFMUL(N10) &&
17629 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17630 SrcVT: N10.getValueType())) {
17631 return matcher.getNode(
17632 PreferredFusedOpcode, SL, VT,
17633 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(i: 0)),
17634 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(i: 1)), N0);
17635 }
17636 }
17637
17638 // More folding opportunities when target permits.
17639 if (Aggressive) {
17640 // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
17641 // -> (fma x, y, (fma (fpext u), (fpext v), z))
17642 auto FoldFAddFMAFPExtFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
17643 SDValue Z) {
17644 return matcher.getNode(
17645 PreferredFusedOpcode, SL, VT, X, Y,
17646 matcher.getNode(PreferredFusedOpcode, SL, VT,
17647 matcher.getNode(ISD::FP_EXTEND, SL, VT, U),
17648 matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
17649 };
17650 if (isFusedOp(N0)) {
17651 SDValue N02 = N0.getOperand(i: 2);
17652 if (matcher.match(N02, ISD::FP_EXTEND)) {
17653 SDValue N020 = N02.getOperand(i: 0);
17654 if (isContractableFMUL(N020) &&
17655 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17656 SrcVT: N020.getValueType())) {
17657 return FoldFAddFMAFPExtFMul(N0.getOperand(i: 0), N0.getOperand(i: 1),
17658 N020.getOperand(i: 0), N020.getOperand(i: 1),
17659 N1);
17660 }
17661 }
17662 }
17663
17664 // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
17665 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
17666 // FIXME: This turns two single-precision and one double-precision
17667 // operation into two double-precision operations, which might not be
17668 // interesting for all targets, especially GPUs.
17669 auto FoldFAddFPExtFMAFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
17670 SDValue Z) {
17671 return matcher.getNode(
17672 PreferredFusedOpcode, SL, VT,
17673 matcher.getNode(ISD::FP_EXTEND, SL, VT, X),
17674 matcher.getNode(ISD::FP_EXTEND, SL, VT, Y),
17675 matcher.getNode(PreferredFusedOpcode, SL, VT,
17676 matcher.getNode(ISD::FP_EXTEND, SL, VT, U),
17677 matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
17678 };
17679 if (N0.getOpcode() == ISD::FP_EXTEND) {
17680 SDValue N00 = N0.getOperand(i: 0);
17681 if (isFusedOp(N00)) {
17682 SDValue N002 = N00.getOperand(i: 2);
17683 if (isContractableFMUL(N002) &&
17684 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17685 SrcVT: N00.getValueType())) {
17686 return FoldFAddFPExtFMAFMul(N00.getOperand(i: 0), N00.getOperand(i: 1),
17687 N002.getOperand(i: 0), N002.getOperand(i: 1),
17688 N1);
17689 }
17690 }
17691 }
17692
    // fold (fadd x, (fma y, z, (fpext (fmul u, v))))
17694 // -> (fma y, z, (fma (fpext u), (fpext v), x))
17695 if (isFusedOp(N1)) {
17696 SDValue N12 = N1.getOperand(i: 2);
17697 if (N12.getOpcode() == ISD::FP_EXTEND) {
17698 SDValue N120 = N12.getOperand(i: 0);
17699 if (isContractableFMUL(N120) &&
17700 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17701 SrcVT: N120.getValueType())) {
17702 return FoldFAddFMAFPExtFMul(N1.getOperand(i: 0), N1.getOperand(i: 1),
17703 N120.getOperand(i: 0), N120.getOperand(i: 1),
17704 N0);
17705 }
17706 }
17707 }
17708
    // fold (fadd x, (fpext (fma y, z, (fmul u, v))))
17710 // -> (fma (fpext y), (fpext z), (fma (fpext u), (fpext v), x))
17711 // FIXME: This turns two single-precision and one double-precision
17712 // operation into two double-precision operations, which might not be
17713 // interesting for all targets, especially GPUs.
17714 if (N1.getOpcode() == ISD::FP_EXTEND) {
17715 SDValue N10 = N1.getOperand(i: 0);
17716 if (isFusedOp(N10)) {
17717 SDValue N102 = N10.getOperand(i: 2);
17718 if (isContractableFMUL(N102) &&
17719 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17720 SrcVT: N10.getValueType())) {
17721 return FoldFAddFPExtFMAFMul(N10.getOperand(i: 0), N10.getOperand(i: 1),
17722 N102.getOperand(i: 0), N102.getOperand(i: 1),
17723 N0);
17724 }
17725 }
17726 }
17727 }
17728
17729 return SDValue();
17730}
17731
17732/// Try to perform FMA combining on a given FSUB node.
17733template <class MatchContextClass>
17734SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
17735 SDValue N0 = N->getOperand(Num: 0);
17736 SDValue N1 = N->getOperand(Num: 1);
17737 EVT VT = N->getValueType(ResNo: 0);
17738 SDLoc SL(N);
17739 MatchContextClass matcher(DAG, TLI, N);
17740 const TargetOptions &Options = DAG.getTarget().Options;
17741
17742 bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
17743
17744 // Floating-point multiply-add with intermediate rounding.
17745 // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext.
17746 // FIXME: Add VP_FMAD opcode.
17747 bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N));
17748
17749 // Floating-point multiply-add without intermediate rounding.
17750 bool HasFMA =
17751 (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT)) &&
17752 TLI.isFMAFasterThanFMulAndFAdd(MF: DAG.getMachineFunction(), VT);
17753
17754 // No valid opcode, do not combine.
17755 if (!HasFMAD && !HasFMA)
17756 return SDValue();
17757
17758 const SDNodeFlags Flags = N->getFlags();
17759 bool AllowFusionGlobally =
17760 (Options.AllowFPOpFusion == FPOpFusion::Fast || HasFMAD);
17761
17762 // If the subtraction is not contractable, do not combine.
17763 if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
17764 return SDValue();
17765
17766 if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
17767 return SDValue();
17768
17769 // Always prefer FMAD to FMA for precision.
17770 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
17771 bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
17772 bool NoSignedZero = Flags.hasNoSignedZeros();
17773
17774 // Is the node an FMUL and contractable either due to global flags or
17775 // SDNodeFlags.
17776 auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) {
17777 if (!matcher.match(N, ISD::FMUL))
17778 return false;
17779 return AllowFusionGlobally || N->getFlags().hasAllowContract();
17780 };
17781
17782 // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
17783 auto tryToFoldXYSubZ = [&](SDValue XY, SDValue Z) {
17784 if (isContractableFMUL(XY) && (Aggressive || XY->hasOneUse())) {
17785 return matcher.getNode(PreferredFusedOpcode, SL, VT, XY.getOperand(i: 0),
17786 XY.getOperand(i: 1),
17787 matcher.getNode(ISD::FNEG, SL, VT, Z));
17788 }
17789 return SDValue();
17790 };
17791
17792 // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
17793 // Note: Commutes FSUB operands.
17794 auto tryToFoldXSubYZ = [&](SDValue X, SDValue YZ) {
17795 if (isContractableFMUL(YZ) && (Aggressive || YZ->hasOneUse())) {
17796 return matcher.getNode(
17797 PreferredFusedOpcode, SL, VT,
17798 matcher.getNode(ISD::FNEG, SL, VT, YZ.getOperand(i: 0)),
17799 YZ.getOperand(i: 1), X);
17800 }
17801 return SDValue();
17802 };
17803
17804 // If we have two choices trying to fold (fsub (fmul u, v), (fmul x, y)),
17805 // prefer to fold the multiply with fewer uses.
17806 if (isContractableFMUL(N0) && isContractableFMUL(N1) &&
17807 (N0->use_size() > N1->use_size())) {
17808 // fold (fsub (fmul a, b), (fmul c, d)) -> (fma (fneg c), d, (fmul a, b))
17809 if (SDValue V = tryToFoldXSubYZ(N0, N1))
17810 return V;
17811 // fold (fsub (fmul a, b), (fmul c, d)) -> (fma a, b, (fneg (fmul c, d)))
17812 if (SDValue V = tryToFoldXYSubZ(N0, N1))
17813 return V;
17814 } else {
17815 // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
17816 if (SDValue V = tryToFoldXYSubZ(N0, N1))
17817 return V;
17818 // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
17819 if (SDValue V = tryToFoldXSubYZ(N0, N1))
17820 return V;
17821 }
17822
  // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
17824 if (matcher.match(N0, ISD::FNEG) && isContractableFMUL(N0.getOperand(i: 0)) &&
17825 (Aggressive || (N0->hasOneUse() && N0.getOperand(i: 0).hasOneUse()))) {
17826 SDValue N00 = N0.getOperand(i: 0).getOperand(i: 0);
17827 SDValue N01 = N0.getOperand(i: 0).getOperand(i: 1);
17828 return matcher.getNode(PreferredFusedOpcode, SL, VT,
17829 matcher.getNode(ISD::FNEG, SL, VT, N00), N01,
17830 matcher.getNode(ISD::FNEG, SL, VT, N1));
17831 }
17832
17833 // Look through FP_EXTEND nodes to do more combining.
17834
17835 // fold (fsub (fpext (fmul x, y)), z)
17836 // -> (fma (fpext x), (fpext y), (fneg z))
17837 if (matcher.match(N0, ISD::FP_EXTEND)) {
17838 SDValue N00 = N0.getOperand(i: 0);
17839 if (isContractableFMUL(N00) &&
17840 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17841 SrcVT: N00.getValueType())) {
17842 return matcher.getNode(
17843 PreferredFusedOpcode, SL, VT,
17844 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 0)),
17845 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 1)),
17846 matcher.getNode(ISD::FNEG, SL, VT, N1));
17847 }
17848 }
17849
17850 // fold (fsub x, (fpext (fmul y, z)))
17851 // -> (fma (fneg (fpext y)), (fpext z), x)
17852 // Note: Commutes FSUB operands.
17853 if (matcher.match(N1, ISD::FP_EXTEND)) {
17854 SDValue N10 = N1.getOperand(i: 0);
17855 if (isContractableFMUL(N10) &&
17856 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17857 SrcVT: N10.getValueType())) {
17858 return matcher.getNode(
17859 PreferredFusedOpcode, SL, VT,
17860 matcher.getNode(
17861 ISD::FNEG, SL, VT,
17862 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(i: 0))),
17863 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(i: 1)), N0);
17864 }
17865 }
17866
  // fold (fsub (fpext (fneg (fmul x, y))), z)
  //   -> (fneg (fma (fpext x), (fpext y), z))
  // Note: This could be removed with appropriate canonicalization of the
  // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
  // command-line flag -fp-contract=fast and the fast-math flag 'contract'
  // prevent us from implementing the canonicalization in visitFSUB.
17873 if (matcher.match(N0, ISD::FP_EXTEND)) {
17874 SDValue N00 = N0.getOperand(i: 0);
17875 if (matcher.match(N00, ISD::FNEG)) {
17876 SDValue N000 = N00.getOperand(i: 0);
17877 if (isContractableFMUL(N000) &&
17878 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17879 SrcVT: N00.getValueType())) {
17880 return matcher.getNode(
17881 ISD::FNEG, SL, VT,
17882 matcher.getNode(
17883 PreferredFusedOpcode, SL, VT,
17884 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(i: 0)),
17885 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(i: 1)),
17886 N1));
17887 }
17888 }
17889 }
17890
  // fold (fsub (fneg (fpext (fmul x, y))), z)
  //   -> (fneg (fma (fpext x), (fpext y), z))
  // Note: This could be removed with appropriate canonicalization of the
  // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
  // command-line flag -fp-contract=fast and the fast-math flag 'contract'
  // prevent us from implementing the canonicalization in visitFSUB.
17897 if (matcher.match(N0, ISD::FNEG)) {
17898 SDValue N00 = N0.getOperand(i: 0);
17899 if (matcher.match(N00, ISD::FP_EXTEND)) {
17900 SDValue N000 = N00.getOperand(i: 0);
17901 if (isContractableFMUL(N000) &&
17902 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17903 SrcVT: N000.getValueType())) {
17904 return matcher.getNode(
17905 ISD::FNEG, SL, VT,
17906 matcher.getNode(
17907 PreferredFusedOpcode, SL, VT,
17908 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(i: 0)),
17909 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(i: 1)),
17910 N1));
17911 }
17912 }
17913 }
17914
17915 auto isContractableAndReassociableFMUL = [&isContractableFMUL](SDValue N) {
17916 return isContractableFMUL(N) && N->getFlags().hasAllowReassociation();
17917 };
17918
17919 auto isFusedOp = [&](SDValue N) {
17920 return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD);
17921 };
17922
17923 // More folding opportunities when target permits.
17924 if (Aggressive && N->getFlags().hasAllowReassociation()) {
17925 bool CanFuse = N->getFlags().hasAllowContract();
17926 // fold (fsub (fma x, y, (fmul u, v)), z)
    //   -> (fma x, y, (fma u, v, (fneg z)))
17928 if (CanFuse && isFusedOp(N0) &&
17929 isContractableAndReassociableFMUL(N0.getOperand(i: 2)) &&
17930 N0->hasOneUse() && N0.getOperand(i: 2)->hasOneUse()) {
17931 return matcher.getNode(
17932 PreferredFusedOpcode, SL, VT, N0.getOperand(i: 0), N0.getOperand(i: 1),
17933 matcher.getNode(PreferredFusedOpcode, SL, VT,
17934 N0.getOperand(i: 2).getOperand(i: 0),
17935 N0.getOperand(i: 2).getOperand(i: 1),
17936 matcher.getNode(ISD::FNEG, SL, VT, N1)));
17937 }
17938
17939 // fold (fsub x, (fma y, z, (fmul u, v)))
17940 // -> (fma (fneg y), z, (fma (fneg u), v, x))
17941 if (CanFuse && isFusedOp(N1) &&
17942 isContractableAndReassociableFMUL(N1.getOperand(i: 2)) &&
17943 N1->hasOneUse() && NoSignedZero) {
17944 SDValue N20 = N1.getOperand(i: 2).getOperand(i: 0);
17945 SDValue N21 = N1.getOperand(i: 2).getOperand(i: 1);
17946 return matcher.getNode(
17947 PreferredFusedOpcode, SL, VT,
17948 matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(i: 0)),
17949 N1.getOperand(i: 1),
17950 matcher.getNode(PreferredFusedOpcode, SL, VT,
17951 matcher.getNode(ISD::FNEG, SL, VT, N20), N21, N0));
17952 }
17953
17954 // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
    //   -> (fma x, y, (fma (fpext u), (fpext v), (fneg z)))
17956 if (isFusedOp(N0) && N0->hasOneUse()) {
17957 SDValue N02 = N0.getOperand(i: 2);
17958 if (matcher.match(N02, ISD::FP_EXTEND)) {
17959 SDValue N020 = N02.getOperand(i: 0);
17960 if (isContractableAndReassociableFMUL(N020) &&
17961 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17962 SrcVT: N020.getValueType())) {
17963 return matcher.getNode(
17964 PreferredFusedOpcode, SL, VT, N0.getOperand(i: 0), N0.getOperand(i: 1),
17965 matcher.getNode(
17966 PreferredFusedOpcode, SL, VT,
17967 matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(i: 0)),
17968 matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(i: 1)),
17969 matcher.getNode(ISD::FNEG, SL, VT, N1)));
17970 }
17971 }
17972 }
17973
17974 // fold (fsub (fpext (fma x, y, (fmul u, v))), z)
17975 // -> (fma (fpext x), (fpext y),
17976 // (fma (fpext u), (fpext v), (fneg z)))
17977 // FIXME: This turns two single-precision and one double-precision
17978 // operation into two double-precision operations, which might not be
17979 // interesting for all targets, especially GPUs.
17980 if (matcher.match(N0, ISD::FP_EXTEND)) {
17981 SDValue N00 = N0.getOperand(i: 0);
17982 if (isFusedOp(N00)) {
17983 SDValue N002 = N00.getOperand(i: 2);
17984 if (isContractableAndReassociableFMUL(N002) &&
17985 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
17986 SrcVT: N00.getValueType())) {
17987 return matcher.getNode(
17988 PreferredFusedOpcode, SL, VT,
17989 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 0)),
17990 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 1)),
17991 matcher.getNode(
17992 PreferredFusedOpcode, SL, VT,
17993 matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(i: 0)),
17994 matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(i: 1)),
17995 matcher.getNode(ISD::FNEG, SL, VT, N1)));
17996 }
17997 }
17998 }
17999
18000 // fold (fsub x, (fma y, z, (fpext (fmul u, v))))
18001 // -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x))
18002 if (isFusedOp(N1) && matcher.match(N1.getOperand(i: 2), ISD::FP_EXTEND) &&
18003 N1->hasOneUse()) {
18004 SDValue N120 = N1.getOperand(i: 2).getOperand(i: 0);
18005 if (isContractableAndReassociableFMUL(N120) &&
18006 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
18007 SrcVT: N120.getValueType())) {
18008 SDValue N1200 = N120.getOperand(i: 0);
18009 SDValue N1201 = N120.getOperand(i: 1);
18010 return matcher.getNode(
18011 PreferredFusedOpcode, SL, VT,
18012 matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(i: 0)),
18013 N1.getOperand(i: 1),
18014 matcher.getNode(
18015 PreferredFusedOpcode, SL, VT,
18016 matcher.getNode(ISD::FNEG, SL, VT,
18017 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1200)),
18018 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1201), N0));
18019 }
18020 }
18021
18022 // fold (fsub x, (fpext (fma y, z, (fmul u, v))))
18023 // -> (fma (fneg (fpext y)), (fpext z),
18024 // (fma (fneg (fpext u)), (fpext v), x))
18025 // FIXME: This turns two single-precision and one double-precision
18026 // operation into two double-precision operations, which might not be
18027 // interesting for all targets, especially GPUs.
18028 if (matcher.match(N1, ISD::FP_EXTEND) && isFusedOp(N1.getOperand(i: 0))) {
18029 SDValue CvtSrc = N1.getOperand(i: 0);
18030 SDValue N100 = CvtSrc.getOperand(i: 0);
18031 SDValue N101 = CvtSrc.getOperand(i: 1);
18032 SDValue N102 = CvtSrc.getOperand(i: 2);
18033 if (isContractableAndReassociableFMUL(N102) &&
18034 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
18035 SrcVT: CvtSrc.getValueType())) {
18036 SDValue N1020 = N102.getOperand(i: 0);
18037 SDValue N1021 = N102.getOperand(i: 1);
18038 return matcher.getNode(
18039 PreferredFusedOpcode, SL, VT,
18040 matcher.getNode(ISD::FNEG, SL, VT,
18041 matcher.getNode(ISD::FP_EXTEND, SL, VT, N100)),
18042 matcher.getNode(ISD::FP_EXTEND, SL, VT, N101),
18043 matcher.getNode(
18044 PreferredFusedOpcode, SL, VT,
18045 matcher.getNode(ISD::FNEG, SL, VT,
18046 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1020)),
18047 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1021), N0));
18048 }
18049 }
18050 }
18051
18052 return SDValue();
18053}
18054
18055/// Try to perform FMA combining on a given FMUL node based on the distributive
18056/// law x * (y + 1) = x * y + x and variants thereof (commuted versions,
18057/// subtraction instead of addition).
18058SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {
18059 SDValue N0 = N->getOperand(Num: 0);
18060 SDValue N1 = N->getOperand(Num: 1);
18061 EVT VT = N->getValueType(ResNo: 0);
18062 SDLoc SL(N);
18063
18064 assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation");
18065
18066 const TargetOptions &Options = DAG.getTarget().Options;
18067
18068 // The transforms below are incorrect when x == 0 and y == inf, because the
18069 // intermediate multiplication produces a nan.
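  // Worked example (illustrative): with x == 0.0 and y == inf,
  // (x + 1.0) * y == 1.0 * inf == inf, but the fused form
  // fma(x, y, y) == (0.0 * inf) + inf == NaN + inf == NaN.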
18070 SDValue FAdd = N0.getOpcode() == ISD::FADD ? N0 : N1;
18071 if (!FAdd->getFlags().hasNoInfs())
18072 return SDValue();
18073
18074 // Floating-point multiply-add without intermediate rounding.
18075 bool HasFMA =
18076 isContractableFMUL(Options, N: SDValue(N, 0)) &&
18077 (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::FMA, VT)) &&
18078 TLI.isFMAFasterThanFMulAndFAdd(MF: DAG.getMachineFunction(), VT);
18079
18080 // Floating-point multiply-add with intermediate rounding. This can result
18081 // in a less precise result due to the changed rounding order.
18082 bool HasFMAD = LegalOperations && TLI.isFMADLegal(DAG, N);
18083
18084 // No valid opcode, do not combine.
18085 if (!HasFMAD && !HasFMA)
18086 return SDValue();
18087
18088 // Always prefer FMAD to FMA for precision.
18089 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
18090 bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
18091
18092 // fold (fmul (fadd x0, +1.0), y) -> (fma x0, y, y)
18093 // fold (fmul (fadd x0, -1.0), y) -> (fma x0, y, (fneg y))
18094 auto FuseFADD = [&](SDValue X, SDValue Y) {
18095 if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) {
18096 if (auto *C = isConstOrConstSplatFP(N: X.getOperand(i: 1), AllowUndefs: true)) {
18097 if (C->isExactlyValue(V: +1.0))
18098 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT, N1: X.getOperand(i: 0), N2: Y,
18099 N3: Y);
18100 if (C->isExactlyValue(V: -1.0))
18101 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT, N1: X.getOperand(i: 0), N2: Y,
18102 N3: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: Y));
18103 }
18104 }
18105 return SDValue();
18106 };
18107
18108 if (SDValue FMA = FuseFADD(N0, N1))
18109 return FMA;
18110 if (SDValue FMA = FuseFADD(N1, N0))
18111 return FMA;
18112
18113 // fold (fmul (fsub +1.0, x1), y) -> (fma (fneg x1), y, y)
18114 // fold (fmul (fsub -1.0, x1), y) -> (fma (fneg x1), y, (fneg y))
18115 // fold (fmul (fsub x0, +1.0), y) -> (fma x0, y, (fneg y))
18116 // fold (fmul (fsub x0, -1.0), y) -> (fma x0, y, y)
18117 auto FuseFSUB = [&](SDValue X, SDValue Y) {
18118 if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) {
18119 if (auto *C0 = isConstOrConstSplatFP(N: X.getOperand(i: 0), AllowUndefs: true)) {
18120 if (C0->isExactlyValue(V: +1.0))
18121 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT,
18122 N1: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: X.getOperand(i: 1)), N2: Y,
18123 N3: Y);
18124 if (C0->isExactlyValue(V: -1.0))
18125 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT,
18126 N1: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: X.getOperand(i: 1)), N2: Y,
18127 N3: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: Y));
18128 }
18129 if (auto *C1 = isConstOrConstSplatFP(N: X.getOperand(i: 1), AllowUndefs: true)) {
18130 if (C1->isExactlyValue(V: +1.0))
18131 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT, N1: X.getOperand(i: 0), N2: Y,
18132 N3: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: Y));
18133 if (C1->isExactlyValue(V: -1.0))
18134 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT, N1: X.getOperand(i: 0), N2: Y,
18135 N3: Y);
18136 }
18137 }
18138 return SDValue();
18139 };
18140
18141 if (SDValue FMA = FuseFSUB(N0, N1))
18142 return FMA;
18143 if (SDValue FMA = FuseFSUB(N1, N0))
18144 return FMA;
18145
18146 return SDValue();
18147}
18148
18149SDValue DAGCombiner::visitVP_FADD(SDNode *N) {
18150 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18151
18152 // FADD -> FMA combines:
18153 if (SDValue Fused = visitFADDForFMACombine<VPMatchContext>(N)) {
18154 if (Fused.getOpcode() != ISD::DELETED_NODE)
18155 AddToWorklist(N: Fused.getNode());
18156 return Fused;
18157 }
18158 return SDValue();
18159}
18160
18161SDValue DAGCombiner::visitFADD(SDNode *N) {
18162 SDValue N0 = N->getOperand(Num: 0);
18163 SDValue N1 = N->getOperand(Num: 1);
18164 bool N0CFP = DAG.isConstantFPBuildVectorOrConstantFP(N: N0);
18165 bool N1CFP = DAG.isConstantFPBuildVectorOrConstantFP(N: N1);
18166 EVT VT = N->getValueType(ResNo: 0);
18167 SDLoc DL(N);
18168 SDNodeFlags Flags = N->getFlags();
18169 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18170
18171 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
18172 return R;
18173
18174 // fold (fadd c1, c2) -> c1 + c2
18175 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FADD, DL, VT, Ops: {N0, N1}))
18176 return C;
18177
18178 // canonicalize constant to RHS
18179 if (N0CFP && !N1CFP)
18180 return DAG.getNode(Opcode: ISD::FADD, DL, VT, N1, N2: N0);
18181
18182 // fold vector ops
18183 if (VT.isVector())
18184 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
18185 return FoldedVOp;
18186
18187 // N0 + -0.0 --> N0 (also allowed with +0.0 and fast-math)
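  // (Illustrative: under the default rounding mode, x + (+0.0) == x for every
  // x except x == -0.0, where -0.0 + (+0.0) == +0.0, so the +0.0 case is only
  // safe when the sign of a zero result can be ignored.)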
18188 ConstantFPSDNode *N1C = isConstOrConstSplatFP(N: N1, AllowUndefs: true);
18189 if (N1C && N1C->isZero())
18190 if (N1C->isNegative() || DAG.canIgnoreSignBitOfZero(Op: SDValue(N, 0)))
18191 return N0;
18192
18193 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
18194 return NewSel;
18195
18196 // fold (fadd A, (fneg B)) -> (fsub A, B)
18197 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::FSUB, VT))
18198 if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
18199 Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
18200 return DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1: N0, N2: NegN1);
18201
18202 // fold (fadd (fneg A), B) -> (fsub B, A)
18203 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::FSUB, VT))
18204 if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
18205 Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
18206 return DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1, N2: NegN0);
18207
18208 auto isFMulNegTwo = [](SDValue FMul) {
18209 if (!FMul.hasOneUse() || FMul.getOpcode() != ISD::FMUL)
18210 return false;
18211 auto *C = isConstOrConstSplatFP(N: FMul.getOperand(i: 1), AllowUndefs: true);
18212 return C && C->isExactlyValue(V: -2.0);
18213 };
18214
18215 // fadd (fmul B, -2.0), A --> fsub A, (fadd B, B)
18216 if (isFMulNegTwo(N0)) {
18217 SDValue B = N0.getOperand(i: 0);
18218 SDValue Add = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: B, N2: B);
18219 return DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1, N2: Add);
18220 }
18221 // fadd A, (fmul B, -2.0) --> fsub A, (fadd B, B)
18222 if (isFMulNegTwo(N1)) {
18223 SDValue B = N1.getOperand(i: 0);
18224 SDValue Add = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: B, N2: B);
18225 return DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1: N0, N2: Add);
18226 }
18227
  // No FP constant should be created after legalization, as the instruction
  // selection pass has a hard time dealing with FP constants.
18230 bool AllowNewConst = (Level < AfterLegalizeDAG);
18231
18232 // If nnan is enabled, fold lots of things.
18233 if (Flags.hasNoNaNs() && AllowNewConst) {
18234 // If allowed, fold (fadd (fneg x), x) -> 0.0
18235 if (N0.getOpcode() == ISD::FNEG && N0.getOperand(i: 0) == N1)
18236 return DAG.getConstantFP(Val: 0.0, DL, VT);
18237
18238 // If allowed, fold (fadd x, (fneg x)) -> 0.0
18239 if (N1.getOpcode() == ISD::FNEG && N1.getOperand(i: 0) == N0)
18240 return DAG.getConstantFP(Val: 0.0, DL, VT);
18241 }
18242
18243 // If reassoc and nsz, fold lots of things.
18244 // TODO: break out portions of the transformations below for which Unsafe is
18245 // considered and which do not require both nsz and reassoc
18246 if (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros() &&
18247 AllowNewConst) {
18248 // fadd (fadd x, c1), c2 -> fadd x, c1 + c2
18249 if (N1CFP && N0.getOpcode() == ISD::FADD &&
18250 DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 1))) {
18251 SDValue NewC = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0.getOperand(i: 1), N2: N1);
18252 return DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0.getOperand(i: 0), N2: NewC);
18253 }
18254
18255 // We can fold chains of FADD's of the same value into multiplications.
18256 // This transform is not safe in general because we are reducing the number
18257 // of rounding steps.
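    // (Illustrative: (fadd (fmul x, c), x) rounds twice -- once for the
    // multiply and once for the add -- while (fmul x, c+1) rounds only once,
    // so the two can differ by an ulp for some inputs.)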
18258 if (TLI.isOperationLegalOrCustom(Op: ISD::FMUL, VT) && !N0CFP && !N1CFP) {
18259 if (N0.getOpcode() == ISD::FMUL) {
18260 bool CFP00 = DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 0));
18261 bool CFP01 = DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 1));
18262
18263 // (fadd (fmul x, c), x) -> (fmul x, c+1)
18264 if (CFP01 && !CFP00 && N0.getOperand(i: 0) == N1) {
18265 SDValue NewCFP = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0.getOperand(i: 1),
18266 N2: DAG.getConstantFP(Val: 1.0, DL, VT));
18267 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1, N2: NewCFP);
18268 }
18269
18270 // (fadd (fmul x, c), (fadd x, x)) -> (fmul x, c+2)
18271 if (CFP01 && !CFP00 && N1.getOpcode() == ISD::FADD &&
18272 N1.getOperand(i: 0) == N1.getOperand(i: 1) &&
18273 N0.getOperand(i: 0) == N1.getOperand(i: 0)) {
18274 SDValue NewCFP = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0.getOperand(i: 1),
18275 N2: DAG.getConstantFP(Val: 2.0, DL, VT));
18276 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0.getOperand(i: 0), N2: NewCFP);
18277 }
18278 }
18279
18280 if (N1.getOpcode() == ISD::FMUL) {
18281 bool CFP10 = DAG.isConstantFPBuildVectorOrConstantFP(N: N1.getOperand(i: 0));
18282 bool CFP11 = DAG.isConstantFPBuildVectorOrConstantFP(N: N1.getOperand(i: 1));
18283
18284 // (fadd x, (fmul x, c)) -> (fmul x, c+1)
18285 if (CFP11 && !CFP10 && N1.getOperand(i: 0) == N0) {
18286 SDValue NewCFP = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N1.getOperand(i: 1),
18287 N2: DAG.getConstantFP(Val: 1.0, DL, VT));
18288 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: NewCFP);
18289 }
18290
18291 // (fadd (fadd x, x), (fmul x, c)) -> (fmul x, c+2)
18292 if (CFP11 && !CFP10 && N0.getOpcode() == ISD::FADD &&
18293 N0.getOperand(i: 0) == N0.getOperand(i: 1) &&
18294 N1.getOperand(i: 0) == N0.getOperand(i: 0)) {
18295 SDValue NewCFP = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N1.getOperand(i: 1),
18296 N2: DAG.getConstantFP(Val: 2.0, DL, VT));
18297 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N1.getOperand(i: 0), N2: NewCFP);
18298 }
18299 }
18300
18301 if (N0.getOpcode() == ISD::FADD) {
18302 bool CFP00 = DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 0));
18303 // (fadd (fadd x, x), x) -> (fmul x, 3.0)
18304 if (!CFP00 && N0.getOperand(i: 0) == N0.getOperand(i: 1) &&
18305 (N0.getOperand(i: 0) == N1)) {
18306 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1,
18307 N2: DAG.getConstantFP(Val: 3.0, DL, VT));
18308 }
18309 }
18310
18311 if (N1.getOpcode() == ISD::FADD) {
18312 bool CFP10 = DAG.isConstantFPBuildVectorOrConstantFP(N: N1.getOperand(i: 0));
18313 // (fadd x, (fadd x, x)) -> (fmul x, 3.0)
18314 if (!CFP10 && N1.getOperand(i: 0) == N1.getOperand(i: 1) &&
18315 N1.getOperand(i: 0) == N0) {
18316 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0,
18317 N2: DAG.getConstantFP(Val: 3.0, DL, VT));
18318 }
18319 }
18320
18321 // (fadd (fadd x, x), (fadd x, x)) -> (fmul x, 4.0)
18322 if (N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD &&
18323 N0.getOperand(i: 0) == N0.getOperand(i: 1) &&
18324 N1.getOperand(i: 0) == N1.getOperand(i: 1) &&
18325 N0.getOperand(i: 0) == N1.getOperand(i: 0)) {
18326 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0.getOperand(i: 0),
18327 N2: DAG.getConstantFP(Val: 4.0, DL, VT));
18328 }
18329 }
18330 } // reassoc && nsz && AllowNewConst
18331
18332 if (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros()) {
18333 // Fold fadd(vecreduce(x), vecreduce(y)) -> vecreduce(fadd(x, y))
18334 if (SDValue SD = reassociateReduction(RedOpc: ISD::VECREDUCE_FADD, Opc: ISD::FADD, DL,
18335 VT, N0, N1, Flags))
18336 return SD;
18337 }
18338
18339 // FADD -> FMA combines:
18340 if (SDValue Fused = visitFADDForFMACombine<EmptyMatchContext>(N)) {
18341 if (Fused.getOpcode() != ISD::DELETED_NODE)
18342 AddToWorklist(N: Fused.getNode());
18343 return Fused;
18344 }
18345 return SDValue();
18346}
18347
18348SDValue DAGCombiner::visitSTRICT_FADD(SDNode *N) {
18349 SDValue Chain = N->getOperand(Num: 0);
18350 SDValue N0 = N->getOperand(Num: 1);
18351 SDValue N1 = N->getOperand(Num: 2);
18352 EVT VT = N->getValueType(ResNo: 0);
18353 EVT ChainVT = N->getValueType(ResNo: 1);
18354 SDLoc DL(N);
18355 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18356
18357 // fold (strict_fadd A, (fneg B)) -> (strict_fsub A, B)
18358 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::STRICT_FSUB, VT))
18359 if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
18360 Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize)) {
18361 return DAG.getNode(Opcode: ISD::STRICT_FSUB, DL, VTList: DAG.getVTList(VT1: VT, VT2: ChainVT),
18362 Ops: {Chain, N0, NegN1});
18363 }
18364
18365 // fold (strict_fadd (fneg A), B) -> (strict_fsub B, A)
18366 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::STRICT_FSUB, VT))
18367 if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
18368 Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize)) {
18369 return DAG.getNode(Opcode: ISD::STRICT_FSUB, DL, VTList: DAG.getVTList(VT1: VT, VT2: ChainVT),
18370 Ops: {Chain, N1, NegN0});
18371 }
18372 return SDValue();
18373}
18374
18375SDValue DAGCombiner::visitFSUB(SDNode *N) {
18376 SDValue N0 = N->getOperand(Num: 0);
18377 SDValue N1 = N->getOperand(Num: 1);
18378 ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N: N0, AllowUndefs: true);
18379 ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N: N1, AllowUndefs: true);
18380 EVT VT = N->getValueType(ResNo: 0);
18381 SDLoc DL(N);
18382 const SDNodeFlags Flags = N->getFlags();
18383 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18384
18385 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
18386 return R;
18387
18388 // fold (fsub c1, c2) -> c1-c2
18389 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FSUB, DL, VT, Ops: {N0, N1}))
18390 return C;
18391
18392 // fold vector ops
18393 if (VT.isVector())
18394 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
18395 return FoldedVOp;
18396
18397 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
18398 return NewSel;
18399
18400 // (fsub A, 0) -> A
18401 if (N1CFP && N1CFP->isZero()) {
18402 if (!N1CFP->isNegative() || DAG.canIgnoreSignBitOfZero(Op: SDValue(N, 0))) {
18403 return N0;
18404 }
18405 }
18406
18407 if (N0 == N1) {
18408 // (fsub x, x) -> 0.0
18409 if (Flags.hasNoNaNs())
18410 return DAG.getConstantFP(Val: 0.0f, DL, VT);
18411 }
18412
18413 // (fsub -0.0, N1) -> -N1
18414 if (N0CFP && N0CFP->isZero()) {
18415 if (N0CFP->isNegative() || DAG.canIgnoreSignBitOfZero(Op: SDValue(N, 0))) {
18416 // We cannot replace an FSUB(+-0.0,X) with FNEG(X) when denormals are
18417 // flushed to zero, unless all users treat denorms as zero (DAZ).
18418 // FIXME: This transform will change the sign of a NaN and the behavior
18419 // of a signaling NaN. It is only valid when a NoNaN flag is present.
18420 DenormalMode DenormMode = DAG.getDenormalMode(VT);
18421 if (DenormMode == DenormalMode::getIEEE()) {
18422 if (SDValue NegN1 =
18423 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
18424 return NegN1;
18425 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::FNEG, VT))
18426 return DAG.getNode(Opcode: ISD::FNEG, DL, VT, Operand: N1);
18427 }
18428 }
18429 }
18430
18431 if (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros() &&
18432 N1.getOpcode() == ISD::FADD) {
18433 // X - (X + Y) -> -Y
18434 if (N0 == N1->getOperand(Num: 0))
18435 return DAG.getNode(Opcode: ISD::FNEG, DL, VT, Operand: N1->getOperand(Num: 1));
18436 // X - (Y + X) -> -Y
18437 if (N0 == N1->getOperand(Num: 1))
18438 return DAG.getNode(Opcode: ISD::FNEG, DL, VT, Operand: N1->getOperand(Num: 0));
18439 }
18440
18441 // fold (fsub A, (fneg B)) -> (fadd A, B)
18442 if (SDValue NegN1 =
18443 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
18444 return DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0, N2: NegN1);
18445
18446 // FSUB -> FMA combines:
18447 if (SDValue Fused = visitFSUBForFMACombine<EmptyMatchContext>(N)) {
18448 AddToWorklist(N: Fused.getNode());
18449 return Fused;
18450 }
18451
18452 return SDValue();
18453}
18454
18455// Transform IEEE Floats:
18456// (fmul C, (uitofp Pow2))
18457// -> (bitcast_to_FP (add (bitcast_to_INT C), Log2(Pow2) << mantissa))
18458// (fdiv C, (uitofp Pow2))
18459// -> (bitcast_to_FP (sub (bitcast_to_INT C), Log2(Pow2) << mantissa))
18460//
// The rationale is that fmul/fdiv by a power of 2 just changes the exponent,
// so there is no need for more than an add/sub.
18463//
18464// This is valid under the following circumstances:
18465// 1) We are dealing with IEEE floats
18466// 2) C is normal
18467// 3) The fmul/fdiv add/sub will not go outside of min/max exponent bounds.
// TODO: Much of this could also be used for generating `ldexp` on targets that
// prefer it.
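// Worked example (illustrative, assuming IEEE binary32): C = 1.5f has bit
// pattern 0x3FC00000. Multiplying by Pow2 = 4 (Log2 == 2) adds
// 2 << 23 = 0x01000000 to the integer view, giving 0x40C00000, which is
// exactly 6.0f.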
18470SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {
18471 EVT VT = N->getValueType(ResNo: 0);
18472 if (!APFloat::isIEEELikeFP(VT.getFltSemantics()))
18473 return SDValue();
18474
18475 SDValue ConstOp, Pow2Op;
18476
18477 std::optional<int> Mantissa;
18478 auto GetConstAndPow2Ops = [&](unsigned ConstOpIdx) {
18479 if (ConstOpIdx == 1 && N->getOpcode() == ISD::FDIV)
18480 return false;
18481
18482 ConstOp = peekThroughBitcasts(V: N->getOperand(Num: ConstOpIdx));
18483 Pow2Op = N->getOperand(Num: 1 - ConstOpIdx);
18484 if (Pow2Op.getOpcode() != ISD::UINT_TO_FP &&
18485 (Pow2Op.getOpcode() != ISD::SINT_TO_FP ||
18486 !DAG.computeKnownBits(Op: Pow2Op).isNonNegative()))
18487 return false;
18488
18489 Pow2Op = Pow2Op.getOperand(i: 0);
18490
18491 // `Log2(Pow2Op) < Pow2Op.getScalarSizeInBits()`.
18492 // TODO: We could use knownbits to make this bound more precise.
18493 int MaxExpChange = Pow2Op.getValueType().getScalarSizeInBits();
18494
18495 auto IsFPConstValid = [N, MaxExpChange, &Mantissa](ConstantFPSDNode *CFP) {
18496 if (CFP == nullptr)
18497 return false;
18498
18499 const APFloat &APF = CFP->getValueAPF();
18500
      // Make sure we have a normal constant.
18502 if (!APF.isNormal())
18503 return false;
18504
      // Make sure the float's exponent is within the bounds for which this
      // transform produces a bitwise-identical value.
18507 int CurExp = ilogb(Arg: APF);
18508 // FMul by pow2 will only increase exponent.
18509 int MinExp =
18510 N->getOpcode() == ISD::FMUL ? CurExp : (CurExp - MaxExpChange);
18511 // FDiv by pow2 will only decrease exponent.
18512 int MaxExp =
18513 N->getOpcode() == ISD::FDIV ? CurExp : (CurExp + MaxExpChange);
18514 if (MinExp <= APFloat::semanticsMinExponent(APF.getSemantics()) ||
18515 MaxExp >= APFloat::semanticsMaxExponent(APF.getSemantics()))
18516 return false;
18517
18518 // Finally make sure we actually know the mantissa for the float type.
18519 int ThisMantissa = APFloat::semanticsPrecision(APF.getSemantics()) - 1;
18520 if (!Mantissa)
18521 Mantissa = ThisMantissa;
18522
18523 return *Mantissa == ThisMantissa && ThisMantissa > 0;
18524 };
18525
18526 // TODO: We may be able to include undefs.
18527 return ISD::matchUnaryFpPredicate(Op: ConstOp, Match: IsFPConstValid);
18528 };
18529
18530 if (!GetConstAndPow2Ops(0) && !GetConstAndPow2Ops(1))
18531 return SDValue();
18532
18533 if (!TLI.optimizeFMulOrFDivAsShiftAddBitcast(N, FPConst: ConstOp, IntPow2: Pow2Op))
18534 return SDValue();
18535
18536 // Get log2 after all other checks have taken place. This is because
18537 // BuildLogBase2 may create a new node.
18538 SDLoc DL(N);
18539 // Get Log2 type with same bitwidth as the float type (VT).
18540 EVT NewIntVT = VT.changeElementType(
18541 Context&: *DAG.getContext(),
18542 EltVT: EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: VT.getScalarSizeInBits()));
18543
18544 SDValue Log2 = BuildLogBase2(V: Pow2Op, DL, KnownNeverZero: DAG.isKnownNeverZero(Op: Pow2Op),
18545 /*InexpensiveOnly*/ true, OutVT: NewIntVT);
18546 if (!Log2)
18547 return SDValue();
18548
18549 // Perform actual transform.
18550 SDValue MantissaShiftCnt =
18551 DAG.getShiftAmountConstant(Val: *Mantissa, VT: NewIntVT, DL);
  // TODO: Sometimes Log2 is of the form `(X + C)`. `(X + C) << C1` should fold
  // to `(X << C1) + (C << C1)`, but that isn't always the case because of the
  // cast. We could implement that here by handling the casts.
18555 SDValue Shift = DAG.getNode(Opcode: ISD::SHL, DL, VT: NewIntVT, N1: Log2, N2: MantissaShiftCnt);
18556 SDValue ResAsInt =
18557 DAG.getNode(Opcode: N->getOpcode() == ISD::FMUL ? ISD::ADD : ISD::SUB, DL,
18558 VT: NewIntVT, N1: DAG.getBitcast(VT: NewIntVT, V: ConstOp), N2: Shift);
18559 SDValue ResAsFP = DAG.getBitcast(VT, V: ResAsInt);
18560 return ResAsFP;
18561}
18562
18563SDValue DAGCombiner::visitFMUL(SDNode *N) {
18564 SDValue N0 = N->getOperand(Num: 0);
18565 SDValue N1 = N->getOperand(Num: 1);
18566 ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N: N1, AllowUndefs: true);
18567 EVT VT = N->getValueType(ResNo: 0);
18568 SDLoc DL(N);
18569 const SDNodeFlags Flags = N->getFlags();
18570 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18571
18572 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
18573 return R;
18574
18575 // fold (fmul c1, c2) -> c1*c2
18576 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FMUL, DL, VT, Ops: {N0, N1}))
18577 return C;
18578
18579 // canonicalize constant to RHS
18580 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0) &&
18581 !DAG.isConstantFPBuildVectorOrConstantFP(N: N1))
18582 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1, N2: N0);
18583
18584 // fold vector ops
18585 if (VT.isVector())
18586 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
18587 return FoldedVOp;
18588
18589 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
18590 return NewSel;
18591
18592 if (Flags.hasAllowReassociation()) {
18593 // fmul (fmul X, C1), C2 -> fmul X, C1 * C2
18594 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N1) &&
18595 N0.getOpcode() == ISD::FMUL) {
18596 SDValue N00 = N0.getOperand(i: 0);
18597 SDValue N01 = N0.getOperand(i: 1);
18598 // Avoid an infinite loop by making sure that N00 is not a constant
18599 // (the inner multiply has not been constant folded yet).
18600 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N01) &&
18601 !DAG.isConstantFPBuildVectorOrConstantFP(N: N00)) {
18602 SDValue MulConsts = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N01, N2: N1);
18603 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N00, N2: MulConsts);
18604 }
18605 }
18606
18607 // Match a special-case: we convert X * 2.0 into fadd.
18608 // fmul (fadd X, X), C -> fmul X, 2.0 * C
18609 if (N0.getOpcode() == ISD::FADD && N0.hasOneUse() &&
18610 N0.getOperand(i: 0) == N0.getOperand(i: 1)) {
18611 const SDValue Two = DAG.getConstantFP(Val: 2.0, DL, VT);
18612 SDValue MulConsts = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Two, N2: N1);
18613 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0.getOperand(i: 0), N2: MulConsts);
18614 }
18615
18616 // Fold fmul(vecreduce(x), vecreduce(y)) -> vecreduce(fmul(x, y))
18617 if (SDValue SD = reassociateReduction(RedOpc: ISD::VECREDUCE_FMUL, Opc: ISD::FMUL, DL,
18618 VT, N0, N1, Flags))
18619 return SD;
18620 }
18621
18622 // fold (fmul X, 2.0) -> (fadd X, X)
18623 if (N1CFP && N1CFP->isExactlyValue(V: +2.0))
18624 return DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0, N2: N0);
18625
18626 // fold (fmul X, -1.0) -> (fsub -0.0, X)
18627 if (N1CFP && N1CFP->isExactlyValue(V: -1.0)) {
18628 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::FSUB, VT)) {
18629 return DAG.getNode(Opcode: ISD::FSUB, DL, VT,
18630 N1: DAG.getConstantFP(Val: -0.0, DL, VT), N2: N0, Flags);
18631 }
18632 }
18633
18634 // -N0 * -N1 --> N0 * N1
18635 TargetLowering::NegatibleCost CostN0 =
18636 TargetLowering::NegatibleCost::Expensive;
18637 TargetLowering::NegatibleCost CostN1 =
18638 TargetLowering::NegatibleCost::Expensive;
18639 SDValue NegN0 =
18640 TLI.getNegatedExpression(Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN0);
18641 if (NegN0) {
18642 HandleSDNode NegN0Handle(NegN0);
18643 SDValue NegN1 =
18644 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN1);
18645 if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
18646 CostN1 == TargetLowering::NegatibleCost::Cheaper))
18647 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: NegN0, N2: NegN1);
18648 }
18649
18650 // fold (fmul X, (select (fcmp X > 0.0), -1.0, 1.0)) -> (fneg (fabs X))
18651 // fold (fmul X, (select (fcmp X > 0.0), 1.0, -1.0)) -> (fabs X)
18652 if (Flags.hasNoNaNs() && Flags.hasNoSignedZeros() &&
18653 (N0.getOpcode() == ISD::SELECT || N1.getOpcode() == ISD::SELECT) &&
18654 TLI.isOperationLegal(Op: ISD::FABS, VT)) {
18655 SDValue Select = N0, X = N1;
18656 if (Select.getOpcode() != ISD::SELECT)
18657 std::swap(a&: Select, b&: X);
18658
18659 SDValue Cond = Select.getOperand(i: 0);
18660 auto TrueOpnd = dyn_cast<ConstantFPSDNode>(Val: Select.getOperand(i: 1));
18661 auto FalseOpnd = dyn_cast<ConstantFPSDNode>(Val: Select.getOperand(i: 2));
18662
18663 if (TrueOpnd && FalseOpnd &&
18664 Cond.getOpcode() == ISD::SETCC && Cond.getOperand(i: 0) == X &&
18665 isa<ConstantFPSDNode>(Val: Cond.getOperand(i: 1)) &&
18666 cast<ConstantFPSDNode>(Val: Cond.getOperand(i: 1))->isExactlyValue(V: 0.0)) {
18667 ISD::CondCode CC = cast<CondCodeSDNode>(Val: Cond.getOperand(i: 2))->get();
18668 switch (CC) {
18669 default: break;
18670 case ISD::SETOLT:
18671 case ISD::SETULT:
18672 case ISD::SETOLE:
18673 case ISD::SETULE:
18674 case ISD::SETLT:
18675 case ISD::SETLE:
18676 std::swap(a&: TrueOpnd, b&: FalseOpnd);
18677 [[fallthrough]];
18678 case ISD::SETOGT:
18679 case ISD::SETUGT:
18680 case ISD::SETOGE:
18681 case ISD::SETUGE:
18682 case ISD::SETGT:
18683 case ISD::SETGE:
18684 if (TrueOpnd->isExactlyValue(V: -1.0) && FalseOpnd->isExactlyValue(V: 1.0) &&
18685 TLI.isOperationLegal(Op: ISD::FNEG, VT))
18686 return DAG.getNode(Opcode: ISD::FNEG, DL, VT,
18687 Operand: DAG.getNode(Opcode: ISD::FABS, DL, VT, Operand: X));
18688 if (TrueOpnd->isExactlyValue(V: 1.0) && FalseOpnd->isExactlyValue(V: -1.0))
18689 return DAG.getNode(Opcode: ISD::FABS, DL, VT, Operand: X);
18690
18691 break;
18692 }
18693 }
18694 }
18695
18696 // FMUL -> FMA combines:
18697 if (SDValue Fused = visitFMULForFMADistributiveCombine(N)) {
18698 AddToWorklist(N: Fused.getNode());
18699 return Fused;
18700 }
18701
18702 // Don't do `combineFMulOrFDivWithIntPow2` until after FMUL -> FMA has been
18703 // able to run.
18704 if (SDValue R = combineFMulOrFDivWithIntPow2(N))
18705 return R;
18706
18707 return SDValue();
18708}
18709
18710template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
18711 SDValue N0 = N->getOperand(Num: 0);
18712 SDValue N1 = N->getOperand(Num: 1);
18713 SDValue N2 = N->getOperand(Num: 2);
18714 ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(Val&: N0);
18715 ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(Val&: N1);
18716 ConstantFPSDNode *N2CFP = dyn_cast<ConstantFPSDNode>(Val&: N2);
18717 EVT VT = N->getValueType(ResNo: 0);
18718 SDLoc DL(N);
18719 // FMA nodes have flags that propagate to the created nodes.
18720 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18721 MatchContextClass matcher(DAG, TLI, N);
18722
18723 // Constant fold FMA.
18724 if (SDValue C =
18725 DAG.FoldConstantArithmetic(Opcode: N->getOpcode(), DL, VT, Ops: {N0, N1, N2}))
18726 return C;
18727
18728 // (-N0 * -N1) + N2 --> (N0 * N1) + N2
18729 TargetLowering::NegatibleCost CostN0 =
18730 TargetLowering::NegatibleCost::Expensive;
18731 TargetLowering::NegatibleCost CostN1 =
18732 TargetLowering::NegatibleCost::Expensive;
18733 SDValue NegN0 =
18734 TLI.getNegatedExpression(Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN0);
18735 if (NegN0) {
18736 HandleSDNode NegN0Handle(NegN0);
18737 SDValue NegN1 =
18738 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN1);
18739 if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
18740 CostN1 == TargetLowering::NegatibleCost::Cheaper))
18741 return matcher.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2);
18742 }
18743
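  // Fold (fma 0, x, y) -> y and (fma x, 0, y) -> y. This needs nnan and ninf
  // because 0 * NaN and 0 * inf are NaN rather than 0, and it needs nsz (or a
  // constant addend other than -0.0) because (+0.0) + (-0.0) is +0.0, not
  // -0.0.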
18744 if (N->getFlags().hasNoNaNs() && N->getFlags().hasNoInfs()) {
18745 if (N->getFlags().hasNoSignedZeros() ||
18746 (N2CFP && !N2CFP->isExactlyValue(V: -0.0))) {
18747 if (N0CFP && N0CFP->isZero())
18748 return N2;
18749 if (N1CFP && N1CFP->isZero())
18750 return N2;
18751 }
18752 }
18753
18754 // FIXME: Support splat of constant.
18755 if (N0CFP && N0CFP->isExactlyValue(V: 1.0))
18756 return matcher.getNode(ISD::FADD, DL, VT, N1, N2);
18757 if (N1CFP && N1CFP->isExactlyValue(V: 1.0))
18758 return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
18759
18760 // Canonicalize (fma c, x, y) -> (fma x, c, y)
18761 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0) &&
18762 !DAG.isConstantFPBuildVectorOrConstantFP(N: N1))
18763 return matcher.getNode(ISD::FMA, DL, VT, N1, N0, N2);
18764
18765 bool CanReassociate = N->getFlags().hasAllowReassociation();
18766 if (CanReassociate) {
18767 // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
18768 if (matcher.match(N2, ISD::FMUL) && N0 == N2.getOperand(i: 0) &&
18769 DAG.isConstantFPBuildVectorOrConstantFP(N: N1) &&
18770 DAG.isConstantFPBuildVectorOrConstantFP(N: N2.getOperand(i: 1))) {
18771 return matcher.getNode(
18772 ISD::FMUL, DL, VT, N0,
18773 matcher.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(i: 1)));
18774 }
18775
18776 // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
18777 if (matcher.match(N0, ISD::FMUL) &&
18778 DAG.isConstantFPBuildVectorOrConstantFP(N: N1) &&
18779 DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 1))) {
18780 return matcher.getNode(
18781 ISD::FMA, DL, VT, N0.getOperand(i: 0),
18782 matcher.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(i: 1)), N2);
18783 }
18784 }
18785
18786 // (fma x, -1, y) -> (fadd (fneg x), y)
18787 // FIXME: Support splat of constant.
18788 if (N1CFP) {
18789 if (N1CFP->isExactlyValue(V: 1.0))
18790 return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
18791
18792 if (N1CFP->isExactlyValue(V: -1.0) &&
18793 (!LegalOperations || TLI.isOperationLegal(Op: ISD::FNEG, VT))) {
18794 SDValue RHSNeg = matcher.getNode(ISD::FNEG, DL, VT, N0);
18795 AddToWorklist(N: RHSNeg.getNode());
18796 return matcher.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
18797 }
18798
    // fma (fneg x), K, y -> fma x, -K, y
18800 if (matcher.match(N0, ISD::FNEG) &&
18801 (TLI.isOperationLegal(Op: ISD::ConstantFP, VT) ||
18802 (N1.hasOneUse() &&
18803 !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
18804 return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(i: 0),
18805 matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
18806 }
18807 }
18808
18809 // FIXME: Support splat of constant.
18810 if (CanReassociate) {
18811 // (fma x, c, x) -> (fmul x, (c+1))
18812 if (N1CFP && N0 == N2) {
18813 return matcher.getNode(ISD::FMUL, DL, VT, N0,
18814 matcher.getNode(ISD::FADD, DL, VT, N1,
18815 DAG.getConstantFP(Val: 1.0, DL, VT)));
18816 }
18817
18818 // (fma x, c, (fneg x)) -> (fmul x, (c-1))
18819 if (N1CFP && matcher.match(N2, ISD::FNEG) && N2.getOperand(i: 0) == N0) {
18820 return matcher.getNode(ISD::FMUL, DL, VT, N0,
18821 matcher.getNode(ISD::FADD, DL, VT, N1,
18822 DAG.getConstantFP(Val: -1.0, DL, VT)));
18823 }
18824 }
18825
18826 // fold (fma (fneg X), Y, (fneg Z)) -> (fneg (fma X, Y, Z))
18827 // fold (fma X, (fneg Y), (fneg Z)) -> (fneg (fma X, Y, Z))
18828 if (!TLI.isFNegFree(VT))
18829 if (SDValue Neg = TLI.getCheaperNegatedExpression(
18830 Op: SDValue(N, 0), DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
18831 return matcher.getNode(ISD::FNEG, DL, VT, Neg);
18832 return SDValue();
18833}
18834
18835SDValue DAGCombiner::visitFMAD(SDNode *N) {
18836 SDValue N0 = N->getOperand(Num: 0);
18837 SDValue N1 = N->getOperand(Num: 1);
18838 SDValue N2 = N->getOperand(Num: 2);
18839 EVT VT = N->getValueType(ResNo: 0);
18840 SDLoc DL(N);
18841
18842 // Constant fold FMAD.
18843 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FMAD, DL, VT, Ops: {N0, N1, N2}))
18844 return C;
18845
18846 return SDValue();
18847}
18848
18849SDValue DAGCombiner::visitFMULADD(SDNode *N) {
18850 SDValue N0 = N->getOperand(Num: 0);
18851 SDValue N1 = N->getOperand(Num: 1);
18852 SDValue N2 = N->getOperand(Num: 2);
18853 EVT VT = N->getValueType(ResNo: 0);
18854 SDLoc DL(N);
18855
18856 // Constant fold FMULADD.
18857 if (SDValue C =
18858 DAG.FoldConstantArithmetic(Opcode: ISD::FMULADD, DL, VT, Ops: {N0, N1, N2}))
18859 return C;
18860
18861 return SDValue();
18862}
18863
18864// Combine multiple FDIVs with the same divisor into multiple FMULs by the
18865// reciprocal.
18866// E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)
18867 // Note that this is not always beneficial. One reason is that different
18868 // targets may have different costs for FDIV and FMUL, so the cost of two
18869 // FDIVs may sometimes be lower than the cost of one FDIV and two FMULs.
18870 // Another reason is that the critical path grows from "one FDIV" to "one FDIV + one FMUL".
18871SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) {
18872 // TODO: Limit this transform based on optsize/minsize - it always creates at
18873 // least 1 extra instruction. But the perf win may be substantial enough
18874 // that only minsize should restrict this.
18875 const SDNodeFlags Flags = N->getFlags();
18876 if (LegalDAG || !Flags.hasAllowReciprocal())
18877 return SDValue();
18878
18879 // Skip if current node is a reciprocal/fneg-reciprocal.
18880 SDValue N0 = N->getOperand(Num: 0), N1 = N->getOperand(Num: 1);
18881 ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N: N0, /* AllowUndefs */ true);
18882 if (N0CFP && (N0CFP->isExactlyValue(V: 1.0) || N0CFP->isExactlyValue(V: -1.0)))
18883 return SDValue();
18884
18885 // Exit early if the target does not want this transform or if there can't
18886 // possibly be enough uses of the divisor to make the transform worthwhile.
18887 unsigned MinUses = TLI.combineRepeatedFPDivisors();
18888
18889 // For splat vectors, scale the number of uses by the splat factor. If we can
18890 // convert the division into a scalar op, that will likely be much faster.
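  // Illustrative example: with a target MinUses of 2, one fdiv whose divisor
  // is a splatted <4 x float> value counts as 4 scaled uses (1 * 4), so the
  // transform can still fire even though there is only a single FDIV user.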
18891 unsigned NumElts = 1;
18892 EVT VT = N->getValueType(ResNo: 0);
18893 if (VT.isVector() && DAG.isSplatValue(V: N1))
18894 NumElts = VT.getVectorMinNumElements();
18895
18896 if (!MinUses || (N1->use_size() * NumElts) < MinUses)
18897 return SDValue();
18898
18899 // Find all FDIV users of the same divisor.
18900 // Use a set because duplicates may be present in the user list.
18901 SetVector<SDNode *> Users;
18902 for (auto *U : N1->users()) {
18903 if (U->getOpcode() == ISD::FDIV && U->getOperand(Num: 1) == N1) {
18904 // Skip X/sqrt(X) that has not been simplified to sqrt(X) yet.
18905 if (U->getOperand(Num: 1).getOpcode() == ISD::FSQRT &&
18906 U->getOperand(Num: 0) == U->getOperand(Num: 1).getOperand(i: 0) &&
18907 U->getFlags().hasAllowReassociation() &&
18908 U->getFlags().hasNoSignedZeros())
18909 continue;
18910
18911 // This division is eligible for the optimization only if it allows
18912 // reciprocal formation (i.e. it carries the 'arcp' fast-math flag).
18913 if (U->getFlags().hasAllowReciprocal())
18914 Users.insert(X: U);
18915 }
18916 }
18917
18918 // Now that we have the actual number of divisor uses, make sure it meets
18919 // the minimum threshold specified by the target.
18920 if ((Users.size() * NumElts) < MinUses)
18921 return SDValue();
18922
18923 SDLoc DL(N);
18924 SDValue FPOne = DAG.getConstantFP(Val: 1.0, DL, VT);
18925 SDValue Reciprocal = DAG.getNode(Opcode: ISD::FDIV, DL, VT, N1: FPOne, N2: N1, Flags);
18926
18927 // Dividend / Divisor -> Dividend * Reciprocal
18928 for (auto *U : Users) {
18929 SDValue Dividend = U->getOperand(Num: 0);
18930 if (Dividend != FPOne) {
18931 SDValue NewNode = DAG.getNode(Opcode: ISD::FMUL, DL: SDLoc(U), VT, N1: Dividend,
18932 N2: Reciprocal, Flags);
18933 CombineTo(N: U, Res: NewNode);
18934 } else if (U != Reciprocal.getNode()) {
18935 // In the absence of fast-math-flags, this user node is always the
18936 // same node as Reciprocal, but with FMF they may be different nodes.
18937 CombineTo(N: U, Res: Reciprocal);
18938 }
18939 }
18940 return SDValue(N, 0); // N was replaced.
18941}
18942
18943SDValue DAGCombiner::visitFDIV(SDNode *N) {
18944 SDValue N0 = N->getOperand(Num: 0);
18945 SDValue N1 = N->getOperand(Num: 1);
18946 EVT VT = N->getValueType(ResNo: 0);
18947 SDLoc DL(N);
18948 SDNodeFlags Flags = N->getFlags();
18949 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18950
18951 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
18952 return R;
18953
18954 // fold (fdiv c1, c2) -> c1/c2
18955 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FDIV, DL, VT, Ops: {N0, N1}))
18956 return C;
18957
18958 // fold vector ops
18959 if (VT.isVector())
18960 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
18961 return FoldedVOp;
18962
18963 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
18964 return NewSel;
18965
18966 if (SDValue V = combineRepeatedFPDivisors(N))
18967 return V;
18968
18969 // fold (fdiv X, c2) -> (fmul X, 1/c2) if there is no loss in precision, or
18970 // the loss is acceptable with AllowReciprocal.
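  // For example, (fdiv X, 4.0) becomes (fmul X, 0.25) even without 'arcp'
  // because the reciprocal 0.25 is exact, while (fdiv X, 3.0) becomes a
  // multiply by an inexact reciprocal only when 'arcp' is present; both are
  // subject to the FP-immediate legality checks below.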
18971 if (auto *N1CFP = isConstOrConstSplatFP(N: N1, AllowUndefs: true)) {
18972 // Compute the reciprocal 1.0 / c2.
18973 const APFloat &N1APF = N1CFP->getValueAPF();
18974 APFloat Recip = APFloat::getOne(Sem: N1APF.getSemantics());
18975 APFloat::opStatus st = Recip.divide(RHS: N1APF, RM: APFloat::rmNearestTiesToEven);
18976 // Only do the transform if the reciprocal is a legal fp immediate that
18977 // isn't too nasty (eg NaN, denormal, ...).
18978 if (((st == APFloat::opOK && !Recip.isDenormal()) ||
18979 (st == APFloat::opInexact && Flags.hasAllowReciprocal())) &&
18980 (!LegalOperations ||
18981 // FIXME: custom lowering of ConstantFP might fail (see e.g. ARM
18982 // backend)... we should handle this gracefully after Legalize.
18983 // TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT) ||
18984 TLI.isOperationLegal(Op: ISD::ConstantFP, VT) ||
18985 TLI.isFPImmLegal(Recip, VT, ForCodeSize)))
18986 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0,
18987 N2: DAG.getConstantFP(Val: Recip, DL, VT));
18988 }
18989
18990 if (Flags.hasAllowReciprocal()) {
18991 // If this FDIV is part of a reciprocal square root, it may be folded
18992 // into a target-specific square root estimate instruction.
18993 bool N1AllowReciprocal = N1->getFlags().hasAllowReciprocal();
18994 if (N1.getOpcode() == ISD::FSQRT) {
18995 if (SDValue RV = buildRsqrtEstimate(Op: N1.getOperand(i: 0)))
18996 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: RV);
18997 } else if (N1.getOpcode() == ISD::FP_EXTEND &&
18998 N1.getOperand(i: 0).getOpcode() == ISD::FSQRT &&
18999 N1AllowReciprocal) {
19000 if (SDValue RV = buildRsqrtEstimate(Op: N1.getOperand(i: 0).getOperand(i: 0))) {
19001 RV = DAG.getNode(Opcode: ISD::FP_EXTEND, DL: SDLoc(N1), VT, Operand: RV);
19002 AddToWorklist(N: RV.getNode());
19003 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: RV);
19004 }
19005 } else if (N1.getOpcode() == ISD::FP_ROUND &&
19006 N1.getOperand(i: 0).getOpcode() == ISD::FSQRT) {
19007 if (SDValue RV = buildRsqrtEstimate(Op: N1.getOperand(i: 0).getOperand(i: 0))) {
19008 RV = DAG.getNode(Opcode: ISD::FP_ROUND, DL: SDLoc(N1), VT, N1: RV, N2: N1.getOperand(i: 1));
19009 AddToWorklist(N: RV.getNode());
19010 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: RV);
19011 }
19012 } else if (N1.getOpcode() == ISD::FMUL) {
19013 // Look through an FMUL. Even though this won't remove the FDIV directly,
19014 // it's still worthwhile to get rid of the FSQRT if possible.
19015 SDValue Sqrt, Y;
19016 if (N1.getOperand(i: 0).getOpcode() == ISD::FSQRT) {
19017 Sqrt = N1.getOperand(i: 0);
19018 Y = N1.getOperand(i: 1);
19019 } else if (N1.getOperand(i: 1).getOpcode() == ISD::FSQRT) {
19020 Sqrt = N1.getOperand(i: 1);
19021 Y = N1.getOperand(i: 0);
19022 }
19023 if (Sqrt.getNode()) {
19024 // If the other multiply operand is known positive, pull it into the
19025 // sqrt. That will eliminate the division if we convert to an estimate.
19026 if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
19027 N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse()) {
19028 SDValue A;
19029 if (Y.getOpcode() == ISD::FABS && Y.hasOneUse())
19030 A = Y.getOperand(i: 0);
19031 else if (Y == Sqrt.getOperand(i: 0))
19032 A = Y;
19033 if (A) {
19034 // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
19035 // X / (A * sqrt(A)) --> X / sqrt(A*A*A) --> X * rsqrt(A*A*A)
19036 SDValue AA = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: A, N2: A);
19037 SDValue AAZ =
19038 DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: AA, N2: Sqrt.getOperand(i: 0));
19039 if (SDValue Rsqrt = buildRsqrtEstimate(Op: AAZ))
19040 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: Rsqrt);
19041
19042 // Estimate creation failed. Clean up speculatively created nodes.
19043 recursivelyDeleteUnusedNodes(N: AAZ.getNode());
19044 }
19045 }
19046
19047 // We found a FSQRT, so try to make this fold:
19048 // X / (Y * sqrt(Z)) -> X * (rsqrt(Z) / Y)
19049 if (SDValue Rsqrt = buildRsqrtEstimate(Op: Sqrt.getOperand(i: 0))) {
19050 SDValue Div = DAG.getNode(Opcode: ISD::FDIV, DL: SDLoc(N1), VT, N1: Rsqrt, N2: Y);
19051 AddToWorklist(N: Div.getNode());
19052 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: Div);
19053 }
19054 }
19055 }
19056
19057 // Fold into a reciprocal estimate and multiply instead of a real divide.
19058 if (Flags.hasNoInfs())
19059 if (SDValue RV = BuildDivEstimate(N: N0, Op: N1, Flags))
19060 return RV;
19061 }
19062
19063 // Fold X/Sqrt(X) -> Sqrt(X)
19064 if (DAG.canIgnoreSignBitOfZero(Op: SDValue(N, 0)) &&
19065 Flags.hasAllowReassociation())
19066 if (N1.getOpcode() == ISD::FSQRT && N0 == N1.getOperand(i: 0))
19067 return N1;
19068
19069 // (fdiv (fneg X), (fneg Y)) -> (fdiv X, Y)
19070 TargetLowering::NegatibleCost CostN0 =
19071 TargetLowering::NegatibleCost::Expensive;
19072 TargetLowering::NegatibleCost CostN1 =
19073 TargetLowering::NegatibleCost::Expensive;
19074 SDValue NegN0 =
19075 TLI.getNegatedExpression(Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN0);
19076 if (NegN0) {
19077 HandleSDNode NegN0Handle(NegN0);
19078 SDValue NegN1 =
19079 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN1);
19080 if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
19081 CostN1 == TargetLowering::NegatibleCost::Cheaper))
19082 return DAG.getNode(Opcode: ISD::FDIV, DL, VT, N1: NegN0, N2: NegN1);
19083 }
19084
19085 if (SDValue R = combineFMulOrFDivWithIntPow2(N))
19086 return R;
19087
19088 return SDValue();
19089}
19090
19091SDValue DAGCombiner::visitFREM(SDNode *N) {
19092 SDValue N0 = N->getOperand(Num: 0);
19093 SDValue N1 = N->getOperand(Num: 1);
19094 EVT VT = N->getValueType(ResNo: 0);
19095 SDNodeFlags Flags = N->getFlags();
19096 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
19097 SDLoc DL(N);
19098
19099 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
19100 return R;
19101
19102 // fold (frem c1, c2) -> fmod(c1,c2)
19103 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FREM, DL, VT, Ops: {N0, N1}))
19104 return C;
19105
19106 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
19107 return NewSel;
19108
19109 // Lower frem N0, N1 => N0 - trunc(N0 / N1) * N1, provided N1 is an integer
19110 // power of 2.
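  // Worked example (illustrative): frem 7.5, 2.0:
  //   div = 7.5 / 2.0 = 3.75, trunc(3.75) = 3.0,
  //   7.5 - 3.0 * 2.0 = 1.5, which matches fmod(7.5, 2.0).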
19111 if (!TLI.isOperationLegal(Op: ISD::FREM, VT) &&
19112 TLI.isOperationLegalOrCustom(Op: ISD::FMUL, VT) &&
19113 TLI.isOperationLegalOrCustom(Op: ISD::FDIV, VT) &&
19114 TLI.isOperationLegalOrCustom(Op: ISD::FTRUNC, VT) &&
19115 DAG.isKnownToBeAPowerOfTwoFP(Val: N1)) {
19116 bool NeedsCopySign = !DAG.canIgnoreSignBitOfZero(Op: SDValue(N, 0)) &&
19117 !DAG.cannotBeOrderedNegativeFP(Op: N0);
19118 SDValue Div = DAG.getNode(Opcode: ISD::FDIV, DL, VT, N1: N0, N2: N1);
19119 SDValue Rnd = DAG.getNode(Opcode: ISD::FTRUNC, DL, VT, Operand: Div);
19120 SDValue MLA;
19121 if (TLI.isFMAFasterThanFMulAndFAdd(MF: DAG.getMachineFunction(), VT)) {
19122 MLA = DAG.getNode(Opcode: ISD::FMA, DL, VT, N1: DAG.getNode(Opcode: ISD::FNEG, DL, VT, Operand: Rnd),
19123 N2: N1, N3: N0);
19124 } else {
19125 SDValue Mul = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Rnd, N2: N1);
19126 MLA = DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1: N0, N2: Mul);
19127 }
19128 return NeedsCopySign ? DAG.getNode(Opcode: ISD::FCOPYSIGN, DL, VT, N1: MLA, N2: N0) : MLA;
19129 }
19130
19131 return SDValue();
19132}
19133
19134SDValue DAGCombiner::visitFSQRT(SDNode *N) {
19135 SDNodeFlags Flags = N->getFlags();
19136
19137 // Require 'ninf' flag since sqrt(+Inf) = +Inf, but the estimation goes as:
19138 // sqrt(+Inf) == rsqrt(+Inf) * +Inf = 0 * +Inf = NaN
19139 if (!Flags.hasApproximateFuncs() || !Flags.hasNoInfs())
19140 return SDValue();
19141
19142 SDValue N0 = N->getOperand(Num: 0);
19143 if (TLI.isFsqrtCheap(X: N0, DAG))
19144 return SDValue();
19145
19146 // FSQRT nodes have flags that propagate to the created nodes.
19147 SelectionDAG::FlagInserter FlagInserter(DAG, Flags);
19148 // TODO: If this is N0/sqrt(N0), and we reach this node before trying to
19149 // transform the fdiv, we may produce a sub-optimal estimate sequence
19150 // because the reciprocal calculation may not have to filter out a
19151 // 0.0 input.
19152 return buildSqrtEstimate(Op: N0);
19153}
19154
19155/// copysign(x, fp_extend(y)) -> copysign(x, y)
19156/// copysign(x, fp_round(y)) -> copysign(x, y)
19157 /// The arguments are the types of X and Y, respectively.
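/// For example, fcopysign(f64 X, fp_extend(f32 Y)) only needs the sign bit of
/// Y, so the extend can be dropped and the node becomes fcopysign(f64 X, f32 Y).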
19158static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(EVT XTy, EVT YTy) {
19159 // Always fold no-op FP casts.
19160 if (XTy == YTy)
19161 return true;
19162
19163 // Do not optimize out the type conversion for f128 yet.
19164 // For some targets like x86_64, the configuration was changed to keep one
19165 // f128 value in one SSE register, but instruction selection cannot handle
19166 // FCOPYSIGN on SSE registers yet.
19167 if (YTy == MVT::f128)
19168 return false;
19169
19170 // Avoid mismatched vector operand types, for better instruction selection.
19171 return !YTy.isVector();
19172}
19173
19174static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) {
19175 SDValue N1 = N->getOperand(Num: 1);
19176 if (N1.getOpcode() != ISD::FP_EXTEND &&
19177 N1.getOpcode() != ISD::FP_ROUND)
19178 return false;
19179 EVT N1VT = N1->getValueType(ResNo: 0);
19180 EVT N1Op0VT = N1->getOperand(Num: 0).getValueType();
19181 return CanCombineFCOPYSIGN_EXTEND_ROUND(XTy: N1VT, YTy: N1Op0VT);
19182}
19183
19184SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) {
19185 SDValue N0 = N->getOperand(Num: 0);
19186 SDValue N1 = N->getOperand(Num: 1);
19187 EVT VT = N->getValueType(ResNo: 0);
19188 SDLoc DL(N);
19189
19190 // fold (fcopysign c1, c2) -> fcopysign(c1,c2)
19191 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FCOPYSIGN, DL, VT, Ops: {N0, N1}))
19192 return C;
19193
19194 // copysign(x, fp_extend(y)) -> copysign(x, y)
19195 // copysign(x, fp_round(y)) -> copysign(x, y)
19196 if (CanCombineFCOPYSIGN_EXTEND_ROUND(N))
19197 return DAG.getNode(Opcode: ISD::FCOPYSIGN, DL, VT, N1: N0, N2: N1.getOperand(i: 0));
19198
19199 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
19200 return SDValue(N, 0);
19201
19202 if (VT != N1.getValueType())
19203 return SDValue();
19204
19205 // If this is equivalent to a disjoint or, replace it with one. This can
19206 // happen if the sign operand is a sign mask (i.e., x << sign_bit_position).
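  // Illustrative f32 case: if N1 = (shl Y, 31), every bit of N1 other than the
  // sign bit is known zero, so when N0's sign bit is also known zero the
  // copysign is just a disjoint OR of the two bit patterns.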
19207 if (DAG.SignBitIsZeroFP(Op: N0) &&
19208 DAG.computeKnownBits(Op: N1).Zero.isMaxSignedValue()) {
19209 // TODO: Just directly match the shift pattern. computeKnownBits is heavy
19210 // for such a narrowly targeted case.
19211 EVT IntVT = VT.changeTypeToInteger();
19212 // TODO: It appears to be profitable in some situations to unconditionally
19213 // emit a fabs(n0) to perform this combine.
19214 SDValue CastSrc0 = DAG.getNode(Opcode: ISD::BITCAST, DL, VT: IntVT, Operand: N0);
19215 SDValue CastSrc1 = DAG.getNode(Opcode: ISD::BITCAST, DL, VT: IntVT, Operand: N1);
19216
19217 SDValue SignOr = DAG.getNode(Opcode: ISD::OR, DL, VT: IntVT, N1: CastSrc0, N2: CastSrc1,
19218 Flags: SDNodeFlags::Disjoint);
19219 return DAG.getNode(Opcode: ISD::BITCAST, DL, VT, Operand: SignOr);
19220 }
19221
19222 return SDValue();
19223}
19224
19225SDValue DAGCombiner::visitFPOW(SDNode *N) {
19226 ConstantFPSDNode *ExponentC = isConstOrConstSplatFP(N: N->getOperand(Num: 1));
19227 if (!ExponentC)
19228 return SDValue();
19229 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
19230
19231 // Try to convert x ** (1/3) into cube root.
19232 // TODO: Handle the various flavors of long double.
19233 // TODO: Since we're approximating, we don't need an exact 1/3 exponent.
19234 // Some range near 1/3 should be fine.
19235 EVT VT = N->getValueType(ResNo: 0);
19236 if ((VT == MVT::f32 && ExponentC->getValueAPF().isExactlyValue(V: 1.0f/3.0f)) ||
19237 (VT == MVT::f64 && ExponentC->getValueAPF().isExactlyValue(V: 1.0/3.0))) {
19238 // pow(-0.0, 1/3) = +0.0; cbrt(-0.0) = -0.0.
19239 // pow(-inf, 1/3) = +inf; cbrt(-inf) = -inf.
19240 // pow(-val, 1/3) = nan; cbrt(-val) = -cbrt(val).
19241 // For regular numbers, rounding may cause the results to differ.
19242 // Therefore, we require { nsz ninf nnan afn } for this transform.
19243 // TODO: We could select out the special cases if we don't have nsz/ninf.
19244 SDNodeFlags Flags = N->getFlags();
19245 if (!Flags.hasNoSignedZeros() || !Flags.hasNoInfs() || !Flags.hasNoNaNs() ||
19246 !Flags.hasApproximateFuncs())
19247 return SDValue();
19248
19249 // Do not create a cbrt() libcall if the target does not have it, and do not
19250 // turn a pow that has lowering support into a cbrt() libcall.
19251 if (!DAG.getLibInfo().has(F: LibFunc_cbrt) ||
19252 (!DAG.getTargetLoweringInfo().isOperationExpand(Op: ISD::FPOW, VT) &&
19253 DAG.getTargetLoweringInfo().isOperationExpand(Op: ISD::FCBRT, VT)))
19254 return SDValue();
19255
19256 return DAG.getNode(Opcode: ISD::FCBRT, DL: SDLoc(N), VT, Operand: N->getOperand(Num: 0));
19257 }
19258
19259 // Try to convert x ** (1/4) and x ** (3/4) into square roots.
19260 // x ** (1/2) is canonicalized to sqrt, so we do not bother with that case.
19261 // TODO: This could be extended (using a target hook) to handle smaller
19262 // power-of-2 fractional exponents.
19263 bool ExponentIs025 = ExponentC->getValueAPF().isExactlyValue(V: 0.25);
19264 bool ExponentIs075 = ExponentC->getValueAPF().isExactlyValue(V: 0.75);
19265 if (ExponentIs025 || ExponentIs075) {
19266 // pow(-0.0, 0.25) = +0.0; sqrt(sqrt(-0.0)) = -0.0.
19267 // pow(-inf, 0.25) = +inf; sqrt(sqrt(-inf)) = NaN.
19268 // pow(-0.0, 0.75) = +0.0; sqrt(-0.0) * sqrt(sqrt(-0.0)) = +0.0.
19269 // pow(-inf, 0.75) = +inf; sqrt(-inf) * sqrt(sqrt(-inf)) = NaN.
19270 // For regular numbers, rounding may cause the results to differ.
19271 // Therefore, we require { nsz ninf afn } for this transform.
19272 // TODO: We could select out the special cases if we don't have nsz/ninf.
19273 SDNodeFlags Flags = N->getFlags();
19274
19275 // We only need no signed zeros for the 0.25 case.
19276 if ((!Flags.hasNoSignedZeros() && ExponentIs025) || !Flags.hasNoInfs() ||
19277 !Flags.hasApproximateFuncs())
19278 return SDValue();
19279
19280 // Don't double the number of libcalls. We are trying to inline fast code.
19281 if (!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(Op: ISD::FSQRT, VT))
19282 return SDValue();
19283
19284 // Assume that libcalls are the smallest code.
19285 // TODO: This restriction should probably be lifted for vectors.
19286 if (ForCodeSize)
19287 return SDValue();
19288
19289 // pow(X, 0.25) --> sqrt(sqrt(X))
19290 SDLoc DL(N);
19291 SDValue Sqrt = DAG.getNode(Opcode: ISD::FSQRT, DL, VT, Operand: N->getOperand(Num: 0));
19292 SDValue SqrtSqrt = DAG.getNode(Opcode: ISD::FSQRT, DL, VT, Operand: Sqrt);
19293 if (ExponentIs025)
19294 return SqrtSqrt;
19295 // pow(X, 0.75) --> sqrt(X) * sqrt(sqrt(X))
19296 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Sqrt, N2: SqrtSqrt);
19297 }
19298
19299 return SDValue();
19300}
19301
19302static SDValue foldFPToIntToFP(SDNode *N, const SDLoc &DL, SelectionDAG &DAG,
19303 const TargetLowering &TLI) {
19304 // We can fold the fpto[us]i -> [us]itofp pattern into a single ftrunc.
19305 // Additionally, if there are clamps ([us]min or [us]max) around
19306 // the fpto[us]i, we can fold those into fminnum/fmaxnum around the ftrunc.
19307 // If NoSignedZerosFPMath is enabled, this is a direct replacement.
19308 // Otherwise, for strict math, we must handle edge cases:
19309 // 1. For unsigned conversions, use FABS to handle negative cases. Take -0.0
19310 // as an example: it first becomes integer 0 and is then converted back to
19311 // +0.0, whereas FTRUNC on its own could produce -0.0.
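  // A sketch of the intent (f32, assuming the sign-of-zero and legality checks
  // below pass):
  //   sint_to_fp (fp_to_sint X)             --> ftrunc X
  //   uint_to_fp (umin (fp_to_uint X), 255) --> fmin* (ftrunc X), 255.0
  // where 255 must round-trip exactly through the FP type and fmin* is
  // whatever FP min/max opcode getMinMaxOpcodeForClamp selects (an FABS may
  // also be inserted in the unsigned case).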
19312
19313 // FIXME: We should be able to use node-level FMF here.
19314 EVT VT = N->getValueType(ResNo: 0);
19315 if (!TLI.isOperationLegal(Op: ISD::FTRUNC, VT))
19316 return SDValue();
19317
19318 bool IsUnsigned = N->getOpcode() == ISD::UINT_TO_FP;
19319 bool IsSigned = N->getOpcode() == ISD::SINT_TO_FP;
19320 assert(IsSigned || IsUnsigned);
19321
19322 bool IsSignedZeroSafe = DAG.getTarget().Options.NoSignedZerosFPMath ||
19323 DAG.canIgnoreSignBitOfZero(Op: SDValue(N, 0));
19324 // For signed conversions: The optimization changes signed zero behavior.
19325 if (IsSigned && !IsSignedZeroSafe)
19326 return SDValue();
19327 // For unsigned conversions, we need FABS to canonicalize -0.0 to +0.0
19328 // (unless outputting a signed zero is OK).
19329 if (IsUnsigned && !IsSignedZeroSafe && !TLI.isFAbsFree(VT))
19330 return SDValue();
19331
19332 // Collect potential clamp operations (outermost to innermost) and peel.
19333 struct ClampInfo {
19334 bool IsMin;
19335 SDValue Constant;
19336 };
19337 constexpr unsigned MaxClamps = 2;
19338 SmallVector<ClampInfo, MaxClamps> Clamps;
19339 unsigned MinOp = IsUnsigned ? ISD::UMIN : ISD::SMIN;
19340 unsigned MaxOp = IsUnsigned ? ISD::UMAX : ISD::SMAX;
19341 SDValue IntVal = N->getOperand(Num: 0);
19342 for (unsigned Level = 0; Level < MaxClamps; ++Level) {
19343 if (!IntVal.hasOneUse() ||
19344 (IntVal.getOpcode() != MinOp && IntVal.getOpcode() != MaxOp))
19345 break;
19346 SDValue RHS = IntVal.getOperand(i: 1);
19347 APInt IntConst;
19348 if (auto *IntConstNode = dyn_cast<ConstantSDNode>(Val&: RHS))
19349 IntConst = IntConstNode->getAPIntValue();
19350 else if (!ISD::isConstantSplatVector(N: RHS.getNode(), SplatValue&: IntConst))
19351 return SDValue();
19352 APFloat FPConst(VT.getFltSemantics());
19353 FPConst.convertFromAPInt(Input: IntConst, IsSigned, RM: APFloat::rmNearestTiesToEven);
19354 // Verify roundtrip exactness.
19355 APSInt RoundTrip(IntConst.getBitWidth(), IsUnsigned);
19356 bool IsExact;
19357 if (FPConst.convertToInteger(Result&: RoundTrip, RM: APFloat::rmTowardZero, IsExact: &IsExact) !=
19358 APFloat::opOK ||
19359 !IsExact || static_cast<const APInt &>(RoundTrip) != IntConst)
19360 return SDValue();
19361 bool IsMin = IntVal.getOpcode() == MinOp;
19362 Clamps.push_back(Elt: {.IsMin: IsMin, .Constant: DAG.getConstantFP(Val: FPConst, DL, VT)});
19363 IntVal = IntVal.getOperand(i: 0);
19364 }
19365
19366 // Check that the sequence ends with the correct kind of fpto[us]i.
19367 unsigned FPToIntOp = IsUnsigned ? ISD::FP_TO_UINT : ISD::FP_TO_SINT;
19368 if (IntVal.getOpcode() != FPToIntOp ||
19369 IntVal.getOperand(i: 0).getValueType() != VT)
19370 return SDValue();
19371
19372 SDValue Result = IntVal.getOperand(i: 0);
19373 if (IsUnsigned && !IsSignedZeroSafe && TLI.isFAbsFree(VT))
19374 Result = DAG.getNode(Opcode: ISD::FABS, DL, VT, Operand: Result);
19375 Result = DAG.getNode(Opcode: ISD::FTRUNC, DL, VT, Operand: Result);
19376 // Apply clamps, if any, in reverse order (innermost first).
19377 for (const ClampInfo &Clamp : reverse(C&: Clamps)) {
19378 unsigned FPClampOp =
19379 getMinMaxOpcodeForClamp(IsMin: Clamp.IsMin, Operand1: Result, Operand2: Clamp.Constant, DAG, TLI);
19380 if (FPClampOp == ISD::DELETED_NODE)
19381 return SDValue();
19382 Result = DAG.getNode(Opcode: FPClampOp, DL, VT, N1: Result, N2: Clamp.Constant);
19383 }
19384 return Result;
19385}
19386
19387SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) {
19388 SDValue N0 = N->getOperand(Num: 0);
19389 EVT VT = N->getValueType(ResNo: 0);
19390 EVT OpVT = N0.getValueType();
19391 SDLoc DL(N);
19392
19393 // [us]itofp(undef) = 0, because the result value is bounded.
19394 if (N0.isUndef())
19395 return DAG.getConstantFP(Val: 0.0, DL, VT);
19396
19397 // fold (sint_to_fp c1) -> c1fp
19398 // ...but only if the target supports immediate floating-point values
19399 if ((!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::ConstantFP, VT)))
19400 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SINT_TO_FP, DL, VT, Ops: {N0}))
19401 return C;
19402
19403 // If the input is a legal type, and SINT_TO_FP is not legal on this target,
19404 // but UINT_TO_FP is legal on this target, try to convert.
19405 if (!hasOperation(Opcode: ISD::SINT_TO_FP, VT: OpVT) &&
19406 hasOperation(Opcode: ISD::UINT_TO_FP, VT: OpVT)) {
19407 // If the sign bit is known to be zero, we can change this to UINT_TO_FP.
19408 if (DAG.SignBitIsZero(Op: N0))
19409 return DAG.getNode(Opcode: ISD::UINT_TO_FP, DL, VT, Operand: N0);
19410 }
19411
19412 // The next optimizations are desirable only if SELECT_CC can be lowered.
19413 // fold (sint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), -1.0, 0.0)
19414 if (N0.getOpcode() == ISD::SETCC && N0.getValueType() == MVT::i1 &&
19415 !VT.isVector() &&
19416 (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::ConstantFP, VT)))
19417 return DAG.getSelect(DL, VT, Cond: N0, LHS: DAG.getConstantFP(Val: -1.0, DL, VT),
19418 RHS: DAG.getConstantFP(Val: 0.0, DL, VT));
19419
19420 // fold (sint_to_fp (zext (setcc x, y, cc))) ->
19421 // (select (setcc x, y, cc), 1.0, 0.0)
19422 if (N0.getOpcode() == ISD::ZERO_EXTEND &&
19423 N0.getOperand(i: 0).getOpcode() == ISD::SETCC && !VT.isVector() &&
19424 (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::ConstantFP, VT)))
19425 return DAG.getSelect(DL, VT, Cond: N0.getOperand(i: 0),
19426 LHS: DAG.getConstantFP(Val: 1.0, DL, VT),
19427 RHS: DAG.getConstantFP(Val: 0.0, DL, VT));
19428
19429 if (SDValue FTrunc = foldFPToIntToFP(N, DL, DAG, TLI))
19430 return FTrunc;
19431
19432 // fold (sint_to_fp (trunc nsw x)) -> (sint_to_fp x)
19433 if (N0.getOpcode() == ISD::TRUNCATE && N0->getFlags().hasNoSignedWrap() &&
19434 TLI.isTypeDesirableForOp(ISD::SINT_TO_FP,
19435 VT: N0.getOperand(i: 0).getValueType()))
19436 return DAG.getNode(Opcode: ISD::SINT_TO_FP, DL, VT, Operand: N0.getOperand(i: 0));
19437
19438 return SDValue();
19439}
19440
19441SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) {
19442 SDValue N0 = N->getOperand(Num: 0);
19443 EVT VT = N->getValueType(ResNo: 0);
19444 EVT OpVT = N0.getValueType();
19445 SDLoc DL(N);
19446
19447 // [us]itofp(undef) = 0, because the result value is bounded.
19448 if (N0.isUndef())
19449 return DAG.getConstantFP(Val: 0.0, DL, VT);
19450
19451 // fold (uint_to_fp c1) -> c1fp
19452 // ...but only if the target supports immediate floating-point values
19453 if ((!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::ConstantFP, VT)))
19454 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::UINT_TO_FP, DL, VT, Ops: {N0}))
19455 return C;
19456
19457 // If the input is a legal type, and UINT_TO_FP is not legal on this target,
19458 // but SINT_TO_FP is legal on this target, try to convert.
19459 if (!hasOperation(Opcode: ISD::UINT_TO_FP, VT: OpVT) &&
19460 hasOperation(Opcode: ISD::SINT_TO_FP, VT: OpVT)) {
19461 // If the sign bit is known to be zero, we can change this to SINT_TO_FP.
19462 if (DAG.SignBitIsZero(Op: N0))
19463 return DAG.getNode(Opcode: ISD::SINT_TO_FP, DL, VT, Operand: N0);
19464 }
19465
19466 // fold (uint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), 1.0, 0.0)
19467 if (N0.getOpcode() == ISD::SETCC && !VT.isVector() &&
19468 (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::ConstantFP, VT)))
19469 return DAG.getSelect(DL, VT, Cond: N0, LHS: DAG.getConstantFP(Val: 1.0, DL, VT),
19470 RHS: DAG.getConstantFP(Val: 0.0, DL, VT));
19471
19472 if (SDValue FTrunc = foldFPToIntToFP(N, DL, DAG, TLI))
19473 return FTrunc;
19474
19475 // fold (uint_to_fp (trunc nuw x)) -> (uint_to_fp x)
19476 if (N0.getOpcode() == ISD::TRUNCATE && N0->getFlags().hasNoUnsignedWrap() &&
19477 TLI.isTypeDesirableForOp(ISD::UINT_TO_FP,
19478 VT: N0.getOperand(i: 0).getValueType()))
19479 return DAG.getNode(Opcode: ISD::UINT_TO_FP, DL, VT, Operand: N0.getOperand(i: 0));
19480
19481 return SDValue();
19482}
19483
19484 // Fold (fp_to_{s/u}int ({s/u}int_to_fp x)) -> zext x, sext x, trunc x, or x
19485static SDValue FoldIntToFPToInt(SDNode *N, const SDLoc &DL, SelectionDAG &DAG) {
19486 SDValue N0 = N->getOperand(Num: 0);
19487 EVT VT = N->getValueType(ResNo: 0);
19488
19489 if (N0.getOpcode() != ISD::UINT_TO_FP && N0.getOpcode() != ISD::SINT_TO_FP)
19490 return SDValue();
19491
19492 SDValue Src = N0.getOperand(i: 0);
19493 EVT SrcVT = Src.getValueType();
19494 bool IsInputSigned = N0.getOpcode() == ISD::SINT_TO_FP;
19495 bool IsOutputSigned = N->getOpcode() == ISD::FP_TO_SINT;
19496
19497 // We can safely assume the conversion won't overflow the output range,
19498 // because (for example) (uint8_t)18293.f is undefined behavior.
19499
19500 // Since we can assume the conversion won't overflow, our decision as to
19501 // whether the input will fit in the float should depend on the minimum
19502 // of the input range and output range.
19503
19504 // This means this is also safe for a signed input and unsigned output, since
19505 // a negative input would lead to undefined behavior.
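  // For example (illustrative, f32 intermediate): fp_to_sint (sint_to_fp i16 X)
  // producing i32 folds to (sign_extend X); the 15 significant input bits fit
  // in f32's 24-bit significand, so the round-trip through FP is exact.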
19506 unsigned InputSize = (int)SrcVT.getScalarSizeInBits() - IsInputSigned;
19507 unsigned OutputSize = (int)VT.getScalarSizeInBits();
19508 unsigned ActualSize = std::min(a: InputSize, b: OutputSize);
19509 const fltSemantics &Sem = N0.getValueType().getFltSemantics();
19510
19511 // We can only fold away the float conversion if the input range can be
19512 // represented exactly in the float range.
19513 if (APFloat::semanticsPrecision(Sem) >= ActualSize) {
19514 if (VT.getScalarSizeInBits() > SrcVT.getScalarSizeInBits()) {
19515 unsigned ExtOp =
19516 IsInputSigned && IsOutputSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
19517 return DAG.getNode(Opcode: ExtOp, DL, VT, Operand: Src);
19518 }
19519 if (VT.getScalarSizeInBits() < SrcVT.getScalarSizeInBits())
19520 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Src);
19521 return DAG.getBitcast(VT, V: Src);
19522 }
19523 return SDValue();
19524}
19525
19526SDValue DAGCombiner::visitFP_TO_SINT(SDNode *N) {
19527 SDValue N0 = N->getOperand(Num: 0);
19528 EVT VT = N->getValueType(ResNo: 0);
19529 SDLoc DL(N);
19530
19531 // fold (fp_to_sint undef) -> undef
19532 if (N0.isUndef())
19533 return DAG.getUNDEF(VT);
19534
19535 // fold (fp_to_sint c1fp) -> c1
19536 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FP_TO_SINT, DL, VT, Ops: {N0}))
19537 return C;
19538
19539 return FoldIntToFPToInt(N, DL, DAG);
19540}
19541
19542SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
19543 SDValue N0 = N->getOperand(Num: 0);
19544 EVT VT = N->getValueType(ResNo: 0);
19545 SDLoc DL(N);
19546
19547 // fold (fp_to_uint undef) -> undef
19548 if (N0.isUndef())
19549 return DAG.getUNDEF(VT);
19550
19551 // fold (fp_to_uint c1fp) -> c1
19552 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FP_TO_UINT, DL, VT, Ops: {N0}))
19553 return C;
19554
19555 return FoldIntToFPToInt(N, DL, DAG);
19556}
19557
19558SDValue DAGCombiner::visitXROUND(SDNode *N) {
19559 SDValue N0 = N->getOperand(Num: 0);
19560 EVT VT = N->getValueType(ResNo: 0);
19561
19562 // fold (lrint|llrint undef) -> undef
19563 // fold (lround|llround undef) -> undef
19564 if (N0.isUndef())
19565 return DAG.getUNDEF(VT);
19566
19567 // fold (lrint|llrint c1fp) -> c1
19568 // fold (lround|llround c1fp) -> c1
19569 if (SDValue C =
19570 DAG.FoldConstantArithmetic(Opcode: N->getOpcode(), DL: SDLoc(N), VT, Ops: {N0}))
19571 return C;
19572
19573 return SDValue();
19574}
19575
19576SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
19577 SDValue N0 = N->getOperand(Num: 0);
19578 SDValue N1 = N->getOperand(Num: 1);
19579 EVT VT = N->getValueType(ResNo: 0);
19580 SDLoc DL(N);
19581
19582 // fold (fp_round c1fp) -> c1fp
19583 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FP_ROUND, DL, VT, Ops: {N0, N1}))
19584 return C;
19585
19586 // fold (fp_round (fp_extend x)) -> x
19587 if (N0.getOpcode() == ISD::FP_EXTEND && VT == N0.getOperand(i: 0).getValueType())
19588 return N0.getOperand(i: 0);
19589
19590 // fold (fp_round (fp_round x)) -> (fp_round x)
19591 if (N0.getOpcode() == ISD::FP_ROUND) {
19592 const bool NIsTrunc = N->getConstantOperandVal(Num: 1) == 1;
19593 const bool N0IsTrunc = N0.getConstantOperandVal(i: 1) == 1;
19594
19595 // Avoid folding legal fp_rounds into non-legal ones.
19596 if (!hasOperation(Opcode: ISD::FP_ROUND, VT))
19597 return SDValue();
19598
19599 // Skip this folding if it results in an fp_round from f80 to f16.
19600 //
19601 // f80 to f16 always generates an expensive (and as yet, unimplemented)
19602 // libcall to __truncxfhf2 instead of selecting native f16 conversion
19603 // instructions from f32 or f64. Moreover, the first (value-preserving)
19604 // fp_round from f80 to either f32 or f64 may become a NOP in platforms like
19605 // x86.
19606 if (N0.getOperand(i: 0).getValueType() == MVT::f80 && VT == MVT::f16)
19607 return SDValue();
19608
19609 // If the first fp_round isn't a value preserving truncation, it might
19610 // introduce a tie in the second fp_round, that wouldn't occur in the
19611 // single-step fp_round we want to fold to.
19612 // In other words, double rounding isn't the same as rounding.
19613 // Also, this is a value preserving truncation iff both fp_round's are.
19614 if ((N->getFlags().hasAllowContract() &&
19615 N0->getFlags().hasAllowContract()) ||
19616 N0IsTrunc)
19617 return DAG.getNode(
19618 Opcode: ISD::FP_ROUND, DL, VT, N1: N0.getOperand(i: 0),
19619 N2: DAG.getIntPtrConstant(Val: NIsTrunc && N0IsTrunc, DL, /*isTarget=*/true));
19620 }
19621
19622 // fold (fp_round (copysign X, Y)) -> (copysign (fp_round X), Y)
19623 // Note: From a legality perspective, this is a two step transform. First,
19624 // we duplicate the fp_round to the arguments of the copysign, then we
19625 // eliminate the fp_round on Y. The second step requires an additional
19626 // predicate to match the implementation above.
19627 if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() &&
19628 CanCombineFCOPYSIGN_EXTEND_ROUND(XTy: VT,
19629 YTy: N0.getValueType())) {
19630 SDValue Tmp = DAG.getNode(Opcode: ISD::FP_ROUND, DL: SDLoc(N0), VT,
19631 N1: N0.getOperand(i: 0), N2: N1);
19632 AddToWorklist(N: Tmp.getNode());
19633 return DAG.getNode(Opcode: ISD::FCOPYSIGN, DL, VT, N1: Tmp, N2: N0.getOperand(i: 1));
19634 }
19635
19636 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
19637 return NewVSel;
19638
19639 return SDValue();
19640}
19641
19642// Eliminate a floating-point widening of a narrowed value if the fast math
19643// flags allow it.
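// For example (illustrative), (fp_extend (fp_round X to f16)) back to the
// original type of X can be replaced by X when the widening node carries nnan
// and ninf and both nodes carry the contract flag.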
19644static SDValue eliminateFPCastPair(SDNode *N) {
19645 SDValue N0 = N->getOperand(Num: 0);
19646 EVT VT = N->getValueType(ResNo: 0);
19647
19648 unsigned NarrowingOp;
19649 switch (N->getOpcode()) {
19650 case ISD::FP16_TO_FP:
19651 NarrowingOp = ISD::FP_TO_FP16;
19652 break;
19653 case ISD::BF16_TO_FP:
19654 NarrowingOp = ISD::FP_TO_BF16;
19655 break;
19656 case ISD::FP_EXTEND:
19657 NarrowingOp = ISD::FP_ROUND;
19658 break;
19659 default:
19660 llvm_unreachable("Expected widening FP cast");
19661 }
19662
19663 if (N0.getOpcode() == NarrowingOp && N0.getOperand(i: 0).getValueType() == VT) {
19664 const SDNodeFlags NarrowFlags = N0->getFlags();
19665 const SDNodeFlags WidenFlags = N->getFlags();
19666 // Narrowing can introduce inf and change the encoding of a nan, so the
19667 // widen must have the nnan and ninf flags to indicate that we don't need to
19668 // care about that. We are also removing a rounding step, and that requires
19669 // both the narrow and widen to allow contraction.
19670 if (WidenFlags.hasNoNaNs() && WidenFlags.hasNoInfs() &&
19671 NarrowFlags.hasAllowContract() && WidenFlags.hasAllowContract()) {
19672 return N0.getOperand(i: 0);
19673 }
19674 }
19675
19676 return SDValue();
19677}
19678
19679SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
19680 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
19681 SDValue N0 = N->getOperand(Num: 0);
19682 EVT VT = N->getValueType(ResNo: 0);
19683 SDLoc DL(N);
19684
19685 if (VT.isVector())
19686 if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
19687 return FoldedVOp;
19688
19689 // If this fp_extend's only user is an fp_round, don't fold it; let the fp_round fold us away instead.
19690 if (N->hasOneUse() && N->user_begin()->getOpcode() == ISD::FP_ROUND)
19691 return SDValue();
19692
19693 // fold (fp_extend c1fp) -> c1fp
19694 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FP_EXTEND, DL, VT, Ops: {N0}))
19695 return C;
19696
19697 // fold (fp_extend (fp16_to_fp op)) -> (fp16_to_fp op)
19698 if (N0.getOpcode() == ISD::FP16_TO_FP &&
19699 TLI.getOperationAction(Op: ISD::FP16_TO_FP, VT) == TargetLowering::Legal)
19700 return DAG.getNode(Opcode: ISD::FP16_TO_FP, DL, VT, Operand: N0.getOperand(i: 0));
19701
19702 // Turn fp_extend(fp_round(X, 1)) -> x since the fp_round doesn't affect the
19703 // value of X.
19704 if (N0.getOpcode() == ISD::FP_ROUND && N0.getConstantOperandVal(i: 1) == 1) {
19705 SDValue In = N0.getOperand(i: 0);
19706 if (In.getValueType() == VT) return In;
19707 if (VT.bitsLT(VT: In.getValueType()))
19708 return DAG.getNode(Opcode: ISD::FP_ROUND, DL, VT, N1: In, N2: N0.getOperand(i: 1));
19709 return DAG.getNode(Opcode: ISD::FP_EXTEND, DL, VT, Operand: In);
19710 }
19711
19712 // fold (fpext (load x)) -> (fpext (fptrunc (extload x)))
19713 if (ISD::isNormalLoad(N: N0.getNode()) && N0.hasOneUse() &&
19714 TLI.isLoadExtLegalOrCustom(ExtType: ISD::EXTLOAD, ValVT: VT, MemVT: N0.getValueType())) {
19715 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
19716 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::EXTLOAD, dl: DL, VT,
19717 Chain: LN0->getChain(),
19718 Ptr: LN0->getBasePtr(), MemVT: N0.getValueType(),
19719 MMO: LN0->getMemOperand());
19720 CombineTo(N, Res: ExtLoad);
19721 CombineTo(
19722 N: N0.getNode(),
19723 Res0: DAG.getNode(Opcode: ISD::FP_ROUND, DL: SDLoc(N0), VT: N0.getValueType(), N1: ExtLoad,
19724 N2: DAG.getIntPtrConstant(Val: 1, DL: SDLoc(N0), /*isTarget=*/true)),
19725 Res1: ExtLoad.getValue(R: 1));
19726 return SDValue(N, 0); // Return N so it doesn't get rechecked!
19727 }
19728
19729 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
19730 return NewVSel;
19731
19732 if (SDValue CastEliminated = eliminateFPCastPair(N))
19733 return CastEliminated;
19734
19735 return SDValue();
19736}
19737
19738SDValue DAGCombiner::visitFCEIL(SDNode *N) {
19739 SDValue N0 = N->getOperand(Num: 0);
19740 EVT VT = N->getValueType(ResNo: 0);
19741
19742 // fold (fceil c1) -> fceil(c1)
19743 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FCEIL, DL: SDLoc(N), VT, Ops: {N0}))
19744 return C;
19745
19746 return SDValue();
19747}
19748
19749SDValue DAGCombiner::visitFTRUNC(SDNode *N) {
19750 SDValue N0 = N->getOperand(Num: 0);
19751 EVT VT = N->getValueType(ResNo: 0);
19752
19753 // fold (ftrunc c1) -> ftrunc(c1)
19754 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FTRUNC, DL: SDLoc(N), VT, Ops: {N0}))
19755 return C;
19756
19757 // fold ftrunc (known rounded int x) -> x
19758 // ftrunc is a part of fptosi/fptoui expansion on some targets, so this is
19759 // likely to be generated to extract integer from a rounded floating value.
19760 switch (N0.getOpcode()) {
19761 default: break;
19762 case ISD::FRINT:
19763 case ISD::FTRUNC:
19764 case ISD::FNEARBYINT:
19765 case ISD::FROUNDEVEN:
19766 case ISD::FFLOOR:
19767 case ISD::FCEIL:
19768 return N0;
19769 }
19770
19771 return SDValue();
19772}
19773
19774SDValue DAGCombiner::visitFFREXP(SDNode *N) {
19775 SDValue N0 = N->getOperand(Num: 0);
19776
19777 // fold (ffrexp c1) -> ffrexp(c1)
19778 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
19779 return DAG.getNode(Opcode: ISD::FFREXP, DL: SDLoc(N), VTList: N->getVTList(), N: N0);
19780 return SDValue();
19781}
19782
19783SDValue DAGCombiner::visitFFLOOR(SDNode *N) {
19784 SDValue N0 = N->getOperand(Num: 0);
19785 EVT VT = N->getValueType(ResNo: 0);
19786
19787 // fold (ffloor c1) -> ffloor(c1)
19788 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FFLOOR, DL: SDLoc(N), VT, Ops: {N0}))
19789 return C;
19790
19791 return SDValue();
19792}
19793
19794SDValue DAGCombiner::visitFNEG(SDNode *N) {
19795 SDValue N0 = N->getOperand(Num: 0);
19796 EVT VT = N->getValueType(ResNo: 0);
19797 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
19798
19799 // Constant fold FNEG.
19800 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FNEG, DL: SDLoc(N), VT, Ops: {N0}))
19801 return C;
19802
19803 if (SDValue NegN0 =
19804 TLI.getNegatedExpression(Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
19805 return NegN0;
19806
19807 // -(X-Y) -> (Y-X) is unsafe because when X==Y, -0.0 != +0.0
19808 // FIXME: This is duplicated in getNegatibleCost, but getNegatibleCost doesn't
19809 // know it was called from a context with a nsz flag if the input fsub does
19810 // not.
19811 if (N0.getOpcode() == ISD::FSUB && N->getFlags().hasNoSignedZeros() &&
19812 N0.hasOneUse()) {
19813 return DAG.getNode(Opcode: ISD::FSUB, DL: SDLoc(N), VT, N1: N0.getOperand(i: 1),
19814 N2: N0.getOperand(i: 0));
19815 }
19816
19817 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
19818 return SDValue(N, 0);
19819
19820 if (SDValue Cast = foldSignChangeInBitcast(N))
19821 return Cast;
19822
19823 return SDValue();
19824}
19825
19826SDValue DAGCombiner::visitFMinMax(SDNode *N) {
19827 SDValue N0 = N->getOperand(Num: 0);
19828 SDValue N1 = N->getOperand(Num: 1);
19829 EVT VT = N->getValueType(ResNo: 0);
19830 const SDNodeFlags Flags = N->getFlags();
19831 unsigned Opc = N->getOpcode();
19832 bool PropAllNaNsToQNaNs = Opc == ISD::FMINIMUM || Opc == ISD::FMAXIMUM;
19833 bool ReturnsOtherForAllNaNs =
19834 Opc == ISD::FMINIMUMNUM || Opc == ISD::FMAXIMUMNUM;
19835 bool IsMin =
19836 Opc == ISD::FMINNUM || Opc == ISD::FMINIMUM || Opc == ISD::FMINIMUMNUM;
19837 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
19838
19839 // Constant fold.
19840 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: Opc, DL: SDLoc(N), VT, Ops: {N0, N1}))
19841 return C;
19842
19843 // Canonicalize to constant on RHS.
19844 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0) &&
19845 !DAG.isConstantFPBuildVectorOrConstantFP(N: N1))
19846 return DAG.getNode(Opcode: N->getOpcode(), DL: SDLoc(N), VT, N1, N2: N0);
19847
19848 if (const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N: N1)) {
19849 const APFloat &AF = N1CFP->getValueAPF();
19850
19851 // minnum(X, qnan) -> X
19852 // maxnum(X, qnan) -> X
19853 // minimum(X, nan) -> qnan
19854 // maximum(X, nan) -> qnan
19855 // minimumnum(X, nan) -> X
19856 // maximumnum(X, nan) -> X
19857 if (AF.isNaN()) {
19858 if (PropAllNaNsToQNaNs) {
19859 if (AF.isSignaling())
19860 return DAG.getConstantFP(Val: AF.makeQuiet(), DL: SDLoc(N), VT);
19861 return N->getOperand(Num: 1);
19862 } else if (ReturnsOtherForAllNaNs || !AF.isSignaling()) {
19863 return N->getOperand(Num: 0);
19864 }
19865 return SDValue();
19866 }
19867
19868 // In the following folds, inf can be replaced with the largest finite
19869 // float, if the ninf flag is set.
19870 if (AF.isInfinity() || (Flags.hasNoInfs() && AF.isLargest())) {
19871 // minimum(X, -inf) -> -inf if nnan
19872 // maximum(X, +inf) -> +inf if nnan
19873 // minimumnum(X, -inf) -> -inf
19874 // maximumnum(X, +inf) -> +inf
19875 if (IsMin == AF.isNegative() &&
19876 (ReturnsOtherForAllNaNs || Flags.hasNoNaNs()))
19877 return N->getOperand(Num: 1);
19878
19879 // minnum(X, +inf) -> X if nnan
19880 // maxnum(X, -inf) -> X if nnan
19881 // minimum(X, +inf) -> X (ignoring quieting of sNaNs)
19882 // maximum(X, -inf) -> X (ignoring quieting of sNaNs)
19883 // minimumnum(X, +inf) -> X if nnan
19884 // maximumnum(X, -inf) -> X if nnan
19885 if (IsMin != AF.isNegative() && (PropAllNaNsToQNaNs || Flags.hasNoNaNs()))
19886 return N->getOperand(Num: 0);
19887 }
19888 }
19889
19890 // There are no VECREDUCE variants of FMINIMUMNUM or FMAXIMUMNUM
19891 if (Opc == ISD::FMINIMUMNUM || Opc == ISD::FMAXIMUMNUM)
19892 return SDValue();
19893
19894 if (SDValue SD = reassociateReduction(
19895 RedOpc: PropAllNaNsToQNaNs
19896 ? (IsMin ? ISD::VECREDUCE_FMINIMUM : ISD::VECREDUCE_FMAXIMUM)
19897 : (IsMin ? ISD::VECREDUCE_FMIN : ISD::VECREDUCE_FMAX),
19898 Opc, DL: SDLoc(N), VT, N0, N1, Flags))
19899 return SD;
19900
19901 return SDValue();
19902}
19903
19904SDValue DAGCombiner::visitFABS(SDNode *N) {
19905 SDValue N0 = N->getOperand(Num: 0);
19906 EVT VT = N->getValueType(ResNo: 0);
19907 SDLoc DL(N);
19908
19909 // fold (fabs c1) -> fabs(c1)
19910 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FABS, DL, VT, Ops: {N0}))
19911 return C;
19912
19913 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
19914 return SDValue(N, 0);
19915
19916 if (SDValue Cast = foldSignChangeInBitcast(N))
19917 return Cast;
19918
19919 return SDValue();
19920}
19921
19922SDValue DAGCombiner::visitBRCOND(SDNode *N) {
19923 SDValue Chain = N->getOperand(Num: 0);
19924 SDValue N1 = N->getOperand(Num: 1);
19925 SDValue N2 = N->getOperand(Num: 2);
19926
19927 // BRCOND(FREEZE(cond)) is equivalent to BRCOND(cond) (both are
19928 // nondeterministic jumps).
19929 if (N1->getOpcode() == ISD::FREEZE && N1.hasOneUse()) {
19930 return DAG.getNode(Opcode: ISD::BRCOND, DL: SDLoc(N), VT: MVT::Other, N1: Chain,
19931 N2: N1->getOperand(Num: 0), N3: N2, Flags: N->getFlags());
19932 }
19933
19934 // Variant of the previous fold where there is a SETCC in between:
19935 // BRCOND(SETCC(FREEZE(X), CONST, Cond))
19936 // =>
19937 // BRCOND(FREEZE(SETCC(X, CONST, Cond)))
19938 // =>
19939 // BRCOND(SETCC(X, CONST, Cond))
19940 // This is correct if FREEZE(X) has one use and SETCC(FREEZE(X), CONST, Cond)
19941 // isn't equivalent to true or false.
19942 // For example, SETCC(FREEZE(X), -128, SETULT) cannot be folded to
19943 // FREEZE(SETCC(X, -128, SETULT)) because X can be poison.
19944 if (N1->getOpcode() == ISD::SETCC && N1.hasOneUse()) {
19945 SDValue S0 = N1->getOperand(Num: 0), S1 = N1->getOperand(Num: 1);
19946 ISD::CondCode Cond = cast<CondCodeSDNode>(Val: N1->getOperand(Num: 2))->get();
19947 ConstantSDNode *S0C = dyn_cast<ConstantSDNode>(Val&: S0);
19948 ConstantSDNode *S1C = dyn_cast<ConstantSDNode>(Val&: S1);
19949 bool Updated = false;
19950
19951 // Is 'X Cond C' always true or false?
19952 auto IsAlwaysTrueOrFalse = [](ISD::CondCode Cond, ConstantSDNode *C) {
19953 bool False = (Cond == ISD::SETULT && C->isZero()) ||
19954 (Cond == ISD::SETLT && C->isMinSignedValue()) ||
19955 (Cond == ISD::SETUGT && C->isAllOnes()) ||
19956 (Cond == ISD::SETGT && C->isMaxSignedValue());
19957 bool True = (Cond == ISD::SETULE && C->isAllOnes()) ||
19958 (Cond == ISD::SETLE && C->isMaxSignedValue()) ||
19959 (Cond == ISD::SETUGE && C->isZero()) ||
19960 (Cond == ISD::SETGE && C->isMinSignedValue());
19961 return True || False;
19962 };
19963
19964 if (S0->getOpcode() == ISD::FREEZE && S0.hasOneUse() && S1C) {
19965 if (!IsAlwaysTrueOrFalse(Cond, S1C)) {
19966 S0 = S0->getOperand(Num: 0);
19967 Updated = true;
19968 }
19969 }
19970 if (S1->getOpcode() == ISD::FREEZE && S1.hasOneUse() && S0C) {
19971 if (!IsAlwaysTrueOrFalse(ISD::getSetCCSwappedOperands(Operation: Cond), S0C)) {
19972 S1 = S1->getOperand(Num: 0);
19973 Updated = true;
19974 }
19975 }
19976
19977 if (Updated)
19978 return DAG.getNode(
19979 Opcode: ISD::BRCOND, DL: SDLoc(N), VT: MVT::Other, N1: Chain,
19980 N2: DAG.getSetCC(DL: SDLoc(N1), VT: N1->getValueType(ResNo: 0), LHS: S0, RHS: S1, Cond), N3: N2,
19981 Flags: N->getFlags());
19982 }
19983
19984 // If N is a constant we could fold this into a fallthrough or unconditional
19985 // branch. However that doesn't happen very often in normal code, because
19986 // Instcombine/SimplifyCFG should have handled the available opportunities.
19987 // If we did this folding here, it would be necessary to update the
19988 // MachineBasicBlock CFG, which is awkward.
19989
19990 // fold a brcond with a setcc condition into a BR_CC node if BR_CC is legal
19991 // on the target, also copy fast math flags.
19992 if (N1.getOpcode() == ISD::SETCC &&
19993 TLI.isOperationLegalOrCustom(Op: ISD::BR_CC,
19994 VT: N1.getOperand(i: 0).getValueType())) {
19995 return DAG.getNode(Opcode: ISD::BR_CC, DL: SDLoc(N), VT: MVT::Other, N1: Chain,
19996 N2: N1.getOperand(i: 2), N3: N1.getOperand(i: 0), N4: N1.getOperand(i: 1), N5: N2,
19997 Flags: N1->getFlags());
19998 }
19999
20000 if (N1.hasOneUse()) {
20001 // rebuildSetCC calls visitXor which may change the Chain when there is a
20002 // STRICT_FSETCC/STRICT_FSETCCS involved. Use a handle to track changes.
20003 HandleSDNode ChainHandle(Chain);
20004 if (SDValue NewN1 = rebuildSetCC(N: N1))
20005 return DAG.getNode(Opcode: ISD::BRCOND, DL: SDLoc(N), VT: MVT::Other,
20006 N1: ChainHandle.getValue(), N2: NewN1, N3: N2, Flags: N->getFlags());
20007 }
20008
20009 return SDValue();
20010}
20011
20012SDValue DAGCombiner::rebuildSetCC(SDValue N) {
20013 if (N.getOpcode() == ISD::SRL ||
20014 (N.getOpcode() == ISD::TRUNCATE &&
20015 (N.getOperand(i: 0).hasOneUse() &&
20016 N.getOperand(i: 0).getOpcode() == ISD::SRL))) {
20017 // Look past the truncate.
20018 if (N.getOpcode() == ISD::TRUNCATE)
20019 N = N.getOperand(i: 0);
20020
20021 // Match this pattern so that we can generate simpler code:
20022 //
20023 // %a = ...
20024 // %b = and i32 %a, 2
20025 // %c = srl i32 %b, 1
20026 // brcond i32 %c ...
20027 //
20028 // into
20029 //
20030 // %a = ...
20031 // %b = and i32 %a, 2
20032 // %c = setcc eq %b, 0
20033 // brcond %c ...
20034 //
20035 // This applies only when the AND constant value has one bit set and the
20036 // SRL constant is equal to the log2 of the AND constant. The back-end is
20037 // smart enough to convert the result into a TEST/JMP sequence.
20038 SDValue Op0 = N.getOperand(i: 0);
20039 SDValue Op1 = N.getOperand(i: 1);
20040
20041 if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::Constant) {
20042 SDValue AndOp1 = Op0.getOperand(i: 1);
20043
20044 if (AndOp1.getOpcode() == ISD::Constant) {
20045 const APInt &AndConst = AndOp1->getAsAPIntVal();
20046
20047 if (AndConst.isPowerOf2() &&
20048 Op1->getAsAPIntVal() == AndConst.logBase2()) {
20049 SDLoc DL(N);
20050 return DAG.getSetCC(DL, VT: getSetCCResultType(VT: Op0.getValueType()),
20051 LHS: Op0, RHS: DAG.getConstant(Val: 0, DL, VT: Op0.getValueType()),
20052 Cond: ISD::SETNE);
20053 }
20054 }
20055 }
20056 }
20057
20058 // Transform (brcond (xor x, y)) -> (brcond (setcc, x, y, ne))
20059 // Transform (brcond (xor (xor x, y), -1)) -> (brcond (setcc, x, y, eq))
20060 if (N.getOpcode() == ISD::XOR) {
20061 // Because we may call this on a speculatively constructed
20062 // SimplifiedSetCC Node, we need to simplify this node first.
20063 // Ideally this should be folded into SimplifySetCC and not
20064 // here. For now, grab a handle to N so we don't lose it from
20065 // replacements internal to the visit.
20066 HandleSDNode XORHandle(N);
20067 while (N.getOpcode() == ISD::XOR) {
20068 SDValue Tmp = visitXOR(N: N.getNode());
20069 // No simplification done.
20070 if (!Tmp.getNode())
20071 break;
20072 // Returning N is a form of in-visit replacement that may invalidate
20073 // N. Grab the value from the handle.
20074 if (Tmp.getNode() == N.getNode())
20075 N = XORHandle.getValue();
20076 else // Node simplified. Try simplifying again.
20077 N = Tmp;
20078 }
20079
20080 if (N.getOpcode() != ISD::XOR)
20081 return N;
20082
20083 SDValue Op0 = N->getOperand(Num: 0);
20084 SDValue Op1 = N->getOperand(Num: 1);
20085
20086 if (Op0.getOpcode() != ISD::SETCC && Op1.getOpcode() != ISD::SETCC) {
20087 bool Equal = false;
20088 // (brcond (xor (xor x, y), -1)) -> (brcond (setcc x, y, eq))
20089 if (isBitwiseNot(V: N) && Op0.hasOneUse() && Op0.getOpcode() == ISD::XOR &&
20090 Op0.getValueType() == MVT::i1) {
20091 N = Op0;
20092 Op0 = N->getOperand(Num: 0);
20093 Op1 = N->getOperand(Num: 1);
20094 Equal = true;
20095 }
20096
20097 EVT SetCCVT = N.getValueType();
20098 if (LegalTypes)
20099 SetCCVT = getSetCCResultType(VT: SetCCVT);
20100 // Replace the uses of XOR with SETCC. Note, avoid this transformation if
20101 // it would introduce illegal operations post-legalization as this can
20102 // result in infinite looping between converting xor->setcc here, and
20103 // expanding setcc->xor in LegalizeSetCCCondCode if requested.
20104 const ISD::CondCode CC = Equal ? ISD::SETEQ : ISD::SETNE;
20105 if (!LegalOperations || TLI.isCondCodeLegal(CC, VT: Op0.getSimpleValueType()))
20106 return DAG.getSetCC(DL: SDLoc(N), VT: SetCCVT, LHS: Op0, RHS: Op1, Cond: CC);
20107 }
20108 }
20109
20110 return SDValue();
20111}
20112
20113// Operand List for BR_CC: Chain, CondCC, CondLHS, CondRHS, DestBB.
20114//
20115SDValue DAGCombiner::visitBR_CC(SDNode *N) {
20116 CondCodeSDNode *CC = cast<CondCodeSDNode>(Val: N->getOperand(Num: 1));
20117 SDValue CondLHS = N->getOperand(Num: 2), CondRHS = N->getOperand(Num: 3);
20118
20119 // If N is a constant we could fold this into a fallthrough or unconditional
20120 // branch. However that doesn't happen very often in normal code, because
20121 // Instcombine/SimplifyCFG should have handled the available opportunities.
20122 // If we did this folding here, it would be necessary to update the
20123 // MachineBasicBlock CFG, which is awkward.
20124
20125 // Use SimplifySetCC to simplify SETCC's.
20126 SDValue Simp = SimplifySetCC(VT: getSetCCResultType(VT: CondLHS.getValueType()),
20127 N0: CondLHS, N1: CondRHS, Cond: CC->get(), DL: SDLoc(N),
20128 foldBooleans: false);
20129 if (Simp.getNode()) AddToWorklist(N: Simp.getNode());
20130
20131 // fold to a simpler setcc
20132 if (Simp.getNode() && Simp.getOpcode() == ISD::SETCC)
20133 return DAG.getNode(Opcode: ISD::BR_CC, DL: SDLoc(N), VT: MVT::Other,
20134 N1: N->getOperand(Num: 0), N2: Simp.getOperand(i: 2),
20135 N3: Simp.getOperand(i: 0), N4: Simp.getOperand(i: 1),
20136 N5: N->getOperand(Num: 4));
20137
20138 return SDValue();
20139}
20140
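/// Helper for the indexed load/store combines below: extract the base pointer
/// of a (masked) load or store, classify the node, and check that the target
/// supports either of the requested indexed addressing modes for its memory
/// VT. Returns false if the node is already indexed or is unsupported.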
20141static bool getCombineLoadStoreParts(SDNode *N, unsigned Inc, unsigned Dec,
20142 bool &IsLoad, bool &IsMasked, SDValue &Ptr,
20143 const TargetLowering &TLI) {
20144 if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Val: N)) {
20145 if (LD->isIndexed())
20146 return false;
20147 EVT VT = LD->getMemoryVT();
20148 if (!TLI.isIndexedLoadLegal(IdxMode: Inc, VT) && !TLI.isIndexedLoadLegal(IdxMode: Dec, VT))
20149 return false;
20150 Ptr = LD->getBasePtr();
20151 } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(Val: N)) {
20152 if (ST->isIndexed())
20153 return false;
20154 EVT VT = ST->getMemoryVT();
20155 if (!TLI.isIndexedStoreLegal(IdxMode: Inc, VT) && !TLI.isIndexedStoreLegal(IdxMode: Dec, VT))
20156 return false;
20157 Ptr = ST->getBasePtr();
20158 IsLoad = false;
20159 } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(Val: N)) {
20160 if (LD->isIndexed())
20161 return false;
20162 EVT VT = LD->getMemoryVT();
20163 if (!TLI.isIndexedMaskedLoadLegal(IdxMode: Inc, VT) &&
20164 !TLI.isIndexedMaskedLoadLegal(IdxMode: Dec, VT))
20165 return false;
20166 Ptr = LD->getBasePtr();
20167 IsMasked = true;
20168 } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(Val: N)) {
20169 if (ST->isIndexed())
20170 return false;
20171 EVT VT = ST->getMemoryVT();
20172 if (!TLI.isIndexedMaskedStoreLegal(IdxMode: Inc, VT) &&
20173 !TLI.isIndexedMaskedStoreLegal(IdxMode: Dec, VT))
20174 return false;
20175 Ptr = ST->getBasePtr();
20176 IsLoad = false;
20177 IsMasked = true;
20178 } else {
20179 return false;
20180 }
20181 return true;
20182}
20183
20184/// Try turning a load/store into a pre-indexed load/store when the base
20185/// pointer is an add or subtract and it has other uses besides the load/store.
20186/// After the transformation, the new indexed load/store has effectively folded
20187/// the add/subtract in and all of its other uses are redirected to the
20188/// new load/store.
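///
/// For example, a store whose address is (add x, 8), where the add has other
/// uses, can become a pre-indexed store that also produces x + 8 as a result;
/// the remaining uses of the add are then redirected to that result.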
20189bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
20190 if (Level < AfterLegalizeDAG)
20191 return false;
20192
20193 bool IsLoad = true;
20194 bool IsMasked = false;
20195 SDValue Ptr;
20196 if (!getCombineLoadStoreParts(N, Inc: ISD::PRE_INC, Dec: ISD::PRE_DEC, IsLoad, IsMasked,
20197 Ptr, TLI))
20198 return false;
20199
20200 // If the pointer is not an add/sub, or if it doesn't have multiple uses, bail
20201 // out. There is no reason to make this a preinc/predec.
20202 if ((Ptr.getOpcode() != ISD::ADD && Ptr.getOpcode() != ISD::SUB) ||
20203 Ptr->hasOneUse())
20204 return false;
20205
20206 // Ask the target to do addressing mode selection.
20207 SDValue BasePtr;
20208 SDValue Offset;
20209 ISD::MemIndexedMode AM = ISD::UNINDEXED;
20210 if (!TLI.getPreIndexedAddressParts(N, BasePtr, Offset, AM, DAG))
20211 return false;
20212
20213 // Backends without true r+i pre-indexed forms may need to pass a
20214 // constant base with a variable offset so that constant coercion
20215 // will work with the patterns in canonical form.
20216 bool Swapped = false;
20217 if (isa<ConstantSDNode>(Val: BasePtr)) {
20218 std::swap(a&: BasePtr, b&: Offset);
20219 Swapped = true;
20220 }
20221
  // Don't create an indexed load / store with zero offset.
20223 if (isNullConstant(V: Offset))
20224 return false;
20225
20226 // Try turning it into a pre-indexed load / store except when:
20227 // 1) The new base ptr is a frame index.
20228 // 2) If N is a store and the new base ptr is either the same as or is a
20229 // predecessor of the value being stored.
  // 3) Another use of the old base ptr is a predecessor of N. If ptr is
  //    folded, that would create a cycle.
20232 // 4) All uses are load / store ops that use it as old base ptr.
20233
20234 // Check #1. Preinc'ing a frame index would require copying the stack pointer
20235 // (plus the implicit offset) to a register to preinc anyway.
20236 if (isa<FrameIndexSDNode>(Val: BasePtr) || isa<RegisterSDNode>(Val: BasePtr))
20237 return false;
20238
20239 // Check #2.
20240 if (!IsLoad) {
20241 SDValue Val = IsMasked ? cast<MaskedStoreSDNode>(Val: N)->getValue()
20242 : cast<StoreSDNode>(Val: N)->getValue();
20243
20244 // Would require a copy.
20245 if (Val == BasePtr)
20246 return false;
20247
20248 // Would create a cycle.
20249 if (Val == Ptr || Ptr->isPredecessorOf(N: Val.getNode()))
20250 return false;
20251 }
20252
20253 // Caches for hasPredecessorHelper.
20254 SmallPtrSet<const SDNode *, 32> Visited;
20255 SmallVector<const SDNode *, 16> Worklist;
20256 Worklist.push_back(Elt: N);
20257
20258 // If the offset is a constant, there may be other adds of constants that
20259 // can be folded with this one. We should do this to avoid having to keep
20260 // a copy of the original base pointer.
20261 SmallVector<SDNode *, 16> OtherUses;
20262 unsigned MaxSteps = SelectionDAG::getHasPredecessorMaxSteps();
20263 if (isa<ConstantSDNode>(Val: Offset))
20264 for (SDUse &Use : BasePtr->uses()) {
20265 // Skip the use that is Ptr and uses of other results from BasePtr's
20266 // node (important for nodes that return multiple results).
20267 if (Use.getUser() == Ptr.getNode() || Use != BasePtr)
20268 continue;
20269
20270 if (SDNode::hasPredecessorHelper(N: Use.getUser(), Visited, Worklist,
20271 MaxSteps))
20272 continue;
20273
20274 if (Use.getUser()->getOpcode() != ISD::ADD &&
20275 Use.getUser()->getOpcode() != ISD::SUB) {
20276 OtherUses.clear();
20277 break;
20278 }
20279
20280 SDValue Op1 = Use.getUser()->getOperand(Num: (Use.getOperandNo() + 1) & 1);
20281 if (!isa<ConstantSDNode>(Val: Op1)) {
20282 OtherUses.clear();
20283 break;
20284 }
20285
20286 // FIXME: In some cases, we can be smarter about this.
20287 if (Op1.getValueType() != Offset.getValueType()) {
20288 OtherUses.clear();
20289 break;
20290 }
20291
20292 OtherUses.push_back(Elt: Use.getUser());
20293 }
20294
20295 if (Swapped)
20296 std::swap(a&: BasePtr, b&: Offset);
20297
20298 // Now check for #3 and #4.
20299 bool RealUse = false;
20300
20301 for (SDNode *User : Ptr->users()) {
20302 if (User == N)
20303 continue;
20304 if (SDNode::hasPredecessorHelper(N: User, Visited, Worklist, MaxSteps))
20305 return false;
20306
    // Only count this as a real use if Ptr cannot be folded into the user's
    // addressing mode; if every other use can fold Ptr, this transformation
    // is not profitable.
20309 if (!canFoldInAddressingMode(N: Ptr.getNode(), Use: User, DAG, TLI))
20310 RealUse = true;
20311 }
20312
20313 if (!RealUse)
20314 return false;
20315
20316 SDValue Result;
20317 if (!IsMasked) {
20318 if (IsLoad)
20319 Result = DAG.getIndexedLoad(OrigLoad: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr, Offset, AM);
20320 else
20321 Result =
20322 DAG.getIndexedStore(OrigStore: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr, Offset, AM);
20323 } else {
20324 if (IsLoad)
20325 Result = DAG.getIndexedMaskedLoad(OrigLoad: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr,
20326 Offset, AM);
20327 else
20328 Result = DAG.getIndexedMaskedStore(OrigStore: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr,
20329 Offset, AM);
20330 }
20331 ++PreIndexedNodes;
20332 ++NodesCombined;
20333 LLVM_DEBUG(dbgs() << "\nReplacing.4 "; N->dump(&DAG); dbgs() << "\nWith: ";
20334 Result.dump(&DAG); dbgs() << '\n');
20335 WorklistRemover DeadNodes(*this);
20336 if (IsLoad) {
20337 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result.getValue(R: 0));
20338 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: Result.getValue(R: 2));
20339 } else {
20340 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result.getValue(R: 1));
20341 }
20342
20343 // Finally, since the node is now dead, remove it from the graph.
20344 deleteAndRecombine(N);
20345
20346 if (Swapped)
20347 std::swap(a&: BasePtr, b&: Offset);
20348
20349 // Replace other uses of BasePtr that can be updated to use Ptr
20350 for (SDNode *OtherUse : OtherUses) {
20351 unsigned OffsetIdx = 1;
20352 if (OtherUse->getOperand(Num: OffsetIdx).getNode() == BasePtr.getNode())
20353 OffsetIdx = 0;
20354 assert(OtherUse->getOperand(!OffsetIdx).getNode() == BasePtr.getNode() &&
20355 "Expected BasePtr operand");
20356
20357 // We need to replace ptr0 in the following expression:
20358 // x0 * offset0 + y0 * ptr0 = t0
20359 // knowing that
20360 // x1 * offset1 + y1 * ptr0 = t1 (the indexed load/store)
20361 //
20362 // where x0, x1, y0 and y1 in {-1, 1} are given by the types of the
20363 // indexed load/store and the expression that needs to be re-written.
20364 //
20365 // Therefore, we have:
    // t0 = (x0 * offset0 - x1 * y0 * y1 * offset1) + (y0 * y1) * t1
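    //
    // For example, with a pre-increment store whose written-back pointer is
    // t1 = ptr0 + 4 (x1 = y1 = 1) and another use t0 = ptr0 + 7 (x0 = y0 = 1),
    // this gives t0 = (7 - 4) + t1, i.e. the use is rewritten as (add t1, 3).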
20367
20368 auto *CN = cast<ConstantSDNode>(Val: OtherUse->getOperand(Num: OffsetIdx));
20369 const APInt &Offset0 = CN->getAPIntValue();
20370 const APInt &Offset1 = Offset->getAsAPIntVal();
20371 int X0 = (OtherUse->getOpcode() == ISD::SUB && OffsetIdx == 1) ? -1 : 1;
20372 int Y0 = (OtherUse->getOpcode() == ISD::SUB && OffsetIdx == 0) ? -1 : 1;
20373 int X1 = (AM == ISD::PRE_DEC && !Swapped) ? -1 : 1;
20374 int Y1 = (AM == ISD::PRE_DEC && Swapped) ? -1 : 1;
20375
20376 unsigned Opcode = (Y0 * Y1 < 0) ? ISD::SUB : ISD::ADD;
20377
20378 APInt CNV = Offset0;
20379 if (X0 < 0) CNV = -CNV;
20380 if (X1 * Y0 * Y1 < 0) CNV = CNV + Offset1;
20381 else CNV = CNV - Offset1;
20382
20383 SDLoc DL(OtherUse);
20384
20385 // We can now generate the new expression.
20386 SDValue NewOp1 = DAG.getConstant(Val: CNV, DL, VT: CN->getValueType(ResNo: 0));
20387 SDValue NewOp2 = Result.getValue(R: IsLoad ? 1 : 0);
20388
20389 SDValue NewUse =
20390 DAG.getNode(Opcode, DL, VT: OtherUse->getValueType(ResNo: 0), N1: NewOp1, N2: NewOp2);
20391 DAG.ReplaceAllUsesOfValueWith(From: SDValue(OtherUse, 0), To: NewUse);
20392 deleteAndRecombine(N: OtherUse);
20393 }
20394
20395 // Replace the uses of Ptr with uses of the updated base value.
20396 DAG.ReplaceAllUsesOfValueWith(From: Ptr, To: Result.getValue(R: IsLoad ? 1 : 0));
20397 deleteAndRecombine(N: Ptr.getNode());
20398 AddToWorklist(N: Result.getNode());
20399
20400 return true;
20401}
20402
20403static bool shouldCombineToPostInc(SDNode *N, SDValue Ptr, SDNode *PtrUse,
20404 SDValue &BasePtr, SDValue &Offset,
20405 ISD::MemIndexedMode &AM,
20406 SelectionDAG &DAG,
20407 const TargetLowering &TLI) {
20408 if (PtrUse == N ||
20409 (PtrUse->getOpcode() != ISD::ADD && PtrUse->getOpcode() != ISD::SUB))
20410 return false;
20411
20412 if (!TLI.getPostIndexedAddressParts(N, PtrUse, BasePtr, Offset, AM, DAG))
20413 return false;
20414
  // Don't create an indexed load / store with zero offset.
20416 if (isNullConstant(V: Offset))
20417 return false;
20418
20419 if (isa<FrameIndexSDNode>(Val: BasePtr) || isa<RegisterSDNode>(Val: BasePtr))
20420 return false;
20421
20422 SmallPtrSet<const SDNode *, 32> Visited;
20423 unsigned MaxSteps = SelectionDAG::getHasPredecessorMaxSteps();
20424 for (SDNode *User : BasePtr->users()) {
20425 if (User == Ptr.getNode())
20426 continue;
20427
    // Say no if there's a later user which could perform the indexing instead.
20429 if (isa<MemSDNode>(Val: User)) {
20430 bool IsLoad = true;
20431 bool IsMasked = false;
20432 SDValue OtherPtr;
20433 if (getCombineLoadStoreParts(N: User, Inc: ISD::POST_INC, Dec: ISD::POST_DEC, IsLoad,
20434 IsMasked, Ptr&: OtherPtr, TLI)) {
20435 SmallVector<const SDNode *, 2> Worklist;
20436 Worklist.push_back(Elt: User);
20437 if (SDNode::hasPredecessorHelper(N, Visited, Worklist, MaxSteps))
20438 return false;
20439 }
20440 }
20441
20442 // If all the uses are load / store addresses, then don't do the
20443 // transformation.
20444 if (User->getOpcode() == ISD::ADD || User->getOpcode() == ISD::SUB) {
20445 for (SDNode *UserUser : User->users())
20446 if (canFoldInAddressingMode(N: User, Use: UserUser, DAG, TLI))
20447 return false;
20448 }
20449 }
20450 return true;
20451}
20452
20453static SDNode *getPostIndexedLoadStoreOp(SDNode *N, bool &IsLoad,
20454 bool &IsMasked, SDValue &Ptr,
20455 SDValue &BasePtr, SDValue &Offset,
20456 ISD::MemIndexedMode &AM,
20457 SelectionDAG &DAG,
20458 const TargetLowering &TLI) {
20459 if (!getCombineLoadStoreParts(N, Inc: ISD::POST_INC, Dec: ISD::POST_DEC, IsLoad,
20460 IsMasked, Ptr, TLI) ||
20461 Ptr->hasOneUse())
20462 return nullptr;
20463
20464 // Try turning it into a post-indexed load / store except when
20465 // 1) All uses are load / store ops that use it as base ptr (and
  //    it may be folded as addressing mode).
20467 // 2) Op must be independent of N, i.e. Op is neither a predecessor
20468 // nor a successor of N. Otherwise, if Op is folded that would
20469 // create a cycle.
20470 unsigned MaxSteps = SelectionDAG::getHasPredecessorMaxSteps();
20471 for (SDUse &U : Ptr->uses()) {
20472 if (U.getResNo() != Ptr.getResNo())
20473 continue;
20474
20475 // Check for #1.
20476 SDNode *Op = U.getUser();
20477 if (!shouldCombineToPostInc(N, Ptr, PtrUse: Op, BasePtr, Offset, AM, DAG, TLI))
20478 continue;
20479
20480 // Check for #2.
20481 SmallPtrSet<const SDNode *, 32> Visited;
20482 SmallVector<const SDNode *, 8> Worklist;
20483 // Ptr is predecessor to both N and Op.
20484 Visited.insert(Ptr: Ptr.getNode());
20485 Worklist.push_back(Elt: N);
20486 Worklist.push_back(Elt: Op);
20487 if (!SDNode::hasPredecessorHelper(N, Visited, Worklist, MaxSteps) &&
20488 !SDNode::hasPredecessorHelper(N: Op, Visited, Worklist, MaxSteps))
20489 return Op;
20490 }
20491 return nullptr;
20492}
20493
/// Try to combine a load/store with an add/sub of the base pointer node into a
/// post-indexed load/store. The transformation effectively folds the
/// add/subtract into the new indexed load/store, and all uses of the add/sub
/// are redirected to the new load/store.
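///
/// For example, a load from p followed by an otherwise independent (add p, 4)
/// can become a post-indexed load producing both the loaded value and the
/// incremented pointer; uses of the add are redirected to that pointer result.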
20498bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) {
20499 if (Level < AfterLegalizeDAG)
20500 return false;
20501
20502 bool IsLoad = true;
20503 bool IsMasked = false;
20504 SDValue Ptr;
20505 SDValue BasePtr;
20506 SDValue Offset;
20507 ISD::MemIndexedMode AM = ISD::UNINDEXED;
20508 SDNode *Op = getPostIndexedLoadStoreOp(N, IsLoad, IsMasked, Ptr, BasePtr,
20509 Offset, AM, DAG, TLI);
20510 if (!Op)
20511 return false;
20512
20513 SDValue Result;
20514 if (!IsMasked)
20515 Result = IsLoad ? DAG.getIndexedLoad(OrigLoad: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr,
20516 Offset, AM)
20517 : DAG.getIndexedStore(OrigStore: SDValue(N, 0), dl: SDLoc(N),
20518 Base: BasePtr, Offset, AM);
20519 else
20520 Result = IsLoad ? DAG.getIndexedMaskedLoad(OrigLoad: SDValue(N, 0), dl: SDLoc(N),
20521 Base: BasePtr, Offset, AM)
20522 : DAG.getIndexedMaskedStore(OrigStore: SDValue(N, 0), dl: SDLoc(N),
20523 Base: BasePtr, Offset, AM);
20524 ++PostIndexedNodes;
20525 ++NodesCombined;
20526 LLVM_DEBUG(dbgs() << "\nReplacing.5 "; N->dump(&DAG); dbgs() << "\nWith: ";
20527 Result.dump(&DAG); dbgs() << '\n');
20528 WorklistRemover DeadNodes(*this);
20529 if (IsLoad) {
20530 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result.getValue(R: 0));
20531 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: Result.getValue(R: 2));
20532 } else {
20533 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result.getValue(R: 1));
20534 }
20535
20536 // Finally, since the node is now dead, remove it from the graph.
20537 deleteAndRecombine(N);
20538
  // Replace the uses of Op with uses of the updated base value.
20540 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Op, 0),
20541 To: Result.getValue(R: IsLoad ? 1 : 0));
20542 deleteAndRecombine(N: Op);
20543 return true;
20544}
20545
20546/// Return the base-pointer arithmetic from an indexed \p LD.
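/// For example, a PRE_INC or POST_INC load with base BP and offset C yields
/// (add BP, C); the *_DEC forms yield (sub BP, C).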
20547SDValue DAGCombiner::SplitIndexingFromLoad(LoadSDNode *LD) {
20548 ISD::MemIndexedMode AM = LD->getAddressingMode();
20549 assert(AM != ISD::UNINDEXED);
20550 SDValue BP = LD->getOperand(Num: 1);
20551 SDValue Inc = LD->getOperand(Num: 2);
20552
20553 // Some backends use TargetConstants for load offsets, but don't expect
20554 // TargetConstants in general ADD nodes. We can convert these constants into
20555 // regular Constants (if the constant is not opaque).
20556 assert((Inc.getOpcode() != ISD::TargetConstant ||
20557 !cast<ConstantSDNode>(Inc)->isOpaque()) &&
20558 "Cannot split out indexing using opaque target constants");
20559 if (Inc.getOpcode() == ISD::TargetConstant) {
20560 ConstantSDNode *ConstInc = cast<ConstantSDNode>(Val&: Inc);
20561 Inc = DAG.getConstant(Val: *ConstInc->getConstantIntValue(), DL: SDLoc(Inc),
20562 VT: ConstInc->getValueType(ResNo: 0));
20563 }
20564
20565 unsigned Opc =
20566 (AM == ISD::PRE_INC || AM == ISD::POST_INC ? ISD::ADD : ISD::SUB);
20567 return DAG.getNode(Opcode: Opc, DL: SDLoc(LD), VT: BP.getSimpleValueType(), N1: BP, N2: Inc);
20568}
20569
20570static inline ElementCount numVectorEltsOrZero(EVT T) {
20571 return T.isVector() ? T.getVectorElementCount() : ElementCount::getFixed(MinVal: 0);
20572}
20573
20574bool DAGCombiner::getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val) {
20575 EVT STType = Val.getValueType();
20576 EVT STMemType = ST->getMemoryVT();
20577 if (STType == STMemType)
20578 return true;
20579 if (isTypeLegal(VT: STMemType))
20580 return false; // fail.
20581 if (STType.isFloatingPoint() && STMemType.isFloatingPoint() &&
20582 TLI.isOperationLegal(Op: ISD::FTRUNC, VT: STMemType)) {
20583 Val = DAG.getNode(Opcode: ISD::FTRUNC, DL: SDLoc(ST), VT: STMemType, Operand: Val);
20584 return true;
20585 }
20586 if (numVectorEltsOrZero(T: STType) == numVectorEltsOrZero(T: STMemType) &&
20587 STType.isInteger() && STMemType.isInteger()) {
20588 Val = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(ST), VT: STMemType, Operand: Val);
20589 return true;
20590 }
20591 if (STType.getSizeInBits() == STMemType.getSizeInBits()) {
20592 Val = DAG.getBitcast(VT: STMemType, V: Val);
20593 return true;
20594 }
20595 return false; // fail.
20596}
20597
20598bool DAGCombiner::extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val) {
20599 EVT LDMemType = LD->getMemoryVT();
20600 EVT LDType = LD->getValueType(ResNo: 0);
20601 assert(Val.getValueType() == LDMemType &&
20602 "Attempting to extend value of non-matching type");
20603 if (LDType == LDMemType)
20604 return true;
20605 if (LDMemType.isInteger() && LDType.isInteger()) {
20606 switch (LD->getExtensionType()) {
20607 case ISD::NON_EXTLOAD:
20608 Val = DAG.getBitcast(VT: LDType, V: Val);
20609 return true;
20610 case ISD::EXTLOAD:
20611 Val = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL: SDLoc(LD), VT: LDType, Operand: Val);
20612 return true;
20613 case ISD::SEXTLOAD:
20614 Val = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL: SDLoc(LD), VT: LDType, Operand: Val);
20615 return true;
20616 case ISD::ZEXTLOAD:
20617 Val = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: SDLoc(LD), VT: LDType, Operand: Val);
20618 return true;
20619 }
20620 }
20621 return false;
20622}
20623
20624StoreSDNode *DAGCombiner::getUniqueStoreFeeding(LoadSDNode *LD,
20625 int64_t &Offset) {
20626 SDValue Chain = LD->getOperand(Num: 0);
20627
20628 // Look through CALLSEQ_START.
20629 if (Chain.getOpcode() == ISD::CALLSEQ_START)
20630 Chain = Chain->getOperand(Num: 0);
20631
20632 StoreSDNode *ST = nullptr;
20633 SmallVector<SDValue, 8> Aliases;
20634 if (Chain.getOpcode() == ISD::TokenFactor) {
20635 // Look for unique store within the TokenFactor.
20636 for (SDValue Op : Chain->ops()) {
20637 StoreSDNode *Store = dyn_cast<StoreSDNode>(Val: Op.getNode());
20638 if (!Store)
20639 continue;
20640 BaseIndexOffset BasePtrLD = BaseIndexOffset::match(N: LD, DAG);
20641 BaseIndexOffset BasePtrST = BaseIndexOffset::match(N: Store, DAG);
20642 if (!BasePtrST.equalBaseIndex(Other: BasePtrLD, DAG, Off&: Offset))
20643 continue;
20644 // Make sure the store is not aliased with any nodes in TokenFactor.
20645 GatherAllAliases(N: Store, OriginalChain: Chain, Aliases);
20646 if (Aliases.empty() ||
20647 (Aliases.size() == 1 && Aliases.front().getNode() == Store))
20648 ST = Store;
20649 break;
20650 }
20651 } else {
20652 StoreSDNode *Store = dyn_cast<StoreSDNode>(Val: Chain.getNode());
20653 if (Store) {
20654 BaseIndexOffset BasePtrLD = BaseIndexOffset::match(N: LD, DAG);
20655 BaseIndexOffset BasePtrST = BaseIndexOffset::match(N: Store, DAG);
20656 if (BasePtrST.equalBaseIndex(Other: BasePtrLD, DAG, Off&: Offset))
20657 ST = Store;
20658 }
20659 }
20660
20661 return ST;
20662}
20663
20664SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) {
20665 if (OptLevel == CodeGenOptLevel::None || !LD->isSimple())
20666 return SDValue();
20667 SDValue Chain = LD->getOperand(Num: 0);
20668 int64_t Offset;
20669
20670 StoreSDNode *ST = getUniqueStoreFeeding(LD, Offset);
20671 // TODO: Relax this restriction for unordered atomics (see D66309)
20672 if (!ST || !ST->isSimple() || ST->getAddressSpace() != LD->getAddressSpace())
20673 return SDValue();
20674
20675 EVT LDType = LD->getValueType(ResNo: 0);
20676 EVT LDMemType = LD->getMemoryVT();
20677 EVT STMemType = ST->getMemoryVT();
20678 EVT STType = ST->getValue().getValueType();
20679
20680 // There are two cases to consider here:
20681 // 1. The store is fixed width and the load is scalable. In this case we
20682 // don't know at compile time if the store completely envelops the load
20683 // so we abandon the optimisation.
20684 // 2. The store is scalable and the load is fixed width. We could
20685 // potentially support a limited number of cases here, but there has been
20686 // no cost-benefit analysis to prove it's worth it.
20687 bool LdStScalable = LDMemType.isScalableVT();
20688 if (LdStScalable != STMemType.isScalableVT())
20689 return SDValue();
20690
20691 // If we are dealing with scalable vectors on a big endian platform the
20692 // calculation of offsets below becomes trickier, since we do not know at
20693 // compile time the absolute size of the vector. Until we've done more
20694 // analysis on big-endian platforms it seems better to bail out for now.
20695 if (LdStScalable && DAG.getDataLayout().isBigEndian())
20696 return SDValue();
20697
  // Normalize for endianness. After this, Offset=0 will denote that the least
  // significant bit in the loaded value maps to the least significant bit in
  // the stored value. With Offset=n (for n > 0) the loaded value starts at the
  // n-th least significant byte of the stored value.
20702 int64_t OrigOffset = Offset;
20703 if (DAG.getDataLayout().isBigEndian())
20704 Offset = ((int64_t)STMemType.getStoreSizeInBits().getFixedValue() -
20705 (int64_t)LDMemType.getStoreSizeInBits().getFixedValue()) /
20706 8 -
20707 Offset;
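  // For example, loading 2 bytes at Offset 2 out of an 8-byte stored value
  // corresponds, on a big-endian target, to Offset (8 - 2) - 2 = 4.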
20708
  // Check that the stored value covers all bits that are loaded.
20710 bool STCoversLD;
20711
20712 TypeSize LdMemSize = LDMemType.getSizeInBits();
20713 TypeSize StMemSize = STMemType.getSizeInBits();
20714 if (LdStScalable)
20715 STCoversLD = (Offset == 0) && LdMemSize == StMemSize;
20716 else
20717 STCoversLD = (Offset >= 0) && (Offset * 8 + LdMemSize.getFixedValue() <=
20718 StMemSize.getFixedValue());
20719
20720 auto ReplaceLd = [&](LoadSDNode *LD, SDValue Val, SDValue Chain) -> SDValue {
20721 if (LD->isIndexed()) {
20722 // Cannot handle opaque target constants and we must respect the user's
20723 // request not to split indexes from loads.
20724 if (!canSplitIdx(LD))
20725 return SDValue();
20726 SDValue Idx = SplitIndexingFromLoad(LD);
20727 SDValue Ops[] = {Val, Idx, Chain};
20728 return CombineTo(N: LD, To: Ops, NumTo: 3);
20729 }
20730 return CombineTo(N: LD, Res0: Val, Res1: Chain);
20731 };
20732
20733 if (!STCoversLD)
20734 return SDValue();
20735
20736 // Memory as copy space (potentially masked).
20737 if (Offset == 0 && LDType == STType && STMemType == LDMemType) {
20738 // Simple case: Direct non-truncating forwarding
20739 if (LDType.getSizeInBits() == LdMemSize)
20740 return ReplaceLd(LD, ST->getValue(), Chain);
20741 // Can we model the truncate and extension with an and mask?
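    // For example, an i32 value trunc-stored as i16 and then zero- or
    // any-extend-loaded back as i32 can be forwarded as (and StoredVal, 0xFFFF).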
20742 if (STType.isInteger() && LDMemType.isInteger() && !STType.isVector() &&
20743 !LDMemType.isVector() && LD->getExtensionType() != ISD::SEXTLOAD) {
20744 // Mask to size of LDMemType
20745 auto Mask =
20746 DAG.getConstant(Val: APInt::getLowBitsSet(numBits: STType.getFixedSizeInBits(),
20747 loBitsSet: StMemSize.getFixedValue()),
20748 DL: SDLoc(ST), VT: STType);
20749 auto Val = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(LD), VT: LDType, N1: ST->getValue(), N2: Mask);
20750 return ReplaceLd(LD, Val, Chain);
20751 }
20752 }
20753
20754 // Handle some cases for big-endian that would be Offset 0 and handled for
20755 // little-endian.
20756 SDValue Val = ST->getValue();
20757 if (DAG.getDataLayout().isBigEndian() && Offset > 0 && OrigOffset == 0) {
20758 if (STType.isInteger() && !STType.isVector() && LDType.isInteger() &&
20759 !LDType.isVector() && isTypeLegal(VT: STType) &&
20760 TLI.isOperationLegal(Op: ISD::SRL, VT: STType)) {
20761 Val = DAG.getNode(
20762 Opcode: ISD::SRL, DL: SDLoc(LD), VT: STType, N1: Val,
20763 N2: DAG.getShiftAmountConstant(Val: Offset * 8, VT: STType, DL: SDLoc(LD)));
20764 Offset = 0;
20765 }
20766 }
20767
20768 // TODO: Deal with nonzero offset.
20769 if (LD->getBasePtr().isUndef() || Offset != 0)
20770 return SDValue();
  // Model necessary truncations / extensions.
  // Truncate the value to the stored memory size.
20773 do {
20774 if (!getTruncatedStoreValue(ST, Val))
20775 break;
20776 if (!isTypeLegal(VT: LDMemType))
20777 break;
20778 if (STMemType != LDMemType) {
20779 if (LdMemSize == StMemSize) {
20780 if (TLI.isOperationLegal(Op: ISD::BITCAST, VT: LDMemType) &&
20781 isTypeLegal(VT: LDMemType) &&
20782 TLI.isOperationLegal(Op: ISD::BITCAST, VT: STMemType) &&
20783 isTypeLegal(VT: STMemType) &&
20784 TLI.isLoadBitCastBeneficial(LoadVT: LDMemType, BitcastVT: STMemType, DAG,
20785 MMO: *LD->getMemOperand()))
20786 Val = DAG.getBitcast(VT: LDMemType, V: Val);
20787 else
20788 break;
20789 } else if (LDMemType.isVector() && isTypeLegal(VT: STMemType)) {
20790 EVT EltVT = LDMemType.getVectorElementType();
20791 TypeSize EltSize = EltVT.getSizeInBits();
20792
20793 if (!StMemSize.isKnownMultipleOf(RHS: EltSize))
20794 break;
20795
20796 EVT InterVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: EltVT,
20797 NumElements: StMemSize.divideCoefficientBy(RHS: EltSize));
20798 if (!TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_SUBVECTOR, VT: InterVT))
20799 break;
20800
20801 Val = DAG.getExtractSubvector(DL: SDLoc(LD), VT: LDMemType,
20802 Vec: DAG.getBitcast(VT: InterVT, V: Val), Idx: 0);
20803 } else if (!STMemType.isVector() && !LDMemType.isVector() &&
20804 STMemType.isInteger() && LDMemType.isInteger())
20805 Val = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(LD), VT: LDMemType, Operand: Val);
20806 else
20807 break;
20808 }
20809 if (!extendLoadedValueToExtension(LD, Val))
20810 break;
20811 return ReplaceLd(LD, Val, Chain);
20812 } while (false);
20813
20814 // On failure, cleanup dead nodes we may have created.
20815 if (Val->use_empty())
20816 deleteAndRecombine(N: Val.getNode());
20817 return SDValue();
20818}
20819
20820SDValue DAGCombiner::visitLOAD(SDNode *N) {
20821 LoadSDNode *LD = cast<LoadSDNode>(Val: N);
20822 SDValue Chain = LD->getChain();
20823 SDValue Ptr = LD->getBasePtr();
20824
20825 // If load is not volatile and there are no uses of the loaded value (and
20826 // the updated indexed value in case of indexed loads), change uses of the
20827 // chain value into uses of the chain input (i.e. delete the dead load).
20828 // TODO: Allow this for unordered atomics (see D66309)
20829 if (LD->isSimple()) {
20830 if (N->getValueType(ResNo: 1) == MVT::Other) {
20831 // Unindexed loads.
20832 if (!N->hasAnyUseOfValue(Value: 0)) {
20833 // It's not safe to use the two value CombineTo variant here. e.g.
20834 // v1, chain2 = load chain1, loc
20835 // v2, chain3 = load chain2, loc
20836 // v3 = add v2, c
20837 // Now we replace use of chain2 with chain1. This makes the second load
20838 // isomorphic to the one we are deleting, and thus makes this load live.
20839 LLVM_DEBUG(dbgs() << "\nReplacing.6 "; N->dump(&DAG);
20840 dbgs() << "\nWith chain: "; Chain.dump(&DAG);
20841 dbgs() << "\n");
20842 WorklistRemover DeadNodes(*this);
20843 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: Chain);
20844 AddUsersToWorklist(N: Chain.getNode());
20845 if (N->use_empty())
20846 deleteAndRecombine(N);
20847
20848 return SDValue(N, 0); // Return N so it doesn't get rechecked!
20849 }
20850 } else {
20851 // Indexed loads.
20852 assert(N->getValueType(2) == MVT::Other && "Malformed indexed loads?");
20853
20854 // If this load has an opaque TargetConstant offset, then we cannot split
20855 // the indexing into an add/sub directly (that TargetConstant may not be
20856 // valid for a different type of node, and we cannot convert an opaque
20857 // target constant into a regular constant).
20858 bool CanSplitIdx = canSplitIdx(LD);
20859
20860 if (!N->hasAnyUseOfValue(Value: 0) && (CanSplitIdx || !N->hasAnyUseOfValue(Value: 1))) {
20861 SDValue Undef = DAG.getUNDEF(VT: N->getValueType(ResNo: 0));
20862 SDValue Index;
20863 if (N->hasAnyUseOfValue(Value: 1) && CanSplitIdx) {
20864 Index = SplitIndexingFromLoad(LD);
20865 // Try to fold the base pointer arithmetic into subsequent loads and
20866 // stores.
20867 AddUsersToWorklist(N);
20868 } else
20869 Index = DAG.getUNDEF(VT: N->getValueType(ResNo: 1));
20870 LLVM_DEBUG(dbgs() << "\nReplacing.7 "; N->dump(&DAG);
20871 dbgs() << "\nWith: "; Undef.dump(&DAG);
20872 dbgs() << " and 2 other values\n");
20873 WorklistRemover DeadNodes(*this);
20874 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Undef);
20875 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: Index);
20876 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 2), To: Chain);
20877 deleteAndRecombine(N);
20878 return SDValue(N, 0); // Return N so it doesn't get rechecked!
20879 }
20880 }
20881 }
20882
20883 // If this load is directly stored, replace the load value with the stored
20884 // value.
20885 if (auto V = ForwardStoreValueToDirectLoad(LD))
20886 return V;
20887
20888 // Try to infer better alignment information than the load already has.
20889 if (OptLevel != CodeGenOptLevel::None && LD->isUnindexed() &&
20890 !LD->isAtomic()) {
20891 if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
20892 if (*Alignment > LD->getAlign() &&
20893 isAligned(Lhs: *Alignment, SizeInBytes: LD->getSrcValueOffset())) {
20894 SDValue NewLoad = DAG.getExtLoad(
20895 ExtType: LD->getExtensionType(), dl: SDLoc(N), VT: LD->getValueType(ResNo: 0), Chain, Ptr,
20896 PtrInfo: LD->getPointerInfo(), MemVT: LD->getMemoryVT(), Alignment: *Alignment,
20897 MMOFlags: LD->getMemOperand()->getFlags(), AAInfo: LD->getAAInfo());
20898 // NewLoad will always be N as we are only refining the alignment
20899 assert(NewLoad.getNode() == N);
20900 (void)NewLoad;
20901 }
20902 }
20903 }
20904
20905 if (LD->isUnindexed()) {
20906 // Walk up chain skipping non-aliasing memory nodes.
20907 SDValue BetterChain = FindBetterChain(N: LD, Chain);
20908
20909 // If there is a better chain.
20910 if (Chain != BetterChain) {
20911 SDValue ReplLoad;
20912
      // Replace the chain to avoid the dependency.
20914 if (LD->getExtensionType() == ISD::NON_EXTLOAD) {
20915 ReplLoad = DAG.getLoad(VT: N->getValueType(ResNo: 0), dl: SDLoc(LD),
20916 Chain: BetterChain, Ptr, MMO: LD->getMemOperand());
20917 } else {
20918 ReplLoad = DAG.getExtLoad(ExtType: LD->getExtensionType(), dl: SDLoc(LD),
20919 VT: LD->getValueType(ResNo: 0),
20920 Chain: BetterChain, Ptr, MemVT: LD->getMemoryVT(),
20921 MMO: LD->getMemOperand());
20922 }
20923
20924 // Create token factor to keep old chain connected.
20925 SDValue Token = DAG.getNode(Opcode: ISD::TokenFactor, DL: SDLoc(N),
20926 VT: MVT::Other, N1: Chain, N2: ReplLoad.getValue(R: 1));
20927
20928 // Replace uses with load result and token factor
20929 return CombineTo(N, Res0: ReplLoad.getValue(R: 0), Res1: Token);
20930 }
20931 }
20932
20933 // Try transforming N to an indexed load.
20934 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
20935 return SDValue(N, 0);
20936
20937 // Try to slice up N to more direct loads if the slices are mapped to
20938 // different register banks or pairing can take place.
20939 if (SliceUpLoad(N))
20940 return SDValue(N, 0);
20941
20942 return SDValue();
20943}
20944
20945namespace {
20946
20947/// Helper structure used to slice a load in smaller loads.
20948/// Basically a slice is obtained from the following sequence:
20949/// Origin = load Ty1, Base
20950/// Shift = srl Ty1 Origin, CstTy Amount
20951/// Inst = trunc Shift to Ty2
20952///
20953/// Then, it will be rewritten into:
20954/// Slice = load SliceTy, Base + SliceOffset
20955/// [Inst = zext Slice to Ty2], only if SliceTy <> Ty2
20956///
20957/// SliceTy is deduced from the number of bits that are actually used to
20958/// build Inst.
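///
/// For example, an i32 Origin used as (trunc Origin to i8) and as
/// (trunc (srl Origin, 16) to i16) yields two slices: an i8 load at
/// Base + 0 and an i16 load at Base + 2 (on a little-endian target).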
20959struct LoadedSlice {
20960 /// Helper structure used to compute the cost of a slice.
20961 struct Cost {
20962 /// Are we optimizing for code size.
20963 bool ForCodeSize = false;
20964
    /// Various costs.
20966 unsigned Loads = 0;
20967 unsigned Truncates = 0;
20968 unsigned CrossRegisterBanksCopies = 0;
20969 unsigned ZExts = 0;
20970 unsigned Shift = 0;
20971
20972 explicit Cost(bool ForCodeSize) : ForCodeSize(ForCodeSize) {}
20973
20974 /// Get the cost of one isolated slice.
20975 Cost(const LoadedSlice &LS, bool ForCodeSize)
20976 : ForCodeSize(ForCodeSize), Loads(1) {
20977 EVT TruncType = LS.Inst->getValueType(ResNo: 0);
20978 EVT LoadedType = LS.getLoadedType();
20979 if (TruncType != LoadedType &&
20980 !LS.DAG->getTargetLoweringInfo().isZExtFree(FromTy: LoadedType, ToTy: TruncType))
20981 ZExts = 1;
20982 }
20983
    /// Account for slicing gain in the current cost.
    /// Slicing provides a few gains, like removing a shift or a
    /// truncate. This method allows growing the cost of the original
    /// load by the gain from this slice.
20988 void addSliceGain(const LoadedSlice &LS) {
20989 // Each slice saves a truncate.
20990 const TargetLowering &TLI = LS.DAG->getTargetLoweringInfo();
20991 if (!TLI.isTruncateFree(Val: LS.Inst->getOperand(Num: 0), VT2: LS.Inst->getValueType(ResNo: 0)))
20992 ++Truncates;
20993 // If there is a shift amount, this slice gets rid of it.
20994 if (LS.Shift)
20995 ++Shift;
20996 // If this slice can merge a cross register bank copy, account for it.
20997 if (LS.canMergeExpensiveCrossRegisterBankCopy())
20998 ++CrossRegisterBanksCopies;
20999 }
21000
21001 Cost &operator+=(const Cost &RHS) {
21002 Loads += RHS.Loads;
21003 Truncates += RHS.Truncates;
21004 CrossRegisterBanksCopies += RHS.CrossRegisterBanksCopies;
21005 ZExts += RHS.ZExts;
21006 Shift += RHS.Shift;
21007 return *this;
21008 }
21009
21010 bool operator==(const Cost &RHS) const {
21011 return Loads == RHS.Loads && Truncates == RHS.Truncates &&
21012 CrossRegisterBanksCopies == RHS.CrossRegisterBanksCopies &&
21013 ZExts == RHS.ZExts && Shift == RHS.Shift;
21014 }
21015
21016 bool operator!=(const Cost &RHS) const { return !(*this == RHS); }
21017
21018 bool operator<(const Cost &RHS) const {
21019 // Assume cross register banks copies are as expensive as loads.
21020 // FIXME: Do we want some more target hooks?
21021 unsigned ExpensiveOpsLHS = Loads + CrossRegisterBanksCopies;
21022 unsigned ExpensiveOpsRHS = RHS.Loads + RHS.CrossRegisterBanksCopies;
21023 // Unless we are optimizing for code size, consider the
21024 // expensive operation first.
21025 if (!ForCodeSize && ExpensiveOpsLHS != ExpensiveOpsRHS)
21026 return ExpensiveOpsLHS < ExpensiveOpsRHS;
21027 return (Truncates + ZExts + Shift + ExpensiveOpsLHS) <
21028 (RHS.Truncates + RHS.ZExts + RHS.Shift + ExpensiveOpsRHS);
21029 }
21030
21031 bool operator>(const Cost &RHS) const { return RHS < *this; }
21032
21033 bool operator<=(const Cost &RHS) const { return !(RHS < *this); }
21034
21035 bool operator>=(const Cost &RHS) const { return !(*this < RHS); }
21036 };
21037
  // The last instruction that represents the slice. This should be a
  // truncate instruction.
21040 SDNode *Inst;
21041
21042 // The original load instruction.
21043 LoadSDNode *Origin;
21044
21045 // The right shift amount in bits from the original load.
21046 unsigned Shift;
21047
  // The DAG that Origin came from.
21049 // This is used to get some contextual information about legal types, etc.
21050 SelectionDAG *DAG;
21051
21052 LoadedSlice(SDNode *Inst = nullptr, LoadSDNode *Origin = nullptr,
21053 unsigned Shift = 0, SelectionDAG *DAG = nullptr)
21054 : Inst(Inst), Origin(Origin), Shift(Shift), DAG(DAG) {}
21055
  /// Get the bits used in a chunk of bits \p BitWidth large.
  /// \return Result is \p BitWidth bits wide, with used bits set to 1 and
  /// unused bits set to 0.
21059 APInt getUsedBits() const {
21060 // Reproduce the trunc(lshr) sequence:
21061 // - Start from the truncated value.
21062 // - Zero extend to the desired bit width.
21063 // - Shift left.
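    //
    // For example, an i32 Origin truncated to i8 after a shift of 16 gives
    // UsedBits = 0x00FF0000.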
21064 assert(Origin && "No original load to compare against.");
21065 unsigned BitWidth = Origin->getValueSizeInBits(ResNo: 0);
21066 assert(Inst && "This slice is not bound to an instruction");
21067 assert(Inst->getValueSizeInBits(0) <= BitWidth &&
21068 "Extracted slice is bigger than the whole type!");
21069 APInt UsedBits(Inst->getValueSizeInBits(ResNo: 0), 0);
21070 UsedBits.setAllBits();
21071 UsedBits = UsedBits.zext(width: BitWidth);
21072 UsedBits <<= Shift;
21073 return UsedBits;
21074 }
21075
21076 /// Get the size of the slice to be loaded in bytes.
21077 unsigned getLoadedSize() const {
21078 unsigned SliceSize = getUsedBits().popcount();
21079 assert(!(SliceSize & 0x7) && "Size is not a multiple of a byte.");
21080 return SliceSize / 8;
21081 }
21082
21083 /// Get the type that will be loaded for this slice.
21084 /// Note: This may not be the final type for the slice.
21085 EVT getLoadedType() const {
21086 assert(DAG && "Missing context");
21087 LLVMContext &Ctxt = *DAG->getContext();
21088 return EVT::getIntegerVT(Context&: Ctxt, BitWidth: getLoadedSize() * 8);
21089 }
21090
21091 /// Get the alignment of the load used for this slice.
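  /// For example, an 8-byte-aligned Origin sliced at byte offset 2 yields an
  /// alignment of 2.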
21092 Align getAlign() const {
21093 Align Alignment = Origin->getAlign();
21094 uint64_t Offset = getOffsetFromBase();
21095 if (Offset != 0)
21096 Alignment = commonAlignment(A: Alignment, Offset: Alignment.value() + Offset);
21097 return Alignment;
21098 }
21099
21100 /// Check if this slice can be rewritten with legal operations.
21101 bool isLegal() const {
21102 // An invalid slice is not legal.
21103 if (!Origin || !Inst || !DAG)
21104 return false;
21105
    // Offsets are for indexed loads only; we do not handle that.
21107 if (!Origin->getOffset().isUndef())
21108 return false;
21109
21110 const TargetLowering &TLI = DAG->getTargetLoweringInfo();
21111
21112 // Check that the type is legal.
21113 EVT SliceType = getLoadedType();
21114 if (!TLI.isTypeLegal(VT: SliceType))
21115 return false;
21116
21117 // Check that the load is legal for this type.
21118 if (!TLI.isOperationLegal(Op: ISD::LOAD, VT: SliceType))
21119 return false;
21120
21121 // Check that the offset can be computed.
21122 // 1. Check its type.
21123 EVT PtrType = Origin->getBasePtr().getValueType();
21124 if (PtrType == MVT::Untyped || PtrType.isExtended())
21125 return false;
21126
21127 // 2. Check that it fits in the immediate.
21128 if (!TLI.isLegalAddImmediate(getOffsetFromBase()))
21129 return false;
21130
21131 // 3. Check that the computation is legal.
21132 if (!TLI.isOperationLegal(Op: ISD::ADD, VT: PtrType))
21133 return false;
21134
21135 // Check that the zext is legal if it needs one.
21136 EVT TruncateType = Inst->getValueType(ResNo: 0);
21137 if (TruncateType != SliceType &&
21138 !TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT: TruncateType))
21139 return false;
21140
21141 return true;
21142 }
21143
21144 /// Get the offset in bytes of this slice in the original chunk of
21145 /// bits.
21146 /// \pre DAG != nullptr.
21147 uint64_t getOffsetFromBase() const {
21148 assert(DAG && "Missing context.");
21149 bool IsBigEndian = DAG->getDataLayout().isBigEndian();
21150 assert(!(Shift & 0x7) && "Shifts not aligned on Bytes are not supported.");
21151 uint64_t Offset = Shift / 8;
21152 unsigned TySizeInBytes = Origin->getValueSizeInBits(ResNo: 0) / 8;
21153 assert(!(Origin->getValueSizeInBits(0) & 0x7) &&
21154 "The size of the original loaded type is not a multiple of a"
21155 " byte.");
21156 // If Offset is bigger than TySizeInBytes, it means we are loading all
    // zeros. This should have been optimized away earlier in the pipeline.
21158 assert(TySizeInBytes > Offset &&
21159 "Invalid shift amount for given loaded size");
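    // For example, for an i32 Origin with Shift == 16 and a 1-byte slice, the
    // offset is 2 on a little-endian target but 4 - 2 - 1 = 1 on big-endian.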
21160 if (IsBigEndian)
21161 Offset = TySizeInBytes - Offset - getLoadedSize();
21162 return Offset;
21163 }
21164
21165 /// Generate the sequence of instructions to load the slice
21166 /// represented by this object and redirect the uses of this slice to
21167 /// this new sequence of instructions.
21168 /// \pre this->Inst && this->Origin are valid Instructions and this
21169 /// object passed the legal check: LoadedSlice::isLegal returned true.
21170 /// \return The last instruction of the sequence used to load the slice.
21171 SDValue loadSlice() const {
21172 assert(Inst && Origin && "Unable to replace a non-existing slice.");
21173 const SDValue &OldBaseAddr = Origin->getBasePtr();
21174 SDValue BaseAddr = OldBaseAddr;
21175 // Get the offset in that chunk of bytes w.r.t. the endianness.
21176 int64_t Offset = static_cast<int64_t>(getOffsetFromBase());
21177 assert(Offset >= 0 && "Offset too big to fit in int64_t!");
21178 if (Offset) {
21179 // BaseAddr = BaseAddr + Offset.
21180 EVT ArithType = BaseAddr.getValueType();
21181 SDLoc DL(Origin);
21182 BaseAddr = DAG->getNode(Opcode: ISD::ADD, DL, VT: ArithType, N1: BaseAddr,
21183 N2: DAG->getConstant(Val: Offset, DL, VT: ArithType));
21184 }
21185
21186 // Create the type of the loaded slice according to its size.
21187 EVT SliceType = getLoadedType();
21188
21189 // Create the load for the slice.
21190 SDValue LastInst =
21191 DAG->getLoad(VT: SliceType, dl: SDLoc(Origin), Chain: Origin->getChain(), Ptr: BaseAddr,
21192 PtrInfo: Origin->getPointerInfo().getWithOffset(O: Offset), Alignment: getAlign(),
21193 MMOFlags: Origin->getMemOperand()->getFlags());
21194 // If the final type is not the same as the loaded type, this means that
21195 // we have to pad with zero. Create a zero extend for that.
21196 EVT FinalType = Inst->getValueType(ResNo: 0);
21197 if (SliceType != FinalType)
21198 LastInst =
21199 DAG->getNode(Opcode: ISD::ZERO_EXTEND, DL: SDLoc(LastInst), VT: FinalType, Operand: LastInst);
21200 return LastInst;
21201 }
21202
21203 /// Check if this slice can be merged with an expensive cross register
21204 /// bank copy. E.g.,
21205 /// i = load i32
21206 /// f = bitcast i32 i to float
21207 bool canMergeExpensiveCrossRegisterBankCopy() const {
21208 if (!Inst || !Inst->hasOneUse())
21209 return false;
21210 SDNode *User = *Inst->user_begin();
21211 if (User->getOpcode() != ISD::BITCAST)
21212 return false;
21213 assert(DAG && "Missing context");
21214 const TargetLowering &TLI = DAG->getTargetLoweringInfo();
21215 EVT ResVT = User->getValueType(ResNo: 0);
21216 const TargetRegisterClass *ResRC =
21217 TLI.getRegClassFor(VT: ResVT.getSimpleVT(), isDivergent: User->isDivergent());
21218 const TargetRegisterClass *ArgRC =
21219 TLI.getRegClassFor(VT: User->getOperand(Num: 0).getValueType().getSimpleVT(),
21220 isDivergent: User->getOperand(Num: 0)->isDivergent());
21221 if (ArgRC == ResRC || !TLI.isOperationLegal(Op: ISD::LOAD, VT: ResVT))
21222 return false;
21223
21224 // At this point, we know that we perform a cross-register-bank copy.
21225 // Check if it is expensive.
21226 const TargetRegisterInfo *TRI = DAG->getSubtarget().getRegisterInfo();
21227 // Assume bitcasts are cheap, unless both register classes do not
21228 // explicitly share a common sub class.
21229 if (!TRI || TRI->getCommonSubClass(A: ArgRC, B: ResRC))
21230 return false;
21231
21232 // Check if it will be merged with the load.
21233 // 1. Check the alignment / fast memory access constraint.
21234 unsigned IsFast = 0;
21235 if (!TLI.allowsMemoryAccess(Context&: *DAG->getContext(), DL: DAG->getDataLayout(), VT: ResVT,
21236 AddrSpace: Origin->getAddressSpace(), Alignment: getAlign(),
21237 Flags: Origin->getMemOperand()->getFlags(), Fast: &IsFast) ||
21238 !IsFast)
21239 return false;
21240
21241 // 2. Check that the load is a legal operation for that type.
21242 if (!TLI.isOperationLegal(Op: ISD::LOAD, VT: ResVT))
21243 return false;
21244
21245 // 3. Check that we do not have a zext in the way.
21246 if (Inst->getValueType(ResNo: 0) != getLoadedType())
21247 return false;
21248
21249 return true;
21250 }
21251};
21252
21253} // end anonymous namespace
21254
21255/// Check that all bits set in \p UsedBits form a dense region, i.e.,
21256/// \p UsedBits looks like 0..0 1..1 0..0.
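/// For example, 0b00111100 is dense, while 0b00100100 is not.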
21257static bool areUsedBitsDense(const APInt &UsedBits) {
21258 // If all the bits are one, this is dense!
21259 if (UsedBits.isAllOnes())
21260 return true;
21261
21262 // Get rid of the unused bits on the right.
21263 APInt NarrowedUsedBits = UsedBits.lshr(shiftAmt: UsedBits.countr_zero());
21264 // Get rid of the unused bits on the left.
21265 if (NarrowedUsedBits.countl_zero())
21266 NarrowedUsedBits = NarrowedUsedBits.trunc(width: NarrowedUsedBits.getActiveBits());
21267 // Check that the chunk of bits is completely used.
21268 return NarrowedUsedBits.isAllOnes();
21269}
21270
21271/// Check whether or not \p First and \p Second are next to each other
21272/// in memory. This means that there is no hole between the bits loaded
21273/// by \p First and the bits loaded by \p Second.
21274static bool areSlicesNextToEachOther(const LoadedSlice &First,
21275 const LoadedSlice &Second) {
21276 assert(First.Origin == Second.Origin && First.Origin &&
21277 "Unable to match different memory origins.");
21278 APInt UsedBits = First.getUsedBits();
21279 assert((UsedBits & Second.getUsedBits()) == 0 &&
21280 "Slices are not supposed to overlap.");
21281 UsedBits |= Second.getUsedBits();
21282 return areUsedBitsDense(UsedBits);
21283}
21284
/// Adjust the \p GlobalLSCost according to the target
/// pairing capabilities and the layout of the slices.
/// \pre \p GlobalLSCost should account for at least as many loads as
/// there are in the slices in \p LoadedSlices.
21289static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
21290 LoadedSlice::Cost &GlobalLSCost) {
21291 unsigned NumberOfSlices = LoadedSlices.size();
  // If there are fewer than 2 elements, no pairing is possible.
21293 if (NumberOfSlices < 2)
21294 return;
21295
21296 // Sort the slices so that elements that are likely to be next to each
21297 // other in memory are next to each other in the list.
21298 llvm::sort(C&: LoadedSlices, Comp: [](const LoadedSlice &LHS, const LoadedSlice &RHS) {
21299 assert(LHS.Origin == RHS.Origin && "Different bases not implemented.");
21300 return LHS.getOffsetFromBase() < RHS.getOffsetFromBase();
21301 });
21302 const TargetLowering &TLI = LoadedSlices[0].DAG->getTargetLoweringInfo();
  // First (resp. Second) is the first (resp. second) potential candidate
  // to be placed in a paired load.
21305 const LoadedSlice *First = nullptr;
21306 const LoadedSlice *Second = nullptr;
21307 for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice,
21308 // Set the beginning of the pair.
21309 First = Second) {
21310 Second = &LoadedSlices[CurrSlice];
21311
21312 // If First is NULL, it means we start a new pair.
21313 // Get to the next slice.
21314 if (!First)
21315 continue;
21316
21317 EVT LoadedType = First->getLoadedType();
21318
21319 // If the types of the slices are different, we cannot pair them.
21320 if (LoadedType != Second->getLoadedType())
21321 continue;
21322
21323 // Check if the target supplies paired loads for this type.
21324 Align RequiredAlignment;
21325 if (!TLI.hasPairedLoad(LoadedType, RequiredAlignment)) {
      // Move to the next pair; this type is hopeless.
21327 Second = nullptr;
21328 continue;
21329 }
21330 // Check if we meet the alignment requirement.
21331 if (First->getAlign() < RequiredAlignment)
21332 continue;
21333
21334 // Check that both loads are next to each other in memory.
21335 if (!areSlicesNextToEachOther(First: *First, Second: *Second))
21336 continue;
21337
21338 assert(GlobalLSCost.Loads > 0 && "We save more loads than we created!");
21339 --GlobalLSCost.Loads;
21340 // Move to the next pair.
21341 Second = nullptr;
21342 }
21343}
21344
/// Check the profitability of all involved LoadedSlices.
/// Currently, it is considered profitable if there are exactly two
/// involved slices (1) which are (2) next to each other in memory, and
21348/// whose cost (\see LoadedSlice::Cost) is smaller than the original load (3).
21349///
21350/// Note: The order of the elements in \p LoadedSlices may be modified, but not
21351/// the elements themselves.
21352///
21353/// FIXME: When the cost model will be mature enough, we can relax
21354/// constraints (1) and (2).
21355static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
21356 const APInt &UsedBits, bool ForCodeSize) {
21357 unsigned NumberOfSlices = LoadedSlices.size();
21358 if (StressLoadSlicing)
21359 return NumberOfSlices > 1;
21360
21361 // Check (1).
21362 if (NumberOfSlices != 2)
21363 return false;
21364
21365 // Check (2).
21366 if (!areUsedBitsDense(UsedBits))
21367 return false;
21368
21369 // Check (3).
21370 LoadedSlice::Cost OrigCost(ForCodeSize), GlobalSlicingCost(ForCodeSize);
21371 // The original code has one big load.
21372 OrigCost.Loads = 1;
21373 for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice) {
21374 const LoadedSlice &LS = LoadedSlices[CurrSlice];
21375 // Accumulate the cost of all the slices.
21376 LoadedSlice::Cost SliceCost(LS, ForCodeSize);
21377 GlobalSlicingCost += SliceCost;
21378
21379 // Account as cost in the original configuration the gain obtained
21380 // with the current slices.
21381 OrigCost.addSliceGain(LS);
21382 }
21383
21384 // If the target supports paired load, adjust the cost accordingly.
21385 adjustCostForPairing(LoadedSlices, GlobalLSCost&: GlobalSlicingCost);
21386 return OrigCost > GlobalSlicingCost;
21387}
21388
/// If the given load, \p N, is used only by trunc or trunc(lshr)
21390/// operations, split it in the various pieces being extracted.
21391///
21392/// This sort of thing is introduced by SROA.
21393/// This slicing takes care not to insert overlapping loads.
/// \pre \p N is a simple load (i.e., not an atomic or volatile load).
21395bool DAGCombiner::SliceUpLoad(SDNode *N) {
21396 if (Level < AfterLegalizeDAG)
21397 return false;
21398
21399 LoadSDNode *LD = cast<LoadSDNode>(Val: N);
21400 if (!LD->isSimple() || !ISD::isNormalLoad(N: LD) ||
21401 !LD->getValueType(ResNo: 0).isInteger())
21402 return false;
21403
21404 // The algorithm to split up a load of a scalable vector into individual
21405 // elements currently requires knowing the length of the loaded type,
21406 // so will need adjusting to work on scalable vectors.
21407 if (LD->getValueType(ResNo: 0).isScalableVector())
21408 return false;
21409
21410 // Keep track of already used bits to detect overlapping values.
21411 // In that case, we will just abort the transformation.
21412 APInt UsedBits(LD->getValueSizeInBits(ResNo: 0), 0);
21413
21414 SmallVector<LoadedSlice, 4> LoadedSlices;
21415
21416 // Check if this load is used as several smaller chunks of bits.
21417 // Basically, look for uses in trunc or trunc(lshr) and record a new chain
21418 // of computation for each trunc.
21419 for (SDUse &U : LD->uses()) {
21420 // Skip the uses of the chain.
21421 if (U.getResNo() != 0)
21422 continue;
21423
21424 SDNode *User = U.getUser();
21425 unsigned Shift = 0;
21426
21427 // Check if this is a trunc(lshr).
21428 if (User->getOpcode() == ISD::SRL && User->hasOneUse() &&
21429 isa<ConstantSDNode>(Val: User->getOperand(Num: 1))) {
21430 Shift = User->getConstantOperandVal(Num: 1);
21431 User = *User->user_begin();
21432 }
21433
    // At this point, User is a truncate iff we encountered trunc or
    // trunc(lshr).
21436 if (User->getOpcode() != ISD::TRUNCATE)
21437 return false;
21438
    // The width of the type must be a power of 2 that is at least 8 bits.
    // Otherwise the load cannot be represented in LLVM IR.
    // Moreover, if we shifted by an amount that is not a multiple of 8 bits,
    // the slice would span partial bytes. We do not support that.
21443 unsigned Width = User->getValueSizeInBits(ResNo: 0);
21444 if (Width < 8 || !isPowerOf2_32(Value: Width) || (Shift & 0x7))
21445 return false;
21446
21447 // Build the slice for this chain of computations.
21448 LoadedSlice LS(User, LD, Shift, &DAG);
21449 APInt CurrentUsedBits = LS.getUsedBits();
21450
21451 // Check if this slice overlaps with another.
21452 if ((CurrentUsedBits & UsedBits) != 0)
21453 return false;
21454 // Update the bits used globally.
21455 UsedBits |= CurrentUsedBits;
21456
21457 // Check if the new slice would be legal.
21458 if (!LS.isLegal())
21459 return false;
21460
21461 // Record the slice.
21462 LoadedSlices.push_back(Elt: LS);
21463 }
21464
21465 // Abort slicing if it does not seem to be profitable.
21466 if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize))
21467 return false;
21468
21469 ++SlicedLoads;
21470
21471 // Rewrite each chain to use an independent load.
21472 // By construction, each chain can be represented by a unique load.
21473
21474 // Prepare the argument for the new token factor for all the slices.
21475 SmallVector<SDValue, 8> ArgChains;
21476 for (const LoadedSlice &LS : LoadedSlices) {
21477 SDValue SliceInst = LS.loadSlice();
21478 CombineTo(N: LS.Inst, Res: SliceInst, AddTo: true);
21479 if (SliceInst.getOpcode() != ISD::LOAD)
21480 SliceInst = SliceInst.getOperand(i: 0);
21481 assert(SliceInst->getOpcode() == ISD::LOAD &&
21482 "It takes more than a zext to get to the loaded slice!!");
21483 ArgChains.push_back(Elt: SliceInst.getValue(R: 1));
21484 }
21485
21486 SDValue Chain = DAG.getNode(Opcode: ISD::TokenFactor, DL: SDLoc(LD), VT: MVT::Other,
21487 Ops: ArgChains);
21488 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: Chain);
21489 AddToWorklist(N: Chain.getNode());
21490 return true;
21491}
21492
/// Check to see if V is (and (load ptr), imm), where the AND clears out
/// specific bytes of the loaded value. If so, return the number of bytes being
/// masked out and the byte shift amount.
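/// For example, V = (and (load i32 Ptr), 0xFFFF00FF) masks out a single byte
/// at byte offset 1, so this returns {1, 1}.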
21496static std::pair<unsigned, unsigned>
21497CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) {
21498 std::pair<unsigned, unsigned> Result(0, 0);
21499
21500 // Check for the structure we're looking for.
21501 if (V->getOpcode() != ISD::AND ||
21502 !isa<ConstantSDNode>(Val: V->getOperand(Num: 1)) ||
21503 !ISD::isNormalLoad(N: V->getOperand(Num: 0).getNode()))
21504 return Result;
21505
21506 // Check the chain and pointer.
21507 LoadSDNode *LD = cast<LoadSDNode>(Val: V->getOperand(Num: 0));
21508 if (LD->getBasePtr() != Ptr) return Result; // Not from same pointer.
21509
21510 // This only handles simple types.
21511 if (V.getValueType() != MVT::i16 &&
21512 V.getValueType() != MVT::i32 &&
21513 V.getValueType() != MVT::i64)
21514 return Result;
21515
  // Check the constant mask. Invert it so that the bits being masked out by
  // the AND become 1 and the bits being kept become 0. Use getSExtValue so
  // that leading bits follow the sign bit for uniformity.
21519 uint64_t NotMask = ~cast<ConstantSDNode>(Val: V->getOperand(Num: 1))->getSExtValue();
21520 unsigned NotMaskLZ = llvm::countl_zero(Val: NotMask);
21521 if (NotMaskLZ & 7) return Result; // Must be multiple of a byte.
21522 unsigned NotMaskTZ = llvm::countr_zero(Val: NotMask);
21523 if (NotMaskTZ & 7) return Result; // Must be multiple of a byte.
21524 if (NotMaskLZ == 64) return Result; // All zero mask.
21525
  // See if we have a contiguous run of bits; if so, NotMask has the form 0*1+0*.
21527 if (llvm::countr_one(Value: NotMask >> NotMaskTZ) + NotMaskTZ + NotMaskLZ != 64)
21528 return Result;
21529
21530 // Adjust NotMaskLZ down to be from the actual size of the int instead of i64.
21531 if (V.getValueType() != MVT::i64 && NotMaskLZ)
21532 NotMaskLZ -= 64-V.getValueSizeInBits();
21533
21534 unsigned MaskedBytes = (V.getValueSizeInBits()-NotMaskLZ-NotMaskTZ)/8;
21535 switch (MaskedBytes) {
21536 case 1:
21537 case 2:
21538 case 4: break;
21539 default: return Result; // All one mask, or 5-byte mask.
21540 }
21541
  // Verify that the masked-out region starts at a byte offset that is a
  // multiple of its width so that the narrowed access stays naturally aligned.
21544 if (NotMaskTZ && NotMaskTZ/8 % MaskedBytes) return Result;
21545
  // For narrowing to be valid, the load must be the memory operation that
  // immediately precedes the store.
21548 if (LD == Chain.getNode())
21549 ; // ok.
21550 else if (Chain->getOpcode() == ISD::TokenFactor &&
21551 SDValue(LD, 1).hasOneUse()) {
    // LD has only 1 chain use, so there are no indirect dependencies.
21553 if (!LD->isOperandOf(N: Chain.getNode()))
21554 return Result;
21555 } else
21556 return Result; // Fail.
21557
21558 Result.first = MaskedBytes;
21559 Result.second = NotMaskTZ/8;
21560 return Result;
21561}
21562
21563/// Check to see if IVal is something that provides a value as specified by
21564/// MaskInfo. If so, replace the specified store with a narrower store of
21565/// truncated IVal.
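/// For example, with MaskInfo = {1, 1} the wide store of the 'or' result can
/// be replaced by a one-byte store of (trunc (srl IVal, 8)) at Ptr + 1 on a
/// little-endian target.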
21566static SDValue
21567ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo,
21568 SDValue IVal, StoreSDNode *St,
21569 DAGCombiner *DC) {
21570 unsigned NumBytes = MaskInfo.first;
21571 unsigned ByteShift = MaskInfo.second;
21572 SelectionDAG &DAG = DC->getDAG();
21573
21574 // Check to see if IVal is all zeros in the part being masked in by the 'or'
21575 // that uses this. If not, this is not a replacement.
21576 APInt Mask = ~APInt::getBitsSet(numBits: IVal.getValueSizeInBits(),
21577 loBit: ByteShift*8, hiBit: (ByteShift+NumBytes)*8);
21578 if (!DAG.MaskedValueIsZero(Op: IVal, Mask)) return SDValue();
21579
21580 // Check that it is legal on the target to do this. It is legal if the new
21581 // VT we're shrinking to (i8/i16/i32) is legal or we're still before type
21582 // legalization. If the source type is legal, but the store type isn't, see
21583 // if we can use a truncating store.
21584 MVT VT = MVT::getIntegerVT(BitWidth: NumBytes * 8);
21585 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
21586 bool UseTruncStore;
21587 if (DC->isTypeLegal(VT))
21588 UseTruncStore = false;
21589 else if (TLI.isTypeLegal(VT: IVal.getValueType()) &&
21590 TLI.isTruncStoreLegal(ValVT: IVal.getValueType(), MemVT: VT))
21591 UseTruncStore = true;
21592 else
21593 return SDValue();
21594
21595 // Can't do this for indexed stores.
21596 if (St->isIndexed())
21597 return SDValue();
21598
21599 // Check that the target doesn't think this is a bad idea.
21600 if (St->getMemOperand() &&
21601 !TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT,
21602 MMO: *St->getMemOperand()))
21603 return SDValue();
21604
21605 // Okay, we can do this! Replace the 'St' store with a store of IVal that is
21606 // shifted by ByteShift and truncated down to NumBytes.
21607 if (ByteShift) {
21608 SDLoc DL(IVal);
21609 IVal = DAG.getNode(
21610 Opcode: ISD::SRL, DL, VT: IVal.getValueType(), N1: IVal,
21611 N2: DAG.getShiftAmountConstant(Val: ByteShift * 8, VT: IVal.getValueType(), DL));
21612 }
21613
21614 // Figure out the offset for the store and the alignment of the access.
21615 unsigned StOffset;
21616 if (DAG.getDataLayout().isLittleEndian())
21617 StOffset = ByteShift;
21618 else
21619 StOffset = IVal.getValueType().getStoreSize() - ByteShift - NumBytes;
21620
21621 SDValue Ptr = St->getBasePtr();
21622 if (StOffset) {
21623 SDLoc DL(IVal);
21624 Ptr = DAG.getMemBasePlusOffset(Base: Ptr, Offset: TypeSize::getFixed(ExactSize: StOffset), DL);
21625 }
21626
21627 ++OpsNarrowed;
21628 if (UseTruncStore)
21629 return DAG.getTruncStore(Chain: St->getChain(), dl: SDLoc(St), Val: IVal, Ptr,
21630 PtrInfo: St->getPointerInfo().getWithOffset(O: StOffset), SVT: VT,
21631 Alignment: St->getBaseAlign());
21632
21633 // Truncate down to the new size.
21634 IVal = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(IVal), VT, Operand: IVal);
21635
21636 return DAG.getStore(Chain: St->getChain(), dl: SDLoc(St), Val: IVal, Ptr,
21637 PtrInfo: St->getPointerInfo().getWithOffset(O: StOffset),
21638 Alignment: St->getBaseAlign());
21639}
21640
21641/// Look for sequence of load / op / store where op is one of 'or', 'xor', and
21642/// 'and' of immediates. If 'op' is only touching some of the loaded bits, try
21643/// narrowing the load and store if it would end up being a win for performance
21644/// or code size.
21645SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) {
21646 StoreSDNode *ST = cast<StoreSDNode>(Val: N);
21647 if (!ST->isSimple())
21648 return SDValue();
21649
21650 SDValue Chain = ST->getChain();
21651 SDValue Value = ST->getValue();
21652 SDValue Ptr = ST->getBasePtr();
21653 EVT VT = Value.getValueType();
21654
21655 if (ST->isTruncatingStore() || VT.isVector())
21656 return SDValue();
21657
21658 unsigned Opc = Value.getOpcode();
21659
21660 if ((Opc != ISD::OR && Opc != ISD::XOR && Opc != ISD::AND) ||
21661 !Value.hasOneUse())
21662 return SDValue();
21663
21664 // If this is "store (or X, Y), P" and X is "(and (load P), cst)", where cst
21665 // is a byte mask indicating a consecutive number of bytes, check to see if
21666 // Y is known to provide just those bytes. If so, we try to replace the
21667 // load + replace + store sequence with a single (narrower) store, which makes
21668 // the load dead.
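  // For example:
  //   store (or (and (load p), 0xFFFFFF00), (and y, 0xFF)), p
  // only changes the least significant byte in memory, so it can be replaced
  // by a single i8 store of (trunc y) to the address of that byte, leaving the
  // wide load dead.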
21669 if (Opc == ISD::OR && EnableShrinkLoadReplaceStoreWithStore) {
21670 std::pair<unsigned, unsigned> MaskedLoad;
21671 MaskedLoad = CheckForMaskedLoad(V: Value.getOperand(i: 0), Ptr, Chain);
21672 if (MaskedLoad.first)
21673 if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskInfo: MaskedLoad,
21674 IVal: Value.getOperand(i: 1), St: ST,DC: this))
21675 return NewST;
21676
21677 // Or is commutative, so try swapping X and Y.
21678 MaskedLoad = CheckForMaskedLoad(V: Value.getOperand(i: 1), Ptr, Chain);
21679 if (MaskedLoad.first)
21680 if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskInfo: MaskedLoad,
21681 IVal: Value.getOperand(i: 0), St: ST,DC: this))
21682 return NewST;
21683 }
21684
21685 if (!EnableReduceLoadOpStoreWidth)
21686 return SDValue();
21687
21688 if (Value.getOperand(i: 1).getOpcode() != ISD::Constant)
21689 return SDValue();
21690
21691 SDValue N0 = Value.getOperand(i: 0);
21692 if (ISD::isNormalLoad(N: N0.getNode()) && N0.hasOneUse() &&
21693 Chain == SDValue(N0.getNode(), 1)) {
21694 LoadSDNode *LD = cast<LoadSDNode>(Val&: N0);
21695 if (LD->getBasePtr() != Ptr ||
21696 LD->getPointerInfo().getAddrSpace() !=
21697 ST->getPointerInfo().getAddrSpace())
21698 return SDValue();
21699
21700 // Find the type NewVT to narrow the load / op / store to.
21701 SDValue N1 = Value.getOperand(i: 1);
21702 unsigned BitWidth = N1.getValueSizeInBits();
21703 APInt Imm = N1->getAsAPIntVal();
21704 if (Opc == ISD::AND)
21705 Imm.flipAllBits();
21706 if (Imm == 0 || Imm.isAllOnes())
21707 return SDValue();
    // Find the least/most significant bits that need to be part of the
    // narrowed operation. We assume the target will need to address/access
    // full bytes, so we make sure to align LSB and MSB at byte boundaries.
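    // For example, for an i32 operation where Imm (after any inversion for
    // AND) is 0x00FF0000, this gives LSB = 16, MSB = 23, and an initial NewBW
    // of 8 (i8).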
21711 unsigned BitsPerByteMask = 7u;
21712 unsigned LSB = Imm.countr_zero() & ~BitsPerByteMask;
21713 unsigned MSB = (Imm.getActiveBits() - 1) | BitsPerByteMask;
21714 unsigned NewBW = NextPowerOf2(A: MSB - LSB);
21715 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: NewBW);
21716 // The narrowing should be profitable, the load/store operation should be
21717 // legal (or custom) and the store size should be equal to the NewVT width.
21718 while (NewBW < BitWidth &&
21719 (NewVT.getStoreSizeInBits() != NewBW ||
21720 !TLI.isOperationLegalOrCustom(Op: Opc, VT: NewVT) ||
21721 (!ReduceLoadOpStoreWidthForceNarrowingProfitable &&
21722 !TLI.isNarrowingProfitable(N, SrcVT: VT, DestVT: NewVT)))) {
21723 NewBW = NextPowerOf2(A: NewBW);
21724 NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: NewBW);
21725 }
21726 if (NewBW >= BitWidth)
21727 return SDValue();
21728
    // If we get this far, NewVT/NewBW reflect a power-of-2 sized type that is
    // large enough to cover all bits that should be modified. This type might
    // however be larger than really needed (such as i32 while we actually only
    // need to modify one byte). Now we need to find out how to align the
    // memory accesses to satisfy preferred alignments as well as avoid
    // accessing memory outside the store size of the original access.
21735
21736 unsigned VTStoreSize = VT.getStoreSizeInBits().getFixedValue();
21737
    // Let ShAmt denote the number of bits to skip, counted from the least
    // significant bits of Imm, and let PtrOff denote how much the pointer
    // needs to be offset (in bytes) for the new access.
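    // Continuing the 0x00FF0000 example above, the loop below settles on
    // ShAmt = 16, giving PtrOff = 2 on little-endian targets (and 1 on
    // big-endian targets for an i32 access), assuming the narrowed access is
    // allowed and fast.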
21741 unsigned ShAmt = 0;
21742 uint64_t PtrOff = 0;
21743 for (; ShAmt + NewBW <= VTStoreSize; ShAmt += 8) {
      // Make sure the range [ShAmt, ShAmt+NewBW) covers both LSB and MSB.
21745 if (ShAmt > LSB)
21746 return SDValue();
21747 if (ShAmt + NewBW < MSB)
21748 continue;
21749
21750 // Calculate PtrOff.
21751 unsigned PtrAdjustmentInBits = DAG.getDataLayout().isBigEndian()
21752 ? VTStoreSize - NewBW - ShAmt
21753 : ShAmt;
21754 PtrOff = PtrAdjustmentInBits / 8;
21755
21756 // Now check if narrow access is allowed and fast, considering alignments.
21757 unsigned IsFast = 0;
21758 Align NewAlign = commonAlignment(A: LD->getAlign(), Offset: PtrOff);
21759 if (TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: NewVT,
21760 AddrSpace: LD->getAddressSpace(), Alignment: NewAlign,
21761 Flags: LD->getMemOperand()->getFlags(), Fast: &IsFast) &&
21762 IsFast)
21763 break;
21764 }
    // If the loop above did not find an acceptable ShAmt, exit here.
21766 if (ShAmt + NewBW > VTStoreSize)
21767 return SDValue();
21768
21769 APInt NewImm = Imm.lshr(shiftAmt: ShAmt).trunc(width: NewBW);
21770 if (Opc == ISD::AND)
21771 NewImm.flipAllBits();
21772 Align NewAlign = commonAlignment(A: LD->getAlign(), Offset: PtrOff);
21773 SDValue NewPtr =
21774 DAG.getMemBasePlusOffset(Base: Ptr, Offset: TypeSize::getFixed(ExactSize: PtrOff), DL: SDLoc(LD));
21775 SDValue NewLD =
21776 DAG.getLoad(VT: NewVT, dl: SDLoc(N0), Chain: LD->getChain(), Ptr: NewPtr,
21777 PtrInfo: LD->getPointerInfo().getWithOffset(O: PtrOff), Alignment: NewAlign,
21778 MMOFlags: LD->getMemOperand()->getFlags(), AAInfo: LD->getAAInfo());
21779 SDValue NewVal = DAG.getNode(Opcode: Opc, DL: SDLoc(Value), VT: NewVT, N1: NewLD,
21780 N2: DAG.getConstant(Val: NewImm, DL: SDLoc(Value), VT: NewVT));
21781 SDValue NewST =
21782 DAG.getStore(Chain, dl: SDLoc(N), Val: NewVal, Ptr: NewPtr,
21783 PtrInfo: ST->getPointerInfo().getWithOffset(O: PtrOff), Alignment: NewAlign);
21784
21785 AddToWorklist(N: NewPtr.getNode());
21786 AddToWorklist(N: NewLD.getNode());
21787 AddToWorklist(N: NewVal.getNode());
21788 WorklistRemover DeadNodes(*this);
21789 DAG.ReplaceAllUsesOfValueWith(From: N0.getValue(R: 1), To: NewLD.getValue(R: 1));
21790 ++OpsNarrowed;
21791 return NewST;
21792 }
21793
21794 return SDValue();
21795}
21796
21797/// For a given floating point load / store pair, if the load value isn't used
21798/// by any other operations, then consider transforming the pair to integer
21799/// load / store operations if the target deems the transformation profitable.
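/// For example, an f64 load whose only use is feeding an f64 store may be
/// rewritten as an i64 load feeding an i64 store when the target reports the
/// integer operations as both legal and desirable.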
21800SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) {
21801 StoreSDNode *ST = cast<StoreSDNode>(Val: N);
21802 SDValue Value = ST->getValue();
21803 if (ISD::isNormalStore(N: ST) && ISD::isNormalLoad(N: Value.getNode()) &&
21804 Value.hasOneUse()) {
21805 LoadSDNode *LD = cast<LoadSDNode>(Val&: Value);
21806 EVT VT = LD->getMemoryVT();
21807 if (!VT.isSimple() || !VT.isFloatingPoint() || VT != ST->getMemoryVT() ||
21808 LD->isNonTemporal() || ST->isNonTemporal() ||
21809 LD->getPointerInfo().getAddrSpace() != 0 ||
21810 ST->getPointerInfo().getAddrSpace() != 0)
21811 return SDValue();
21812
21813 TypeSize VTSize = VT.getSizeInBits();
21814
21815 // We don't know the size of scalable types at compile time so we cannot
21816 // create an integer of the equivalent size.
21817 if (VTSize.isScalable())
21818 return SDValue();
21819
21820 unsigned FastLD = 0, FastST = 0;
21821 EVT IntVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: VTSize.getFixedValue());
21822 if (!TLI.isOperationLegal(Op: ISD::LOAD, VT: IntVT) ||
21823 !TLI.isOperationLegal(Op: ISD::STORE, VT: IntVT) ||
21824 !TLI.isDesirableToTransformToIntegerOp(ISD::LOAD, VT) ||
21825 !TLI.isDesirableToTransformToIntegerOp(ISD::STORE, VT) ||
21826 !TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: IntVT,
21827 MMO: *LD->getMemOperand(), Fast: &FastLD) ||
21828 !TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: IntVT,
21829 MMO: *ST->getMemOperand(), Fast: &FastST) ||
21830 !FastLD || !FastST)
21831 return SDValue();
21832
21833 SDValue NewLD = DAG.getLoad(VT: IntVT, dl: SDLoc(Value), Chain: LD->getChain(),
21834 Ptr: LD->getBasePtr(), MMO: LD->getMemOperand());
21835
21836 SDValue NewST = DAG.getStore(Chain: ST->getChain(), dl: SDLoc(N), Val: NewLD,
21837 Ptr: ST->getBasePtr(), MMO: ST->getMemOperand());
21838
21839 AddToWorklist(N: NewLD.getNode());
21840 AddToWorklist(N: NewST.getNode());
21841 WorklistRemover DeadNodes(*this);
21842 DAG.ReplaceAllUsesOfValueWith(From: Value.getValue(R: 1), To: NewLD.getValue(R: 1));
21843 ++LdStFP2Int;
21844 return NewST;
21845 }
21846
21847 return SDValue();
21848}
21849
21850// This is a helper function for visitMUL to check the profitability
21851// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
21852// MulNode is the original multiply, AddNode is (add x, c1),
21853// and ConstNode is c2.
21854//
21855// If the (add x, c1) has multiple uses, we could increase
21856// the number of adds if we make this transformation.
21857// It would only be worth doing this if we can remove a
21858// multiply in the process. Check for that here.
21859// To illustrate:
21860// (A + c1) * c3
21861// (A + c2) * c3
21862// We're checking for cases where we have common "c3 * A" expressions.
21863bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
21864 SDValue ConstNode) {
21865 // If the add only has one use, and the target thinks the folding is
21866 // profitable or does not lead to worse code, this would be OK to do.
21867 if (AddNode->hasOneUse() &&
21868 TLI.isMulAddWithConstProfitable(AddNode, ConstNode))
21869 return true;
21870
21871 // Walk all the users of the constant with which we're multiplying.
21872 for (SDNode *User : ConstNode->users()) {
21873 if (User == MulNode) // This use is the one we're on right now. Skip it.
21874 continue;
21875
21876 if (User->getOpcode() == ISD::MUL) { // We have another multiply use.
21877 SDNode *OtherOp;
21878 SDNode *MulVar = AddNode.getOperand(i: 0).getNode();
21879
21880 // OtherOp is what we're multiplying against the constant.
21881 if (User->getOperand(Num: 0) == ConstNode)
21882 OtherOp = User->getOperand(Num: 1).getNode();
21883 else
21884 OtherOp = User->getOperand(Num: 0).getNode();
21885
21886 // Check to see if multiply is with the same operand of our "add".
21887 //
21888 // ConstNode = CONST
21889 // User = ConstNode * A <-- visiting User. OtherOp is A.
21890 // ...
21891 // AddNode = (A + c1) <-- MulVar is A.
21892 // = AddNode * ConstNode <-- current visiting instruction.
21893 //
21894 // If we make this transformation, we will have a common
21895 // multiply (ConstNode * A) that we can save.
21896 if (OtherOp == MulVar)
21897 return true;
21898
21899 // Now check to see if a future expansion will give us a common
21900 // multiply.
21901 //
21902 // ConstNode = CONST
21903 // AddNode = (A + c1)
21904 // ... = AddNode * ConstNode <-- current visiting instruction.
21905 // ...
21906 // OtherOp = (A + c2)
21907 // User = OtherOp * ConstNode <-- visiting User.
21908 //
21909 // If we make this transformation, we will have a common
      // multiply (CONST * A) after we also do the same transformation
      // to the "User" instruction.
21912 if (OtherOp->getOpcode() == ISD::ADD &&
21913 DAG.isConstantIntBuildVectorOrConstantInt(N: OtherOp->getOperand(Num: 1)) &&
21914 OtherOp->getOperand(Num: 0).getNode() == MulVar)
21915 return true;
21916 }
21917 }
21918
21919 // Didn't find a case where this would be profitable.
21920 return false;
21921}
21922
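/// Build a TokenFactor over the incoming chains of the first \p NumStores
/// stores in \p StoreNodes, skipping chains that are themselves one of the
/// merged stores and deduplicating repeated chains.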
21923SDValue DAGCombiner::getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
21924 unsigned NumStores) {
21925 SmallVector<SDValue, 8> Chains;
21926 SmallPtrSet<const SDNode *, 8> Visited;
21927 SDLoc StoreDL(StoreNodes[0].MemNode);
21928
21929 for (unsigned i = 0; i < NumStores; ++i) {
21930 Visited.insert(Ptr: StoreNodes[i].MemNode);
21931 }
21932
21933 // don't include nodes that are children or repeated nodes.
21934 for (unsigned i = 0; i < NumStores; ++i) {
21935 if (Visited.insert(Ptr: StoreNodes[i].MemNode->getChain().getNode()).second)
21936 Chains.push_back(Elt: StoreNodes[i].MemNode->getChain());
21937 }
21938
21939 assert(!Chains.empty() && "Chain should have generated a chain");
21940 return DAG.getTokenFactor(DL: StoreDL, Vals&: Chains);
21941}
21942
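/// Return true if every store in \p StoreNodes is known to access the same
/// underlying IR object, so the first store's pointer info can safely be
/// reused for the merged store.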
21943bool DAGCombiner::hasSameUnderlyingObj(ArrayRef<MemOpLink> StoreNodes) {
21944 const Value *UnderlyingObj = nullptr;
21945 for (const auto &MemOp : StoreNodes) {
21946 const MachineMemOperand *MMO = MemOp.MemNode->getMemOperand();
    // A pseudo value (e.g. a stack frame object) has its own frame index and
    // size, so we should not reuse the first store's frame index for other
    // frames.
21949 if (MMO->getPseudoValue())
21950 return false;
21951
21952 if (!MMO->getValue())
21953 return false;
21954
21955 const Value *Obj = getUnderlyingObject(V: MMO->getValue());
21956
21957 if (UnderlyingObj && UnderlyingObj != Obj)
21958 return false;
21959
21960 if (!UnderlyingObj)
21961 UnderlyingObj = Obj;
21962 }
21963
21964 return true;
21965}
21966
21967bool DAGCombiner::mergeStoresOfConstantsOrVecElts(
21968 SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores,
21969 bool IsConstantSrc, bool UseVector, bool UseTrunc) {
21970 // Make sure we have something to merge.
21971 if (NumStores < 2)
21972 return false;
21973
21974 assert((!UseTrunc || !UseVector) &&
21975 "This optimization cannot emit a vector truncating store");
21976
21977 // The latest Node in the DAG.
21978 SDLoc DL(StoreNodes[0].MemNode);
21979
21980 TypeSize ElementSizeBits = MemVT.getStoreSizeInBits();
21981 unsigned SizeInBits = NumStores * ElementSizeBits;
21982 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
21983
21984 std::optional<MachineMemOperand::Flags> Flags;
21985 AAMDNodes AAInfo;
21986 for (unsigned I = 0; I != NumStores; ++I) {
21987 StoreSDNode *St = cast<StoreSDNode>(Val: StoreNodes[I].MemNode);
21988 if (!Flags) {
21989 Flags = St->getMemOperand()->getFlags();
21990 AAInfo = St->getAAInfo();
21991 continue;
21992 }
21993 // Skip merging if there's an inconsistent flag.
21994 if (Flags != St->getMemOperand()->getFlags())
21995 return false;
21996 // Concatenate AA metadata.
21997 AAInfo = AAInfo.concat(Other: St->getAAInfo());
21998 }
21999
22000 EVT StoreTy;
22001 if (UseVector) {
22002 unsigned Elts = NumStores * NumMemElts;
22003 // Get the type for the merged vector store.
22004 StoreTy = EVT::getVectorVT(Context&: *DAG.getContext(), VT: MemVT.getScalarType(), NumElements: Elts);
22005 } else
22006 StoreTy = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SizeInBits);
22007
22008 SDValue StoredVal;
22009 if (UseVector) {
22010 if (IsConstantSrc) {
22011 SmallVector<SDValue, 8> BuildVector;
22012 for (unsigned I = 0; I != NumStores; ++I) {
22013 StoreSDNode *St = cast<StoreSDNode>(Val: StoreNodes[I].MemNode);
22014 SDValue Val = St->getValue();
22015 // If constant is of the wrong type, convert it now. This comes up
22016 // when one of our stores was truncating.
22017 if (MemVT != Val.getValueType()) {
22018 Val = peekThroughBitcasts(V: Val);
22019 // Deal with constants of wrong size.
22020 if (ElementSizeBits != Val.getValueSizeInBits()) {
22021 auto *C = dyn_cast<ConstantSDNode>(Val);
22022 if (!C)
22023 // Not clear how to truncate FP values.
22024 // TODO: Handle truncation of build_vector constants
22025 return false;
22026
22027 EVT IntMemVT =
22028 EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: MemVT.getSizeInBits());
22029 Val = DAG.getConstant(Val: C->getAPIntValue()
22030 .zextOrTrunc(width: Val.getValueSizeInBits())
22031 .zextOrTrunc(width: ElementSizeBits),
22032 DL: SDLoc(C), VT: IntMemVT);
22033 }
          // Bitcast so the value ends up with the correctly sized type, MemVT.
22035 Val = DAG.getBitcast(VT: MemVT, V: Val);
22036 }
22037 BuildVector.push_back(Elt: Val);
22038 }
22039 StoredVal = DAG.getNode(Opcode: MemVT.isVector() ? ISD::CONCAT_VECTORS
22040 : ISD::BUILD_VECTOR,
22041 DL, VT: StoreTy, Ops: BuildVector);
22042 } else {
22043 SmallVector<SDValue, 8> Ops;
22044 for (unsigned i = 0; i < NumStores; ++i) {
22045 StoreSDNode *St = cast<StoreSDNode>(Val: StoreNodes[i].MemNode);
22046 SDValue Val = peekThroughBitcasts(V: St->getValue());
22047 // All operands of BUILD_VECTOR / CONCAT_VECTOR must be of
22048 // type MemVT. If the underlying value is not the correct
22049 // type, but it is an extraction of an appropriate vector we
22050 // can recast Val to be of the correct type. This may require
22051 // converting between EXTRACT_VECTOR_ELT and
22052 // EXTRACT_SUBVECTOR.
22053 if ((MemVT != Val.getValueType()) &&
22054 (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
22055 Val.getOpcode() == ISD::EXTRACT_SUBVECTOR)) {
22056 EVT MemVTScalarTy = MemVT.getScalarType();
22057 // We may need to add a bitcast here to get types to line up.
22058 if (MemVTScalarTy != Val.getValueType().getScalarType()) {
22059 Val = DAG.getBitcast(VT: MemVT, V: Val);
22060 } else if (MemVT.isVector() &&
22061 Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
22062 Val = DAG.getNode(Opcode: ISD::BUILD_VECTOR, DL, VT: MemVT, Operand: Val);
22063 } else {
22064 unsigned OpC = MemVT.isVector() ? ISD::EXTRACT_SUBVECTOR
22065 : ISD::EXTRACT_VECTOR_ELT;
22066 SDValue Vec = Val.getOperand(i: 0);
22067 SDValue Idx = Val.getOperand(i: 1);
22068 Val = DAG.getNode(Opcode: OpC, DL: SDLoc(Val), VT: MemVT, N1: Vec, N2: Idx);
22069 }
22070 }
22071 Ops.push_back(Elt: Val);
22072 }
22073
22074 // Build the extracted vector elements back into a vector.
22075 StoredVal = DAG.getNode(Opcode: MemVT.isVector() ? ISD::CONCAT_VECTORS
22076 : ISD::BUILD_VECTOR,
22077 DL, VT: StoreTy, Ops);
22078 }
22079 } else {
22080 // We should always use a vector store when merging extracted vector
22081 // elements, so this path implies a store of constants.
22082 assert(IsConstantSrc && "Merged vector elements should use vector store");
22083
22084 APInt StoreInt(SizeInBits, 0);
22085
22086 // Construct a single integer constant which is made of the smaller
22087 // constant inputs.
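    // For example, on a little-endian target, merging an i16 store of 0x1122
    // (at the lower address) with an i16 store of 0x3344 produces a single
    // i32 store of 0x33441122.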
22088 bool IsLE = DAG.getDataLayout().isLittleEndian();
22089 for (unsigned i = 0; i < NumStores; ++i) {
22090 unsigned Idx = IsLE ? (NumStores - 1 - i) : i;
22091 StoreSDNode *St = cast<StoreSDNode>(Val: StoreNodes[Idx].MemNode);
22092
22093 SDValue Val = St->getValue();
22094 Val = peekThroughBitcasts(V: Val);
22095 StoreInt <<= ElementSizeBits;
22096 if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val)) {
22097 StoreInt |= C->getAPIntValue()
22098 .zextOrTrunc(width: ElementSizeBits)
22099 .zextOrTrunc(width: SizeInBits);
22100 } else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val)) {
22101 StoreInt |= C->getValueAPF()
22102 .bitcastToAPInt()
22103 .zextOrTrunc(width: ElementSizeBits)
22104 .zextOrTrunc(width: SizeInBits);
22105 // If fp truncation is necessary give up for now.
22106 if (MemVT.getSizeInBits() != ElementSizeBits)
22107 return false;
22108 } else if (ISD::isBuildVectorOfConstantSDNodes(N: Val.getNode()) ||
22109 ISD::isBuildVectorOfConstantFPSDNodes(N: Val.getNode())) {
22110 // Not yet handled
22111 return false;
22112 } else {
22113 llvm_unreachable("Invalid constant element type");
22114 }
22115 }
22116
22117 // Create the new Load and Store operations.
22118 StoredVal = DAG.getConstant(Val: StoreInt, DL, VT: StoreTy);
22119 }
22120
22121 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
22122 SDValue NewChain = getMergeStoreChains(StoreNodes, NumStores);
22123 bool CanReusePtrInfo = hasSameUnderlyingObj(StoreNodes);
22124
  // Make sure we use a trunc store if it's necessary to be legal.
  // When generating the new widened store, if the first store's pointer info
  // cannot be reused, discard the pointer info except for the address space,
  // because the widened store can no longer be represented by the original
  // pointer info, which describes the narrower memory object.
22130 SDValue NewStore;
22131 if (!UseTrunc) {
22132 NewStore = DAG.getStore(
22133 Chain: NewChain, dl: DL, Val: StoredVal, Ptr: FirstInChain->getBasePtr(),
22134 PtrInfo: CanReusePtrInfo
22135 ? FirstInChain->getPointerInfo()
22136 : MachinePointerInfo(FirstInChain->getPointerInfo().getAddrSpace()),
22137 Alignment: FirstInChain->getAlign(), MMOFlags: *Flags, AAInfo);
22138 } else { // Must be realized as a trunc store
22139 EVT LegalizedStoredValTy =
22140 TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: StoredVal.getValueType());
22141 unsigned LegalizedStoreSize = LegalizedStoredValTy.getSizeInBits();
22142 ConstantSDNode *C = cast<ConstantSDNode>(Val&: StoredVal);
22143 SDValue ExtendedStoreVal =
22144 DAG.getConstant(Val: C->getAPIntValue().zextOrTrunc(width: LegalizedStoreSize), DL,
22145 VT: LegalizedStoredValTy);
22146 NewStore = DAG.getTruncStore(
22147 Chain: NewChain, dl: DL, Val: ExtendedStoreVal, Ptr: FirstInChain->getBasePtr(),
22148 PtrInfo: CanReusePtrInfo
22149 ? FirstInChain->getPointerInfo()
22150 : MachinePointerInfo(FirstInChain->getPointerInfo().getAddrSpace()),
22151 SVT: StoredVal.getValueType() /*TVT*/, Alignment: FirstInChain->getAlign(), MMOFlags: *Flags,
22152 AAInfo);
22153 }
22154
22155 // Replace all merged stores with the new store.
22156 for (unsigned i = 0; i < NumStores; ++i)
22157 CombineTo(N: StoreNodes[i].MemNode, Res: NewStore);
22158
22159 AddToWorklist(N: NewChain.getNode());
22160 return true;
22161}
22162
22163SDNode *
22164DAGCombiner::getStoreMergeCandidates(StoreSDNode *St,
22165 SmallVectorImpl<MemOpLink> &StoreNodes) {
22166 // This holds the base pointer, index, and the offset in bytes from the base
22167 // pointer. We must have a base and an offset. Do not handle stores to undef
22168 // base pointers.
22169 BaseIndexOffset BasePtr = BaseIndexOffset::match(N: St, DAG);
22170 if (!BasePtr.getBase().getNode() || BasePtr.getBase().isUndef())
22171 return nullptr;
22172
22173 SDValue Val = peekThroughBitcasts(V: St->getValue());
22174 StoreSource StoreSrc = getStoreSource(StoreVal: Val);
22175 assert(StoreSrc != StoreSource::Unknown && "Expected known source for store");
22176
22177 // Match on loadbaseptr if relevant.
22178 EVT MemVT = St->getMemoryVT();
22179 BaseIndexOffset LBasePtr;
22180 EVT LoadVT;
22181 if (StoreSrc == StoreSource::Load) {
22182 auto *Ld = cast<LoadSDNode>(Val);
22183 LBasePtr = BaseIndexOffset::match(N: Ld, DAG);
22184 LoadVT = Ld->getMemoryVT();
22185 // Load and store should be the same type.
22186 if (MemVT != LoadVT)
22187 return nullptr;
22188 // Loads must only have one use.
22189 if (!Ld->hasNUsesOfValue(NUses: 1, Value: 0))
22190 return nullptr;
22191 // The memory operands must not be volatile/indexed/atomic.
22192 // TODO: May be able to relax for unordered atomics (see D66309)
22193 if (!Ld->isSimple() || Ld->isIndexed())
22194 return nullptr;
22195 }
22196 auto CandidateMatch = [&](StoreSDNode *Other, BaseIndexOffset &Ptr,
22197 int64_t &Offset) -> bool {
22198 // The memory operands must not be volatile/indexed/atomic.
22199 // TODO: May be able to relax for unordered atomics (see D66309)
22200 if (!Other->isSimple() || Other->isIndexed())
22201 return false;
22202 // Don't mix temporal stores with non-temporal stores.
22203 if (St->isNonTemporal() != Other->isNonTemporal())
22204 return false;
22205 if (!TLI.areTwoSDNodeTargetMMOFlagsMergeable(NodeX: *St, NodeY: *Other))
22206 return false;
22207 SDValue OtherBC = peekThroughBitcasts(V: Other->getValue());
22208 // Allow merging constants of different types as integers.
22209 bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(VT: Other->getMemoryVT())
22210 : Other->getMemoryVT() != MemVT;
22211 switch (StoreSrc) {
22212 case StoreSource::Load: {
22213 if (NoTypeMatch)
22214 return false;
22215 // The Load's Base Ptr must also match.
22216 auto *OtherLd = dyn_cast<LoadSDNode>(Val&: OtherBC);
22217 if (!OtherLd)
22218 return false;
22219 BaseIndexOffset LPtr = BaseIndexOffset::match(N: OtherLd, DAG);
22220 if (LoadVT != OtherLd->getMemoryVT())
22221 return false;
22222 // Loads must only have one use.
22223 if (!OtherLd->hasNUsesOfValue(NUses: 1, Value: 0))
22224 return false;
22225 // The memory operands must not be volatile/indexed/atomic.
22226 // TODO: May be able to relax for unordered atomics (see D66309)
22227 if (!OtherLd->isSimple() || OtherLd->isIndexed())
22228 return false;
22229 // Don't mix temporal loads with non-temporal loads.
22230 if (cast<LoadSDNode>(Val)->isNonTemporal() != OtherLd->isNonTemporal())
22231 return false;
22232 if (!TLI.areTwoSDNodeTargetMMOFlagsMergeable(NodeX: *cast<LoadSDNode>(Val),
22233 NodeY: *OtherLd))
22234 return false;
22235 if (!(LBasePtr.equalBaseIndex(Other: LPtr, DAG)))
22236 return false;
22237 break;
22238 }
22239 case StoreSource::Constant:
22240 if (NoTypeMatch)
22241 return false;
22242 if (getStoreSource(StoreVal: OtherBC) != StoreSource::Constant)
22243 return false;
22244 break;
22245 case StoreSource::Extract:
22246 // Do not merge truncated stores here.
22247 if (Other->isTruncatingStore())
22248 return false;
22249 if (!MemVT.bitsEq(VT: OtherBC.getValueType()))
22250 return false;
22251 if (OtherBC.getOpcode() != ISD::EXTRACT_VECTOR_ELT &&
22252 OtherBC.getOpcode() != ISD::EXTRACT_SUBVECTOR)
22253 return false;
22254 break;
22255 default:
22256 llvm_unreachable("Unhandled store source for merging");
22257 }
22258 Ptr = BaseIndexOffset::match(N: Other, DAG);
22259 return (BasePtr.equalBaseIndex(Other: Ptr, DAG, Off&: Offset));
22260 };
22261
  // We are looking for a root node which is an ancestor to all mergeable
  // stores. We search up through a load, to our root and then down
  // through all children. For instance we will find Store{1,2,3} if
  // St is Store1, Store2 or Store3 where the root is not a load
  // which is always true for nonvolatile ops. TODO: Expand
  // the search to find all valid candidates through multiple layers of loads.
22268 //
22269 // Root
22270 // |-------|-------|
22271 // Load Load Store3
22272 // | |
22273 // Store1 Store2
22274 //
22275 // FIXME: We should be able to climb and
22276 // descend TokenFactors to find candidates as well.
22277
22278 SDNode *RootNode = St->getChain().getNode();
22279 // Bail out if we already analyzed this root node and found nothing.
22280 if (ChainsWithoutMergeableStores.contains(Ptr: RootNode))
22281 return nullptr;
22282
  // Check if this StoreNode / RootNode pair has already bailed out of the
  // dependence check more times than the limit allows.
22285 auto OverLimitInDependenceCheck = [&](SDNode *StoreNode,
22286 SDNode *RootNode) -> bool {
22287 auto RootCount = StoreRootCountMap.find(Val: StoreNode);
22288 return RootCount != StoreRootCountMap.end() &&
22289 RootCount->second.first == RootNode &&
22290 RootCount->second.second > StoreMergeDependenceLimit;
22291 };
22292
22293 auto TryToAddCandidate = [&](SDUse &Use) {
22294 // This must be a chain use.
22295 if (Use.getOperandNo() != 0)
22296 return;
22297 if (auto *OtherStore = dyn_cast<StoreSDNode>(Val: Use.getUser())) {
22298 BaseIndexOffset Ptr;
22299 int64_t PtrDiff;
22300 if (CandidateMatch(OtherStore, Ptr, PtrDiff) &&
22301 !OverLimitInDependenceCheck(OtherStore, RootNode))
22302 StoreNodes.push_back(Elt: MemOpLink(OtherStore, PtrDiff));
22303 }
22304 };
22305
22306 unsigned NumNodesExplored = 0;
22307 const unsigned MaxSearchNodes = 1024;
22308 if (auto *Ldn = dyn_cast<LoadSDNode>(Val: RootNode)) {
22309 RootNode = Ldn->getChain().getNode();
22310 // Bail out if we already analyzed this root node and found nothing.
22311 if (ChainsWithoutMergeableStores.contains(Ptr: RootNode))
22312 return nullptr;
22313 for (auto I = RootNode->use_begin(), E = RootNode->use_end();
22314 I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored) {
22315 SDNode *User = I->getUser();
22316 if (I->getOperandNo() == 0 && isa<LoadSDNode>(Val: User)) { // walk down chain
22317 for (SDUse &U2 : User->uses())
22318 TryToAddCandidate(U2);
22319 }
22320 // Check stores that depend on the root (e.g. Store 3 in the chart above).
22321 if (I->getOperandNo() == 0 && isa<StoreSDNode>(Val: User)) {
22322 TryToAddCandidate(*I);
22323 }
22324 }
22325 } else {
22326 for (auto I = RootNode->use_begin(), E = RootNode->use_end();
22327 I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored)
22328 TryToAddCandidate(*I);
22329 }
22330
22331 return RootNode;
22332}
22333
22334// We need to check that merging these stores does not cause a loop in the
22335// DAG. Any store candidate may depend on another candidate indirectly through
22336// its operands. Check in parallel by searching up from operands of candidates.
22337bool DAGCombiner::checkMergeStoreCandidatesForDependencies(
22338 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
22339 SDNode *RootNode) {
  // FIXME: We should be able to truncate a full search of
  // predecessors by doing a BFS and keeping tabs on the originating
  // stores from which worklist nodes come, in a similar way to
  // TokenFactor simplification.
22344
22345 SmallPtrSet<const SDNode *, 32> Visited;
22346 SmallVector<const SDNode *, 8> Worklist;
22347
22348 // RootNode is a predecessor to all candidates so we need not search
22349 // past it. Add RootNode (peeking through TokenFactors). Do not count
22350 // these towards size check.
22351
22352 Worklist.push_back(Elt: RootNode);
22353 while (!Worklist.empty()) {
22354 auto N = Worklist.pop_back_val();
22355 if (!Visited.insert(Ptr: N).second)
22356 continue; // Already present in Visited.
22357 if (N->getOpcode() == ISD::TokenFactor) {
22358 for (SDValue Op : N->ops())
22359 Worklist.push_back(Elt: Op.getNode());
22360 }
22361 }
22362
22363 // Don't count pruning nodes towards max.
22364 unsigned int Max = 1024 + Visited.size();
22365 // Search Ops of store candidates.
22366 for (unsigned i = 0; i < NumStores; ++i) {
22367 SDNode *N = StoreNodes[i].MemNode;
22368 // Of the 4 Store Operands:
22369 // * Chain (Op 0) -> We have already considered these
22370 // in candidate selection, but only by following the
22371 // chain dependencies. We could still have a chain
22372 // dependency to a load, that has a non-chain dep to
22373 // another load, that depends on a store, etc. So it is
22374 // possible to have dependencies that consist of a mix
22375 // of chain and non-chain deps, and we need to include
    //                    chain operands in the analysis here.
    // * Value (Op 1) -> Cycles may happen (e.g. through load chains)
    // * Address (Op 2) -> Merged addresses may only vary by a fixed constant,
    //                      but aren't necessarily from the same base node, so
22380 // cycles possible (e.g. via indexed store).
22381 // * (Op 3) -> Represents the pre or post-indexing offset (or undef for
22382 // non-indexed stores). Not constant on all targets (e.g. ARM)
22383 // and so can participate in a cycle.
22384 for (const SDValue &Op : N->op_values())
22385 Worklist.push_back(Elt: Op.getNode());
22386 }
22387 // Search through DAG. We can stop early if we find a store node.
22388 for (unsigned i = 0; i < NumStores; ++i)
22389 if (SDNode::hasPredecessorHelper(N: StoreNodes[i].MemNode, Visited, Worklist,
22390 MaxSteps: Max)) {
      // If the search bails out, record the StoreNode and RootNode in the
      // StoreRootCountMap. If we have seen the pair many times over a limit,
      // we won't add the StoreNode into the StoreNodes set again.
22394 if (Visited.size() >= Max) {
22395 auto &RootCount = StoreRootCountMap[StoreNodes[i].MemNode];
22396 if (RootCount.first == RootNode)
22397 RootCount.second++;
22398 else
22399 RootCount = {RootNode, 1};
22400 }
22401 return false;
22402 }
22403 return true;
22404}
22405
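/// Return true if a call (CALLSEQ_END) is found on the chain path walked
/// upward from \p St to \p Ld; used to avoid merging loads and stores across
/// a call unless the target allows it.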
22406bool DAGCombiner::hasCallInLdStChain(StoreSDNode *St, LoadSDNode *Ld) {
22407 SmallPtrSet<const SDNode *, 32> Visited;
22408 SmallVector<std::pair<const SDNode *, bool>, 8> Worklist;
22409 Worklist.emplace_back(Args: St->getChain().getNode(), Args: false);
22410
22411 while (!Worklist.empty()) {
22412 auto [Node, FoundCall] = Worklist.pop_back_val();
22413 if (!Visited.insert(Ptr: Node).second || Node->getNumOperands() == 0)
22414 continue;
22415
22416 switch (Node->getOpcode()) {
22417 case ISD::CALLSEQ_END:
22418 Worklist.emplace_back(Args: Node->getOperand(Num: 0).getNode(), Args: true);
22419 break;
22420 case ISD::TokenFactor:
22421 for (SDValue Op : Node->ops())
22422 Worklist.emplace_back(Args: Op.getNode(), Args&: FoundCall);
22423 break;
22424 case ISD::LOAD:
22425 if (Node == Ld)
22426 return FoundCall;
22427 [[fallthrough]];
22428 default:
22429 assert(Node->getOperand(0).getValueType() == MVT::Other &&
22430 "Invalid chain type");
22431 Worklist.emplace_back(Args: Node->getOperand(Num: 0).getNode(), Args&: FoundCall);
22432 break;
22433 }
22434 }
22435 return false;
22436}
22437
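/// Scan \p StoreNodes (sorted by offset) for a run of stores to consecutive
/// addresses, erasing leading candidates that cannot start such a run.
/// Returns the number of consecutive stores found, or 0 if no run of at least
/// two stores exists.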
22438unsigned
22439DAGCombiner::getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
22440 int64_t ElementSizeBytes) const {
22441 while (true) {
22442 // Find a store past the width of the first store.
22443 size_t StartIdx = 0;
22444 while ((StartIdx + 1 < StoreNodes.size()) &&
22445 StoreNodes[StartIdx].OffsetFromBase + ElementSizeBytes !=
22446 StoreNodes[StartIdx + 1].OffsetFromBase)
22447 ++StartIdx;
22448
22449 // Bail if we don't have enough candidates to merge.
22450 if (StartIdx + 1 >= StoreNodes.size())
22451 return 0;
22452
22453 // Trim stores that overlapped with the first store.
22454 if (StartIdx)
22455 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + StartIdx);
22456
22457 // Scan the memory operations on the chain and find the first
22458 // non-consecutive store memory address.
22459 unsigned NumConsecutiveStores = 1;
22460 int64_t StartAddress = StoreNodes[0].OffsetFromBase;
22461 // Check that the addresses are consecutive starting from the second
22462 // element in the list of stores.
22463 for (unsigned i = 1, e = StoreNodes.size(); i < e; ++i) {
22464 int64_t CurrAddress = StoreNodes[i].OffsetFromBase;
22465 if (CurrAddress - StartAddress != (ElementSizeBytes * i))
22466 break;
22467 NumConsecutiveStores = i + 1;
22468 }
22469 if (NumConsecutiveStores > 1)
22470 return NumConsecutiveStores;
22471
22472 // There are no consecutive stores at the start of the list.
22473 // Remove the first store and try again.
22474 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + 1);
22475 }
22476}
22477
22478bool DAGCombiner::tryStoreMergeOfConstants(
22479 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
22480 EVT MemVT, SDNode *RootNode, bool AllowVectors) {
22481 LLVMContext &Context = *DAG.getContext();
22482 const DataLayout &DL = DAG.getDataLayout();
22483 int64_t ElementSizeBytes = MemVT.getStoreSize();
22484 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
22485 bool MadeChange = false;
22486
22487 // Store the constants into memory as one consecutive store.
22488 while (NumConsecutiveStores >= 2) {
22489 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
22490 unsigned FirstStoreAS = FirstInChain->getAddressSpace();
22491 Align FirstStoreAlign = FirstInChain->getAlign();
22492 unsigned LastLegalType = 1;
22493 unsigned LastLegalVectorType = 1;
22494 bool LastIntegerTrunc = false;
22495 bool NonZero = false;
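    // Track the first zero-valued store that follows a non-zero one; when no
    // merge is found we avoid skipping past it, since a run starting there may
    // still be mergeable (e.g. as a cheap store of vector zeros).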
22496 unsigned FirstZeroAfterNonZero = NumConsecutiveStores;
22497 for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
22498 StoreSDNode *ST = cast<StoreSDNode>(Val: StoreNodes[i].MemNode);
22499 SDValue StoredVal = ST->getValue();
22500 bool IsElementZero = false;
22501 if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val&: StoredVal))
22502 IsElementZero = C->isZero();
22503 else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val&: StoredVal))
22504 IsElementZero = C->getConstantFPValue()->isNullValue();
22505 else if (ISD::isBuildVectorAllZeros(N: StoredVal.getNode()))
22506 IsElementZero = true;
22507 if (IsElementZero) {
22508 if (NonZero && FirstZeroAfterNonZero == NumConsecutiveStores)
22509 FirstZeroAfterNonZero = i;
22510 }
22511 NonZero |= !IsElementZero;
22512
22513 // Find a legal type for the constant store.
22514 unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
22515 EVT StoreTy = EVT::getIntegerVT(Context, BitWidth: SizeInBits);
22516 unsigned IsFast = 0;
22517
22518 // Break early when size is too large to be legal.
22519 if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
22520 break;
22521
22522 if (TLI.isTypeLegal(VT: StoreTy) &&
22523 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: StoreTy,
22524 MF: DAG.getMachineFunction()) &&
22525 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
22526 MMO: *FirstInChain->getMemOperand(), Fast: &IsFast) &&
22527 IsFast) {
22528 LastIntegerTrunc = false;
22529 LastLegalType = i + 1;
22530 // Or check whether a truncstore is legal.
22531 } else if (TLI.getTypeAction(Context, VT: StoreTy) ==
22532 TargetLowering::TypePromoteInteger) {
22533 EVT LegalizedStoredValTy =
22534 TLI.getTypeToTransformTo(Context, VT: StoredVal.getValueType());
22535 if (TLI.isTruncStoreLegal(ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
22536 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: LegalizedStoredValTy,
22537 MF: DAG.getMachineFunction()) &&
22538 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
22539 MMO: *FirstInChain->getMemOperand(), Fast: &IsFast) &&
22540 IsFast) {
22541 LastIntegerTrunc = true;
22542 LastLegalType = i + 1;
22543 }
22544 }
22545
22546 // We only use vectors if the target allows it and the function is not
22547 // marked with the noimplicitfloat attribute.
22548 if (TLI.storeOfVectorConstantIsCheap(IsZero: !NonZero, MemVT, NumElem: i + 1, AddrSpace: FirstStoreAS) &&
22549 AllowVectors) {
22550 // Find a legal type for the vector store.
22551 unsigned Elts = (i + 1) * NumMemElts;
22552 EVT Ty = EVT::getVectorVT(Context, VT: MemVT.getScalarType(), NumElements: Elts);
22553 if (TLI.isTypeLegal(VT: Ty) && TLI.isTypeLegal(VT: MemVT) &&
22554 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: Ty, MF: DAG.getMachineFunction()) &&
22555 TLI.allowsMemoryAccess(Context, DL, VT: Ty,
22556 MMO: *FirstInChain->getMemOperand(), Fast: &IsFast) &&
22557 IsFast)
22558 LastLegalVectorType = i + 1;
22559 }
22560 }
22561
22562 bool UseVector = (LastLegalVectorType > LastLegalType) && AllowVectors;
22563 unsigned NumElem = (UseVector) ? LastLegalVectorType : LastLegalType;
22564 bool UseTrunc = LastIntegerTrunc && !UseVector;
22565
22566 // Check if we found a legal integer type that creates a meaningful
22567 // merge.
22568 if (NumElem < 2) {
22569 // We know that candidate stores are in order and of correct
22570 // shape. While there is no mergeable sequence from the
22571 // beginning one may start later in the sequence. The only
22572 // reason a merge of size N could have failed where another of
22573 // the same size would not have, is if the alignment has
22574 // improved or we've dropped a non-zero value. Drop as many
22575 // candidates as we can here.
22576 unsigned NumSkip = 1;
22577 while ((NumSkip < NumConsecutiveStores) &&
22578 (NumSkip < FirstZeroAfterNonZero) &&
22579 (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
22580 NumSkip++;
22581
22582 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumSkip);
22583 NumConsecutiveStores -= NumSkip;
22584 continue;
22585 }
22586
22587 // Check that we can merge these candidates without causing a cycle.
22588 if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStores: NumElem,
22589 RootNode)) {
22590 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumElem);
22591 NumConsecutiveStores -= NumElem;
22592 continue;
22593 }
22594
22595 MadeChange |= mergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumStores: NumElem,
22596 /*IsConstantSrc*/ true,
22597 UseVector, UseTrunc);
22598
22599 // Remove merged stores for next iteration.
22600 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumElem);
22601 NumConsecutiveStores -= NumElem;
22602 }
22603 return MadeChange;
22604}
22605
22606bool DAGCombiner::tryStoreMergeOfExtracts(
22607 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
22608 EVT MemVT, SDNode *RootNode) {
22609 LLVMContext &Context = *DAG.getContext();
22610 const DataLayout &DL = DAG.getDataLayout();
22611 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
22612 bool MadeChange = false;
22613
22614 // Loop on Consecutive Stores on success.
22615 while (NumConsecutiveStores >= 2) {
22616 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
22617 unsigned FirstStoreAS = FirstInChain->getAddressSpace();
22618 Align FirstStoreAlign = FirstInChain->getAlign();
22619 unsigned NumStoresToMerge = 1;
22620 for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
22621 // Find a legal type for the vector store.
22622 unsigned Elts = (i + 1) * NumMemElts;
22623 EVT Ty = EVT::getVectorVT(Context&: *DAG.getContext(), VT: MemVT.getScalarType(), NumElements: Elts);
22624 unsigned IsFast = 0;
22625
22626 // Break early when size is too large to be legal.
22627 if (Ty.getSizeInBits() > MaximumLegalStoreInBits)
22628 break;
22629
22630 if (TLI.isTypeLegal(VT: Ty) &&
22631 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: Ty, MF: DAG.getMachineFunction()) &&
22632 TLI.allowsMemoryAccess(Context, DL, VT: Ty,
22633 MMO: *FirstInChain->getMemOperand(), Fast: &IsFast) &&
22634 IsFast)
22635 NumStoresToMerge = i + 1;
22636 }
22637
22638 // Check if we found a legal integer type creating a meaningful
22639 // merge.
22640 if (NumStoresToMerge < 2) {
22641 // We know that candidate stores are in order and of correct
22642 // shape. While there is no mergeable sequence from the
22643 // beginning one may start later in the sequence. The only
22644 // reason a merge of size N could have failed where another of
22645 // the same size would not have, is if the alignment has
22646 // improved. Drop as many candidates as we can here.
22647 unsigned NumSkip = 1;
22648 while ((NumSkip < NumConsecutiveStores) &&
22649 (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
22650 NumSkip++;
22651
22652 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumSkip);
22653 NumConsecutiveStores -= NumSkip;
22654 continue;
22655 }
22656
22657 // Check that we can merge these candidates without causing a cycle.
22658 if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStores: NumStoresToMerge,
22659 RootNode)) {
22660 StoreNodes.erase(CS: StoreNodes.begin(),
22661 CE: StoreNodes.begin() + NumStoresToMerge);
22662 NumConsecutiveStores -= NumStoresToMerge;
22663 continue;
22664 }
22665
22666 MadeChange |= mergeStoresOfConstantsOrVecElts(
22667 StoreNodes, MemVT, NumStores: NumStoresToMerge, /*IsConstantSrc*/ false,
22668 /*UseVector*/ true, /*UseTrunc*/ false);
22669
22670 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumStoresToMerge);
22671 NumConsecutiveStores -= NumStoresToMerge;
22672 }
22673 return MadeChange;
22674}
22675
22676bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
22677 unsigned NumConsecutiveStores, EVT MemVT,
22678 SDNode *RootNode, bool AllowVectors,
22679 bool IsNonTemporalStore,
22680 bool IsNonTemporalLoad) {
22681 LLVMContext &Context = *DAG.getContext();
22682 const DataLayout &DL = DAG.getDataLayout();
22683 int64_t ElementSizeBytes = MemVT.getStoreSize();
22684 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
22685 bool MadeChange = false;
22686
22687 // Look for load nodes which are used by the stored values.
22688 SmallVector<MemOpLink, 8> LoadNodes;
22689
22690 // Find acceptable loads. Loads need to have the same chain (token factor),
22691 // must not be zext, volatile, indexed, and they must be consecutive.
22692 BaseIndexOffset LdBasePtr;
22693
22694 for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
22695 StoreSDNode *St = cast<StoreSDNode>(Val: StoreNodes[i].MemNode);
22696 SDValue Val = peekThroughBitcasts(V: St->getValue());
22697 LoadSDNode *Ld = cast<LoadSDNode>(Val);
22698
22699 BaseIndexOffset LdPtr = BaseIndexOffset::match(N: Ld, DAG);
22700 // If this is not the first ptr that we check.
22701 int64_t LdOffset = 0;
22702 if (LdBasePtr.getBase().getNode()) {
22703 // The base ptr must be the same.
22704 if (!LdBasePtr.equalBaseIndex(Other: LdPtr, DAG, Off&: LdOffset))
22705 break;
22706 } else {
22707 // Check that all other base pointers are the same as this one.
22708 LdBasePtr = LdPtr;
22709 }
22710
22711 // We found a potential memory operand to merge.
22712 LoadNodes.push_back(Elt: MemOpLink(Ld, LdOffset));
22713 }
22714
22715 while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) {
22716 Align RequiredAlignment;
22717 bool NeedRotate = false;
22718 if (LoadNodes.size() == 2) {
22719 // If we have load/store pair instructions and we only have two values,
22720 // don't bother merging.
22721 if (TLI.hasPairedLoad(MemVT, RequiredAlignment) &&
22722 StoreNodes[0].MemNode->getAlign() >= RequiredAlignment) {
22723 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + 2);
22724 LoadNodes.erase(CS: LoadNodes.begin(), CE: LoadNodes.begin() + 2);
22725 break;
22726 }
22727 // If the loads are reversed, see if we can rotate the halves into place.
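      // For example, a pair of i32 loads from p and p+4 whose values feed the
      // stores at q+4 and q respectively can be performed as a single i64
      // load from p that is rotated by 32 bits and stored as one i64 to q.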
22728 int64_t Offset0 = LoadNodes[0].OffsetFromBase;
22729 int64_t Offset1 = LoadNodes[1].OffsetFromBase;
22730 EVT PairVT = EVT::getIntegerVT(Context, BitWidth: ElementSizeBytes * 8 * 2);
22731 if (Offset0 - Offset1 == ElementSizeBytes &&
22732 (hasOperation(Opcode: ISD::ROTL, VT: PairVT) ||
22733 hasOperation(Opcode: ISD::ROTR, VT: PairVT))) {
22734 std::swap(a&: LoadNodes[0], b&: LoadNodes[1]);
22735 NeedRotate = true;
22736 }
22737 }
22738 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
22739 unsigned FirstStoreAS = FirstInChain->getAddressSpace();
22740 Align FirstStoreAlign = FirstInChain->getAlign();
22741 LoadSDNode *FirstLoad = cast<LoadSDNode>(Val: LoadNodes[0].MemNode);
22742
22743 // Scan the memory operations on the chain and find the first
22744 // non-consecutive load memory address. These variables hold the index in
22745 // the store node array.
22746
22747 unsigned LastConsecutiveLoad = 1;
22748
    // These variables refer to sizes, not indexes into the array.
22750 unsigned LastLegalVectorType = 1;
22751 unsigned LastLegalIntegerType = 1;
22752 bool isDereferenceable = true;
22753 bool DoIntegerTruncate = false;
22754 int64_t StartAddress = LoadNodes[0].OffsetFromBase;
22755 SDValue LoadChain = FirstLoad->getChain();
22756 for (unsigned i = 1; i < LoadNodes.size(); ++i) {
22757 // All loads must share the same chain.
22758 if (LoadNodes[i].MemNode->getChain() != LoadChain)
22759 break;
22760
22761 int64_t CurrAddress = LoadNodes[i].OffsetFromBase;
22762 if (CurrAddress - StartAddress != (ElementSizeBytes * i))
22763 break;
22764 LastConsecutiveLoad = i;
22765
22766 if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable())
22767 isDereferenceable = false;
22768
22769 // Find a legal type for the vector store.
22770 unsigned Elts = (i + 1) * NumMemElts;
22771 EVT StoreTy = EVT::getVectorVT(Context, VT: MemVT.getScalarType(), NumElements: Elts);
22772
22773 // Break early when size is too large to be legal.
22774 if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
22775 break;
22776
22777 unsigned IsFastSt = 0;
22778 unsigned IsFastLd = 0;
22779 // Don't try vector types if we need a rotate. We may still fail the
22780 // legality checks for the integer type, but we can't handle the rotate
22781 // case with vectors.
22782 // FIXME: We could use a shuffle in place of the rotate.
22783 if (!NeedRotate && TLI.isTypeLegal(VT: StoreTy) &&
22784 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: StoreTy,
22785 MF: DAG.getMachineFunction()) &&
22786 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
22787 MMO: *FirstInChain->getMemOperand(), Fast: &IsFastSt) &&
22788 IsFastSt &&
22789 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
22790 MMO: *FirstLoad->getMemOperand(), Fast: &IsFastLd) &&
22791 IsFastLd) {
22792 LastLegalVectorType = i + 1;
22793 }
22794
22795 // Find a legal type for the integer store.
22796 unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
22797 StoreTy = EVT::getIntegerVT(Context, BitWidth: SizeInBits);
22798 if (TLI.isTypeLegal(VT: StoreTy) &&
22799 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: StoreTy,
22800 MF: DAG.getMachineFunction()) &&
22801 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
22802 MMO: *FirstInChain->getMemOperand(), Fast: &IsFastSt) &&
22803 IsFastSt &&
22804 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
22805 MMO: *FirstLoad->getMemOperand(), Fast: &IsFastLd) &&
22806 IsFastLd) {
22807 LastLegalIntegerType = i + 1;
22808 DoIntegerTruncate = false;
22809 // Or check whether a truncstore and extload is legal.
22810 } else if (TLI.getTypeAction(Context, VT: StoreTy) ==
22811 TargetLowering::TypePromoteInteger) {
22812 EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, VT: StoreTy);
22813 if (TLI.isTruncStoreLegal(ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
22814 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: LegalizedStoredValTy,
22815 MF: DAG.getMachineFunction()) &&
22816 TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
22817 TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
22818 TLI.isLoadExtLegal(ExtType: ISD::EXTLOAD, ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
22819 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
22820 MMO: *FirstInChain->getMemOperand(), Fast: &IsFastSt) &&
22821 IsFastSt &&
22822 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
22823 MMO: *FirstLoad->getMemOperand(), Fast: &IsFastLd) &&
22824 IsFastLd) {
22825 LastLegalIntegerType = i + 1;
22826 DoIntegerTruncate = true;
22827 }
22828 }
22829 }
22830
22831 // Only use vector types if the vector type is larger than the integer
22832 // type. If they are the same, use integers.
22833 bool UseVectorTy =
22834 LastLegalVectorType > LastLegalIntegerType && AllowVectors;
22835 unsigned LastLegalType =
22836 std::max(a: LastLegalVectorType, b: LastLegalIntegerType);
22837
22838 // We add +1 here because the LastXXX variables refer to location while
22839 // the NumElem refers to array/index size.
22840 unsigned NumElem = std::min(a: NumConsecutiveStores, b: LastConsecutiveLoad + 1);
22841 NumElem = std::min(a: LastLegalType, b: NumElem);
22842 Align FirstLoadAlign = FirstLoad->getAlign();
22843
22844 if (NumElem < 2) {
22845 // We know that candidate stores are in order and of correct
22846 // shape. While there is no mergeable sequence from the
22847 // beginning one may start later in the sequence. The only
22848 // reason a merge of size N could have failed where another of
22849 // the same size would not have is if the alignment or either
22850 // the load or store has improved. Drop as many candidates as we
22851 // can here.
22852 unsigned NumSkip = 1;
22853 while ((NumSkip < LoadNodes.size()) &&
22854 (LoadNodes[NumSkip].MemNode->getAlign() <= FirstLoadAlign) &&
22855 (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
22856 NumSkip++;
22857 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumSkip);
22858 LoadNodes.erase(CS: LoadNodes.begin(), CE: LoadNodes.begin() + NumSkip);
22859 NumConsecutiveStores -= NumSkip;
22860 continue;
22861 }
22862
22863 // Check that we can merge these candidates without causing a cycle.
22864 if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStores: NumElem,
22865 RootNode)) {
22866 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumElem);
22867 LoadNodes.erase(CS: LoadNodes.begin(), CE: LoadNodes.begin() + NumElem);
22868 NumConsecutiveStores -= NumElem;
22869 continue;
22870 }
22871
22872 // Find if it is better to use vectors or integers to load and store
22873 // to memory.
22874 EVT JointMemOpVT;
22875 if (UseVectorTy) {
22876 // Find a legal type for the vector store.
22877 unsigned Elts = NumElem * NumMemElts;
22878 JointMemOpVT = EVT::getVectorVT(Context, VT: MemVT.getScalarType(), NumElements: Elts);
22879 } else {
22880 unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
22881 JointMemOpVT = EVT::getIntegerVT(Context, BitWidth: SizeInBits);
22882 }
22883
22884 // Check if there is a call in the load/store chain.
22885 if (!TLI.shouldMergeStoreOfLoadsOverCall(MemVT, JointMemOpVT) &&
22886 hasCallInLdStChain(St: cast<StoreSDNode>(Val: StoreNodes[0].MemNode),
22887 Ld: cast<LoadSDNode>(Val: LoadNodes[0].MemNode))) {
22888 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumElem);
22889 LoadNodes.erase(CS: LoadNodes.begin(), CE: LoadNodes.begin() + NumElem);
22890 NumConsecutiveStores -= NumElem;
22891 continue;
22892 }
22893
22894 SDLoc LoadDL(LoadNodes[0].MemNode);
22895 SDLoc StoreDL(StoreNodes[0].MemNode);
22896
22897 // The merged loads are required to have the same incoming chain, so
22898 // using the first's chain is acceptable.
22899
22900 SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumStores: NumElem);
22901 bool CanReusePtrInfo = hasSameUnderlyingObj(StoreNodes);
22902 AddToWorklist(N: NewStoreChain.getNode());
22903
22904 MachineMemOperand::Flags LdMMOFlags =
22905 isDereferenceable ? MachineMemOperand::MODereferenceable
22906 : MachineMemOperand::MONone;
22907 if (IsNonTemporalLoad)
22908 LdMMOFlags |= MachineMemOperand::MONonTemporal;
22909
22910 LdMMOFlags |= TLI.getTargetMMOFlags(Node: *FirstLoad);
22911
22912 MachineMemOperand::Flags StMMOFlags = IsNonTemporalStore
22913 ? MachineMemOperand::MONonTemporal
22914 : MachineMemOperand::MONone;
22915
22916 StMMOFlags |= TLI.getTargetMMOFlags(Node: *StoreNodes[0].MemNode);
22917
22918 SDValue NewLoad, NewStore;
22919 if (UseVectorTy || !DoIntegerTruncate) {
22920 NewLoad = DAG.getLoad(
22921 VT: JointMemOpVT, dl: LoadDL, Chain: FirstLoad->getChain(), Ptr: FirstLoad->getBasePtr(),
22922 PtrInfo: FirstLoad->getPointerInfo(), Alignment: FirstLoadAlign, MMOFlags: LdMMOFlags);
22923 SDValue StoreOp = NewLoad;
22924 if (NeedRotate) {
22925 unsigned LoadWidth = ElementSizeBytes * 8 * 2;
22926 assert(JointMemOpVT == EVT::getIntegerVT(Context, LoadWidth) &&
22927 "Unexpected type for rotate-able load pair");
22928 SDValue RotAmt =
22929 DAG.getShiftAmountConstant(Val: LoadWidth / 2, VT: JointMemOpVT, DL: LoadDL);
22930 // Target can convert to the identical ROTR if it does not have ROTL.
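// Rotating by half the width swaps the two halves of the loaded value, e.g.
// an i64 loaded as {Lo, Hi} is stored as {Hi, Lo} by the single wide store.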
22931 StoreOp = DAG.getNode(Opcode: ISD::ROTL, DL: LoadDL, VT: JointMemOpVT, N1: NewLoad, N2: RotAmt);
22932 }
22933 NewStore = DAG.getStore(
22934 Chain: NewStoreChain, dl: StoreDL, Val: StoreOp, Ptr: FirstInChain->getBasePtr(),
22935 PtrInfo: CanReusePtrInfo ? FirstInChain->getPointerInfo()
22936 : MachinePointerInfo(FirstStoreAS),
22937 Alignment: FirstStoreAlign, MMOFlags: StMMOFlags);
22938 } else { // This must be the truncstore/extload case
22939 EVT ExtendedTy =
22940 TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: JointMemOpVT);
22941 NewLoad = DAG.getExtLoad(ExtType: ISD::EXTLOAD, dl: LoadDL, VT: ExtendedTy,
22942 Chain: FirstLoad->getChain(), Ptr: FirstLoad->getBasePtr(),
22943 PtrInfo: FirstLoad->getPointerInfo(), MemVT: JointMemOpVT,
22944 Alignment: FirstLoadAlign, MMOFlags: LdMMOFlags);
22945 NewStore = DAG.getTruncStore(
22946 Chain: NewStoreChain, dl: StoreDL, Val: NewLoad, Ptr: FirstInChain->getBasePtr(),
22947 PtrInfo: CanReusePtrInfo ? FirstInChain->getPointerInfo()
22948 : MachinePointerInfo(FirstStoreAS),
22949 SVT: JointMemOpVT, Alignment: FirstInChain->getAlign(),
22950 MMOFlags: FirstInChain->getMemOperand()->getFlags());
22951 }
22952
22953 // Transfer chain users from old loads to the new load.
22954 for (unsigned i = 0; i < NumElem; ++i) {
22955 LoadSDNode *Ld = cast<LoadSDNode>(Val: LoadNodes[i].MemNode);
22956 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Ld, 1),
22957 To: SDValue(NewLoad.getNode(), 1));
22958 }
22959
22960 // Replace all stores with the new store. Recursively remove corresponding
22961 // values if they are no longer used.
22962 for (unsigned i = 0; i < NumElem; ++i) {
22963 SDValue Val = StoreNodes[i].MemNode->getOperand(Num: 1);
22964 CombineTo(N: StoreNodes[i].MemNode, Res: NewStore);
22965 if (Val->use_empty())
22966 recursivelyDeleteUnusedNodes(N: Val.getNode());
22967 }
22968
22969 MadeChange = true;
22970 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumElem);
22971 LoadNodes.erase(CS: LoadNodes.begin(), CE: LoadNodes.begin() + NumElem);
22972 NumConsecutiveStores -= NumElem;
22973 }
22974 return MadeChange;
22975}
22976
22977bool DAGCombiner::mergeConsecutiveStores(StoreSDNode *St) {
22978 if (OptLevel == CodeGenOptLevel::None || !EnableStoreMerging)
22979 return false;
22980
22981 // TODO: Extend this function to merge stores of scalable vectors.
22982 // (i.e. two <vscale x 8 x i8> stores can be merged to one <vscale x 16 x i8>
22983 // store since we know <vscale x 16 x i8> is exactly twice as large as
22984 // <vscale x 8 x i8>). Until then, bail out for scalable vectors.
22985 EVT MemVT = St->getMemoryVT();
22986 if (MemVT.isScalableVT())
22987 return false;
22988 if (!MemVT.isSimple() || MemVT.getSizeInBits() * 2 > MaximumLegalStoreInBits)
22989 return false;
22990
22991 // This function cannot currently deal with non-byte-sized memory sizes.
22992 int64_t ElementSizeBytes = MemVT.getStoreSize();
22993 if (ElementSizeBytes * 8 != (int64_t)MemVT.getSizeInBits())
22994 return false;
22995
22996 // Do not bother looking at stored values that are not constants, loads, or
22997 // extracted vector elements.
22998 SDValue StoredVal = peekThroughBitcasts(V: St->getValue());
22999 const StoreSource StoreSrc = getStoreSource(StoreVal: StoredVal);
23000 if (StoreSrc == StoreSource::Unknown)
23001 return false;
23002
23003 SmallVector<MemOpLink, 8> StoreNodes;
23004 // Find potential store merge candidates by searching through chain sub-DAG
23005 SDNode *RootNode = getStoreMergeCandidates(St, StoreNodes);
23006
23007 // Check if there is anything to merge.
23008 if (StoreNodes.size() < 2)
23009 return false;
23010
23011 // Sort the memory operands according to their distance from the
23012 // base pointer.
23013 llvm::sort(C&: StoreNodes, Comp: [](MemOpLink LHS, MemOpLink RHS) {
23014 return LHS.OffsetFromBase < RHS.OffsetFromBase;
23015 });
23016
23017 bool AllowVectors = !DAG.getMachineFunction().getFunction().hasFnAttribute(
23018 Kind: Attribute::NoImplicitFloat);
23019 bool IsNonTemporalStore = St->isNonTemporal();
23020 bool IsNonTemporalLoad = StoreSrc == StoreSource::Load &&
23021 cast<LoadSDNode>(Val&: StoredVal)->isNonTemporal();
23022
23023 // Store Merge attempts to merge the lowest stores. This generally
23024 // works out well: when the merge succeeds, the remaining stores are
23025 // checked after the first collection of stores is merged. However, in the
23026 // case that a non-mergeable store is found first, e.g., {p[-2],
23027 // p[0], p[1], p[2], p[3]}, we would fail and miss the subsequent
23028 // mergeable cases. To prevent this, we prune such stores from the
23029 // front of StoreNodes here.
23030 bool MadeChange = false;
23031 while (StoreNodes.size() > 1) {
23032 unsigned NumConsecutiveStores =
23033 getConsecutiveStores(StoreNodes, ElementSizeBytes);
23034 // There are no more stores in the list to examine.
23035 if (NumConsecutiveStores == 0)
23036 return MadeChange;
23037
23038 // We have at least 2 consecutive stores. Try to merge them.
23039 assert(NumConsecutiveStores >= 2 && "Expected at least 2 stores");
23040 switch (StoreSrc) {
23041 case StoreSource::Constant:
23042 MadeChange |= tryStoreMergeOfConstants(StoreNodes, NumConsecutiveStores,
23043 MemVT, RootNode, AllowVectors);
23044 break;
23045
23046 case StoreSource::Extract:
23047 MadeChange |= tryStoreMergeOfExtracts(StoreNodes, NumConsecutiveStores,
23048 MemVT, RootNode);
23049 break;
23050
23051 case StoreSource::Load:
23052 MadeChange |= tryStoreMergeOfLoads(StoreNodes, NumConsecutiveStores,
23053 MemVT, RootNode, AllowVectors,
23054 IsNonTemporalStore, IsNonTemporalLoad);
23055 break;
23056
23057 default:
23058 llvm_unreachable("Unhandled store source type");
23059 }
23060 }
23061
23062 // Remember if we failed to optimize, to save compile time.
23063 if (!MadeChange)
23064 ChainsWithoutMergeableStores.insert(Ptr: RootNode);
23065
23066 return MadeChange;
23067}
23068
23069SDValue DAGCombiner::replaceStoreChain(StoreSDNode *ST, SDValue BetterChain) {
23070 SDLoc SL(ST);
23071 SDValue ReplStore;
23072
23073 // Replace the chain to avoid dependency.
23074 if (ST->isTruncatingStore()) {
23075 ReplStore = DAG.getTruncStore(Chain: BetterChain, dl: SL, Val: ST->getValue(),
23076 Ptr: ST->getBasePtr(), SVT: ST->getMemoryVT(),
23077 MMO: ST->getMemOperand());
23078 } else {
23079 ReplStore = DAG.getStore(Chain: BetterChain, dl: SL, Val: ST->getValue(), Ptr: ST->getBasePtr(),
23080 MMO: ST->getMemOperand());
23081 }
23082
23083 // Create token to keep both nodes around.
23084 SDValue Token = DAG.getNode(Opcode: ISD::TokenFactor, DL: SL,
23085 VT: MVT::Other, N1: ST->getChain(), N2: ReplStore);
23086
23087 // Make sure the new and old chains are cleaned up.
23088 AddToWorklist(N: Token.getNode());
23089
23090 // Don't add users to work list.
23091 return CombineTo(N: ST, Res: Token, AddTo: false);
23092}
23093
23094SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) {
23095 SDValue Value = ST->getValue();
23096 if (Value.getOpcode() == ISD::TargetConstantFP)
23097 return SDValue();
23098
23099 if (!ISD::isNormalStore(N: ST))
23100 return SDValue();
23101
23102 SDLoc DL(ST);
23103
23104 SDValue Chain = ST->getChain();
23105 SDValue Ptr = ST->getBasePtr();
23106
23107 const ConstantFPSDNode *CFP = cast<ConstantFPSDNode>(Val&: Value);
23108
23109 // NOTE: If the original store is volatile, this transform must not increase
23110 // the number of stores. For example, on x86-32 an f64 can be stored in one
23111 // processor operation but an i64 (which is not legal) requires two. So the
23112 // transform should not be done in this case.
23113
23114 SDValue Tmp;
23115 switch (CFP->getSimpleValueType(ResNo: 0).SimpleTy) {
23116 default:
23117 llvm_unreachable("Unknown FP type");
23118 case MVT::f16: // We don't do this for these yet.
23119 case MVT::bf16:
23120 case MVT::f80:
23121 case MVT::f128:
23122 case MVT::ppcf128:
23123 return SDValue();
23124 case MVT::f32:
23125 if ((isTypeLegal(VT: MVT::i32) && !LegalOperations && ST->isSimple()) ||
23126 TLI.isOperationLegalOrCustom(Op: ISD::STORE, VT: MVT::i32)) {
23127 Tmp = DAG.getConstant(Val: (uint32_t)CFP->getValueAPF().
23128 bitcastToAPInt().getZExtValue(), DL: SDLoc(CFP),
23129 VT: MVT::i32);
23130 return DAG.getStore(Chain, dl: DL, Val: Tmp, Ptr, MMO: ST->getMemOperand());
23131 }
23132
23133 return SDValue();
23134 case MVT::f64:
23135 if ((TLI.isTypeLegal(VT: MVT::i64) && !LegalOperations &&
23136 ST->isSimple()) ||
23137 TLI.isOperationLegalOrCustom(Op: ISD::STORE, VT: MVT::i64)) {
23138 Tmp = DAG.getConstant(Val: CFP->getValueAPF().bitcastToAPInt().
23139 getZExtValue(), DL: SDLoc(CFP), VT: MVT::i64);
23140 return DAG.getStore(Chain, dl: DL, Val: Tmp,
23141 Ptr, MMO: ST->getMemOperand());
23142 }
23143
23144 if (ST->isSimple() && TLI.isOperationLegalOrCustom(Op: ISD::STORE, VT: MVT::i32) &&
23145 !TLI.isFPImmLegal(CFP->getValueAPF(), MVT::f64)) {
23146 // Many FP stores are not made apparent until after legalize, e.g. for
23147 // argument passing. Since this is so common, custom legalize the
23148 // 64-bit integer store into two 32-bit stores.
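// E.g. storing the f64 constant 1.0 (bit pattern 0x3FF0000000000000) becomes
// stores of i32 0x00000000 and i32 0x3FF00000, low word first on
// little-endian targets.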
23149 uint64_t Val = CFP->getValueAPF().bitcastToAPInt().getZExtValue();
23150 SDValue Lo = DAG.getConstant(Val: Val & 0xFFFFFFFF, DL: SDLoc(CFP), VT: MVT::i32);
23151 SDValue Hi = DAG.getConstant(Val: Val >> 32, DL: SDLoc(CFP), VT: MVT::i32);
23152 if (DAG.getDataLayout().isBigEndian())
23153 std::swap(a&: Lo, b&: Hi);
23154
23155 MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
23156 AAMDNodes AAInfo = ST->getAAInfo();
23157
23158 SDValue St0 = DAG.getStore(Chain, dl: DL, Val: Lo, Ptr, PtrInfo: ST->getPointerInfo(),
23159 Alignment: ST->getBaseAlign(), MMOFlags, AAInfo);
23160 Ptr = DAG.getMemBasePlusOffset(Base: Ptr, Offset: TypeSize::getFixed(ExactSize: 4), DL);
23161 SDValue St1 = DAG.getStore(Chain, dl: DL, Val: Hi, Ptr,
23162 PtrInfo: ST->getPointerInfo().getWithOffset(O: 4),
23163 Alignment: ST->getBaseAlign(), MMOFlags, AAInfo);
23164 return DAG.getNode(Opcode: ISD::TokenFactor, DL, VT: MVT::Other,
23165 N1: St0, N2: St1);
23166 }
23167
23168 return SDValue();
23169 }
23170}
23171
23172// (store (insert_vector_elt (load p), x, i), p) -> (store x, p+offset)
23173//
23174// If a store of a load with an element inserted into it has no other
23175// uses in between the chain, then we can consider the vector store
23176// dead and replace it with just the single scalar element store.
23177SDValue DAGCombiner::replaceStoreOfInsertLoad(StoreSDNode *ST) {
23178 SDLoc DL(ST);
23179 SDValue Value = ST->getValue();
23180 SDValue Ptr = ST->getBasePtr();
23181 SDValue Chain = ST->getChain();
23182 if (Value.getOpcode() != ISD::INSERT_VECTOR_ELT || !Value.hasOneUse())
23183 return SDValue();
23184
23185 SDValue Elt = Value.getOperand(i: 1);
23186 SDValue Idx = Value.getOperand(i: 2);
23187
23188 // If the element isn't byte sized or is implicitly truncated then we can't
23189 // compute an offset.
23190 EVT EltVT = Elt.getValueType();
23191 if (!EltVT.isByteSized() ||
23192 EltVT != Value.getOperand(i: 0).getValueType().getVectorElementType())
23193 return SDValue();
23194
23195 auto *Ld = dyn_cast<LoadSDNode>(Val: Value.getOperand(i: 0));
23196 if (!Ld || Ld->getBasePtr() != Ptr ||
23197 ST->getMemoryVT() != Ld->getMemoryVT() || !ST->isSimple() ||
23198 !ISD::isNormalStore(N: ST) ||
23199 Ld->getAddressSpace() != ST->getAddressSpace() ||
23200 !Chain.reachesChainWithoutSideEffects(Dest: SDValue(Ld, 1)))
23201 return SDValue();
23202
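// Only proceed when a scalar access of the element type is allowed (and fast)
// for this address space, alignment and memory-operand flags.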
23203 unsigned IsFast;
23204 if (!TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(),
23205 VT: Elt.getValueType(), AddrSpace: ST->getAddressSpace(),
23206 Alignment: ST->getAlign(), Flags: ST->getMemOperand()->getFlags(),
23207 Fast: &IsFast) ||
23208 !IsFast)
23209 return SDValue();
23210
23211 MachinePointerInfo PointerInfo(ST->getAddressSpace());
23212
23213 // If the offset is a known constant then try to recover the pointer
23214 // info
23215 SDValue NewPtr;
23216 if (auto *CIdx = dyn_cast<ConstantSDNode>(Val&: Idx)) {
23217 unsigned COffset = CIdx->getSExtValue() * EltVT.getSizeInBits() / 8;
23218 NewPtr = DAG.getMemBasePlusOffset(Base: Ptr, Offset: TypeSize::getFixed(ExactSize: COffset), DL);
23219 PointerInfo = ST->getPointerInfo().getWithOffset(O: COffset);
23220 } else {
23221 // The original DAG loaded the entire vector from memory, so arithmetic
23222 // within it must be inbounds.
23223 NewPtr = TLI.getInboundsVectorElementPointer(DAG, VecPtr: Ptr, VecVT: Value.getValueType(),
23224 Index: Idx);
23225 }
23226
23227 return DAG.getStore(Chain, dl: DL, Val: Elt, Ptr: NewPtr, PtrInfo: PointerInfo, Alignment: ST->getAlign(),
23228 MMOFlags: ST->getMemOperand()->getFlags());
23229}
23230
23231SDValue DAGCombiner::visitATOMIC_STORE(SDNode *N) {
23232 AtomicSDNode *ST = cast<AtomicSDNode>(Val: N);
23233 SDValue Val = ST->getVal();
23234 EVT VT = Val.getValueType();
23235 EVT MemVT = ST->getMemoryVT();
23236
23237 if (MemVT.bitsLT(VT)) { // Is truncating store
23238 APInt TruncDemandedBits = APInt::getLowBitsSet(numBits: VT.getScalarSizeInBits(),
23239 loBitsSet: MemVT.getScalarSizeInBits());
23240 // See if we can simplify the operation with SimplifyDemandedBits, which
23241 // only works if the value has a single use.
23242 if (SimplifyDemandedBits(Op: Val, DemandedBits: TruncDemandedBits))
23243 return SDValue(N, 0);
23244 }
23245
23246 return SDValue();
23247}
23248
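// Fold a store of a vselect between a vector value and a load from the same
// address back into a masked store:
//   (store (vselect Mask, X, (load Ptr)), Ptr) -> (masked_store X, Ptr, Mask)
// The commuted form selects the loaded value on true, so the mask is inverted.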
23249static SDValue foldToMaskedStore(StoreSDNode *Store, SelectionDAG &DAG,
23250 const SDLoc &Dl) {
23251 if (!Store->isSimple() || !ISD::isNormalStore(N: Store))
23252 return SDValue();
23253
23254 SDValue StoredVal = Store->getValue();
23255 SDValue StorePtr = Store->getBasePtr();
23256 SDValue StoreOffset = Store->getOffset();
23257 EVT VT = Store->getMemoryVT();
23258
23259 // Skip this combine for non-vector and scalable vector types, and for
23260 // <1 x ty> vectors, as the latter will be scalarized later.
23261 if (!VT.isVector() || VT.isScalableVector() || VT.getVectorNumElements() == 1)
23262 return SDValue();
23263
23264 unsigned AddrSpace = Store->getAddressSpace();
23265 Align Alignment = Store->getAlign();
23266 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
23267
23268 if (!TLI.isOperationLegalOrCustom(Op: ISD::MSTORE, VT) ||
23269 !TLI.allowsMisalignedMemoryAccesses(VT, AddrSpace, Alignment))
23270 return SDValue();
23271
23272 SDValue Mask, OtherVec, LoadCh;
23273 unsigned LoadPos;
23274 if (sd_match(N: StoredVal,
23275 P: m_VSelect(Cond: m_Value(N&: Mask), T: m_Value(N&: OtherVec),
23276 F: m_Load(Ch: m_Value(N&: LoadCh), Ptr: m_Specific(N: StorePtr),
23277 Offset: m_Specific(N: StoreOffset))))) {
23278 LoadPos = 2;
23279 } else if (sd_match(N: StoredVal,
23280 P: m_VSelect(Cond: m_Value(N&: Mask),
23281 T: m_Load(Ch: m_Value(N&: LoadCh), Ptr: m_Specific(N: StorePtr),
23282 Offset: m_Specific(N: StoreOffset)),
23283 F: m_Value(N&: OtherVec)))) {
23284 LoadPos = 1;
23285 } else {
23286 return SDValue();
23287 }
23288
23289 auto *Load = cast<LoadSDNode>(Val: StoredVal.getOperand(i: LoadPos));
23290 if (!Load->isSimple() || !ISD::isNormalLoad(N: Load) ||
23291 Load->getAddressSpace() != AddrSpace)
23292 return SDValue();
23293
23294 if (!Store->getChain().reachesChainWithoutSideEffects(Dest: LoadCh))
23295 return SDValue();
23296
23297 if (LoadPos == 1)
23298 Mask = DAG.getNOT(DL: Dl, Val: Mask, VT: Mask.getValueType());
23299
23300 return DAG.getMaskedStore(Chain: Store->getChain(), dl: Dl, Val: OtherVec, Base: StorePtr,
23301 Offset: StoreOffset, Mask, MemVT: VT, MMO: Store->getMemOperand(),
23302 AM: Store->getAddressingMode());
23303}
23304
23305SDValue DAGCombiner::visitSTORE(SDNode *N) {
23306 StoreSDNode *ST = cast<StoreSDNode>(Val: N);
23307 SDValue Chain = ST->getChain();
23308 SDValue Value = ST->getValue();
23309 SDValue Ptr = ST->getBasePtr();
23310
23311 // If this is a store of a bit convert, store the input value if the
23312 // resultant store does not need a higher alignment than the original.
23313 if (Value.getOpcode() == ISD::BITCAST && !ST->isTruncatingStore() &&
23314 ST->isUnindexed()) {
23315 EVT SVT = Value.getOperand(i: 0).getValueType();
23316 // If the store is volatile, we only want to change the store type if the
23317 // resulting store is legal. Otherwise we might increase the number of
23318 // memory accesses. We don't care if the original type was legal or not
23319 // as we assume software couldn't rely on the number of accesses of an
23320 // illegal type.
23321 // TODO: May be able to relax for unordered atomics (see D66309)
23322 if (((!LegalOperations && ST->isSimple()) ||
23323 TLI.isOperationLegal(Op: ISD::STORE, VT: SVT)) &&
23324 TLI.isStoreBitCastBeneficial(StoreVT: Value.getValueType(), BitcastVT: SVT,
23325 DAG, MMO: *ST->getMemOperand())) {
23326 return DAG.getStore(Chain, dl: SDLoc(N), Val: Value.getOperand(i: 0), Ptr,
23327 MMO: ST->getMemOperand());
23328 }
23329 }
23330
23331 // Turn 'store undef, Ptr' -> nothing.
23332 if (Value.isUndef() && ST->isUnindexed() && !ST->isVolatile())
23333 return Chain;
23334
23335 // Try to infer better alignment information than the store already has.
23336 if (OptLevel != CodeGenOptLevel::None && ST->isUnindexed() &&
23337 !ST->isAtomic()) {
23338 if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
23339 if (*Alignment > ST->getAlign() &&
23340 isAligned(Lhs: *Alignment, SizeInBytes: ST->getSrcValueOffset())) {
23341 SDValue NewStore =
23342 DAG.getTruncStore(Chain, dl: SDLoc(N), Val: Value, Ptr, PtrInfo: ST->getPointerInfo(),
23343 SVT: ST->getMemoryVT(), Alignment: *Alignment,
23344 MMOFlags: ST->getMemOperand()->getFlags(), AAInfo: ST->getAAInfo());
23345 // NewStore will always be N as we are only refining the alignment
23346 assert(NewStore.getNode() == N);
23347 (void)NewStore;
23348 }
23349 }
23350 }
23351
23352 // Try transforming a pair floating point load / store ops to integer
23353 // load / store ops.
23354 if (SDValue NewST = TransformFPLoadStorePair(N))
23355 return NewST;
23356
23357 // Try transforming several stores into STORE (BSWAP).
23358 if (SDValue Store = mergeTruncStores(N: ST))
23359 return Store;
23360
23361 if (ST->isUnindexed()) {
23362 // Walk up chain skipping non-aliasing memory nodes, on this store and any
23363 // adjacent stores.
23364 if (findBetterNeighborChains(St: ST)) {
23365 // replaceStoreChain uses CombineTo, which handled all of the worklist
23366 // manipulation. Return the original node to not do anything else.
23367 return SDValue(ST, 0);
23368 }
23369 Chain = ST->getChain();
23370 }
23371
23372 // FIXME: is there such a thing as a truncating indexed store?
23373 if (ST->isTruncatingStore() && ST->isUnindexed() &&
23374 Value.getValueType().isInteger() &&
23375 (!isa<ConstantSDNode>(Val: Value) ||
23376 !cast<ConstantSDNode>(Val&: Value)->isOpaque())) {
23377 // Convert a truncating store of an extension into a standard store.
23378 if ((Value.getOpcode() == ISD::ZERO_EXTEND ||
23379 Value.getOpcode() == ISD::SIGN_EXTEND ||
23380 Value.getOpcode() == ISD::ANY_EXTEND) &&
23381 Value.getOperand(i: 0).getValueType() == ST->getMemoryVT() &&
23382 TLI.isOperationLegalOrCustom(Op: ISD::STORE, VT: ST->getMemoryVT()))
23383 return DAG.getStore(Chain, dl: SDLoc(N), Val: Value.getOperand(i: 0), Ptr,
23384 MMO: ST->getMemOperand());
23385
23386 APInt TruncDemandedBits =
23387 APInt::getLowBitsSet(numBits: Value.getScalarValueSizeInBits(),
23388 loBitsSet: ST->getMemoryVT().getScalarSizeInBits());
23389
23390 // See if we can simplify the operation with SimplifyDemandedBits, which
23391 // only works if the value has a single use.
23392 AddToWorklist(N: Value.getNode());
23393 if (SimplifyDemandedBits(Op: Value, DemandedBits: TruncDemandedBits)) {
23394 // Re-visit the store if anything changed and the store hasn't been merged
23395 // with another node (N is deleted). SimplifyDemandedBits will add Value's
23396 // node back to the worklist if necessary, but we also need to re-visit
23397 // the Store node itself.
23398 if (N->getOpcode() != ISD::DELETED_NODE)
23399 AddToWorklist(N);
23400 return SDValue(N, 0);
23401 }
23402
23403 // Otherwise, see if we can simplify the input to this truncstore with
23404 // knowledge that only the low bits are being used. For example:
23405 // "truncstore (or (shl x, 8), y), i8" -> "truncstore y, i8"
23406 if (SDValue Shorter =
23407 TLI.SimplifyMultipleUseDemandedBits(Op: Value, DemandedBits: TruncDemandedBits, DAG))
23408 return DAG.getTruncStore(Chain, dl: SDLoc(N), Val: Shorter, Ptr, SVT: ST->getMemoryVT(),
23409 MMO: ST->getMemOperand());
23410
23411 // If we're storing a truncated constant, see if we can simplify it.
23412 // TODO: Move this to targetShrinkDemandedConstant?
23413 if (auto *Cst = dyn_cast<ConstantSDNode>(Val&: Value))
23414 if (!Cst->isOpaque()) {
23415 const APInt &CValue = Cst->getAPIntValue();
23416 APInt NewVal = CValue & TruncDemandedBits;
23417 if (NewVal != CValue) {
23418 SDValue Shorter =
23419 DAG.getConstant(Val: NewVal, DL: SDLoc(N), VT: Value.getValueType());
23420 return DAG.getTruncStore(Chain, dl: SDLoc(N), Val: Shorter, Ptr,
23421 SVT: ST->getMemoryVT(), MMO: ST->getMemOperand());
23422 }
23423 }
23424 }
23425
23426 // If this is a load followed by a store to the same location, then the store
23427 // is dead/noop. Peek through any truncates if canCombineTruncStore failed.
23428 // TODO: Add big-endian truncate support with test coverage.
23429 // TODO: Can relax for unordered atomics (see D66309)
23430 SDValue TruncVal = DAG.getDataLayout().isLittleEndian()
23431 ? peekThroughTruncates(V: Value)
23432 : Value;
23433 if (auto *Ld = dyn_cast<LoadSDNode>(Val&: TruncVal)) {
23434 if (Ld->getBasePtr() == Ptr && ST->getMemoryVT() == Ld->getMemoryVT() &&
23435 ST->isUnindexed() && ST->isSimple() &&
23436 Ld->getAddressSpace() == ST->getAddressSpace() &&
23437 // There can't be any side effects between the load and store, such as
23438 // a call or store.
23439 Chain.reachesChainWithoutSideEffects(Dest: SDValue(Ld, 1))) {
23440 // The store is dead, remove it.
23441 return Chain;
23442 }
23443 }
23444
23445 // Try scalarizing vector stores of loads where we only change one element
23446 if (SDValue NewST = replaceStoreOfInsertLoad(ST))
23447 return NewST;
23448
23449 // TODO: Can relax for unordered atomics (see D66309)
23450 if (StoreSDNode *ST1 = dyn_cast<StoreSDNode>(Val&: Chain)) {
23451 if (ST->isUnindexed() && ST->isSimple() &&
23452 ST1->isUnindexed() && ST1->isSimple()) {
23453 if (OptLevel != CodeGenOptLevel::None && ST1->getBasePtr() == Ptr &&
23454 ST1->getValue() == Value && ST->getMemoryVT() == ST1->getMemoryVT() &&
23455 ST->getAddressSpace() == ST1->getAddressSpace()) {
23456 // If this is a store followed by a store with the same value to the
23457 // same location, then the store is dead/noop.
23458 return Chain;
23459 }
23460
23461 if (OptLevel != CodeGenOptLevel::None && ST1->hasOneUse() &&
23462 !ST1->getBasePtr().isUndef() &&
23463 ST->getAddressSpace() == ST1->getAddressSpace()) {
23464 // If we consider two stores and one smaller in size is a scalable
23465 // vector type and another one a bigger size store with a fixed type,
23466 // then we could not allow the scalable store removal because we don't
23467 // know its final size in the end.
23468 if (ST->getMemoryVT().isScalableVector() ||
23469 ST1->getMemoryVT().isScalableVector()) {
23470 if (ST1->getBasePtr() == Ptr &&
23471 TypeSize::isKnownLE(LHS: ST1->getMemoryVT().getStoreSize(),
23472 RHS: ST->getMemoryVT().getStoreSize())) {
23473 CombineTo(N: ST1, Res: ST1->getChain());
23474 return SDValue(N, 0);
23475 }
23476 } else {
23477 const BaseIndexOffset STBase = BaseIndexOffset::match(N: ST, DAG);
23478 const BaseIndexOffset ChainBase = BaseIndexOffset::match(N: ST1, DAG);
23479 // If the preceding store writes to a subset of the current store's
23480 // location and no other node is chained to that store, the preceding
23481 // store is dead and can be dropped. Do not remove stores to undef as they
23482 // may be used as data sinks.
23483 if (STBase.contains(DAG, BitSize: ST->getMemoryVT().getFixedSizeInBits(),
23484 Other: ChainBase,
23485 OtherBitSize: ST1->getMemoryVT().getFixedSizeInBits())) {
23486 CombineTo(N: ST1, Res: ST1->getChain());
23487 return SDValue(N, 0);
23488 }
23489 }
23490 }
23491 }
23492 }
23493
23494 // If this is an FP_ROUND or TRUNC followed by a store, fold this into a
23495 // truncating store. We can do this even if this is already a truncstore.
23496 if ((Value.getOpcode() == ISD::FP_ROUND ||
23497 Value.getOpcode() == ISD::TRUNCATE) &&
23498 Value->hasOneUse() && ST->isUnindexed() &&
23499 TLI.canCombineTruncStore(ValVT: Value.getOperand(i: 0).getValueType(),
23500 MemVT: ST->getMemoryVT(), LegalOnly: LegalOperations)) {
23501 return DAG.getTruncStore(Chain, dl: SDLoc(N), Val: Value.getOperand(i: 0),
23502 Ptr, SVT: ST->getMemoryVT(), MMO: ST->getMemOperand());
23503 }
23504
23505 // Always perform this optimization before types are legal. If the target
23506 // prefers, also try this after legalization to catch stores that were created
23507 // by intrinsics or other nodes.
23508 if (!LegalTypes || (TLI.mergeStoresAfterLegalization(MemVT: ST->getMemoryVT()))) {
23509 while (true) {
23510 // There can be multiple store sequences on the same chain.
23511 // Keep trying to merge store sequences until we are unable to do so
23512 // or until we merge the last store on the chain.
23513 bool Changed = mergeConsecutiveStores(St: ST);
23514 if (!Changed) break;
23515 // Return N, as the merge only uses CombineTo and no worklist
23516 // cleanup is necessary.
23517 if (N->getOpcode() == ISD::DELETED_NODE || !isa<StoreSDNode>(Val: N))
23518 return SDValue(N, 0);
23519 }
23520 }
23521
23522 // Try transforming N to an indexed store.
23523 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
23524 return SDValue(N, 0);
23525
23526 // Turn 'store float 1.0, Ptr' -> 'store int 0x12345678, Ptr'
23527 //
23528 // Make sure to do this only after attempting to merge stores in order to
23529 // avoid changing the types of some subset of stores due to visit order,
23530 // preventing their merging.
23531 if (isa<ConstantFPSDNode>(Val: ST->getValue())) {
23532 if (SDValue NewSt = replaceStoreOfFPConstant(ST))
23533 return NewSt;
23534 }
23535
23536 if (SDValue NewSt = splitMergedValStore(ST))
23537 return NewSt;
23538
23539 if (SDValue MaskedStore = foldToMaskedStore(Store: ST, DAG, Dl: SDLoc(N)))
23540 return MaskedStore;
23541
23542 return ReduceLoadOpStoreWidth(N);
23543}
23544
23545SDValue DAGCombiner::visitLIFETIME_END(SDNode *N) {
23546 const auto *LifetimeEnd = cast<LifetimeSDNode>(Val: N);
23547 const BaseIndexOffset LifetimeEndBase(N->getOperand(Num: 1), SDValue(), 0, false);
23548
23549 // We walk up the chains to find stores.
23550 SmallVector<SDValue, 8> Chains = {N->getOperand(Num: 0)};
23551 while (!Chains.empty()) {
23552 SDValue Chain = Chains.pop_back_val();
23553 if (!Chain.hasOneUse())
23554 continue;
23555 switch (Chain.getOpcode()) {
23556 case ISD::TokenFactor:
23557 for (unsigned Nops = Chain.getNumOperands(); Nops;)
23558 Chains.push_back(Elt: Chain.getOperand(i: --Nops));
23559 break;
23560 case ISD::LIFETIME_START:
23561 case ISD::LIFETIME_END:
23562 // We can forward past any lifetime start/end that can be proven not to
23563 // alias the node.
23564 if (!mayAlias(Op0: Chain.getNode(), Op1: N))
23565 Chains.push_back(Elt: Chain.getOperand(i: 0));
23566 break;
23567 case ISD::STORE: {
23568 StoreSDNode *ST = dyn_cast<StoreSDNode>(Val&: Chain);
23569 // TODO: Can relax for unordered atomics (see D66309)
23570 if (!ST->isSimple() || ST->isIndexed())
23571 continue;
23572 const TypeSize StoreSize = ST->getMemoryVT().getStoreSize();
23573 // The bounds of a scalable store are not known until runtime, so this
23574 // store cannot be elided.
23575 if (StoreSize.isScalable())
23576 continue;
23577 const BaseIndexOffset StoreBase = BaseIndexOffset::match(N: ST, DAG);
23578 // If we store purely within object bounds just before its lifetime ends,
23579 // we can remove the store.
23580 MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
23581 if (LifetimeEndBase.contains(
23582 DAG, BitSize: MFI.getObjectSize(ObjectIdx: LifetimeEnd->getFrameIndex()) * 8,
23583 Other: StoreBase, OtherBitSize: StoreSize.getFixedValue() * 8)) {
23584 LLVM_DEBUG(dbgs() << "\nRemoving store:"; StoreBase.dump();
23585 dbgs() << "\nwithin LIFETIME_END of : ";
23586 LifetimeEndBase.dump(); dbgs() << "\n");
23587 CombineTo(N: ST, Res: ST->getChain());
23588 return SDValue(N, 0);
23589 }
23590 }
23591 }
23592 }
23593 return SDValue();
23594}
23595
23596/// In the store sequence below, the F and I values
23597/// are bundled together as an i64 value before being stored into memory.
23598/// Sometimes it is more efficient to generate separate stores for F and I,
23599/// which can remove the bitwise instructions or sink them to colder places.
23600///
23601/// (store (or (zext (bitcast F to i32) to i64),
23602/// (shl (zext I to i64), 32)), addr) -->
23603/// (store F, addr) and (store I, addr+4)
23604///
23605/// Similarly, splitting for other merged store can also be beneficial, like:
23606/// For pair of {i32, i32}, i64 store --> two i32 stores.
23607/// For pair of {i32, i16}, i64 store --> two i32 stores.
23608/// For pair of {i16, i16}, i32 store --> two i16 stores.
23609/// For pair of {i16, i8}, i32 store --> two i16 stores.
23610/// For pair of {i8, i8}, i16 store --> two i8 stores.
23611///
23612/// We allow each target to determine specifically which kind of splitting is
23613/// supported.
23614///
23615/// The store patterns are commonly seen from the simple code snippet below
23616/// if only std::make_pair(...) is SROA-transformed before being inlined into hoo().
23617/// void goo(const std::pair<int, float> &);
23618/// hoo() {
23619/// ...
23620/// goo(std::make_pair(tmp, ftmp));
23621/// ...
23622/// }
23623///
23624SDValue DAGCombiner::splitMergedValStore(StoreSDNode *ST) {
23625 if (OptLevel == CodeGenOptLevel::None)
23626 return SDValue();
23627
23628 // Can't change the number of memory accesses for a volatile store or break
23629 // atomicity for an atomic one.
23630 if (!ST->isSimple())
23631 return SDValue();
23632
23633 SDValue Val = ST->getValue();
23634 SDLoc DL(ST);
23635
23636 // Match OR operand.
23637 if (!Val.getValueType().isScalarInteger() || Val.getOpcode() != ISD::OR)
23638 return SDValue();
23639
23640 // Match SHL operand and get Lower and Higher parts of Val.
23641 SDValue Op1 = Val.getOperand(i: 0);
23642 SDValue Op2 = Val.getOperand(i: 1);
23643 SDValue Lo, Hi;
23644 if (Op1.getOpcode() != ISD::SHL) {
23645 std::swap(a&: Op1, b&: Op2);
23646 if (Op1.getOpcode() != ISD::SHL)
23647 return SDValue();
23648 }
23649 Lo = Op2;
23650 Hi = Op1.getOperand(i: 0);
23651 if (!Op1.hasOneUse())
23652 return SDValue();
23653
23654 // Match shift amount to HalfValBitSize.
23655 unsigned HalfValBitSize = Val.getValueSizeInBits() / 2;
23656 ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(Val: Op1.getOperand(i: 1));
23657 if (!ShAmt || ShAmt->getAPIntValue() != HalfValBitSize)
23658 return SDValue();
23659
23660 // Lo and Hi must be zero-extended from scalar integers no wider than
23661 // HalfValBitSize (e.g. at most i32 when splitting an i64 store).
23662 if (Lo.getOpcode() != ISD::ZERO_EXTEND || !Lo.hasOneUse() ||
23663 !Lo.getOperand(i: 0).getValueType().isScalarInteger() ||
23664 Lo.getOperand(i: 0).getValueSizeInBits() > HalfValBitSize ||
23665 Hi.getOpcode() != ISD::ZERO_EXTEND || !Hi.hasOneUse() ||
23666 !Hi.getOperand(i: 0).getValueType().isScalarInteger() ||
23667 Hi.getOperand(i: 0).getValueSizeInBits() > HalfValBitSize)
23668 return SDValue();
23669
23670 // Use the EVT of low and high parts before bitcast as the input
23671 // of target query.
23672 EVT LowTy = (Lo.getOperand(i: 0).getOpcode() == ISD::BITCAST)
23673 ? Lo.getOperand(i: 0).getValueType()
23674 : Lo.getValueType();
23675 EVT HighTy = (Hi.getOperand(i: 0).getOpcode() == ISD::BITCAST)
23676 ? Hi.getOperand(i: 0).getValueType()
23677 : Hi.getValueType();
23678 if (!TLI.isMultiStoresCheaperThanBitsMerge(LTy: LowTy, HTy: HighTy))
23679 return SDValue();
23680
23681 // Start to split store.
23682 MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
23683 AAMDNodes AAInfo = ST->getAAInfo();
23684
23685 // Change the sizes of Lo and Hi's value types to HalfValBitSize.
23686 EVT VT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: HalfValBitSize);
23687 Lo = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: Lo.getOperand(i: 0));
23688 Hi = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: Hi.getOperand(i: 0));
23689
23690 SDValue Chain = ST->getChain();
23691 SDValue Ptr = ST->getBasePtr();
23692 // Lower value store.
23693 SDValue St0 = DAG.getStore(Chain, dl: DL, Val: Lo, Ptr, PtrInfo: ST->getPointerInfo(),
23694 Alignment: ST->getBaseAlign(), MMOFlags, AAInfo);
23695 Ptr =
23696 DAG.getMemBasePlusOffset(Base: Ptr, Offset: TypeSize::getFixed(ExactSize: HalfValBitSize / 8), DL);
23697 // Higher value store.
23698 SDValue St1 = DAG.getStore(
23699 Chain: St0, dl: DL, Val: Hi, Ptr, PtrInfo: ST->getPointerInfo().getWithOffset(O: HalfValBitSize / 8),
23700 Alignment: ST->getBaseAlign(), MMOFlags, AAInfo);
23701 return St1;
23702}
23703
23704// Merge an insertion into an existing shuffle:
23705// (insert_vector_elt (vector_shuffle X, Y, Mask),
23706//                     (extract_vector_elt X, N), InsIndex)
23707// --> (vector_shuffle X, Y, NewMask)
23708// and variations where shuffle operands may be CONCAT_VECTORS.
23709static bool mergeEltWithShuffle(SDValue &X, SDValue &Y, ArrayRef<int> Mask,
23710 SmallVectorImpl<int> &NewMask, SDValue Elt,
23711 unsigned InsIndex) {
23712 if (Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
23713 !isa<ConstantSDNode>(Val: Elt.getOperand(i: 1)))
23714 return false;
23715
23716 // Vec's operand 0 is using indices from 0 to N-1 and
23717 // operand 1 from N to 2N - 1, where N is the number of
23718 // elements in the vectors.
23719 SDValue InsertVal0 = Elt.getOperand(i: 0);
23720 int ElementOffset = -1;
23721
23722 // We explore the inputs of the shuffle in order to see if we find the
23723 // source of the extract_vector_elt. If so, we can use it to modify the
23724 // shuffle rather than perform an insert_vector_elt.
23725 SmallVector<std::pair<int, SDValue>, 8> ArgWorkList;
23726 ArgWorkList.emplace_back(Args: Mask.size(), Args&: Y);
23727 ArgWorkList.emplace_back(Args: 0, Args&: X);
23728
23729 while (!ArgWorkList.empty()) {
23730 int ArgOffset;
23731 SDValue ArgVal;
23732 std::tie(args&: ArgOffset, args&: ArgVal) = ArgWorkList.pop_back_val();
23733
23734 if (ArgVal == InsertVal0) {
23735 ElementOffset = ArgOffset;
23736 break;
23737 }
23738
23739 // Peek through concat_vector.
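// E.g. if an operand at offset 0 is (concat_vectors A, B) with v4i32 parts,
// then A is explored with offset 0 and B with offset 4.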
23740 if (ArgVal.getOpcode() == ISD::CONCAT_VECTORS) {
23741 int CurrentArgOffset =
23742 ArgOffset + ArgVal.getValueType().getVectorNumElements();
23743 int Step = ArgVal.getOperand(i: 0).getValueType().getVectorNumElements();
23744 for (SDValue Op : reverse(C: ArgVal->ops())) {
23745 CurrentArgOffset -= Step;
23746 ArgWorkList.emplace_back(Args&: CurrentArgOffset, Args&: Op);
23747 }
23748
23749 // Make sure we went through all the elements and did not screw up index
23750 // computation.
23751 assert(CurrentArgOffset == ArgOffset);
23752 }
23753 }
23754
23755 // If we failed to find a match, see if we can replace an UNDEF shuffle
23756 // operand.
23757 if (ElementOffset == -1) {
23758 if (!Y.isUndef() || InsertVal0.getValueType() != Y.getValueType())
23759 return false;
23760 ElementOffset = Mask.size();
23761 Y = InsertVal0;
23762 }
23763
23764 NewMask.assign(in_start: Mask.begin(), in_end: Mask.end());
23765 NewMask[InsIndex] = ElementOffset + Elt.getConstantOperandVal(i: 1);
23766 assert(NewMask[InsIndex] < (int)(2 * Mask.size()) && NewMask[InsIndex] >= 0 &&
23767 "NewMask[InsIndex] is out of bound");
23768 return true;
23769}
23770
23771// Merge an insertion into an existing shuffle:
23772// (insert_vector_elt (vector_shuffle X, Y), (extract_vector_elt X, N),
23773// InsIndex)
23774// --> (vector_shuffle X, Y) and variations where shuffle operands may be
23775// CONCAT_VECTORS.
23776SDValue DAGCombiner::mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex) {
23777 assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
23778 "Expected insert_vector_elt");
23779 SDValue InsertVal = N->getOperand(Num: 1);
23780 SDValue Vec = N->getOperand(Num: 0);
23781
23782 auto *SVN = dyn_cast<ShuffleVectorSDNode>(Val&: Vec);
23783 if (!SVN || !Vec.hasOneUse())
23784 return SDValue();
23785
23786 ArrayRef<int> Mask = SVN->getMask();
23787 SDValue X = Vec.getOperand(i: 0);
23788 SDValue Y = Vec.getOperand(i: 1);
23789
23790 SmallVector<int, 16> NewMask(Mask);
23791 if (mergeEltWithShuffle(X, Y, Mask, NewMask, Elt: InsertVal, InsIndex)) {
23792 SDValue LegalShuffle = TLI.buildLegalVectorShuffle(
23793 VT: Vec.getValueType(), DL: SDLoc(N), N0: X, N1: Y, Mask: NewMask, DAG);
23794 if (LegalShuffle)
23795 return LegalShuffle;
23796 }
23797
23798 return SDValue();
23799}
23800
23801// Convert a disguised subvector insertion into a shuffle:
23802// insert_vector_elt V, (bitcast X from vector type), IdxC -->
23803// bitcast(shuffle (bitcast V), (extended X), Mask)
23804// Note: We do not use an insert_subvector node because that requires a
23805// legal subvector type.
23806SDValue DAGCombiner::combineInsertEltToShuffle(SDNode *N, unsigned InsIndex) {
23807 assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
23808 "Expected insert_vector_elt");
23809 SDValue InsertVal = N->getOperand(Num: 1);
23810
23811 if (InsertVal.getOpcode() != ISD::BITCAST || !InsertVal.hasOneUse() ||
23812 !InsertVal.getOperand(i: 0).getValueType().isVector())
23813 return SDValue();
23814
23815 SDValue SubVec = InsertVal.getOperand(i: 0);
23816 SDValue DestVec = N->getOperand(Num: 0);
23817 EVT SubVecVT = SubVec.getValueType();
23818 EVT VT = DestVec.getValueType();
23819 unsigned NumSrcElts = SubVecVT.getVectorNumElements();
23820 // Bail out if the inserted value is larger than the vector element, as
23821 // insert_vector_elt performs an implicit truncation in this case.
23822 if (InsertVal.getValueType() != VT.getVectorElementType())
23823 return SDValue();
23824 // If the source only has a single vector element, the cost of creating and
23825 // adding it to a vector is likely to exceed the cost of an insert_vector_elt.
23826 if (NumSrcElts == 1)
23827 return SDValue();
23828 unsigned ExtendRatio = VT.getSizeInBits() / SubVecVT.getSizeInBits();
23829 unsigned NumMaskVals = ExtendRatio * NumSrcElts;
23830
23831 // Step 1: Create a shuffle mask that implements this insert operation. The
23832 // vector that we are inserting into will be operand 0 of the shuffle, so
23833 // those elements are just 'i'. The inserted subvector is in the first
23834 // positions of operand 1 of the shuffle. Example:
23835 // insert v4i32 V, (v2i16 X), 2 --> shuffle v8i16 V', X', {0,1,2,3,8,9,6,7}
23836 SmallVector<int, 16> Mask(NumMaskVals);
23837 for (unsigned i = 0; i != NumMaskVals; ++i) {
23838 if (i / NumSrcElts == InsIndex)
23839 Mask[i] = (i % NumSrcElts) + NumMaskVals;
23840 else
23841 Mask[i] = i;
23842 }
23843
23844 // Bail out if the target can not handle the shuffle we want to create.
23845 EVT SubVecEltVT = SubVecVT.getVectorElementType();
23846 EVT ShufVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SubVecEltVT, NumElements: NumMaskVals);
23847 if (!TLI.isShuffleMaskLegal(Mask, ShufVT))
23848 return SDValue();
23849
23850 // Step 2: Create a wide vector from the inserted source vector by appending
23851 // poison elements. This is the same size as our destination vector.
23852 SDLoc DL(N);
23853 SmallVector<SDValue, 8> ConcatOps(ExtendRatio, DAG.getPOISON(VT: SubVecVT));
23854 ConcatOps[0] = SubVec;
23855 SDValue PaddedSubV = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT: ShufVT, Ops: ConcatOps);
23856
23857 // Step 3: Shuffle in the padded subvector.
23858 SDValue DestVecBC = DAG.getBitcast(VT: ShufVT, V: DestVec);
23859 SDValue Shuf = DAG.getVectorShuffle(VT: ShufVT, dl: DL, N1: DestVecBC, N2: PaddedSubV, Mask);
23860 AddToWorklist(N: PaddedSubV.getNode());
23861 AddToWorklist(N: DestVecBC.getNode());
23862 AddToWorklist(N: Shuf.getNode());
23863 return DAG.getBitcast(VT, V: Shuf);
23864}
23865
23866// Combine insert(shuffle(load, <u,0,1,2>), load, 0) into a single load if
23867// possible and the new load will be quick. We use more loads but fewer shuffles
23868// and inserts.
23869SDValue DAGCombiner::combineInsertEltToLoad(SDNode *N, unsigned InsIndex) {
23870 EVT VT = N->getValueType(ResNo: 0);
23871
23872 // InsIndex is expected to be the first or last lane.
23873 if (!VT.isFixedLengthVector() ||
23874 (InsIndex != 0 && InsIndex != VT.getVectorNumElements() - 1))
23875 return SDValue();
23876
23877 // Look for a shuffle with the mask u,0,1,2,3,4,5,6 or 1,2,3,4,5,6,7,u
23878 // depending on the InsIndex.
23879 auto *Shuffle = dyn_cast<ShuffleVectorSDNode>(Val: N->getOperand(Num: 0));
23880 SDValue Scalar = N->getOperand(Num: 1);
23881 if (!Shuffle || !all_of(Range: enumerate(First: Shuffle->getMask()), P: [&](auto P) {
23882 return InsIndex == P.index() || P.value() < 0 ||
23883 (InsIndex == 0 && P.value() == (int)P.index() - 1) ||
23884 (InsIndex == VT.getVectorNumElements() - 1 &&
23885 P.value() == (int)P.index() + 1);
23886 }))
23887 return SDValue();
23888
23889 // We optionally skip over an extend so long as both loads are extended in the
23890 // same way from the same type.
23891 unsigned Extend = 0;
23892 if (Scalar.getOpcode() == ISD::ZERO_EXTEND ||
23893 Scalar.getOpcode() == ISD::SIGN_EXTEND ||
23894 Scalar.getOpcode() == ISD::ANY_EXTEND) {
23895 Extend = Scalar.getOpcode();
23896 Scalar = Scalar.getOperand(i: 0);
23897 }
23898
23899 auto *ScalarLoad = dyn_cast<LoadSDNode>(Val&: Scalar);
23900 if (!ScalarLoad)
23901 return SDValue();
23902
23903 SDValue Vec = Shuffle->getOperand(Num: 0);
23904 if (Extend) {
23905 if (Vec.getOpcode() != Extend)
23906 return SDValue();
23907 Vec = Vec.getOperand(i: 0);
23908 }
23909 auto *VecLoad = dyn_cast<LoadSDNode>(Val&: Vec);
23910 if (!VecLoad || Vec.getValueType().getScalarType() != Scalar.getValueType())
23911 return SDValue();
23912
23913 int EltSize = ScalarLoad->getValueType(ResNo: 0).getScalarSizeInBits();
23914 if (EltSize == 0 || EltSize % 8 != 0 || !ScalarLoad->isSimple() ||
23915 !VecLoad->isSimple() || VecLoad->getExtensionType() != ISD::NON_EXTLOAD ||
23916 ScalarLoad->getExtensionType() != ISD::NON_EXTLOAD ||
23917 ScalarLoad->getAddressSpace() != VecLoad->getAddressSpace())
23918 return SDValue();
23919
23920 // Check that the offset between the pointers is such that the two loads form
23921 // a single contiguous load.
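// E.g. with InsIndex == 0 the scalar must sit one element before the vector
// load in memory; with InsIndex at the last lane it must sit immediately
// after it.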
23922 if (InsIndex == 0) {
23923 if (!DAG.areNonVolatileConsecutiveLoads(LD: ScalarLoad, Base: VecLoad, Bytes: EltSize / 8,
23924 Dist: -1))
23925 return SDValue();
23926 } else {
23927 if (!DAG.areNonVolatileConsecutiveLoads(
23928 LD: VecLoad, Base: ScalarLoad, Bytes: VT.getVectorNumElements() * EltSize / 8, Dist: -1))
23929 return SDValue();
23930 }
23931
23932 // And that the new unaligned load will be fast.
23933 unsigned IsFast = 0;
23934 Align NewAlign = commonAlignment(A: VecLoad->getAlign(), Offset: EltSize / 8);
23935 if (!TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(),
23936 VT: Vec.getValueType(), AddrSpace: VecLoad->getAddressSpace(),
23937 Alignment: NewAlign, Flags: VecLoad->getMemOperand()->getFlags(),
23938 Fast: &IsFast) ||
23939 !IsFast)
23940 return SDValue();
23941
23942 // Calculate the new Ptr and create the new load.
23943 SDLoc DL(N);
23944 SDValue Ptr = ScalarLoad->getBasePtr();
23945 if (InsIndex != 0)
23946 Ptr = DAG.getNode(Opcode: ISD::ADD, DL, VT: Ptr.getValueType(), N1: VecLoad->getBasePtr(),
23947 N2: DAG.getConstant(Val: EltSize / 8, DL, VT: Ptr.getValueType()));
23948 MachinePointerInfo PtrInfo =
23949 InsIndex == 0 ? ScalarLoad->getPointerInfo()
23950 : VecLoad->getPointerInfo().getWithOffset(O: EltSize / 8);
23951
23952 SDValue Load = DAG.getLoad(VT: VecLoad->getValueType(ResNo: 0), dl: DL,
23953 Chain: ScalarLoad->getChain(), Ptr, PtrInfo, Alignment: NewAlign);
23954 DAG.makeEquivalentMemoryOrdering(OldLoad: ScalarLoad, NewMemOp: Load.getValue(R: 1));
23955 DAG.makeEquivalentMemoryOrdering(OldLoad: VecLoad, NewMemOp: Load.getValue(R: 1));
23956 return Extend ? DAG.getNode(Opcode: Extend, DL, VT, Operand: Load) : Load;
23957}
23958
23959SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
23960 SDValue InVec = N->getOperand(Num: 0);
23961 SDValue InVal = N->getOperand(Num: 1);
23962 SDValue EltNo = N->getOperand(Num: 2);
23963 SDLoc DL(N);
23964
23965 EVT VT = InVec.getValueType();
23966 auto *IndexC = dyn_cast<ConstantSDNode>(Val&: EltNo);
23967
23968 // Insert into out-of-bounds element is poison.
23969 if (IndexC && VT.isFixedLengthVector() &&
23970 IndexC->getZExtValue() >= VT.getVectorNumElements())
23971 return DAG.getPOISON(VT);
23972
23973 // Remove redundant insertions:
23974 // (insert_vector_elt x (extract_vector_elt x idx) idx) -> x
23975 if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
23976 InVec == InVal.getOperand(i: 0) && EltNo == InVal.getOperand(i: 1))
23977 return InVec;
23978
23979 // Remove insert of UNDEF/POISON elements.
23980 if (InVal.isUndef()) {
23981 if (InVal.getOpcode() == ISD::POISON || InVec.getOpcode() == ISD::UNDEF)
23982 return InVec;
23983 return DAG.getFreeze(V: InVec);
23984 }
23985
23986 if (!IndexC) {
23987 // If this is variable insert to undef vector, it might be better to splat:
23988 // inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
23989 if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT))
23990 return DAG.getSplat(VT, DL, Op: InVal);
23991
23992 // Extend this type to be byte-addressable
23993 EVT OldVT = VT;
23994 EVT EltVT = VT.getVectorElementType();
23995 bool IsByteSized = EltVT.isByteSized();
23996 if (!IsByteSized) {
23997 EltVT =
23998 EltVT.changeTypeToInteger().getRoundIntegerType(Context&: *DAG.getContext());
23999 VT = VT.changeElementType(Context&: *DAG.getContext(), EltVT);
24000 }
24001
24002 // Check if this operation will be handled the default way for its type.
24003 auto IsTypeDefaultHandled = [this](EVT VT) {
24004 return TLI.getTypeAction(Context&: *DAG.getContext(), VT) ==
24005 TargetLowering::TypeSplitVector ||
24006 TLI.isOperationExpand(Op: ISD::INSERT_VECTOR_ELT, VT);
24007 };
24008
24009 // Check if this operation is illegal and will be handled the default way,
24010 // even after extending the type to be byte-addressable.
24011 if (IsTypeDefaultHandled(OldVT) && IsTypeDefaultHandled(VT)) {
24012 // For each dynamic insertelt, the default way will save the vector to
24013 // the stack, store at an offset, and load the modified vector. This can
24014 // dramatically increase code size if we have a chain of insertelts on a
24015 // large vector: requiring O(V*C) stores/loads where V = length of
24016 // vector and C is length of chain. If each insertelt is only fed into the
24017 // next, the vector is write-only across this chain, and we can just
24018 // save once before the chain and load after in O(V + C) operations.
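// E.g. a chain of three variable-index inserts into a large illegal vector
// becomes one spill of the vector, three element stores, and one reload,
// instead of three full store/load round trips of the whole vector.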
24019 SmallVector<SDNode *> Seq{N};
24020 unsigned NumDynamic = 1;
24021 while (true) {
24022 SDValue InVec = Seq.back()->getOperand(Num: 0);
24023 if (InVec.getOpcode() != ISD::INSERT_VECTOR_ELT)
24024 break;
24025 Seq.push_back(Elt: InVec.getNode());
24026 NumDynamic += !isa<ConstantSDNode>(Val: InVec.getOperand(i: 2));
24027 }
24028
24029 // It makes sense to lower this sequence only when we have more
24030 // than one dynamic insertelt, since we will not have more than V constant
24031 // insertelts, so we will be reducing the total number of stores+loads.
24032 if (NumDynamic > 1) {
24033 // In cases where the vector is illegal it will be broken down into
24034 // parts and stored in parts - we should use the alignment for the
24035 // smallest part.
24036 Align SmallestAlign = DAG.getReducedAlign(VT, /*UseABI=*/false);
24037 SDValue StackPtr =
24038 DAG.CreateStackTemporary(Bytes: VT.getStoreSize(), Alignment: SmallestAlign);
24039 auto &MF = DAG.getMachineFunction();
24040 int FrameIndex = cast<FrameIndexSDNode>(Val: StackPtr.getNode())->getIndex();
24041 auto PtrInfo = MachinePointerInfo::getFixedStack(MF, FI: FrameIndex);
24042
24043 // Save the vector to the stack
24044 SDValue InVec = Seq.back()->getOperand(Num: 0);
24045 if (!IsByteSized)
24046 InVec = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT, Operand: InVec);
24047 SDValue Store = DAG.getStore(Chain: DAG.getEntryNode(), dl: DL, Val: InVec, Ptr: StackPtr,
24048 PtrInfo, Alignment: SmallestAlign);
24049
24050 // Lower each dynamic insertelt to a store
24051 for (SDNode *N : reverse(C&: Seq)) {
24052 SDValue Elmnt = N->getOperand(Num: 1);
24053 SDValue Index = N->getOperand(Num: 2);
24054
24055 // Check if we have to extend the element type
24056 if (!IsByteSized && Elmnt.getValueType().bitsLT(VT: EltVT))
24057 Elmnt = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: EltVT, Operand: Elmnt);
24058
24059 // Store the new element. This may be larger than the vector element
24060 // type, so use a truncating store.
24061 SDValue EltPtr =
24062 TLI.getVectorElementPointer(DAG, VecPtr: StackPtr, VecVT: VT, Index);
24063 EVT EltVT = Elmnt.getValueType();
24064 Store = DAG.getTruncStore(
24065 Chain: Store, dl: DL, Val: Elmnt, Ptr: EltPtr, PtrInfo: MachinePointerInfo::getUnknownStack(MF),
24066 SVT: EltVT,
24067 Alignment: commonAlignment(A: SmallestAlign, Offset: EltVT.getFixedSizeInBits() / 8));
24068 }
24069
24070 // Load the saved vector from the stack
24071 SDValue Load =
24072 DAG.getLoad(VT, dl: DL, Chain: Store, Ptr: StackPtr, PtrInfo, Alignment: SmallestAlign);
24073 SDValue LoadV = Load.getValue(R: 0);
24074 return IsByteSized ? LoadV : DAG.getAnyExtOrTrunc(Op: LoadV, DL, VT: OldVT);
24075 }
24076 }
24077
24078 return SDValue();
24079 }
24080
24081 if (VT.isScalableVector())
24082 return SDValue();
24083
24084 unsigned NumElts = VT.getVectorNumElements();
24085
24086 // We must know which element is being inserted for folds below here.
24087 unsigned Elt = IndexC->getZExtValue();
24088
24089 // Handle <1 x ???> vector insertion special cases.
24090 if (NumElts == 1) {
24091 // insert_vector_elt(x, extract_vector_elt(y, 0), 0) -> y
24092 if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
24093 InVal.getOperand(i: 0).getValueType() == VT &&
24094 isNullConstant(V: InVal.getOperand(i: 1)))
24095 return InVal.getOperand(i: 0);
24096 }
24097
24098 // Canonicalize insert_vector_elt dag nodes.
24099 // Example:
24100 // (insert_vector_elt (insert_vector_elt A, Idx0), Idx1)
24101 // -> (insert_vector_elt (insert_vector_elt A, Idx1), Idx0)
24102 //
24103 // Do this only if the child insert_vector_elt node has one use; also
24104 // do this only if indices are both constants and Idx1 < Idx0.
24105 if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT && InVec.hasOneUse()
24106 && isa<ConstantSDNode>(Val: InVec.getOperand(i: 2))) {
24107 unsigned OtherElt = InVec.getConstantOperandVal(i: 2);
24108 if (Elt < OtherElt) {
24109 // Swap nodes.
24110 SDValue NewOp = DAG.getNode(Opcode: ISD::INSERT_VECTOR_ELT, DL, VT,
24111 N1: InVec.getOperand(i: 0), N2: InVal, N3: EltNo);
24112 AddToWorklist(N: NewOp.getNode());
24113 return DAG.getNode(Opcode: ISD::INSERT_VECTOR_ELT, DL: SDLoc(InVec.getNode()),
24114 VT, N1: NewOp, N2: InVec.getOperand(i: 1), N3: InVec.getOperand(i: 2));
24115 }
24116 }
24117
24118 if (SDValue Shuf = mergeInsertEltWithShuffle(N, InsIndex: Elt))
24119 return Shuf;
24120
24121 if (SDValue Shuf = combineInsertEltToShuffle(N, InsIndex: Elt))
24122 return Shuf;
24123
24124 if (SDValue Shuf = combineInsertEltToLoad(N, InsIndex: Elt))
24125 return Shuf;
24126
24127 // Attempt to convert an insert_vector_elt chain into a legal build_vector.
24128 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT)) {
24129 // vXi1 vector - we don't need to recurse.
24130 if (NumElts == 1)
24131 return DAG.getBuildVector(VT, DL, Ops: {InVal});
24132
24133 // If we haven't already collected the element, insert into the op list.
24134 EVT MaxEltVT = InVal.getValueType();
24135 auto AddBuildVectorOp = [&](SmallVectorImpl<SDValue> &Ops, SDValue Elt,
24136 unsigned Idx) {
24137 if (!Ops[Idx]) {
24138 Ops[Idx] = Elt;
24139 if (VT.isInteger()) {
24140 EVT EltVT = Elt.getValueType();
24141 MaxEltVT = MaxEltVT.bitsGE(VT: EltVT) ? MaxEltVT : EltVT;
24142 }
24143 }
24144 };
24145
24146 // Ensure all the operands are the same value type, fill any missing
24147 // operands with UNDEF and create the BUILD_VECTOR.
24148 auto CanonicalizeBuildVector = [&](SmallVectorImpl<SDValue> &Ops,
24149 bool FreezeUndef = false) {
24150 assert(Ops.size() == NumElts && "Unexpected vector size");
24151 SDValue UndefOp = FreezeUndef ? DAG.getFreeze(V: DAG.getUNDEF(VT: MaxEltVT))
24152 : DAG.getUNDEF(VT: MaxEltVT);
24153 for (SDValue &Op : Ops) {
24154 if (Op)
24155 Op = VT.isInteger() ? DAG.getAnyExtOrTrunc(Op, DL, VT: MaxEltVT) : Op;
24156 else
24157 Op = UndefOp;
24158 }
24159 return DAG.getBuildVector(VT, DL, Ops);
24160 };
24161
24162 SmallVector<SDValue, 8> Ops(NumElts, SDValue());
24163 Ops[Elt] = InVal;
24164
24165 // Recurse up a INSERT_VECTOR_ELT chain to build a BUILD_VECTOR.
24166 for (SDValue CurVec = InVec; CurVec;) {
24167 // UNDEF - build new BUILD_VECTOR from already inserted operands.
24168 if (CurVec.isUndef())
24169 return CanonicalizeBuildVector(Ops);
24170
24171 // FREEZE(UNDEF) - build new BUILD_VECTOR from already inserted operands.
24172 if (ISD::isFreezeUndef(N: CurVec.getNode()) && CurVec.hasOneUse())
24173 return CanonicalizeBuildVector(Ops, /*FreezeUndef=*/true);
24174
24175 // BUILD_VECTOR - insert unused operands and build new BUILD_VECTOR.
24176 if (CurVec.getOpcode() == ISD::BUILD_VECTOR && CurVec.hasOneUse()) {
24177 for (unsigned I = 0; I != NumElts; ++I)
24178 AddBuildVectorOp(Ops, CurVec.getOperand(i: I), I);
24179 return CanonicalizeBuildVector(Ops);
24180 }
24181
24182 // SCALAR_TO_VECTOR - insert unused scalar and build new BUILD_VECTOR.
24183 if (CurVec.getOpcode() == ISD::SCALAR_TO_VECTOR && CurVec.hasOneUse()) {
24184 AddBuildVectorOp(Ops, CurVec.getOperand(i: 0), 0);
24185 return CanonicalizeBuildVector(Ops);
24186 }
24187
24188 // INSERT_VECTOR_ELT - insert operand and continue up the chain.
24189 if (CurVec.getOpcode() == ISD::INSERT_VECTOR_ELT && CurVec.hasOneUse())
24190 if (auto *CurIdx = dyn_cast<ConstantSDNode>(Val: CurVec.getOperand(i: 2)))
24191 if (CurIdx->getAPIntValue().ult(RHS: NumElts)) {
24192 unsigned Idx = CurIdx->getZExtValue();
24193 AddBuildVectorOp(Ops, CurVec.getOperand(i: 1), Idx);
24194
24195 // Found entire BUILD_VECTOR.
24196 if (all_of(Range&: Ops, P: [](SDValue Op) { return !!Op; }))
24197 return CanonicalizeBuildVector(Ops);
24198
24199 CurVec = CurVec->getOperand(Num: 0);
24200 continue;
24201 }
24202
24203 // VECTOR_SHUFFLE - if all the operands match the shuffle's sources,
24204 // update the shuffle mask (and second operand if we started with unary
24205 // shuffle) and create a new legal shuffle.
24206 if (CurVec.getOpcode() == ISD::VECTOR_SHUFFLE && CurVec.hasOneUse()) {
24207 auto *SVN = cast<ShuffleVectorSDNode>(Val&: CurVec);
24208 SDValue LHS = SVN->getOperand(Num: 0);
24209 SDValue RHS = SVN->getOperand(Num: 1);
24210 SmallVector<int, 16> Mask(SVN->getMask());
24211 bool Merged = true;
24212 for (auto I : enumerate(First&: Ops)) {
24213 SDValue &Op = I.value();
24214 if (Op) {
24215 SmallVector<int, 16> NewMask;
24216 if (!mergeEltWithShuffle(X&: LHS, Y&: RHS, Mask, NewMask, Elt: Op, InsIndex: I.index())) {
24217 Merged = false;
24218 break;
24219 }
24220 Mask = std::move(NewMask);
24221 }
24222 }
24223 if (Merged)
24224 if (SDValue NewShuffle =
24225 TLI.buildLegalVectorShuffle(VT, DL, N0: LHS, N1: RHS, Mask, DAG))
24226 return NewShuffle;
24227 }
24228
24229 if (!LegalOperations) {
24230 bool IsNull = llvm::isNullConstant(V: InVal);
24231 // We can convert to AND/OR mask if all insertions are zero or -1
24232 // respectively.
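// For example (a sketch of the zero case):
//   insert_vector_elt (insert_vector_elt V, 0, 0), 0, 2
// becomes, for a v4iN vector,
//   and V, <0, -1, 0, -1>
// clearing lanes 0 and 2 while keeping the other lanes of V.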
24233 if ((IsNull || llvm::isAllOnesConstant(V: InVal)) &&
24234 all_of(Range&: Ops, P: [InVal](SDValue Op) { return !Op || Op == InVal; }) &&
24235 count_if(Range&: Ops, P: [InVal](SDValue Op) { return Op == InVal; }) >= 2) {
24236 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: MaxEltVT);
24237 SDValue AllOnes = DAG.getAllOnesConstant(DL, VT: MaxEltVT);
24238 SmallVector<SDValue, 8> Mask(NumElts);
24239
24240 // Build the mask and return the corresponding DAG node.
24241 auto BuildMaskAndNode = [&](SDValue TrueVal, SDValue FalseVal,
24242 unsigned MaskOpcode) {
24243 for (unsigned I = 0; I != NumElts; ++I)
24244 Mask[I] = Ops[I] ? TrueVal : FalseVal;
24245 return DAG.getNode(Opcode: MaskOpcode, DL, VT, N1: CurVec,
24246 N2: DAG.getBuildVector(VT, DL, Ops: Mask));
24247 };
24248
24249 // If the inserted elements are zero, AND with a mask that clears the inserted lanes and keeps the rest.
24250 if (IsNull)
24251 return BuildMaskAndNode(Zero, AllOnes, ISD::AND);
24252
24253 // If the inserted elements are -1, OR with a mask that sets the inserted lanes and keeps the rest.
24254 return BuildMaskAndNode(AllOnes, Zero, ISD::OR);
24255 }
24256 }
24257
24258 // Failed to find a match in the chain - bail.
24259 break;
24260 }
24261
24262 // See if we can fill in the missing constant elements as zeros.
24263 // TODO: Should we do this for any constant?
24264 APInt DemandedZeroElts = APInt::getZero(numBits: NumElts);
24265 for (unsigned I = 0; I != NumElts; ++I)
24266 if (!Ops[I])
24267 DemandedZeroElts.setBit(I);
24268
24269 if (DAG.MaskedVectorIsZero(Op: InVec, DemandedElts: DemandedZeroElts)) {
24270 SDValue Zero = VT.isInteger() ? DAG.getConstant(Val: 0, DL, VT: MaxEltVT)
24271 : DAG.getConstantFP(Val: 0, DL, VT: MaxEltVT);
24272 for (unsigned I = 0; I != NumElts; ++I)
24273 if (!Ops[I])
24274 Ops[I] = Zero;
24275
24276 return CanonicalizeBuildVector(Ops);
24277 }
24278 }
24279
24280 return SDValue();
24281}
24282
24283/// Transform a vector binary operation into a scalar binary operation by moving
24284/// the math/logic after an extract element of a vector.
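/// For example (a sketch, with the constant operand on the RHS):
///   extractelt (add X, <1, 2, 3, 4>), 1 --> add (extractelt X, 1), 2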
24285static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
24286 const SDLoc &DL, bool LegalTypes) {
24287 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24288 SDValue Vec = ExtElt->getOperand(Num: 0);
24289 SDValue Index = ExtElt->getOperand(Num: 1);
24290 auto *IndexC = dyn_cast<ConstantSDNode>(Val&: Index);
24291 unsigned Opc = Vec.getOpcode();
24292 if (!IndexC || !Vec.hasOneUse() || (!TLI.isBinOp(Opcode: Opc) && Opc != ISD::SETCC) ||
24293 Vec->getNumValues() != 1)
24294 return SDValue();
24295
24296 // Targets may want to avoid this to prevent an expensive register transfer.
24297 if (!TLI.shouldScalarizeBinop(VecOp: Vec))
24298 return SDValue();
24299
24300 EVT ResVT = ExtElt->getValueType(ResNo: 0);
24301 if (Opc == ISD::SETCC &&
24302 (ResVT != Vec.getValueType().getVectorElementType() || LegalTypes))
24303 return SDValue();
24304
24305 // Extracting an element of a vector constant is constant-folded, so this
24306 // transform is just replacing a vector op with a scalar op while moving the
24307 // extract.
24308 SDValue Op0 = Vec.getOperand(i: 0);
24309 SDValue Op1 = Vec.getOperand(i: 1);
24310 APInt SplatVal;
24311 if (!isAnyConstantBuildVector(V: Op0, NoOpaques: true) &&
24312 !ISD::isConstantSplatVector(N: Op0.getNode(), SplatValue&: SplatVal) &&
24313 !isAnyConstantBuildVector(V: Op1, NoOpaques: true) &&
24314 !ISD::isConstantSplatVector(N: Op1.getNode(), SplatValue&: SplatVal))
24315 return SDValue();
24316
24317 // extractelt (op X, C), IndexC --> op (extractelt X, IndexC), C'
24318 // extractelt (op C, X), IndexC --> op C', (extractelt X, IndexC)
24319 if (Opc == ISD::SETCC) {
24320 EVT OpVT = Op0.getValueType().getVectorElementType();
24321 Op0 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: OpVT, N1: Op0, N2: Index);
24322 Op1 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: OpVT, N1: Op1, N2: Index);
24323 SDValue NewVal = DAG.getSetCC(
24324 DL, VT: ResVT, LHS: Op0, RHS: Op1, Cond: cast<CondCodeSDNode>(Val: Vec->getOperand(Num: 2))->get());
24325 // We may need to sign- or zero-extend the result to match the same
24326 // behaviour as the vector version of SETCC.
24327 unsigned VecBoolContents = TLI.getBooleanContents(Type: Vec.getValueType());
24328 if (ResVT != MVT::i1 &&
24329 VecBoolContents != TargetLowering::UndefinedBooleanContent &&
24330 VecBoolContents != TLI.getBooleanContents(Type: ResVT)) {
24331 if (VecBoolContents == TargetLowering::ZeroOrNegativeOneBooleanContent)
24332 NewVal = DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT: ResVT, N1: NewVal,
24333 N2: DAG.getValueType(MVT::i1));
24334 else
24335 NewVal = DAG.getZeroExtendInReg(Op: NewVal, DL, VT: MVT::i1);
24336 }
24337 return NewVal;
24338 }
24339 Op0 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ResVT, N1: Op0, N2: Index);
24340 Op1 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ResVT, N1: Op1, N2: Index);
24341 return DAG.getNode(Opcode: Opc, DL, VT: ResVT, N1: Op0, N2: Op1);
24342}
24343
24344 // Given an ISD::EXTRACT_VECTOR_ELT, which is a glorified bit sequence extract,
24345 // recursively analyse all of its users and try to model them as
24346 // bit sequence extractions. If all of them agree on the new, narrower element
24347 // type, and all of them can be modelled as ISD::EXTRACT_VECTOR_ELT's of that
24348 // new element type, do so now.
24349 // This is mainly useful to recover from legalization that scalarized
24350 // the vector as wide elements: we try to rebuild it with narrower elements.
24351//
24352// Some more nodes could be modelled if that helps cover interesting patterns.
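// For example (a little-endian sketch): if t = extract_vector_elt v2i64 V, 1
// is only used as (trunc t to i32) and (trunc (srl t, 32) to i32), with both
// truncates feeding BUILD_VECTORs, we can instead bitcast V to v4i32 and
// rewrite those users as
//   extract_vector_elt v4i32 (bitcast V), 2  and
//   extract_vector_elt v4i32 (bitcast V), 3.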
24353bool DAGCombiner::refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(
24354 SDNode *N) {
24355 // We perform this optimization post type-legalization because
24356 // the type-legalizer often scalarizes integer-promoted vectors.
24357 // Performing this optimization earlier may cause legalization cycles.
24358 if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
24359 return false;
24360
24361 // TODO: Add support for big-endian.
24362 if (DAG.getDataLayout().isBigEndian())
24363 return false;
24364
24365 SDValue VecOp = N->getOperand(Num: 0);
24366 EVT VecVT = VecOp.getValueType();
24367 assert(!VecVT.isScalableVector() && "Only for fixed vectors.");
24368
24369 // We must start with a constant extraction index.
24370 auto *IndexC = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
24371 if (!IndexC)
24372 return false;
24373
24374 assert(IndexC->getZExtValue() < VecVT.getVectorNumElements() &&
24375 "Original ISD::EXTRACT_VECTOR_ELT is undefined?");
24376
24377 // TODO: deal with the case of implicit anyext of the extraction.
24378 unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
24379 EVT ScalarVT = N->getValueType(ResNo: 0);
24380 if (VecVT.getScalarType() != ScalarVT)
24381 return false;
24382
24383 // TODO: deal with the cases other than everything being integer-typed.
24384 if (!ScalarVT.isScalarInteger())
24385 return false;
24386
24387 struct Entry {
24388 SDNode *Producer;
24389
24390 // Which bits of VecOp does it contain?
24391 unsigned BitPos;
24392 int NumBits;
24393 // NOTE: the actual width of \p Producer may be wider than NumBits!
24394
24395 Entry(Entry &&) = default;
24396 Entry(SDNode *Producer_, unsigned BitPos_, int NumBits_)
24397 : Producer(Producer_), BitPos(BitPos_), NumBits(NumBits_) {}
24398
24399 Entry() = delete;
24400 Entry(const Entry &) = delete;
24401 Entry &operator=(const Entry &) = delete;
24402 Entry &operator=(Entry &&) = delete;
24403 };
24404 SmallVector<Entry, 32> Worklist;
24405 SmallVector<Entry, 32> Leafs;
24406
24407 // We start at the "root" ISD::EXTRACT_VECTOR_ELT.
24408 Worklist.emplace_back(Args&: N, /*BitPos=*/Args: VecEltBitWidth * IndexC->getZExtValue(),
24409 /*NumBits=*/Args&: VecEltBitWidth);
24410
24411 while (!Worklist.empty()) {
24412 Entry E = Worklist.pop_back_val();
24413 // Does the node not even use any of the VecOp bits?
24414 if (!(E.NumBits > 0 && E.BitPos < VecVT.getSizeInBits() &&
24415 E.BitPos + E.NumBits <= VecVT.getSizeInBits()))
24416 return false; // Let's allow the other combines clean this up first.
24417 // Did we fail to model any of the users of the Producer?
24418 bool ProducerIsLeaf = false;
24419 // Look at each user of this Producer.
24420 for (SDNode *User : E.Producer->users()) {
24421 switch (User->getOpcode()) {
24422 // TODO: support ISD::BITCAST
24423 // TODO: support ISD::ANY_EXTEND
24424 // TODO: support ISD::ZERO_EXTEND
24425 // TODO: support ISD::SIGN_EXTEND
24426 case ISD::TRUNCATE:
24427 // Truncation simply means we keep the position, but extract fewer bits.
24428 Worklist.emplace_back(Args&: User, Args&: E.BitPos,
24429 /*NumBits=*/Args: User->getValueSizeInBits(ResNo: 0));
24430 break;
24431 // TODO: support ISD::SRA
24432 // TODO: support ISD::SHL
24433 case ISD::SRL:
24434 // We should be shifting the Producer by a constant amount.
24435 if (auto *ShAmtC = dyn_cast<ConstantSDNode>(Val: User->getOperand(Num: 1));
24436 User->getOperand(Num: 0).getNode() == E.Producer && ShAmtC) {
24437 // Logical right-shift means that we start extraction later,
24438 // but stop it at the same position we did previously.
24439 unsigned ShAmt = ShAmtC->getZExtValue();
24440 Worklist.emplace_back(Args&: User, Args: E.BitPos + ShAmt, Args: E.NumBits - ShAmt);
24441 break;
24442 }
24443 [[fallthrough]];
24444 default:
24445 // We cannot model this user of the Producer,
24446 // which means the current Producer will be an ISD::EXTRACT_VECTOR_ELT.
24447 ProducerIsLeaf = true;
24448 // Profitability check: all users that we cannot model
24449 // must be ISD::BUILD_VECTOR's.
24450 if (User->getOpcode() != ISD::BUILD_VECTOR)
24451 return false;
24452 break;
24453 }
24454 }
24455 if (ProducerIsLeaf)
24456 Leafs.emplace_back(Args: std::move(E));
24457 }
24458
24459 unsigned NewVecEltBitWidth = Leafs.front().NumBits;
24460
24461 // If we are still at the same element granularity, give up.
24462 if (NewVecEltBitWidth == VecEltBitWidth)
24463 return false;
24464
24465 // The vector width must be a multiple of the new element width.
24466 if (VecVT.getSizeInBits() % NewVecEltBitWidth != 0)
24467 return false;
24468
24469 // All leafs must agree on the new element width.
24470 // No leaf may expect any "padding" bits on top of that width.
24471 // All leafs must start extraction from a multiple of that width.
24472 if (!all_of(Range&: Leafs, P: [NewVecEltBitWidth](const Entry &E) {
24473 return (unsigned)E.NumBits == NewVecEltBitWidth &&
24474 E.Producer->getValueSizeInBits(ResNo: 0) == NewVecEltBitWidth &&
24475 E.BitPos % NewVecEltBitWidth == 0;
24476 }))
24477 return false;
24478
24479 EVT NewScalarVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: NewVecEltBitWidth);
24480 EVT NewVecVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NewScalarVT,
24481 NumElements: VecVT.getSizeInBits() / NewVecEltBitWidth);
24482
24483 if (LegalTypes &&
24484 !(TLI.isTypeLegal(VT: NewScalarVT) && TLI.isTypeLegal(VT: NewVecVT)))
24485 return false;
24486
24487 if (LegalOperations &&
24488 !(TLI.isOperationLegalOrCustom(Op: ISD::BITCAST, VT: NewVecVT) &&
24489 TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_VECTOR_ELT, VT: NewVecVT)))
24490 return false;
24491
24492 SDValue NewVecOp = DAG.getBitcast(VT: NewVecVT, V: VecOp);
24493 for (const Entry &E : Leafs) {
24494 SDLoc DL(E.Producer);
24495 unsigned NewIndex = E.BitPos / NewVecEltBitWidth;
24496 assert(NewIndex < NewVecVT.getVectorNumElements() &&
24497 "Creating out-of-bounds ISD::EXTRACT_VECTOR_ELT?");
24498 SDValue V = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: NewScalarVT, N1: NewVecOp,
24499 N2: DAG.getVectorIdxConstant(Val: NewIndex, DL));
24500 CombineTo(N: E.Producer, Res: V);
24501 }
24502
24503 return true;
24504}
24505
24506SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
24507 SDValue VecOp = N->getOperand(Num: 0);
24508 SDValue Index = N->getOperand(Num: 1);
24509 EVT ScalarVT = N->getValueType(ResNo: 0);
24510 EVT VecVT = VecOp.getValueType();
24511 if (VecOp.isUndef())
24512 return DAG.getUNDEF(VT: ScalarVT);
24513
24514 // extract_vector_elt (insert_vector_elt vec, val, idx), idx) -> val
24515 //
24516 // This only really matters if the index is non-constant since other combines
24517 // on the constant elements already work.
24518 SDLoc DL(N);
24519 if (VecOp.getOpcode() == ISD::INSERT_VECTOR_ELT &&
24520 Index == VecOp.getOperand(i: 2)) {
24521 SDValue Elt = VecOp.getOperand(i: 1);
24522 AddUsersToWorklist(N: VecOp.getNode());
24523 return VecVT.isInteger() ? DAG.getAnyExtOrTrunc(Op: Elt, DL, VT: ScalarVT) : Elt;
24524 }
24525
24526 // (vextract (scalar_to_vector val), 0) -> val
24527 if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR) {
24528 // Only the 0'th element of SCALAR_TO_VECTOR is defined.
24529 if (DAG.isKnownNeverZero(Op: Index))
24530 return DAG.getPOISON(VT: ScalarVT);
24531
24532 // Check if the result type doesn't match the inserted element type.
24533 // The inserted element and extracted element may have mismatched bitwidths.
24534 // As a result, EXTRACT_VECTOR_ELT may extend or truncate the extracted element.
24535 SDValue InOp = VecOp.getOperand(i: 0);
24536 if (InOp.getValueType() != ScalarVT) {
24537 assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
24538 if (InOp.getValueType().bitsGT(VT: ScalarVT))
24539 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: ScalarVT, Operand: InOp);
24540 return DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: ScalarVT, Operand: InOp);
24541 }
24542 return InOp;
24543 }
24544
24545 // extract_vector_elt of out-of-bounds element -> UNDEF
24546 auto *IndexC = dyn_cast<ConstantSDNode>(Val&: Index);
24547 if (IndexC && VecVT.isFixedLengthVector() &&
24548 IndexC->getAPIntValue().uge(RHS: VecVT.getVectorNumElements()))
24549 return DAG.getUNDEF(VT: ScalarVT);
24550
24551 // extract_vector_elt (build_vector x, y), 1 -> y
24552 if (((IndexC && VecOp.getOpcode() == ISD::BUILD_VECTOR) ||
24553 VecOp.getOpcode() == ISD::SPLAT_VECTOR) &&
24554 TLI.isTypeLegal(VT: VecVT)) {
24555 assert((VecOp.getOpcode() != ISD::BUILD_VECTOR ||
24556 VecVT.isFixedLengthVector()) &&
24557 "BUILD_VECTOR used for scalable vectors");
24558 unsigned IndexVal =
24559 VecOp.getOpcode() == ISD::BUILD_VECTOR ? IndexC->getZExtValue() : 0;
24560 SDValue Elt = VecOp.getOperand(i: IndexVal);
24561 EVT InEltVT = Elt.getValueType();
24562
24563 if (VecOp.hasOneUse() || TLI.aggressivelyPreferBuildVectorSources(VecVT) ||
24564 isNullConstant(V: Elt)) {
24565 // Sometimes build_vector's scalar input types do not match result type.
24566 if (ScalarVT == InEltVT)
24567 return Elt;
24568
24569 // TODO: It may be useful to truncate if the truncation is free and the
24570 // build_vector implicitly converts.
24571 }
24572 }
24573
24574 if (SDValue BO = scalarizeExtractedBinOp(ExtElt: N, DAG, DL, LegalTypes))
24575 return BO;
24576
24577 if (VecVT.isScalableVector())
24578 return SDValue();
24579
24580 // All the code from this point onwards assumes fixed width vectors, but it's
24581 // possible that some of the combinations could be made to work for scalable
24582 // vectors too.
24583 unsigned NumElts = VecVT.getVectorNumElements();
24584 unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
24585
24586 // See if the extracted element is constant, in which case fold it if it is
24587 // a legal fp immediate.
24588 if (IndexC && ScalarVT.isFloatingPoint()) {
24589 APInt EltMask = APInt::getOneBitSet(numBits: NumElts, BitNo: IndexC->getZExtValue());
24590 KnownBits KnownElt = DAG.computeKnownBits(Op: VecOp, DemandedElts: EltMask);
24591 if (KnownElt.isConstant()) {
24592 APFloat CstFP =
24593 APFloat(ScalarVT.getFltSemantics(), KnownElt.getConstant());
24594 if (TLI.isFPImmLegal(CstFP, ScalarVT))
24595 return DAG.getConstantFP(Val: CstFP, DL, VT: ScalarVT);
24596 }
24597 }
24598
24599 // TODO: These transforms should not require the 'hasOneUse' restriction, but
24600 // there are regressions on multiple targets without it. We can end up with a
24601 // mess of scalar and vector code if we reduce only part of the DAG to scalar.
24602 if (IndexC && VecOp.getOpcode() == ISD::BITCAST && VecVT.isInteger() &&
24603 VecOp.hasOneUse()) {
24604 // The vector index of the LSBs of the source depends on the endianness.
24605 bool IsLE = DAG.getDataLayout().isLittleEndian();
24606 unsigned ExtractIndex = IndexC->getZExtValue();
24607 // extract_elt (v2i32 (bitcast i64:x)), BCTruncElt -> i32 (trunc i64:x)
24608 unsigned BCTruncElt = IsLE ? 0 : NumElts - 1;
24609 SDValue BCSrc = VecOp.getOperand(i: 0);
24610 if (ExtractIndex == BCTruncElt && BCSrc.getValueType().isScalarInteger())
24611 return DAG.getAnyExtOrTrunc(Op: BCSrc, DL, VT: ScalarVT);
24612
24613 // TODO: Add support for SCALAR_TO_VECTOR implicit truncation.
24614 if (LegalTypes && BCSrc.getValueType().isInteger() &&
24615 BCSrc.getOpcode() == ISD::SCALAR_TO_VECTOR &&
24616 BCSrc.getScalarValueSizeInBits() ==
24617 BCSrc.getOperand(i: 0).getScalarValueSizeInBits()) {
24618 // ext_elt (bitcast (scalar_to_vec i64 X to v2i64) to v4i32), TruncElt -->
24619 // trunc i64 X to i32
24620 SDValue X = BCSrc.getOperand(i: 0);
24621 EVT XVT = X.getValueType();
24622 assert(XVT.isScalarInteger() && ScalarVT.isScalarInteger() &&
24623 "Extract element and scalar to vector can't change element type "
24624 "from FP to integer.");
24625 unsigned XBitWidth = X.getValueSizeInBits();
24626 unsigned Scale = XBitWidth / VecEltBitWidth;
24627 BCTruncElt = IsLE ? 0 : Scale - 1;
24628
24629 // An extract element return value type can be wider than its vector
24630 // operand element type. In that case, the high bits are undefined, so
24631 // it's possible that we may need to extend rather than truncate.
24632 if (ExtractIndex < Scale && XBitWidth > VecEltBitWidth) {
24633 assert(XBitWidth % VecEltBitWidth == 0 &&
24634 "Scalar bitwidth must be a multiple of vector element bitwidth");
24635
24636 if (ExtractIndex != BCTruncElt) {
24637 unsigned ShiftIndex =
24638 IsLE ? ExtractIndex : (Scale - 1) - ExtractIndex;
24639 X = DAG.getNode(
24640 Opcode: ISD::SRL, DL, VT: XVT, N1: X,
24641 N2: DAG.getShiftAmountConstant(Val: ShiftIndex * VecEltBitWidth, VT: XVT, DL));
24642 }
24643
24644 return DAG.getAnyExtOrTrunc(Op: X, DL, VT: ScalarVT);
24645 }
24646 }
24647 }
24648
24649 // Transform: (EXTRACT_VECTOR_ELT( VECTOR_SHUFFLE )) -> EXTRACT_VECTOR_ELT.
24650 // We only perform this optimization before the op legalization phase because
24651 // we may introduce new vector instructions which are not backed by TD
24652 // patterns. For example, on AVX we cannot extract elements from a wide
24653 // vector without using extract_subvector. However, if we can find an
24654 // underlying scalar value, then we can always use that.
24655 if (IndexC && VecOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
24656 auto *Shuf = cast<ShuffleVectorSDNode>(Val&: VecOp);
24657 // Find the new index to extract from.
24658 int OrigElt = Shuf->getMaskElt(Idx: IndexC->getZExtValue());
24659
24660 // Extracting an undef index is undef.
24661 if (OrigElt == -1)
24662 return DAG.getUNDEF(VT: ScalarVT);
24663
24664 // Select the right vector half to extract from.
24665 SDValue SVInVec;
24666 if (OrigElt < (int)NumElts) {
24667 SVInVec = VecOp.getOperand(i: 0);
24668 } else {
24669 SVInVec = VecOp.getOperand(i: 1);
24670 OrigElt -= NumElts;
24671 }
24672
24673 if (SVInVec.getOpcode() == ISD::BUILD_VECTOR) {
24674 // TODO: Check if shuffle mask is legal?
24675 if (LegalOperations && TLI.isOperationLegal(Op: ISD::VECTOR_SHUFFLE, VT: VecVT) &&
24676 !VecOp.hasOneUse())
24677 return SDValue();
24678
24679 SDValue InOp = SVInVec.getOperand(i: OrigElt);
24680 if (InOp.getValueType() != ScalarVT) {
24681 assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
24682 InOp = DAG.getSExtOrTrunc(Op: InOp, DL, VT: ScalarVT);
24683 }
24684
24685 return InOp;
24686 }
24687
24688 // FIXME: We should handle recursing on other vector shuffles and
24689 // scalar_to_vector here as well.
24690
24691 if (!LegalOperations ||
24692 // FIXME: Should really be just isOperationLegalOrCustom.
24693 TLI.isOperationLegal(Op: ISD::EXTRACT_VECTOR_ELT, VT: VecVT) ||
24694 TLI.isOperationExpand(Op: ISD::VECTOR_SHUFFLE, VT: VecVT)) {
24695 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ScalarVT, N1: SVInVec,
24696 N2: DAG.getVectorIdxConstant(Val: OrigElt, DL));
24697 }
24698 }
24699
24700 // If only EXTRACT_VECTOR_ELT nodes use the source vector we can
24701 // simplify it based on the (valid) extraction indices.
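// For example, if only lanes 0 and 2 of a v4i32 source are ever extracted,
// the remaining lanes are not demanded, and SimplifyDemandedVectorElts /
// SimplifyDemandedBits may be able to simplify the node producing the source.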
24702 if (llvm::all_of(Range: VecOp->users(), P: [&](SDNode *Use) {
24703 return Use->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
24704 Use->getOperand(Num: 0) == VecOp &&
24705 isa<ConstantSDNode>(Val: Use->getOperand(Num: 1));
24706 })) {
24707 APInt DemandedElts = APInt::getZero(numBits: NumElts);
24708 for (SDNode *User : VecOp->users()) {
24709 auto *CstElt = cast<ConstantSDNode>(Val: User->getOperand(Num: 1));
24710 if (CstElt->getAPIntValue().ult(RHS: NumElts))
24711 DemandedElts.setBit(CstElt->getZExtValue());
24712 }
24713 if (SimplifyDemandedVectorElts(Op: VecOp, DemandedElts, AssumeSingleUse: true)) {
24714 // We simplified the vector operand of this extract element. If this
24715 // extract is not dead, visit it again so it is folded properly.
24716 if (N->getOpcode() != ISD::DELETED_NODE)
24717 AddToWorklist(N);
24718 return SDValue(N, 0);
24719 }
24720 APInt DemandedBits = APInt::getAllOnes(numBits: VecEltBitWidth);
24721 if (SimplifyDemandedBits(Op: VecOp, DemandedBits, DemandedElts, AssumeSingleUse: true)) {
24722 // We simplified the vector operand of this extract element. If this
24723 // extract is not dead, visit it again so it is folded properly.
24724 if (N->getOpcode() != ISD::DELETED_NODE)
24725 AddToWorklist(N);
24726 return SDValue(N, 0);
24727 }
24728 }
24729
24730 if (refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(N))
24731 return SDValue(N, 0);
24732
24733 // Everything under here is trying to match an extract of a loaded value.
24734 // If the result of the load has to be truncated, then it's not necessarily
24735 // profitable.
24736 bool BCNumEltsChanged = false;
24737 EVT ExtVT = VecVT.getVectorElementType();
24738 EVT LVT = ExtVT;
24739 if (ScalarVT.bitsLT(VT: LVT) && !TLI.isTruncateFree(FromVT: LVT, ToVT: ScalarVT))
24740 return SDValue();
24741
24742 if (VecOp.getOpcode() == ISD::BITCAST) {
24743 // Don't duplicate a load with other uses.
24744 if (!VecOp.hasOneUse())
24745 return SDValue();
24746
24747 EVT BCVT = VecOp.getOperand(i: 0).getValueType();
24748 if (!BCVT.isVector() || ExtVT.bitsGT(VT: BCVT.getVectorElementType()))
24749 return SDValue();
24750 if (NumElts != BCVT.getVectorNumElements())
24751 BCNumEltsChanged = true;
24752 VecOp = VecOp.getOperand(i: 0);
24753 ExtVT = BCVT.getVectorElementType();
24754 }
24755
24756 // extract (vector load $addr), i --> load $addr + i * size
24757 if (!LegalOperations && !IndexC && VecOp.hasOneUse() &&
24758 ISD::isNormalLoad(N: VecOp.getNode()) &&
24759 !Index->hasPredecessor(N: VecOp.getNode())) {
24760 auto *VecLoad = dyn_cast<LoadSDNode>(Val&: VecOp);
24761 if (VecLoad && VecLoad->isSimple()) {
24762 if (SDValue Scalarized = TLI.scalarizeExtractedVectorLoad(
24763 ResultVT: ScalarVT, DL: SDLoc(N), InVecVT: VecVT, EltNo: Index, OriginalLoad: VecLoad, DAG)) {
24764 ++OpsNarrowed;
24765 return Scalarized;
24766 }
24767 }
24768 }
24769
24770 // Perform only after legalization to ensure build_vector / vector_shuffle
24771 // optimizations have already been done.
24772 if (!LegalOperations || !IndexC)
24773 return SDValue();
24774
24775 // (vextract (v4f32 load $addr), c) -> (f32 load $addr+c*size)
24776 // (vextract (v4f32 s2v (f32 load $addr)), c) -> (f32 load $addr+c*size)
24777 // (vextract (v4f32 shuffle (load $addr), <1,u,u,u>), 0) -> (f32 load $addr)
24778 int Elt = IndexC->getZExtValue();
24779 LoadSDNode *LN0 = nullptr;
24780 if (ISD::isNormalLoad(N: VecOp.getNode())) {
24781 LN0 = cast<LoadSDNode>(Val&: VecOp);
24782 } else if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR &&
24783 VecOp.getOperand(i: 0).getValueType() == ExtVT &&
24784 ISD::isNormalLoad(N: VecOp.getOperand(i: 0).getNode())) {
24785 // Don't duplicate a load with other uses.
24786 if (!VecOp.hasOneUse())
24787 return SDValue();
24788
24789 LN0 = cast<LoadSDNode>(Val: VecOp.getOperand(i: 0));
24790 }
24791 if (auto *Shuf = dyn_cast<ShuffleVectorSDNode>(Val&: VecOp)) {
24792 // (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1)
24793 // =>
24794 // (load $addr+1*size)
24795
24796 // Don't duplicate a load with other uses.
24797 if (!VecOp.hasOneUse())
24798 return SDValue();
24799
24800 // If the bit convert changed the number of elements, it is unsafe
24801 // to examine the mask.
24802 if (BCNumEltsChanged)
24803 return SDValue();
24804
24805 // Select the input vector, guarding against an out-of-range extract index.
24806 int Idx = (Elt > (int)NumElts) ? -1 : Shuf->getMaskElt(Idx: Elt);
24807 VecOp = (Idx < (int)NumElts) ? VecOp.getOperand(i: 0) : VecOp.getOperand(i: 1);
24808
24809 if (VecOp.getOpcode() == ISD::BITCAST) {
24810 // Don't duplicate a load with other uses.
24811 if (!VecOp.hasOneUse())
24812 return SDValue();
24813
24814 VecOp = VecOp.getOperand(i: 0);
24815 }
24816 if (ISD::isNormalLoad(N: VecOp.getNode())) {
24817 LN0 = cast<LoadSDNode>(Val&: VecOp);
24818 Elt = (Idx < (int)NumElts) ? Idx : Idx - (int)NumElts;
24819 Index = DAG.getConstant(Val: Elt, DL, VT: Index.getValueType());
24820 }
24821 } else if (VecOp.getOpcode() == ISD::CONCAT_VECTORS && !BCNumEltsChanged &&
24822 VecVT.getVectorElementType() == ScalarVT &&
24823 (!LegalTypes ||
24824 TLI.isTypeLegal(
24825 VT: VecOp.getOperand(i: 0).getValueType().getVectorElementType()))) {
24826 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 0
24827 // -> extract_vector_elt a, 0
24828 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 1
24829 // -> extract_vector_elt a, 1
24830 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 2
24831 // -> extract_vector_elt b, 0
24832 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 3
24833 // -> extract_vector_elt b, 1
24834 EVT ConcatVT = VecOp.getOperand(i: 0).getValueType();
24835 unsigned ConcatNumElts = ConcatVT.getVectorNumElements();
24836 SDValue NewIdx = DAG.getConstant(Val: Elt % ConcatNumElts, DL,
24837 VT: Index.getValueType());
24838
24839 SDValue ConcatOp = VecOp.getOperand(i: Elt / ConcatNumElts);
24840 SDValue Elt = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL,
24841 VT: ConcatVT.getVectorElementType(),
24842 N1: ConcatOp, N2: NewIdx);
24843 return DAG.getNode(Opcode: ISD::BITCAST, DL, VT: ScalarVT, Operand: Elt);
24844 }
24845
24846 // Make sure we found a simple (non-volatile, non-atomic) load and the
24847 // extractelement is the only use of the loaded value.
24848 if (!LN0 || !LN0->hasNUsesOfValue(NUses: 1, Value: 0) || !LN0->isSimple())
24849 return SDValue();
24850
24851 // If Idx was -1 above, Elt is going to be -1, so just return undef.
24852 if (Elt == -1)
24853 return DAG.getUNDEF(VT: LVT);
24854
24855 if (SDValue Scalarized =
24856 TLI.scalarizeExtractedVectorLoad(ResultVT: LVT, DL, InVecVT: VecVT, EltNo: Index, OriginalLoad: LN0, DAG)) {
24857 ++OpsNarrowed;
24858 return Scalarized;
24859 }
24860
24861 return SDValue();
24862}
24863
24864// Simplify (build_vec (ext )) to (bitcast (build_vec ))
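// For example (a little-endian sketch):
//   v2i64 build_vector (zext i32 a), (zext i32 b)
// --> bitcast (v4i32 build_vector a, 0, b, 0) to v2i64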
24865SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) {
24866 // We perform this optimization post type-legalization because
24867 // the type-legalizer often scalarizes integer-promoted vectors.
24868 // Performing this optimization before may create bit-casts which
24869 // will be type-legalized to complex code sequences.
24870 // We perform this optimization only before the operation legalizer because we
24871 // may introduce illegal operations.
24872 if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
24873 return SDValue();
24874
24875 unsigned NumInScalars = N->getNumOperands();
24876 SDLoc DL(N);
24877 EVT VT = N->getValueType(ResNo: 0);
24878
24879 // Check to see if this is a BUILD_VECTOR of a bunch of values
24880 // which come from any_extend or zero_extend nodes. If so, we can create
24881 // a new BUILD_VECTOR using bit-casts which may enable other BUILD_VECTOR
24882 // optimizations. We do not handle sign-extend because we can't fill the sign
24883 // using shuffles.
24884 EVT SourceType = MVT::Other;
24885 bool AllAnyExt = true;
24886
24887 for (unsigned i = 0; i != NumInScalars; ++i) {
24888 SDValue In = N->getOperand(Num: i);
24889 // Ignore undef inputs.
24890 if (In.isUndef()) continue;
24891
24892 bool AnyExt = In.getOpcode() == ISD::ANY_EXTEND;
24893 bool ZeroExt = In.getOpcode() == ISD::ZERO_EXTEND;
24894
24895 // Abort if the element is not an extension.
24896 if (!ZeroExt && !AnyExt) {
24897 SourceType = MVT::Other;
24898 break;
24899 }
24900
24901 // The input is a ZeroExt or AnyExt. Check the original type.
24902 EVT InTy = In.getOperand(i: 0).getValueType();
24903
24904 // Check that all of the widened source types are the same.
24905 if (SourceType == MVT::Other)
24906 // First time.
24907 SourceType = InTy;
24908 else if (InTy != SourceType) {
24909 // Multiple incoming types. Abort.
24910 SourceType = MVT::Other;
24911 break;
24912 }
24913
24914 // Check if all of the extends are ANY_EXTENDs.
24915 AllAnyExt &= AnyExt;
24916 }
24917
24918 // In order to have valid types, all of the inputs must be extended from the
24919 // same source type and all of the inputs must be any or zero extend.
24920 // Scalar sizes must be a power of two.
24921 EVT OutScalarTy = VT.getScalarType();
24922 bool ValidTypes =
24923 SourceType != MVT::Other &&
24924 llvm::has_single_bit<uint32_t>(Value: OutScalarTy.getSizeInBits()) &&
24925 llvm::has_single_bit<uint32_t>(Value: SourceType.getSizeInBits());
24926
24927 // Create a new simpler BUILD_VECTOR sequence which other optimizations can
24928 // turn into a single shuffle instruction.
24929 if (!ValidTypes)
24930 return SDValue();
24931
24932 // If we already have a splat buildvector, then don't fold it if it means
24933 // introducing zeros.
24934 if (!AllAnyExt && DAG.isSplatValue(V: SDValue(N, 0), /*AllowUndefs*/ true))
24935 return SDValue();
24936
24937 bool isLE = DAG.getDataLayout().isLittleEndian();
24938 unsigned ElemRatio = OutScalarTy.getSizeInBits()/SourceType.getSizeInBits();
24939 assert(ElemRatio > 1 && "Invalid element size ratio");
24940 SDValue Filler = AllAnyExt ? DAG.getPOISON(VT: SourceType)
24941 : DAG.getConstant(Val: 0, DL, VT: SourceType);
24942
24943 unsigned NewBVElems = ElemRatio * VT.getVectorNumElements();
24944 SmallVector<SDValue, 8> Ops(NewBVElems, Filler);
24945
24946 // Populate the new build_vector
24947 for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
24948 SDValue Cast = N->getOperand(Num: i);
24949 assert((Cast.getOpcode() == ISD::ANY_EXTEND ||
24950 Cast.getOpcode() == ISD::ZERO_EXTEND ||
24951 Cast.isUndef()) && "Invalid cast opcode");
24952 SDValue In;
24953 if (Cast.isUndef())
24954 In = DAG.getUNDEF(VT: SourceType);
24955 else
24956 In = Cast->getOperand(Num: 0);
24957 unsigned Index = isLE ? (i * ElemRatio) :
24958 (i * ElemRatio + (ElemRatio - 1));
24959
24960 assert(Index < Ops.size() && "Invalid index");
24961 Ops[Index] = In;
24962 }
24963
24964 // The type of the new BUILD_VECTOR node.
24965 EVT VecVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SourceType, NumElements: NewBVElems);
24966 assert(VecVT.getSizeInBits() == VT.getSizeInBits() &&
24967 "Invalid vector size");
24968 // Check if the new vector type is legal.
24969 if (!isTypeLegal(VT: VecVT) ||
24970 (!TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT: VecVT) &&
24971 TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT)))
24972 return SDValue();
24973
24974 // Make the new BUILD_VECTOR.
24975 SDValue BV = DAG.getBuildVector(VT: VecVT, DL, Ops);
24976
24977 // The new BUILD_VECTOR node has the potential to be further optimized.
24978 AddToWorklist(N: BV.getNode());
24979 // Bitcast to the desired type.
24980 return DAG.getBitcast(VT, V: BV);
24981}
24982
24983// Simplify (build_vec (trunc $1)
24984// (trunc (srl $1 half-width))
24985// (trunc (srl $1 (2 * half-width))))
24986// to (bitcast $1)
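// For example (a little-endian sketch):
//   v4i16 build_vector (trunc X), (trunc (srl X, 16)),
//                      (trunc (srl X, 32)), (trunc (srl X, 48))
// --> bitcast i64 X to v4i16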
24987SDValue DAGCombiner::reduceBuildVecTruncToBitCast(SDNode *N) {
24988 assert(N->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
24989
24990 EVT VT = N->getValueType(ResNo: 0);
24991
24992 // Don't run this before LegalizeTypes if VT is legal.
24993 // Targets may have other preferences.
24994 if (Level < AfterLegalizeTypes && TLI.isTypeLegal(VT))
24995 return SDValue();
24996
24997 // Only for little endian
24998 if (!DAG.getDataLayout().isLittleEndian())
24999 return SDValue();
25000
25001 EVT OutScalarTy = VT.getScalarType();
25002 uint64_t ScalarTypeBitsize = OutScalarTy.getSizeInBits();
25003
25004 // Only for power-of-two types, to be sure that the bitcast works well.
25005 if (!isPowerOf2_64(Value: ScalarTypeBitsize))
25006 return SDValue();
25007
25008 unsigned NumInScalars = N->getNumOperands();
25009
25010 // Look through bitcasts
25011 auto PeekThroughBitcast = [](SDValue Op) {
25012 if (Op.getOpcode() == ISD::BITCAST)
25013 return Op.getOperand(i: 0);
25014 return Op;
25015 };
25016
25017 // The source value where all the parts are extracted.
25018 SDValue Src;
25019 for (unsigned i = 0; i != NumInScalars; ++i) {
25020 SDValue In = PeekThroughBitcast(N->getOperand(Num: i));
25021 // Ignore undef inputs.
25022 if (In.isUndef()) continue;
25023
25024 if (In.getOpcode() != ISD::TRUNCATE)
25025 return SDValue();
25026
25027 In = PeekThroughBitcast(In.getOperand(i: 0));
25028
25029 if (In.getOpcode() != ISD::SRL) {
25030 // For now, only handle a build_vec without shuffling; handle shifts here
25031 // in the future.
25032 if (i != 0)
25033 return SDValue();
25034
25035 Src = In;
25036 } else {
25037 // In is SRL
25038 SDValue part = PeekThroughBitcast(In.getOperand(i: 0));
25039
25040 if (!Src) {
25041 Src = part;
25042 } else if (Src != part) {
25043 // Vector parts do not stem from the same variable
25044 return SDValue();
25045 }
25046
25047 SDValue ShiftAmtVal = In.getOperand(i: 1);
25048 if (!isa<ConstantSDNode>(Val: ShiftAmtVal))
25049 return SDValue();
25050
25051 uint64_t ShiftAmt = In.getConstantOperandVal(i: 1);
25052
25053 // The extracted value is not extracted at the right position
25054 if (ShiftAmt != i * ScalarTypeBitsize)
25055 return SDValue();
25056 }
25057 }
25058
25059 // Only cast if the size is the same
25060 if (!Src || Src.getValueType().getSizeInBits() != VT.getSizeInBits())
25061 return SDValue();
25062
25063 return DAG.getBitcast(VT, V: Src);
25064}
25065
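// Given the input vectors VecIn1/VecIn2 (sorted into non-increasing size
// order) and the per-element source map VectorMask, try to build a single
// VECTOR_SHUFFLE that produces the elements BUILD_VECTOR node N takes from
// these two inputs. The inputs may first be concatenated, split or padded so
// the shuffle operates on a single common type; an extract_subvector recovers
// VT if the shuffle had to be done at a wider width. Returns an empty SDValue
// on failure.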
25066SDValue DAGCombiner::createBuildVecShuffle(const SDLoc &DL, SDNode *N,
25067 ArrayRef<int> VectorMask,
25068 SDValue VecIn1, SDValue VecIn2,
25069 unsigned LeftIdx, bool DidSplitVec) {
25070 EVT VT = N->getValueType(ResNo: 0);
25071 EVT InVT1 = VecIn1.getValueType();
25072 EVT InVT2 = VecIn2.getNode() ? VecIn2.getValueType() : InVT1;
25073
25074 unsigned NumElems = VT.getVectorNumElements();
25075 unsigned ShuffleNumElems = NumElems;
25076
25077 // If we artificially split a vector in two already, then the offsets in the
25078 // operands will all be based off of VecIn1, even those in VecIn2.
25079 unsigned Vec2Offset = DidSplitVec ? 0 : InVT1.getVectorNumElements();
25080
25081 uint64_t VTSize = VT.getFixedSizeInBits();
25082 uint64_t InVT1Size = InVT1.getFixedSizeInBits();
25083 uint64_t InVT2Size = InVT2.getFixedSizeInBits();
25084
25085 assert(InVT2Size <= InVT1Size &&
25086 "Inputs must be sorted to be in non-increasing vector size order.");
25087
25088 // We can't generate a shuffle node with mismatched input and output types.
25089 // Try to make the types match the type of the output.
25090 if (InVT1 != VT || InVT2 != VT) {
25091 if ((VTSize % InVT1Size == 0) && InVT1 == InVT2) {
25092 // If the output vector length is a multiple of both input lengths,
25093 // we can concatenate them and pad the rest with poison.
25094 unsigned NumConcats = VTSize / InVT1Size;
25095 assert(NumConcats >= 2 && "Concat needs at least two inputs!");
25096 SmallVector<SDValue, 2> ConcatOps(NumConcats, DAG.getPOISON(VT: InVT1));
25097 ConcatOps[0] = VecIn1;
25098 ConcatOps[1] = VecIn2 ? VecIn2 : DAG.getPOISON(VT: InVT1);
25099 VecIn1 = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT, Ops: ConcatOps);
25100 VecIn2 = SDValue();
25101 } else if (InVT1Size == VTSize * 2) {
25102 if (!TLI.isExtractSubvectorCheap(ResVT: VT, SrcVT: InVT1, Index: NumElems))
25103 return SDValue();
25104
25105 if (!VecIn2.getNode()) {
25106 // If we only have one input vector, and it's twice the size of the
25107 // output, split it in two.
25108 VecIn2 = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT, N1: VecIn1,
25109 N2: DAG.getVectorIdxConstant(Val: NumElems, DL));
25110 VecIn1 = DAG.getExtractSubvector(DL, VT, Vec: VecIn1, Idx: 0);
25111 // Since we now have shorter input vectors, adjust the offset of the
25112 // second vector's start.
25113 Vec2Offset = NumElems;
25114 } else {
25115 assert(InVT2Size <= InVT1Size &&
25116 "Second input is not going to be larger than the first one.");
25117
25118 // VecIn1 is wider than the output, and we have another, possibly
25119 // smaller input. Pad the smaller input with undefs, shuffle at the
25120 // input vector width, and extract the output.
25121 // The shuffle type is different than VT, so check legality again.
25122 if (LegalOperations &&
25123 !TLI.isOperationLegal(Op: ISD::VECTOR_SHUFFLE, VT: InVT1))
25124 return SDValue();
25125
25126 // Legalizing INSERT_SUBVECTOR is tricky - you basically have to
25127 // lower it back into a BUILD_VECTOR. So if the inserted type is
25128 // illegal, don't even try.
25129 if (InVT1 != InVT2) {
25130 if (!TLI.isTypeLegal(VT: InVT2))
25131 return SDValue();
25132 VecIn2 = DAG.getInsertSubvector(DL, Vec: DAG.getPOISON(VT: InVT1), SubVec: VecIn2, Idx: 0);
25133 }
25134 ShuffleNumElems = NumElems * 2;
25135 }
25136 } else if (InVT2Size * 2 == VTSize && InVT1Size == VTSize) {
25137 SmallVector<SDValue, 2> ConcatOps(2, DAG.getPOISON(VT: InVT2));
25138 ConcatOps[0] = VecIn2;
25139 VecIn2 = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT, Ops: ConcatOps);
25140 } else if (InVT1Size / VTSize > 1 && InVT1Size % VTSize == 0) {
25141 if (!TLI.isExtractSubvectorCheap(ResVT: VT, SrcVT: InVT1, Index: NumElems) ||
25142 !TLI.isTypeLegal(VT: InVT1) || !TLI.isTypeLegal(VT: InVT2))
25143 return SDValue();
25144 // If the dest vector has fewer than two elements, then using a shuffle and
25145 // extracting from larger regs will cost even more.
25146 if (VT.getVectorNumElements() <= 2 || !VecIn2.getNode())
25147 return SDValue();
25148 assert(InVT2Size <= InVT1Size &&
25149 "Second input is not going to be larger than the first one.");
25150
25151 // VecIn1 is wider than the output, and we have another, possibly
25152 // smaller input. Pad the smaller input with undefs, shuffle at the
25153 // input vector width, and extract the output.
25154 // The shuffle type is different than VT, so check legality again.
25155 if (LegalOperations && !TLI.isOperationLegal(Op: ISD::VECTOR_SHUFFLE, VT: InVT1))
25156 return SDValue();
25157
25158 if (InVT1 != InVT2) {
25159 VecIn2 = DAG.getInsertSubvector(DL, Vec: DAG.getPOISON(VT: InVT1), SubVec: VecIn2, Idx: 0);
25160 }
25161 ShuffleNumElems = InVT1Size / VTSize * NumElems;
25162 } else {
25163 // TODO: Support cases where the length mismatch isn't exactly by a
25164 // factor of 2.
25165 // TODO: Move this check upwards, so that if we have bad type
25166 // mismatches, we don't create any DAG nodes.
25167 return SDValue();
25168 }
25169 }
25170
25171 // Initialize mask to undef.
25172 SmallVector<int, 8> Mask(ShuffleNumElems, -1);
25173
25174 // Only need to run up to the number of elements actually used, not the
25175 // total number of elements in the shuffle - if we are shuffling a wider
25176 // vector, the high lanes should be set to undef.
25177 for (unsigned i = 0; i != NumElems; ++i) {
25178 if (VectorMask[i] <= 0)
25179 continue;
25180
25181 unsigned ExtIndex = N->getOperand(Num: i).getConstantOperandVal(i: 1);
25182 if (VectorMask[i] == (int)LeftIdx) {
25183 Mask[i] = ExtIndex;
25184 } else if (VectorMask[i] == (int)LeftIdx + 1) {
25185 Mask[i] = Vec2Offset + ExtIndex;
25186 }
25187 }
25188
25189 // The type of the input vectors may have changed above.
25190 InVT1 = VecIn1.getValueType();
25191
25192 // If we already have a VecIn2, it should have the same type as VecIn1.
25193 // If we don't, get a poison vector of the appropriate type.
25194 VecIn2 = VecIn2.getNode() ? VecIn2 : DAG.getPOISON(VT: InVT1);
25195 assert(InVT1 == VecIn2.getValueType() && "Unexpected second input type.");
25196
25197 SDValue Shuffle = DAG.getVectorShuffle(VT: InVT1, dl: DL, N1: VecIn1, N2: VecIn2, Mask);
25198 if (ShuffleNumElems > NumElems)
25199 Shuffle = DAG.getExtractSubvector(DL, VT, Vec: Shuffle, Idx: 0);
25200
25201 return Shuffle;
25202}
25203
25204static SDValue reduceBuildVecToShuffleWithZero(SDNode *BV, SelectionDAG &DAG) {
25205 assert(BV->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
25206
25207 // First, determine where the build vector is not undef.
25208 // TODO: We could extend this to handle zero elements as well as undefs.
25209 int NumBVOps = BV->getNumOperands();
25210 int ZextElt = -1;
25211 for (int i = 0; i != NumBVOps; ++i) {
25212 SDValue Op = BV->getOperand(Num: i);
25213 if (Op.isUndef())
25214 continue;
25215 if (ZextElt == -1)
25216 ZextElt = i;
25217 else
25218 return SDValue();
25219 }
25220 // Bail out if there's no non-undef element.
25221 if (ZextElt == -1)
25222 return SDValue();
25223
25224 // The build vector contains some number of undef elements and exactly
25225 // one other element. That other element must be a zero-extended scalar
25226 // extracted from a vector at a constant index to turn this into a shuffle.
25227 // Also, require that the build vector does not implicitly truncate/extend
25228 // its elements.
25229 // TODO: This could be enhanced to allow ANY_EXTEND as well as ZERO_EXTEND.
25230 EVT VT = BV->getValueType(ResNo: 0);
25231 SDValue Zext = BV->getOperand(Num: ZextElt);
25232 if (Zext.getOpcode() != ISD::ZERO_EXTEND || !Zext.hasOneUse() ||
25233 Zext.getOperand(i: 0).getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
25234 !isa<ConstantSDNode>(Val: Zext.getOperand(i: 0).getOperand(i: 1)) ||
25235 Zext.getValueSizeInBits() != VT.getScalarSizeInBits())
25236 return SDValue();
25237
25238 // The zero-extend width must be a multiple of the source size, and we must
25239 // be building a vector of the same size as the source of the extract element.
25240 SDValue Extract = Zext.getOperand(i: 0);
25241 unsigned DestSize = Zext.getValueSizeInBits();
25242 unsigned SrcSize = Extract.getValueSizeInBits();
25243 if (DestSize % SrcSize != 0 ||
25244 Extract.getOperand(i: 0).getValueSizeInBits() != VT.getSizeInBits())
25245 return SDValue();
25246
25247 // Create a shuffle mask that will combine the extracted element with zeros
25248 // and undefs.
25249 int ZextRatio = DestSize / SrcSize;
25250 int NumMaskElts = NumBVOps * ZextRatio;
25251 SmallVector<int, 32> ShufMask(NumMaskElts, -1);
25252 for (int i = 0; i != NumMaskElts; ++i) {
25253 if (i / ZextRatio == ZextElt) {
25254 // The low bits of the (potentially translated) extracted element map to
25255 // the source vector. The high bits map to zero. We will use a zero vector
25256 // as the 2nd source operand of the shuffle, so use the 1st element of
25257 // that vector (mask value is number-of-elements) for the high bits.
25258 int Low = DAG.getDataLayout().isBigEndian() ? (ZextRatio - 1) : 0;
25259 ShufMask[i] = (i % ZextRatio == Low) ? Extract.getConstantOperandVal(i: 1)
25260 : NumMaskElts;
25261 }
25262
25263 // Undef elements of the build vector remain undef because we initialize
25264 // the shuffle mask with -1.
25265 }
25266
25267 // buildvec undef, ..., (zext (extractelt V, IndexC)), undef... -->
25268 // bitcast (shuffle V, ZeroVec, VectorMask)
25269 SDLoc DL(BV);
25270 EVT VecVT = Extract.getOperand(i: 0).getValueType();
25271 SDValue ZeroVec = DAG.getConstant(Val: 0, DL, VT: VecVT);
25272 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25273 SDValue Shuf = TLI.buildLegalVectorShuffle(VT: VecVT, DL, N0: Extract.getOperand(i: 0),
25274 N1: ZeroVec, Mask: ShufMask, DAG);
25275 if (!Shuf)
25276 return SDValue();
25277 return DAG.getBitcast(VT, V: Shuf);
25278}
25279
25280// FIXME: promote to STLExtras.
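// Returns the index of the first occurrence of Val in Range, or -1 if Val is
// not present.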
25281template <typename R, typename T>
25282static auto getFirstIndexOf(R &&Range, const T &Val) {
25283 auto I = find(Range, Val);
25284 if (I == Range.end())
25285 return static_cast<decltype(std::distance(Range.begin(), I))>(-1);
25286 return std::distance(Range.begin(), I);
25287}
25288
25289// Check to see if this is a BUILD_VECTOR of a bunch of EXTRACT_VECTOR_ELT
25290// operations. If the types of the vectors we're extracting from allow it,
25291// turn this into a vector_shuffle node.
25292SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) {
25293 SDLoc DL(N);
25294 EVT VT = N->getValueType(ResNo: 0);
25295
25296 // Only type-legal BUILD_VECTOR nodes are converted to shuffle nodes.
25297 if (!isTypeLegal(VT))
25298 return SDValue();
25299
25300 if (SDValue V = reduceBuildVecToShuffleWithZero(BV: N, DAG))
25301 return V;
25302
25303 // May only combine to a shuffle after legalization if the shuffle is legal.
25304 if (LegalOperations && !TLI.isOperationLegal(Op: ISD::VECTOR_SHUFFLE, VT))
25305 return SDValue();
25306
25307 bool UsesZeroVector = false;
25308 unsigned NumElems = N->getNumOperands();
25309
25310 // Record, for each element of the newly built vector, which input vector
25311 // that element comes from. -1 stands for undef, 0 for the zero vector,
25312 // and positive values for the input vectors.
25313 // VectorMask maps each element to its vector number, and VecIn maps vector
25314 // numbers to their initial SDValues.
25315
25316 SmallVector<int, 8> VectorMask(NumElems, -1);
25317 SmallVector<SDValue, 8> VecIn;
25318 VecIn.push_back(Elt: SDValue());
25319
25320 // If we have a single extract_element with a constant index, track the index
25321 // value.
25322 unsigned OneConstExtractIndex = ~0u;
25323
25324 // Count the extract_vector_elt sources (i.e. operands that are neither constant zero nor undef).
25325 unsigned NumExtracts = 0;
25326
25327 for (unsigned i = 0; i != NumElems; ++i) {
25328 SDValue Op = N->getOperand(Num: i);
25329
25330 if (Op.isUndef())
25331 continue;
25332
25333 // See if we can use a blend with a zero vector.
25334 // TODO: Should we generalize this to a blend with an arbitrary constant
25335 // vector?
25336 if (isNullConstant(V: Op) || isNullFPConstant(V: Op)) {
25337 UsesZeroVector = true;
25338 VectorMask[i] = 0;
25339 continue;
25340 }
25341
25342 // Not an undef or zero. If the input is something other than an
25343 // EXTRACT_VECTOR_ELT with an in-range constant index, bail out.
25344 if (Op.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
25345 return SDValue();
25346
25347 SDValue ExtractedFromVec = Op.getOperand(i: 0);
25348 if (ExtractedFromVec.getValueType().isScalableVector())
25349 return SDValue();
25350 auto *ExtractIdx = dyn_cast<ConstantSDNode>(Val: Op.getOperand(i: 1));
25351 if (!ExtractIdx)
25352 return SDValue();
25353
25354 if (ExtractIdx->getAsAPIntVal().uge(
25355 RHS: ExtractedFromVec.getValueType().getVectorNumElements()))
25356 return SDValue();
25357
25358 // All inputs must have the same element type as the output.
25359 if (VT.getVectorElementType() !=
25360 ExtractedFromVec.getValueType().getVectorElementType())
25361 return SDValue();
25362
25363 OneConstExtractIndex = ExtractIdx->getZExtValue();
25364 ++NumExtracts;
25365
25366 // Have we seen this input vector before?
25367 // The vectors are expected to be tiny (usually 1 or 2 elements), so using
25368 // a map back from SDValues to numbers isn't worth it.
25369 int Idx = getFirstIndexOf(Range&: VecIn, Val: ExtractedFromVec);
25370 if (Idx == -1) { // A new source vector?
25371 Idx = VecIn.size();
25372 VecIn.push_back(Elt: ExtractedFromVec);
25373 }
25374
25375 VectorMask[i] = Idx;
25376 }
25377
25378 // If we didn't find at least one input vector, bail out.
25379 if (VecIn.size() < 2)
25380 return SDValue();
25381
25382 // If all the operands of the BUILD_VECTOR extract from the same
25383 // vector, then split the vector efficiently based on the maximum
25384 // vector access index and adjust the VectorMask and
25385 // VecIn accordingly.
25386 bool DidSplitVec = false;
25387 if (VecIn.size() == 2) {
25388 // If we only found a single constant indexed extract_vector_elt feeding the
25389 // build_vector, do not produce a more complicated shuffle if the extract is
25390 // cheap with other constant/undef elements. Skip broadcast patterns with
25391 // multiple uses in the build_vector.
25392
25393 // TODO: This should be more aggressive about skipping the shuffle
25394 // formation, particularly if VecIn[1].hasOneUse(), and regardless of the
25395 // index.
25396 if (NumExtracts == 1 &&
25397 TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_VECTOR_ELT, VT) &&
25398 TLI.isTypeLegal(VT: VT.getVectorElementType()) &&
25399 TLI.isExtractVecEltCheap(VT, Index: OneConstExtractIndex))
25400 return SDValue();
25401
25402 unsigned MaxIndex = 0;
25403 unsigned NearestPow2 = 0;
25404 SDValue Vec = VecIn.back();
25405 EVT InVT = Vec.getValueType();
25406 SmallVector<unsigned, 8> IndexVec(NumElems, 0);
25407
25408 for (unsigned i = 0; i < NumElems; i++) {
25409 if (VectorMask[i] <= 0)
25410 continue;
25411 unsigned Index = N->getOperand(Num: i).getConstantOperandVal(i: 1);
25412 IndexVec[i] = Index;
25413 MaxIndex = std::max(a: MaxIndex, b: Index);
25414 }
25415
25416 NearestPow2 = PowerOf2Ceil(A: MaxIndex);
25417 if (InVT.isSimple() && NearestPow2 > 2 && MaxIndex < NearestPow2 &&
25418 NumElems * 2 < NearestPow2) {
25419 unsigned SplitSize = NearestPow2 / 2;
25420 EVT SplitVT = EVT::getVectorVT(Context&: *DAG.getContext(),
25421 VT: InVT.getVectorElementType(), NumElements: SplitSize);
25422 if (TLI.isTypeLegal(VT: SplitVT) &&
25423 SplitSize + SplitVT.getVectorNumElements() <=
25424 InVT.getVectorNumElements()) {
25425 SDValue VecIn2 = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: SplitVT, N1: Vec,
25426 N2: DAG.getVectorIdxConstant(Val: SplitSize, DL));
25427 SDValue VecIn1 = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: SplitVT, N1: Vec,
25428 N2: DAG.getVectorIdxConstant(Val: 0, DL));
25429 VecIn.pop_back();
25430 VecIn.push_back(Elt: VecIn1);
25431 VecIn.push_back(Elt: VecIn2);
25432 DidSplitVec = true;
25433
25434 for (unsigned i = 0; i < NumElems; i++) {
25435 if (VectorMask[i] <= 0)
25436 continue;
25437 VectorMask[i] = (IndexVec[i] < SplitSize) ? 1 : 2;
25438 }
25439 }
25440 }
25441 }
25442
25443 // Sort input vectors by decreasing vector element count,
25444 // while preserving the relative order of equally-sized vectors.
25445 // Note that we keep the first "implicit" zero vector as-is.
25446 SmallVector<SDValue, 8> SortedVecIn(VecIn);
25447 llvm::stable_sort(Range: MutableArrayRef<SDValue>(SortedVecIn).drop_front(),
25448 C: [](const SDValue &a, const SDValue &b) {
25449 return a.getValueType().getVectorNumElements() >
25450 b.getValueType().getVectorNumElements();
25451 });
25452
25453 // We now also need to rebuild the VectorMask, because it referenced element
25454 // order in VecIn, and we just sorted them.
25455 for (int &SourceVectorIndex : VectorMask) {
25456 if (SourceVectorIndex <= 0)
25457 continue;
25458 unsigned Idx = getFirstIndexOf(Range&: SortedVecIn, Val: VecIn[SourceVectorIndex]);
25459 assert(Idx > 0 && Idx < SortedVecIn.size() &&
25460 VecIn[SourceVectorIndex] == SortedVecIn[Idx] && "Remapping failure");
25461 SourceVectorIndex = Idx;
25462 }
25463
25464 VecIn = std::move(SortedVecIn);
25465
25466 // TODO: Should this fire if some of the input vectors have illegal types
25467 // (like it does now), or should we let legalization run its course first?
25468
25469 // Shuffle phase:
25470 // Take pairs of vectors, and shuffle them so that the result has elements
25471 // from these vectors in the correct places.
25472 // For example, given:
25473 // t10: i32 = extract_vector_elt t1, Constant:i64<0>
25474 // t11: i32 = extract_vector_elt t2, Constant:i64<0>
25475 // t12: i32 = extract_vector_elt t3, Constant:i64<0>
25476 // t13: i32 = extract_vector_elt t1, Constant:i64<1>
25477 // t14: v4i32 = BUILD_VECTOR t10, t11, t12, t13
25478 // We will generate:
25479 // t20: v4i32 = vector_shuffle<0,4,u,1> t1, t2
25480 // t21: v4i32 = vector_shuffle<u,u,0,u> t3, undef
25481 SmallVector<SDValue, 4> Shuffles;
25482 for (unsigned In = 0, Len = (VecIn.size() / 2); In < Len; ++In) {
25483 unsigned LeftIdx = 2 * In + 1;
25484 SDValue VecLeft = VecIn[LeftIdx];
25485 SDValue VecRight =
25486 (LeftIdx + 1) < VecIn.size() ? VecIn[LeftIdx + 1] : SDValue();
25487
25488 if (SDValue Shuffle = createBuildVecShuffle(DL, N, VectorMask, VecIn1: VecLeft,
25489 VecIn2: VecRight, LeftIdx, DidSplitVec))
25490 Shuffles.push_back(Elt: Shuffle);
25491 else
25492 return SDValue();
25493 }
25494
25495 // If we need the zero vector as an "ingredient" in the blend tree, add it
25496 // to the list of shuffles.
25497 if (UsesZeroVector)
25498 Shuffles.push_back(Elt: VT.isInteger() ? DAG.getConstant(Val: 0, DL, VT)
25499 : DAG.getConstantFP(Val: 0.0, DL, VT));
25500
25501 // If we only have one shuffle, we're done.
25502 if (Shuffles.size() == 1)
25503 return Shuffles[0];
25504
25505 // Update the vector mask to point to the post-shuffle vectors.
25506 for (int &Vec : VectorMask)
25507 if (Vec == 0)
25508 Vec = Shuffles.size() - 1;
25509 else
25510 Vec = (Vec - 1) / 2;
25511
25512 // More than one shuffle. Generate a binary tree of blends, e.g. if from
25513 // the previous step we got the set of shuffles t10, t11, t12, t13, we will
25514 // generate:
25515 // t10: v8i32 = vector_shuffle<0,8,u,u,u,u,u,u> t1, t2
25516 // t11: v8i32 = vector_shuffle<u,u,0,8,u,u,u,u> t3, t4
25517 // t12: v8i32 = vector_shuffle<u,u,u,u,0,8,u,u> t5, t6
25518 // t13: v8i32 = vector_shuffle<u,u,u,u,u,u,0,8> t7, t8
25519 // t20: v8i32 = vector_shuffle<0,1,10,11,u,u,u,u> t10, t11
25520 // t21: v8i32 = vector_shuffle<u,u,u,u,4,5,14,15> t12, t13
25521 // t30: v8i32 = vector_shuffle<0,1,2,3,12,13,14,15> t20, t21
25522
25523 // Make sure the initial size of the shuffle list is even.
25524 if (Shuffles.size() % 2)
25525 Shuffles.push_back(Elt: DAG.getPOISON(VT));
25526
25527 for (unsigned CurSize = Shuffles.size(); CurSize > 1; CurSize /= 2) {
25528 if (CurSize % 2) {
25529 Shuffles[CurSize] = DAG.getPOISON(VT);
25530 CurSize++;
25531 }
25532 for (unsigned In = 0, Len = CurSize / 2; In < Len; ++In) {
25533 int Left = 2 * In;
25534 int Right = 2 * In + 1;
25535 SmallVector<int, 8> Mask(NumElems, -1);
25536 SDValue L = Shuffles[Left];
25537 ArrayRef<int> LMask;
25538 bool IsLeftShuffle = L.getOpcode() == ISD::VECTOR_SHUFFLE &&
25539 L.use_empty() && L.getOperand(i: 1).isUndef() &&
25540 L.getOperand(i: 0).getValueType() == L.getValueType();
25541 if (IsLeftShuffle) {
25542 LMask = cast<ShuffleVectorSDNode>(Val: L.getNode())->getMask();
25543 L = L.getOperand(i: 0);
25544 }
25545 SDValue R = Shuffles[Right];
25546 ArrayRef<int> RMask;
25547 bool IsRightShuffle = R.getOpcode() == ISD::VECTOR_SHUFFLE &&
25548 R.use_empty() && R.getOperand(i: 1).isUndef() &&
25549 R.getOperand(i: 0).getValueType() == R.getValueType();
25550 if (IsRightShuffle) {
25551 RMask = cast<ShuffleVectorSDNode>(Val: R.getNode())->getMask();
25552 R = R.getOperand(i: 0);
25553 }
25554 for (unsigned I = 0; I != NumElems; ++I) {
25555 if (VectorMask[I] == Left) {
25556 Mask[I] = I;
25557 if (IsLeftShuffle)
25558 Mask[I] = LMask[I];
25559 VectorMask[I] = In;
25560 } else if (VectorMask[I] == Right) {
25561 Mask[I] = I + NumElems;
25562 if (IsRightShuffle)
25563 Mask[I] = RMask[I] + NumElems;
25564 VectorMask[I] = In;
25565 }
25566 }
25567
25568 Shuffles[In] = DAG.getVectorShuffle(VT, dl: DL, N1: L, N2: R, Mask);
25569 }
25570 }
25571 return Shuffles[0];
25572}
25573
25574 // Try to turn a build vector of zero extends of extract vector elts into a
25575 // vector zero extend and possibly an extract subvector.
25576// TODO: Support sign extend?
25577// TODO: Allow undef elements?
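// For example (a sketch):
//   v2i32 build_vector (zext (extract_elt v4i16 X, 2)),
//                      (zext (extract_elt v4i16 X, 3))
// --> zext (v2i16 extract_subvector X, 2) to v2i32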
25578SDValue DAGCombiner::convertBuildVecZextToZext(SDNode *N) {
25579 if (LegalOperations)
25580 return SDValue();
25581
25582 EVT VT = N->getValueType(ResNo: 0);
25583
25584 bool FoundZeroExtend = false;
25585 SDValue Op0 = N->getOperand(Num: 0);
25586 auto checkElem = [&](SDValue Op) -> int64_t {
25587 unsigned Opc = Op.getOpcode();
25588 FoundZeroExtend |= (Opc == ISD::ZERO_EXTEND);
25589 if ((Opc == ISD::ZERO_EXTEND || Opc == ISD::ANY_EXTEND) &&
25590 Op.getOperand(i: 0).getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
25591 Op0.getOperand(i: 0).getOperand(i: 0) == Op.getOperand(i: 0).getOperand(i: 0))
25592 if (auto *C = dyn_cast<ConstantSDNode>(Val: Op.getOperand(i: 0).getOperand(i: 1)))
25593 return C->getZExtValue();
25594 return -1;
25595 };
25596
25597 // Make sure the first element matches
25598 // (zext (extract_vector_elt X, C))
25599 // Offset must be a constant multiple of the
25600 // known-minimum vector length of the result type.
25601 int64_t Offset = checkElem(Op0);
25602 if (Offset < 0 || (Offset % VT.getVectorNumElements()) != 0)
25603 return SDValue();
25604
25605 unsigned NumElems = N->getNumOperands();
25606 SDValue In = Op0.getOperand(i: 0).getOperand(i: 0);
25607 EVT InSVT = In.getValueType().getScalarType();
25608 EVT InVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: InSVT, NumElements: NumElems);
25609
25610 // Don't create an illegal input type after type legalization.
25611 if (LegalTypes && !TLI.isTypeLegal(VT: InVT))
25612 return SDValue();
25613
25614 // Ensure all the elements come from the same vector and are adjacent.
25615 for (unsigned i = 1; i != NumElems; ++i) {
25616 if ((Offset + i) != checkElem(N->getOperand(Num: i)))
25617 return SDValue();
25618 }
25619
25620 SDLoc DL(N);
25621 In = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: InVT, N1: In,
25622 N2: Op0.getOperand(i: 0).getOperand(i: 1));
25623 return DAG.getNode(Opcode: FoundZeroExtend ? ISD::ZERO_EXTEND : ISD::ANY_EXTEND, DL,
25624 VT, Operand: In);
25625}
25626
25627 // If this is a very simple BUILD_VECTOR whose first element is a ZERO_EXTEND,
25628 // and all other elements are constant zeros, granularize the BUILD_VECTOR's
25629 // element width, absorbing the ZERO_EXTEND and turning it into a constant zero op.
25630 // This pattern can appear during legalization.
25631//
25632 // NOTE: This could be generalized to allow more than a single
25633 // non-constant-zero op and UNDEFs, and to be KnownBits-based.
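// Illustrative example (little-endian, types chosen arbitrarily):
//   (v2i64 build_vector (i64 zero_extend (i32 X)), (i64 0))
//     --> (v2i64 bitcast (v4i32 build_vector X', 0, 0, 0))
// where X' is the zero-extended operand truncated back to i32 (i.e. X).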
25634SDValue DAGCombiner::convertBuildVecZextToBuildVecWithZeros(SDNode *N) {
25635 // Don't run this after legalization. Targets may have other preferences.
25636 if (Level >= AfterLegalizeDAG)
25637 return SDValue();
25638
25639 // FIXME: support big-endian.
25640 if (DAG.getDataLayout().isBigEndian())
25641 return SDValue();
25642
25643 EVT VT = N->getValueType(ResNo: 0);
25644 EVT OpVT = N->getOperand(Num: 0).getValueType();
25645 assert(!VT.isScalableVector() && "Encountered scalable BUILD_VECTOR?");
25646
25647 EVT OpIntVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: OpVT.getSizeInBits());
25648
25649 if (!TLI.isTypeLegal(VT: OpIntVT) ||
25650 (LegalOperations && !TLI.isOperationLegalOrCustom(Op: ISD::BITCAST, VT: OpIntVT)))
25651 return SDValue();
25652
25653 unsigned EltBitwidth = VT.getScalarSizeInBits();
25654 // NOTE: the actual width of operands may be wider than that!
25655
25656 // Analyze all operands of this BUILD_VECTOR. What is the largest number of
25657 // active bits they all have? We'll want to truncate them all to that width.
25658 unsigned ActiveBits = 0;
25659 APInt KnownZeroOps(VT.getVectorNumElements(), 0);
25660 for (auto I : enumerate(First: N->ops())) {
25661 SDValue Op = I.value();
25662 // FIXME: support UNDEF elements?
25663 if (auto *Cst = dyn_cast<ConstantSDNode>(Val&: Op)) {
25664 unsigned OpActiveBits =
25665 Cst->getAPIntValue().trunc(width: EltBitwidth).getActiveBits();
25666 if (OpActiveBits == 0) {
25667 KnownZeroOps.setBit(I.index());
25668 continue;
25669 }
25670 // Profitability check: don't allow non-zero constant operands.
25671 return SDValue();
25672 }
25673 // Profitability check: there must only be a single non-zero operand,
25674 // and it must be the first operand of the BUILD_VECTOR.
25675 if (I.index() != 0)
25676 return SDValue();
25677 // The operand must be a zero-extension itself.
25678 // FIXME: this could be generalized to known leading zeros check.
25679 if (Op.getOpcode() != ISD::ZERO_EXTEND)
25680 return SDValue();
25681 unsigned CurrActiveBits =
25682 Op.getOperand(i: 0).getValueSizeInBits().getFixedValue();
25683 assert(!ActiveBits && "Already encountered non-constant-zero operand?");
25684 ActiveBits = CurrActiveBits;
25685 // We want to at least halve the element size.
25686 if (2 * ActiveBits > EltBitwidth)
25687 return SDValue();
25688 }
25689
25690 // This BUILD_VECTOR must have at least one non-constant-zero operand.
25691 if (ActiveBits == 0)
25692 return SDValue();
25693
25694 // We have EltBitwidth bits and the *minimal* chunk size is ActiveBits;
25695 // into how many chunks can we split our element width?
25696 EVT NewScalarIntVT, NewIntVT;
25697 std::optional<unsigned> Factor;
25698 // We can split the element into at least two chunks, but not into more
25699 // than |_ EltBitwidth / ActiveBits _| chunks. Find the largest split factor
25700 // that divides the element width evenly and for which the resulting
25701 // types/operations on that chunk width are legal.
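// For instance (illustrative numbers): with EltBitwidth = 64 and
// ActiveBits = 8, the candidate factors are 8, 4 and 2; we take the
// largest one whose chunk type (i8, i16 or i32) and widened vector type
// are legal on the target.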
25702 assert(2 * ActiveBits <= EltBitwidth &&
25703 "We know that half or less bits of the element are active.");
25704 for (unsigned Scale = EltBitwidth / ActiveBits; Scale >= 2; --Scale) {
25705 if (EltBitwidth % Scale != 0)
25706 continue;
25707 unsigned ChunkBitwidth = EltBitwidth / Scale;
25708 assert(ChunkBitwidth >= ActiveBits && "As per starting point.");
25709 NewScalarIntVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ChunkBitwidth);
25710 NewIntVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NewScalarIntVT,
25711 NumElements: Scale * N->getNumOperands());
25712 if (!TLI.isTypeLegal(VT: NewScalarIntVT) || !TLI.isTypeLegal(VT: NewIntVT) ||
25713 (LegalOperations &&
25714 !(TLI.isOperationLegalOrCustom(Op: ISD::TRUNCATE, VT: NewScalarIntVT) &&
25715 TLI.isOperationLegalOrCustom(Op: ISD::BUILD_VECTOR, VT: NewIntVT))))
25716 continue;
25717 Factor = Scale;
25718 break;
25719 }
25720 if (!Factor)
25721 return SDValue();
25722
25723 SDLoc DL(N);
25724 SDValue ZeroOp = DAG.getConstant(Val: 0, DL, VT: NewScalarIntVT);
25725
25726 // Recreate the BUILD_VECTOR, with elements now being Factor times smaller.
25727 SmallVector<SDValue, 16> NewOps;
25728 NewOps.reserve(N: NewIntVT.getVectorNumElements());
25729 for (auto I : enumerate(First: N->ops())) {
25730 SDValue Op = I.value();
25731 assert(!Op.isUndef() && "FIXME: after allowing UNDEF's, handle them here.");
25732 unsigned SrcOpIdx = I.index();
25733 if (KnownZeroOps[SrcOpIdx]) {
25734 NewOps.append(NumInputs: *Factor, Elt: ZeroOp);
25735 continue;
25736 }
25737 Op = DAG.getBitcast(VT: OpIntVT, V: Op);
25738 Op = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: NewScalarIntVT, Operand: Op);
25739 NewOps.emplace_back(Args&: Op);
25740 NewOps.append(NumInputs: *Factor - 1, Elt: ZeroOp);
25741 }
25742 assert(NewOps.size() == NewIntVT.getVectorNumElements());
25743 SDValue NewBV = DAG.getBuildVector(VT: NewIntVT, DL, Ops: NewOps);
25744 NewBV = DAG.getBitcast(VT, V: NewBV);
25745 return NewBV;
25746}
25747
25748SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) {
25749 EVT VT = N->getValueType(ResNo: 0);
25750
25751 // A vector built entirely of undefs is undef.
25752 if (ISD::allOperandsUndef(N))
25753 return DAG.getUNDEF(VT);
25754
25755 // If this is a splat of a bitcast from another vector, change to a
25756 // concat_vector.
25757 // For example:
25758 // (build_vector (i64 (bitcast (v2i32 X))), (i64 (bitcast (v2i32 X)))) ->
25759 // (v2i64 (bitcast (concat_vectors (v2i32 X), (v2i32 X))))
25760 //
25761 // If X is a build_vector itself, the concat can become a larger build_vector.
25762 // TODO: Maybe this is useful for non-splat too?
25763 if (!LegalOperations) {
25764 SDValue Splat = cast<BuildVectorSDNode>(Val: N)->getSplatValue();
25765 // Only change build_vector to a concat_vector if the splat value type is
25766 // the same as the vector element type.
25767 if (Splat && Splat.getValueType() == VT.getVectorElementType()) {
25768 Splat = peekThroughBitcasts(V: Splat);
25769 EVT SrcVT = Splat.getValueType();
25770 if (SrcVT.isVector()) {
25771 unsigned NumElts = N->getNumOperands() * SrcVT.getVectorNumElements();
25772 EVT NewVT = EVT::getVectorVT(Context&: *DAG.getContext(),
25773 VT: SrcVT.getVectorElementType(), NumElements: NumElts);
25774 if (!LegalTypes || TLI.isTypeLegal(VT: NewVT)) {
25775 SmallVector<SDValue, 8> Ops(N->getNumOperands(), Splat);
25776 SDValue Concat =
25777 DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT: NewVT, Ops);
25778 return DAG.getBitcast(VT, V: Concat);
25779 }
25780 }
25781 }
25782 }
25783
25784 // Check if we can express the BUILD_VECTOR via a subvector extract.
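// Illustrative example (arbitrary types):
//   (v4i32 build_vector (extract_vector_elt v8i32:X, 4), ...,
//                       (extract_vector_elt v8i32:X, 7))
//     --> (v4i32 extract_subvector v8i32:X, 4)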
25785 if (!LegalTypes && (N->getNumOperands() > 1)) {
25786 SDValue Op0 = N->getOperand(Num: 0);
25787 auto checkElem = [&](SDValue Op) -> uint64_t {
25788 if ((Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT) &&
25789 (Op0.getOperand(i: 0) == Op.getOperand(i: 0)))
25790 if (auto CNode = dyn_cast<ConstantSDNode>(Val: Op.getOperand(i: 1)))
25791 return CNode->getZExtValue();
25792 return -1;
25793 };
25794
25795 int Offset = checkElem(Op0);
25796 for (unsigned i = 0; i < N->getNumOperands(); ++i) {
25797 if (Offset + i != checkElem(N->getOperand(Num: i))) {
25798 Offset = -1;
25799 break;
25800 }
25801 }
25802
25803 if ((Offset == 0) &&
25804 (Op0.getOperand(i: 0).getValueType() == N->getValueType(ResNo: 0)))
25805 return Op0.getOperand(i: 0);
25806 if ((Offset != -1) &&
25807 ((Offset % N->getValueType(ResNo: 0).getVectorNumElements()) ==
25808 0)) // IDX must be multiple of output size.
25809 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N), VT: N->getValueType(ResNo: 0),
25810 N1: Op0.getOperand(i: 0), N2: Op0.getOperand(i: 1));
25811 }
25812
25813 if (SDValue V = convertBuildVecZextToZext(N))
25814 return V;
25815
25816 if (SDValue V = convertBuildVecZextToBuildVecWithZeros(N))
25817 return V;
25818
25819 if (SDValue V = reduceBuildVecExtToExtBuildVec(N))
25820 return V;
25821
25822 if (SDValue V = reduceBuildVecTruncToBitCast(N))
25823 return V;
25824
25825 if (SDValue V = reduceBuildVecToShuffle(N))
25826 return V;
25827
25828 // A splat of a single element is a SPLAT_VECTOR if supported on the target.
25829 // Do this late as some of the above may replace the splat.
25830 if (TLI.getOperationAction(Op: ISD::SPLAT_VECTOR, VT) != TargetLowering::Expand)
25831 if (SDValue V = cast<BuildVectorSDNode>(Val: N)->getSplatValue()) {
25832 assert(!V.isUndef() && "Splat of undef should have been handled earlier");
25833 return DAG.getNode(Opcode: ISD::SPLAT_VECTOR, DL: SDLoc(N), VT, Operand: V);
25834 }
25835
25836 return SDValue();
25837}
25838
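// Fold a CONCAT_VECTORS of scalars that were bitcast to a (not yet legal)
// vector type, or undef, into one BUILD_VECTOR of the scalars.
// Illustrative example (assuming v2i32 is not a legal type on the target):
//   concat_vectors (v2i32 bitcast (i64 A)), (v2i32 bitcast (i64 B))
//     --> (v4i32 bitcast (v2i64 build_vector A, B))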
25839static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
25840 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25841 EVT OpVT = N->getOperand(Num: 0).getValueType();
25842
25843 // If the operands are legal vectors, leave them alone.
25844 if (TLI.isTypeLegal(VT: OpVT) || OpVT.isScalableVector())
25845 return SDValue();
25846
25847 SDLoc DL(N);
25848 EVT VT = N->getValueType(ResNo: 0);
25849 SmallVector<SDValue, 8> Ops;
25850 EVT SVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: OpVT.getSizeInBits());
25851
25852 // Keep track of what we encounter.
25853 EVT AnyFPVT;
25854
25855 for (const SDValue &Op : N->ops()) {
25856 if (ISD::BITCAST == Op.getOpcode() &&
25857 !Op.getOperand(i: 0).getValueType().isVector())
25858 Ops.push_back(Elt: Op.getOperand(i: 0));
25859 else if (Op.isUndef())
25860 Ops.push_back(Elt: DAG.getNode(Opcode: Op.getOpcode(), DL, VT: SVT));
25861 else
25862 return SDValue();
25863
25864 // Note whether we encounter an integer or floating point scalar.
25865 // If it's neither, bail out, it could be something weird like x86mmx.
25866 EVT LastOpVT = Ops.back().getValueType();
25867 if (LastOpVT.isFloatingPoint())
25868 AnyFPVT = LastOpVT;
25869 else if (!LastOpVT.isInteger())
25870 return SDValue();
25871 }
25872
25873 // If any of the operands is a floating point scalar bitcast to a vector,
25874 // use floating point types throughout, and bitcast everything.
25875 // Replace UNDEFs by another scalar UNDEF node, of the final desired type.
25876 if (AnyFPVT != EVT()) {
25877 SVT = AnyFPVT;
25878 for (SDValue &Op : Ops) {
25879 if (Op.getValueType() == SVT)
25880 continue;
25881 if (Op.isUndef())
25882 Op = DAG.getNode(Opcode: Op.getOpcode(), DL, VT: SVT);
25883 else
25884 Op = DAG.getBitcast(VT: SVT, V: Op);
25885 }
25886 }
25887
25888 EVT VecVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SVT,
25889 NumElements: VT.getSizeInBits() / SVT.getSizeInBits());
25890 return DAG.getBitcast(VT, V: DAG.getBuildVector(VT: VecVT, DL, Ops));
25891}
25892
25893// Attempt to merge nested concat_vectors/undefs.
25894// Fold concat_vectors(concat_vectors(x,y,z,w),u,u,concat_vectors(a,b,c,d))
25895// --> concat_vectors(x,y,z,w,u,u,u,u,u,u,u,u,a,b,c,d)
25896static SDValue combineConcatVectorOfConcatVectors(SDNode *N,
25897 SelectionDAG &DAG) {
25898 EVT VT = N->getValueType(ResNo: 0);
25899
25900 // Ensure we're concatenating UNDEF and CONCAT_VECTORS nodes of similar types.
25901 EVT SubVT;
25902 SDValue FirstConcat;
25903 for (const SDValue &Op : N->ops()) {
25904 if (Op.isUndef())
25905 continue;
25906 if (Op.getOpcode() != ISD::CONCAT_VECTORS)
25907 return SDValue();
25908 if (!FirstConcat) {
25909 SubVT = Op.getOperand(i: 0).getValueType();
25910 if (!DAG.getTargetLoweringInfo().isTypeLegal(VT: SubVT))
25911 return SDValue();
25912 FirstConcat = Op;
25913 continue;
25914 }
25915 if (SubVT != Op.getOperand(i: 0).getValueType())
25916 return SDValue();
25917 }
25918 assert(FirstConcat && "Concat of all-undefs found");
25919
25920 SmallVector<SDValue> ConcatOps;
25921 for (const SDValue &Op : N->ops()) {
25922 if (Op.isUndef()) {
25923 ConcatOps.append(NumInputs: FirstConcat->getNumOperands(), Elt: DAG.getPOISON(VT: SubVT));
25924 continue;
25925 }
25926 ConcatOps.append(in_start: Op->op_begin(), in_end: Op->op_end());
25927 }
25928 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, Ops: ConcatOps);
25929}
25930
25931// Check to see if this is a CONCAT_VECTORS of a bunch of EXTRACT_SUBVECTOR
25932// operations. If so, and if the EXTRACT_SUBVECTOR vector inputs come from at
25933// most two distinct vectors the same size as the result, attempt to turn this
25934// into a legal shuffle.
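// Illustrative example (arbitrary types):
//   concat_vectors (v4i32 extract_subvector v8i32:X, 0),
//                  (v4i32 extract_subvector v8i32:Y, 4)
//     --> vector_shuffle<0,1,2,3,12,13,14,15> X, Y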
25935static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) {
25936 EVT VT = N->getValueType(ResNo: 0);
25937 EVT OpVT = N->getOperand(Num: 0).getValueType();
25938
25939 // We currently can't generate an appropriate shuffle for a scalable vector.
25940 if (VT.isScalableVector())
25941 return SDValue();
25942
25943 int NumElts = VT.getVectorNumElements();
25944 int NumOpElts = OpVT.getVectorNumElements();
25945
25946 SDValue SV0 = DAG.getPOISON(VT), SV1 = DAG.getPOISON(VT);
25947 SmallVector<int, 8> Mask;
25948
25949 for (SDValue Op : N->ops()) {
25950 Op = peekThroughBitcasts(V: Op);
25951
25952 // UNDEF nodes convert to UNDEF shuffle mask values.
25953 if (Op.isUndef()) {
25954 Mask.append(NumInputs: (unsigned)NumOpElts, Elt: -1);
25955 continue;
25956 }
25957
25958 if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
25959 return SDValue();
25960
25961 // What vector are we extracting the subvector from and at what index?
25962 SDValue ExtVec = Op.getOperand(i: 0);
25963 int ExtIdx = Op.getConstantOperandVal(i: 1);
25964
25965 // We want the EVT of the original extraction to correctly scale the
25966 // extraction index.
25967 EVT ExtVT = ExtVec.getValueType();
25968 ExtVec = peekThroughBitcasts(V: ExtVec);
25969
25970 // UNDEF nodes convert to UNDEF shuffle mask values.
25971 if (ExtVec.isUndef()) {
25972 Mask.append(NumInputs: (unsigned)NumOpElts, Elt: -1);
25973 continue;
25974 }
25975
25976 // Ensure that we are extracting a subvector from a vector the same
25977 // size as the result.
25978 if (ExtVT.getSizeInBits() != VT.getSizeInBits())
25979 return SDValue();
25980
25981 // Scale the subvector index to account for any bitcast.
25982 int NumExtElts = ExtVT.getVectorNumElements();
25983 if (0 == (NumExtElts % NumElts))
25984 ExtIdx /= (NumExtElts / NumElts);
25985 else if (0 == (NumElts % NumExtElts))
25986 ExtIdx *= (NumElts / NumExtElts);
25987 else
25988 return SDValue();
25989
25990 // At most we can reference 2 inputs in the final shuffle.
25991 if (SV0.isUndef() || SV0 == ExtVec) {
25992 SV0 = ExtVec;
25993 for (int i = 0; i != NumOpElts; ++i)
25994 Mask.push_back(Elt: i + ExtIdx);
25995 } else if (SV1.isUndef() || SV1 == ExtVec) {
25996 SV1 = ExtVec;
25997 for (int i = 0; i != NumOpElts; ++i)
25998 Mask.push_back(Elt: i + ExtIdx + NumElts);
25999 } else {
26000 return SDValue();
26001 }
26002 }
26003
26004 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
26005 return TLI.buildLegalVectorShuffle(VT, DL: SDLoc(N), N0: DAG.getBitcast(VT, V: SV0),
26006 N1: DAG.getBitcast(VT, V: SV1), Mask, DAG);
26007}
26008
26009static SDValue combineConcatVectorOfCasts(SDNode *N, SelectionDAG &DAG) {
26010 unsigned CastOpcode = N->getOperand(Num: 0).getOpcode();
26011 switch (CastOpcode) {
26012 case ISD::SINT_TO_FP:
26013 case ISD::UINT_TO_FP:
26014 case ISD::FP_TO_SINT:
26015 case ISD::FP_TO_UINT:
26016 // TODO: Allow more opcodes?
26017 // case ISD::BITCAST:
26018 // case ISD::TRUNCATE:
26019 // case ISD::ZERO_EXTEND:
26020 // case ISD::SIGN_EXTEND:
26021 // case ISD::FP_EXTEND:
26022 break;
26023 default:
26024 return SDValue();
26025 }
26026
26027 EVT SrcVT = N->getOperand(Num: 0).getOperand(i: 0).getValueType();
26028 if (!SrcVT.isVector())
26029 return SDValue();
26030
26031 // All operands of the concat must be the same kind of cast from the same
26032 // source type.
26033 SmallVector<SDValue, 4> SrcOps;
26034 for (SDValue Op : N->ops()) {
26035 if (Op.getOpcode() != CastOpcode || !Op.hasOneUse() ||
26036 Op.getOperand(i: 0).getValueType() != SrcVT)
26037 return SDValue();
26038 SrcOps.push_back(Elt: Op.getOperand(i: 0));
26039 }
26040
26041 // The wider cast must be supported by the target. This is unusual because
26042 // the operation support type parameter depends on the opcode. In addition,
26043 // check the other type in the cast to make sure this is really legal.
26044 EVT VT = N->getValueType(ResNo: 0);
26045 EVT SrcEltVT = SrcVT.getVectorElementType();
26046 ElementCount NumElts = SrcVT.getVectorElementCount() * N->getNumOperands();
26047 EVT ConcatSrcVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SrcEltVT, EC: NumElts);
26048 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
26049 switch (CastOpcode) {
26050 case ISD::SINT_TO_FP:
26051 case ISD::UINT_TO_FP:
26052 if (!TLI.isOperationLegalOrCustom(Op: CastOpcode, VT: ConcatSrcVT) ||
26053 !TLI.isTypeLegal(VT))
26054 return SDValue();
26055 break;
26056 case ISD::FP_TO_SINT:
26057 case ISD::FP_TO_UINT:
26058 if (!TLI.isOperationLegalOrCustom(Op: CastOpcode, VT) ||
26059 !TLI.isTypeLegal(VT: ConcatSrcVT))
26060 return SDValue();
26061 break;
26062 default:
26063 llvm_unreachable("Unexpected cast opcode");
26064 }
26065
26066 // concat (cast X), (cast Y)... -> cast (concat X, Y...)
26067 SDLoc DL(N);
26068 SDValue NewConcat = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT: ConcatSrcVT, Ops: SrcOps);
26069 return DAG.getNode(Opcode: CastOpcode, DL, VT, Operand: NewConcat);
26070}
26071
26072 // See if this is a simple CONCAT_VECTORS with no UNDEF operands, where one of
26073 // the operands is a SHUFFLE_VECTOR and all other operands are also operands
26074 // of that SHUFFLE_VECTOR; if so, create a wider SHUFFLE_VECTOR.
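// Illustrative example (arbitrary types):
//   concat_vectors (vector_shuffle<1,0,3,2> v4i32:X, undef), v4i32:X
//     --> vector_shuffle<1,0,3,2,0,1,2,3> (concat_vectors X, poison), poison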
26075static SDValue combineConcatVectorOfShuffleAndItsOperands(
26076 SDNode *N, SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes,
26077 bool LegalOperations) {
26078 EVT VT = N->getValueType(ResNo: 0);
26079 EVT OpVT = N->getOperand(Num: 0).getValueType();
26080 if (VT.isScalableVector())
26081 return SDValue();
26082
26083 // For now, only allow simple 2-operand concatenations.
26084 if (N->getNumOperands() != 2)
26085 return SDValue();
26086
26087 // Don't create illegal types/shuffles when not allowed to.
26088 if ((LegalTypes && !TLI.isTypeLegal(VT)) ||
26089 (LegalOperations &&
26090 !TLI.isOperationLegalOrCustom(Op: ISD::VECTOR_SHUFFLE, VT)))
26091 return SDValue();
26092
26093 // Analyze all of the operands of the CONCAT_VECTORS. Out of all of them,
26094 // we want to find one that is: (1) a SHUFFLE_VECTOR (2) only used by us,
26095 // and (3) all operands of CONCAT_VECTORS must be either that SHUFFLE_VECTOR,
26096 // or one of the operands of that SHUFFLE_VECTOR (but not UNDEF!).
26097 // (4) and for now, the SHUFFLE_VECTOR must be unary.
26098 ShuffleVectorSDNode *SVN = nullptr;
26099 for (SDValue Op : N->ops()) {
26100 if (auto *CurSVN = dyn_cast<ShuffleVectorSDNode>(Val&: Op);
26101 CurSVN && CurSVN->getOperand(Num: 1).isUndef() && N->isOnlyUserOf(N: CurSVN) &&
26102 all_of(Range: N->ops(), P: [CurSVN](SDValue Op) {
26103 // FIXME: can we allow UNDEF operands?
26104 return !Op.isUndef() &&
26105 (Op.getNode() == CurSVN || is_contained(Range: CurSVN->ops(), Element: Op));
26106 })) {
26107 SVN = CurSVN;
26108 break;
26109 }
26110 }
26111 if (!SVN)
26112 return SDValue();
26113
26114 // We are going to pad the shuffle operands, so any index that was picking
26115 // from the second operand must be adjusted.
26116 SmallVector<int, 16> AdjustedMask(SVN->getMask());
26117 assert(SVN->getOperand(1).isUndef() && "Expected unary shuffle!");
26118
26119 // Identity masks for the operands of the (padded) shuffle.
26120 SmallVector<int, 32> IdentityMask(2 * OpVT.getVectorNumElements());
26121 MutableArrayRef<int> FirstShufOpIdentityMask =
26122 MutableArrayRef<int>(IdentityMask)
26123 .take_front(N: OpVT.getVectorNumElements());
26124 MutableArrayRef<int> SecondShufOpIdentityMask =
26125 MutableArrayRef<int>(IdentityMask).take_back(N: OpVT.getVectorNumElements());
26126 std::iota(first: FirstShufOpIdentityMask.begin(), last: FirstShufOpIdentityMask.end(), value: 0);
26127 std::iota(first: SecondShufOpIdentityMask.begin(), last: SecondShufOpIdentityMask.end(),
26128 value: VT.getVectorNumElements());
26129
26130 // New combined shuffle mask.
26131 SmallVector<int, 32> Mask;
26132 Mask.reserve(N: VT.getVectorNumElements());
26133 for (SDValue Op : N->ops()) {
26134 assert(!Op.isUndef() && "Not expecting to concatenate UNDEF.");
26135 if (Op.getNode() == SVN) {
26136 append_range(C&: Mask, R&: AdjustedMask);
26137 continue;
26138 }
26139 if (Op == SVN->getOperand(Num: 0)) {
26140 append_range(C&: Mask, R&: FirstShufOpIdentityMask);
26141 continue;
26142 }
26143 if (Op == SVN->getOperand(Num: 1)) {
26144 append_range(C&: Mask, R&: SecondShufOpIdentityMask);
26145 continue;
26146 }
26147 llvm_unreachable("Unexpected operand!");
26148 }
26149
26150 // Don't create illegal shuffle masks.
26151 if (!TLI.isShuffleMaskLegal(Mask, VT))
26152 return SDValue();
26153
26154 // Pad the shuffle operands with poison.
26155 SDLoc dl(N);
26156 std::array<SDValue, 2> ShufOps;
26157 for (auto I : zip(t: SVN->ops(), u&: ShufOps)) {
26158 SDValue ShufOp = std::get<0>(t&: I);
26159 SDValue &NewShufOp = std::get<1>(t&: I);
26160 if (ShufOp.isUndef())
26161 NewShufOp = DAG.getPOISON(VT);
26162 else {
26163 SmallVector<SDValue, 2> ShufOpParts(N->getNumOperands(),
26164 DAG.getPOISON(VT: OpVT));
26165 ShufOpParts[0] = ShufOp;
26166 NewShufOp = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: dl, VT, Ops: ShufOpParts);
26167 }
26168 }
26169 // Finally, create the new wide shuffle.
26170 return DAG.getVectorShuffle(VT, dl, N1: ShufOps[0], N2: ShufOps[1], Mask);
26171}
26172
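// Fold a CONCAT_VECTORS of identical SPLAT_VECTOR operands into one wider
// SPLAT_VECTOR, e.g. (illustrative):
//   concat_vectors (nxv2i32 splat_vector X), (nxv2i32 splat_vector X)
//     --> (nxv4i32 splat_vector X)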
26173static SDValue combineConcatVectorOfSplats(SDNode *N, SelectionDAG &DAG,
26174 const TargetLowering &TLI,
26175 bool LegalTypes,
26176 bool LegalOperations) {
26177 EVT VT = N->getValueType(ResNo: 0);
26178
26179 // Post-legalization we can only create wider SPLAT_VECTOR operations if both
26180 // the type and operation is legal. The Hexagon target has custom
26181 // legalization for SPLAT_VECTOR that splits the operation into two parts and
26182 // concatenates them. Therefore, custom lowering must also be rejected in
26183 // order to avoid an infinite loop.
26184 if ((LegalTypes && !TLI.isTypeLegal(VT)) ||
26185 (LegalOperations && !TLI.isOperationLegal(Op: ISD::SPLAT_VECTOR, VT)))
26186 return SDValue();
26187
26188 SDValue Op0 = N->getOperand(Num: 0);
26189 if (!llvm::all_equal(Range: N->op_values()) || Op0.getOpcode() != ISD::SPLAT_VECTOR)
26190 return SDValue();
26191
26192 return DAG.getNode(Opcode: ISD::SPLAT_VECTOR, DL: SDLoc(N), VT, Operand: Op0.getOperand(i: 0));
26193}
26194
26195SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
26196 // If we only have one input vector, we don't need to do any concatenation.
26197 if (N->getNumOperands() == 1)
26198 return N->getOperand(Num: 0);
26199
26200 // Check if all of the operands are undefs.
26201 EVT VT = N->getValueType(ResNo: 0);
26202 if (ISD::allOperandsUndef(N))
26203 return DAG.getUNDEF(VT);
26204
26205 // Optimize concat_vectors where all but the first of the vectors are undef.
26206 if (all_of(Range: drop_begin(RangeOrContainer: N->ops()),
26207 P: [](const SDValue &Op) { return Op.isUndef(); })) {
26208 SDValue In = N->getOperand(Num: 0);
26209 assert(In.getValueType().isVector() && "Must concat vectors");
26210
26211 // If the input is a concat_vectors, just make a larger concat by padding
26212 // with smaller undefs.
26213 //
26214 // Legalizing in AArch64TargetLowering::LowerCONCAT_VECTORS() and combining
26215 // here could cause an infinite loop. That legalization happens when LegalDAG
26216 // is true and the input of AArch64TargetLowering::LowerCONCAT_VECTORS() is
26217 // scalable.
26218 if (In.getOpcode() == ISD::CONCAT_VECTORS && In.hasOneUse() &&
26219 !(LegalDAG && In.getValueType().isScalableVector())) {
26220 unsigned NumOps = N->getNumOperands() * In.getNumOperands();
26221 SmallVector<SDValue, 4> Ops(In->ops());
26222 Ops.resize(N: NumOps, NV: DAG.getPOISON(VT: Ops[0].getValueType()));
26223 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, Ops);
26224 }
26225
26226 SDValue Scalar = peekThroughOneUseBitcasts(V: In);
26227
26228 // concat_vectors(scalar_to_vector(scalar), undef) ->
26229 // scalar_to_vector(scalar)
26230 if (!LegalOperations && Scalar.getOpcode() == ISD::SCALAR_TO_VECTOR &&
26231 Scalar.hasOneUse()) {
26232 EVT SVT = Scalar.getValueType().getVectorElementType();
26233 if (SVT == Scalar.getOperand(i: 0).getValueType())
26234 Scalar = Scalar.getOperand(i: 0);
26235 }
26236
26237 // concat_vectors(scalar, undef) -> scalar_to_vector(scalar)
26238 if (!Scalar.getValueType().isVector() && In.hasOneUse()) {
26239 // If the bitcast type isn't legal, it might be a trunc of a legal type;
26240 // look through the trunc so we can still do the transform:
26241 // concat_vectors(trunc(scalar), undef) -> scalar_to_vector(scalar)
26242 if (Scalar->getOpcode() == ISD::TRUNCATE &&
26243 !TLI.isTypeLegal(VT: Scalar.getValueType()) &&
26244 TLI.isTypeLegal(VT: Scalar->getOperand(Num: 0).getValueType()))
26245 Scalar = Scalar->getOperand(Num: 0);
26246
26247 EVT SclTy = Scalar.getValueType();
26248
26249 if (!SclTy.isFloatingPoint() && !SclTy.isInteger())
26250 return SDValue();
26251
26252 // Bail out if the vector size is not a multiple of the scalar size.
26253 if (VT.getSizeInBits() % SclTy.getSizeInBits())
26254 return SDValue();
26255
26256 unsigned VNTNumElms = VT.getSizeInBits() / SclTy.getSizeInBits();
26257 if (VNTNumElms < 2)
26258 return SDValue();
26259
26260 EVT NVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SclTy, NumElements: VNTNumElms);
26261 if (!TLI.isTypeLegal(VT: NVT) || !TLI.isTypeLegal(VT: Scalar.getValueType()))
26262 return SDValue();
26263
26264 SDValue Res = DAG.getNode(Opcode: ISD::SCALAR_TO_VECTOR, DL: SDLoc(N), VT: NVT, Operand: Scalar);
26265 return DAG.getBitcast(VT, V: Res);
26266 }
26267 }
26268
26269 // Fold any combination of BUILD_VECTOR or UNDEF nodes into one BUILD_VECTOR.
26270 // We have already tested above for an UNDEF only concatenation.
26271 // fold (concat_vectors (BUILD_VECTOR A, B, ...), (BUILD_VECTOR C, D, ...))
26272 // -> (BUILD_VECTOR A, B, ..., C, D, ...)
26273 auto IsBuildVectorOrUndef = [](const SDValue &Op) {
26274 return Op.isUndef() || ISD::BUILD_VECTOR == Op.getOpcode();
26275 };
26276 if (llvm::all_of(Range: N->ops(), P: IsBuildVectorOrUndef)) {
26277 SmallVector<SDValue, 8> Opnds;
26278 EVT SVT = VT.getScalarType();
26279
26280 EVT MinVT = SVT;
26281 if (!SVT.isFloatingPoint()) {
26282 // If the BUILD_VECTORs are built from integers, they may have different
26283 // operand types. Get the smallest type and truncate all operands to it.
26284 bool FoundMinVT = false;
26285 for (const SDValue &Op : N->ops())
26286 if (ISD::BUILD_VECTOR == Op.getOpcode()) {
26287 EVT OpSVT = Op.getOperand(i: 0).getValueType();
26288 MinVT = (!FoundMinVT || OpSVT.bitsLE(VT: MinVT)) ? OpSVT : MinVT;
26289 FoundMinVT = true;
26290 }
26291 assert(FoundMinVT && "Concat vector type mismatch");
26292 }
26293
26294 for (const SDValue &Op : N->ops()) {
26295 EVT OpVT = Op.getValueType();
26296 unsigned NumElts = OpVT.getVectorNumElements();
26297
26298 if (Op.isUndef())
26299 Opnds.append(NumInputs: NumElts, Elt: DAG.getPOISON(VT: MinVT));
26300
26301 if (ISD::BUILD_VECTOR == Op.getOpcode()) {
26302 if (SVT.isFloatingPoint()) {
26303 assert(SVT == OpVT.getScalarType() && "Concat vector type mismatch");
26304 Opnds.append(in_start: Op->op_begin(), in_end: Op->op_begin() + NumElts);
26305 } else {
26306 for (unsigned i = 0; i != NumElts; ++i)
26307 Opnds.push_back(
26308 Elt: DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N), VT: MinVT, Operand: Op.getOperand(i)));
26309 }
26310 }
26311 }
26312
26313 assert(VT.getVectorNumElements() == Opnds.size() &&
26314 "Concat vector type mismatch");
26315 return DAG.getBuildVector(VT, DL: SDLoc(N), Ops: Opnds);
26316 }
26317
26318 if (SDValue V =
26319 combineConcatVectorOfSplats(N, DAG, TLI, LegalTypes, LegalOperations))
26320 return V;
26321
26322 // Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR.
26323 // FIXME: Add support for concat_vectors(bitcast(vec0),bitcast(vec1),...).
26324 if (SDValue V = combineConcatVectorOfScalars(N, DAG))
26325 return V;
26326
26327 if (Level <= AfterLegalizeVectorOps && TLI.isTypeLegal(VT)) {
26328 // Fold CONCAT_VECTORS of CONCAT_VECTORS (or undef) to VECTOR_SHUFFLE.
26329 if (SDValue V = combineConcatVectorOfConcatVectors(N, DAG))
26330 return V;
26331
26332 // Fold CONCAT_VECTORS of EXTRACT_SUBVECTOR (or undef) to VECTOR_SHUFFLE.
26333 if (SDValue V = combineConcatVectorOfExtracts(N, DAG))
26334 return V;
26335 }
26336
26337 if (SDValue V = combineConcatVectorOfCasts(N, DAG))
26338 return V;
26339
26340 if (SDValue V = combineConcatVectorOfShuffleAndItsOperands(
26341 N, DAG, TLI, LegalTypes, LegalOperations))
26342 return V;
26343
26344 // Type legalization of vectors and DAG canonicalization of SHUFFLE_VECTOR
26345 // nodes often generate nop CONCAT_VECTOR nodes. Scan the CONCAT_VECTOR
26346 // operands and look for CONCAT operations that place the incoming vectors
26347 // at the exact same location.
26348 //
26349 // For scalable vectors, EXTRACT_SUBVECTOR indexes are implicitly scaled.
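// Illustrative example (arbitrary types):
//   concat_vectors (v4i32 extract_subvector v8i32:X, 0),
//                  (v4i32 extract_subvector v8i32:X, 4)
//     --> v8i32:X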
26350 SDValue SingleSource = SDValue();
26351 unsigned PartNumElem =
26352 N->getOperand(Num: 0).getValueType().getVectorMinNumElements();
26353
26354 for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
26355 SDValue Op = N->getOperand(Num: i);
26356
26357 if (Op.isUndef())
26358 continue;
26359
26360 // Check if this is the identity extract:
26361 if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
26362 return SDValue();
26363
26364 // Find the single incoming vector for the extract_subvector.
26365 if (SingleSource.getNode()) {
26366 if (Op.getOperand(i: 0) != SingleSource)
26367 return SDValue();
26368 } else {
26369 SingleSource = Op.getOperand(i: 0);
26370
26371 // Check that the source type is the same as the type of the result.
26372 // If not, this concat may extend the vector, so we cannot
26373 // optimize it away.
26374 if (SingleSource.getValueType() != N->getValueType(ResNo: 0))
26375 return SDValue();
26376 }
26377
26378 // Check that we are reading from the identity index.
26379 unsigned IdentityIndex = i * PartNumElem;
26380 if (Op.getConstantOperandAPInt(i: 1) != IdentityIndex)
26381 return SDValue();
26382 }
26383
26384 if (SingleSource.getNode())
26385 return SingleSource;
26386
26387 return SDValue();
26388}
26389
26390SDValue DAGCombiner::visitVECTOR_INTERLEAVE(SDNode *N) {
26391 // Check to see if all operands are identical.
26392 if (!llvm::all_equal(Range: N->op_values()))
26393 return SDValue();
26394
26395 // Check to see if the identical operand is a splat.
26396 if (!DAG.isSplatValue(V: N->getOperand(Num: 0)))
26397 return SDValue();
26398
26399 // interleave splat(X), splat(X).... --> splat(X), splat(X)....
26400 SmallVector<SDValue, 4> Ops;
26401 Ops.append(in_start: N->op_values().begin(), in_end: N->op_values().end());
26402 return CombineTo(N, To: &Ops);
26403}
26404
26405// Helper that peeks through INSERT_SUBVECTOR/CONCAT_VECTORS to find
26406// if the subvector can be sourced for free.
26407static SDValue getSubVectorSrc(SDValue V, unsigned Index, EVT SubVT) {
26408 if (V.getOpcode() == ISD::INSERT_SUBVECTOR &&
26409 V.getOperand(i: 1).getValueType() == SubVT &&
26410 V.getConstantOperandAPInt(i: 2) == Index) {
26411 return V.getOperand(i: 1);
26412 }
26413 if (V.getOpcode() == ISD::CONCAT_VECTORS &&
26414 V.getOperand(i: 0).getValueType() == SubVT &&
26415 (Index % SubVT.getVectorMinNumElements()) == 0) {
26416 uint64_t SubIdx = Index / SubVT.getVectorMinNumElements();
26417 return V.getOperand(i: SubIdx);
26418 }
26419 return SDValue();
26420}
26421
26422static SDValue narrowInsertExtractVectorBinOp(EVT SubVT, SDValue BinOp,
26423 unsigned Index, const SDLoc &DL,
26424 SelectionDAG &DAG,
26425 bool LegalOperations) {
26426 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
26427 unsigned BinOpcode = BinOp.getOpcode();
26428 if (!TLI.isBinOp(Opcode: BinOpcode) || BinOp->getNumValues() != 1)
26429 return SDValue();
26430
26431 EVT VecVT = BinOp.getValueType();
26432 SDValue Bop0 = BinOp.getOperand(i: 0), Bop1 = BinOp.getOperand(i: 1);
26433 if (VecVT != Bop0.getValueType() || VecVT != Bop1.getValueType())
26434 return SDValue();
26435 if (!TLI.isOperationLegalOrCustom(Op: BinOpcode, VT: SubVT, LegalOnly: LegalOperations))
26436 return SDValue();
26437
26438 SDValue Sub0 = getSubVectorSrc(V: Bop0, Index, SubVT);
26439 SDValue Sub1 = getSubVectorSrc(V: Bop1, Index, SubVT);
26440
26441 // TODO: We could handle the case where only 1 operand is being inserted by
26442 // creating an extract of the other operand, but that requires checking
26443 // number of uses and/or costs.
26444 if (!Sub0 || !Sub1)
26445 return SDValue();
26446
26447 // We are inserting both operands of the wide binop only to extract back
26448 // to the narrow vector size. Eliminate all of the insert/extract:
26449 // ext (binop (ins ?, X, Index), (ins ?, Y, Index)), Index --> binop X, Y
26450 return DAG.getNode(Opcode: BinOpcode, DL, VT: SubVT, N1: Sub0, N2: Sub1, Flags: BinOp->getFlags());
26451}
26452
26453/// If we are extracting a subvector produced by a wide binary operator try
26454/// to use a narrow binary operator and/or avoid concatenation and extraction.
26455static SDValue narrowExtractedVectorBinOp(EVT VT, SDValue Src, unsigned Index,
26456 const SDLoc &DL, SelectionDAG &DAG,
26457 bool LegalOperations) {
26458 // TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share
26459 // some of these bailouts with other transforms.
26460
26461 if (SDValue V = narrowInsertExtractVectorBinOp(SubVT: VT, BinOp: Src, Index, DL, DAG,
26462 LegalOperations))
26463 return V;
26464
26465 // We are looking for an optionally bitcasted wide vector binary operator
26466 // feeding an extract subvector.
26467 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
26468 SDValue BinOp = peekThroughBitcasts(V: Src);
26469 unsigned BOpcode = BinOp.getOpcode();
26470 if (!TLI.isBinOp(Opcode: BOpcode) || BinOp->getNumValues() != 1)
26471 return SDValue();
26472
26473 // Exclude the fake form of fneg (fsub -0.0, x) because that is likely to be
26474 // reduced to the unary fneg when it is visited, and we probably want to deal
26475 // with fneg in a target-specific way.
26476 if (BOpcode == ISD::FSUB) {
26477 auto *C = isConstOrConstSplatFP(N: BinOp.getOperand(i: 0), /*AllowUndefs*/ true);
26478 if (C && C->getValueAPF().isNegZero())
26479 return SDValue();
26480 }
26481
26482 // The binop must be a vector type, so we can extract some fraction of it.
26483 EVT WideBVT = BinOp.getValueType();
26484 // The optimisations below currently assume we are dealing with fixed length
26485 // vectors. It is possible to add support for scalable vectors, but at the
26486 // moment we've done no analysis to prove whether they are profitable or not.
26487 if (!WideBVT.isFixedLengthVector())
26488 return SDValue();
26489
26490 assert((Index % VT.getVectorNumElements()) == 0 &&
26491 "Extract index is not a multiple of the vector length.");
26492
26493 // Bail out if this is not a proper multiple width extraction.
26494 unsigned WideWidth = WideBVT.getSizeInBits();
26495 unsigned NarrowWidth = VT.getSizeInBits();
26496 if (WideWidth % NarrowWidth != 0)
26497 return SDValue();
26498
26499 // Bail out if we are extracting a fraction of a single operation. This can
26500 // occur because we potentially looked through a bitcast of the binop.
26501 unsigned NarrowingRatio = WideWidth / NarrowWidth;
26502 unsigned WideNumElts = WideBVT.getVectorNumElements();
26503 if (WideNumElts % NarrowingRatio != 0)
26504 return SDValue();
26505
26506 // Bail out if the target does not support a narrower version of the binop.
26507 EVT NarrowBVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: WideBVT.getScalarType(),
26508 NumElements: WideNumElts / NarrowingRatio);
26509 if (!TLI.isOperationLegalOrCustomOrPromote(Op: BOpcode, VT: NarrowBVT,
26510 LegalOnly: LegalOperations))
26511 return SDValue();
26512
26513 // If extraction is cheap, we don't need to look at the binop operands
26514 // for concat ops. The narrow binop alone makes this transform profitable.
26515 // We can't just reuse the original extract index operand because we may have
26516 // bitcasted.
26517 unsigned ConcatOpNum = Index / VT.getVectorNumElements();
26518 unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements();
26519 if (TLI.isExtractSubvectorCheap(ResVT: NarrowBVT, SrcVT: WideBVT, Index: ExtBOIdx) &&
26520 BinOp.hasOneUse() && Src->hasOneUse()) {
26521 // extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N)
26522 SDValue NewExtIndex = DAG.getVectorIdxConstant(Val: ExtBOIdx, DL);
26523 SDValue X = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowBVT,
26524 N1: BinOp.getOperand(i: 0), N2: NewExtIndex);
26525 SDValue Y = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowBVT,
26526 N1: BinOp.getOperand(i: 1), N2: NewExtIndex);
26527 SDValue NarrowBinOp =
26528 DAG.getNode(Opcode: BOpcode, DL, VT: NarrowBVT, N1: X, N2: Y, Flags: BinOp->getFlags());
26529 return DAG.getBitcast(VT, V: NarrowBinOp);
26530 }
26531
26532 // Only handle the case where we are doubling and then halving. A larger ratio
26533 // may require more than two narrow binops to replace the wide binop.
26534 if (NarrowingRatio != 2)
26535 return SDValue();
26536
26537 // TODO: The motivating case for this transform is an x86 AVX1 target. That
26538 // target has temptingly almost legal versions of bitwise logic ops in 256-bit
26539 // flavors, but no other 256-bit integer support. This could be extended to
26540 // handle any binop, but that may require fixing/adding other folds to avoid
26541 // codegen regressions.
26542 if (BOpcode != ISD::AND && BOpcode != ISD::OR && BOpcode != ISD::XOR)
26543 return SDValue();
26544
26545 // We need at least one concatenation operation of a binop operand to make
26546 // this transform worthwhile. The concat must double the input vector sizes.
26547 auto GetSubVector = [ConcatOpNum](SDValue V) -> SDValue {
26548 if (V.getOpcode() == ISD::CONCAT_VECTORS && V.getNumOperands() == 2)
26549 return V.getOperand(i: ConcatOpNum);
26550 return SDValue();
26551 };
26552 SDValue SubVecL = GetSubVector(peekThroughBitcasts(V: BinOp.getOperand(i: 0)));
26553 SDValue SubVecR = GetSubVector(peekThroughBitcasts(V: BinOp.getOperand(i: 1)));
26554
26555 if (SubVecL || SubVecR) {
26556 // If a binop operand was not the result of a concat, we must extract a
26557 // half-sized operand for our new narrow binop:
26558 // extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN
26559 // extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, IndexC)
26560 // extract (binop X, (concat Y1, Y2)), N --> binop (extract X, IndexC), YN
26561 SDValue IndexC = DAG.getVectorIdxConstant(Val: ExtBOIdx, DL);
26562 SDValue X = SubVecL ? DAG.getBitcast(VT: NarrowBVT, V: SubVecL)
26563 : DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowBVT,
26564 N1: BinOp.getOperand(i: 0), N2: IndexC);
26565
26566 SDValue Y = SubVecR ? DAG.getBitcast(VT: NarrowBVT, V: SubVecR)
26567 : DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowBVT,
26568 N1: BinOp.getOperand(i: 1), N2: IndexC);
26569
26570 SDValue NarrowBinOp = DAG.getNode(Opcode: BOpcode, DL, VT: NarrowBVT, N1: X, N2: Y);
26571 return DAG.getBitcast(VT, V: NarrowBinOp);
26572 }
26573
26574 return SDValue();
26575}
26576
26577/// If we are extracting a subvector from a wide vector load, convert to a
26578/// narrow load to eliminate the extraction:
26579/// (extract_subvector (load wide vector)) --> (load narrow vector)
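/// Illustrative example (arbitrary types, little-endian):
///   (v4i32 extract_subvector (v8i32 load %p), 4)
///     --> (v4i32 load %p + 16 bytes)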
26580static SDValue narrowExtractedVectorLoad(EVT VT, SDValue Src, unsigned Index,
26581 const SDLoc &DL, SelectionDAG &DAG) {
26582 // TODO: Add support for big-endian. The offset calculation must be adjusted.
26583 if (DAG.getDataLayout().isBigEndian())
26584 return SDValue();
26585
26586 auto *Ld = dyn_cast<LoadSDNode>(Val&: Src);
26587 if (!Ld || !ISD::isNormalLoad(N: Ld) || !Ld->isSimple())
26588 return SDValue();
26589
26590 // We can only create byte sized loads.
26591 if (!VT.isByteSized())
26592 return SDValue();
26593
26594 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
26595 if (!TLI.isOperationLegalOrCustomOrPromote(Op: ISD::LOAD, VT))
26596 return SDValue();
26597
26598 unsigned NumElts = VT.getVectorMinNumElements();
26599 // A fixed length vector being extracted from a scalable vector
26600 // may not be any *smaller* than the scalable one.
26601 if (Index == 0 && NumElts >= Ld->getValueType(ResNo: 0).getVectorMinNumElements())
26602 return SDValue();
26603
26604 // The definition of EXTRACT_SUBVECTOR states that the index must be a
26605 // multiple of the minimum number of elements in the result type.
26606 assert(Index % NumElts == 0 && "The extract subvector index is not a "
26607 "multiple of the result's element count");
26608
26609 // It's fine to use TypeSize here as we know the offset will not be negative.
26610 TypeSize Offset = VT.getStoreSize() * (Index / NumElts);
26611 std::optional<unsigned> ByteOffset;
26612 if (Offset.isFixed())
26613 ByteOffset = Offset.getFixedValue();
26614
26615 if (!TLI.shouldReduceLoadWidth(Load: Ld, ExtTy: Ld->getExtensionType(), NewVT: VT, ByteOffset))
26616 return SDValue();
26617
26618 // The narrow load will be offset from the base address of the old load if
26619 // we are extracting from something besides index 0 (little-endian).
26620 // TODO: Use "BaseIndexOffset" to make this more effective.
26621 SDValue NewAddr = DAG.getMemBasePlusOffset(Base: Ld->getBasePtr(), Offset, DL);
26622
26623 MachineFunction &MF = DAG.getMachineFunction();
26624 MachineMemOperand *MMO;
26625 if (Offset.isScalable()) {
26626 MachinePointerInfo MPI =
26627 MachinePointerInfo(Ld->getPointerInfo().getAddrSpace());
26628 MMO = MF.getMachineMemOperand(MMO: Ld->getMemOperand(), PtrInfo: MPI, Size: VT.getStoreSize());
26629 } else
26630 MMO = MF.getMachineMemOperand(MMO: Ld->getMemOperand(), Offset: Offset.getFixedValue(),
26631 Size: VT.getStoreSize());
26632
26633 SDValue NewLd = DAG.getLoad(VT, dl: DL, Chain: Ld->getChain(), Ptr: NewAddr, MMO);
26634 DAG.makeEquivalentMemoryOrdering(OldLoad: Ld, NewMemOp: NewLd);
26635 return NewLd;
26636}
26637
26638/// Given EXTRACT_SUBVECTOR(VECTOR_SHUFFLE(Op0, Op1, Mask)),
26639/// try to produce VECTOR_SHUFFLE(EXTRACT_SUBVECTOR(Op?, ?),
26640/// EXTRACT_SUBVECTOR(Op?, ?),
26641/// Mask'))
26642/// iff it is legal and profitable to do so. Notably, the trimmed mask
26643/// (containing only the elements that are extracted)
26644/// must reference at most two subvectors.
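/// Illustrative example (arbitrary types):
///   (v4i32 extract_subvector
///        (v8i32 vector_shuffle<0,8,1,9,2,10,3,11> X, Y), 0)
///     --> vector_shuffle<0,4,1,5> (v4i32 extract_subvector X, 0),
///                                 (v4i32 extract_subvector Y, 0)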
26645static SDValue foldExtractSubvectorFromShuffleVector(EVT NarrowVT, SDValue Src,
26646 unsigned Index,
26647 const SDLoc &DL,
26648 SelectionDAG &DAG,
26649 bool LegalOperations) {
26650 // Only deal with non-scalable vectors.
26651 EVT WideVT = Src.getValueType();
26652 if (!NarrowVT.isFixedLengthVector() || !WideVT.isFixedLengthVector())
26653 return SDValue();
26654
26655 // The operand must be a shufflevector.
26656 auto *WideShuffleVector = dyn_cast<ShuffleVectorSDNode>(Val&: Src);
26657 if (!WideShuffleVector)
26658 return SDValue();
26659
26660 // The old shuffle needs to go away.
26661 if (!WideShuffleVector->hasOneUse())
26662 return SDValue();
26663
26664 // And the narrow shufflevector that we'll form must be legal.
26665 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
26666 if (LegalOperations &&
26667 !TLI.isOperationLegalOrCustom(Op: ISD::VECTOR_SHUFFLE, VT: NarrowVT))
26668 return SDValue();
26669
26670 int NumEltsExtracted = NarrowVT.getVectorNumElements();
26671 assert((Index % NumEltsExtracted) == 0 &&
26672 "Extract index is not a multiple of the output vector length.");
26673
26674 int WideNumElts = WideVT.getVectorNumElements();
26675
26676 SmallVector<int, 16> NewMask;
26677 NewMask.reserve(N: NumEltsExtracted);
26678 SmallSetVector<std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>, 2>
26679 DemandedSubvectors;
26680
26681 // Try to decode the wide mask into narrow mask from at most two subvectors.
26682 for (int M : WideShuffleVector->getMask().slice(N: Index, M: NumEltsExtracted)) {
26683 assert((M >= -1) && (M < (2 * WideNumElts)) &&
26684 "Out-of-bounds shuffle mask?");
26685
26686 if (M < 0) {
26687 // Does not depend on operands, does not require adjustment.
26688 NewMask.emplace_back(Args&: M);
26689 continue;
26690 }
26691
26692 // From which operand of the shuffle does this shuffle mask element pick?
26693 int WideShufOpIdx = M / WideNumElts;
26694 // Which element of that operand is picked?
26695 int OpEltIdx = M % WideNumElts;
26696
26697 assert((OpEltIdx + WideShufOpIdx * WideNumElts) == M &&
26698 "Shuffle mask vector decomposition failure.");
26699
26700 // And which NumEltsExtracted-sized subvector of that operand is that?
26701 int OpSubvecIdx = OpEltIdx / NumEltsExtracted;
26702 // And which element within that subvector of that operand is that?
26703 int OpEltIdxInSubvec = OpEltIdx % NumEltsExtracted;
26704
26705 assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted) == OpEltIdx &&
26706 "Shuffle mask subvector decomposition failure.");
26707
26708 assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted +
26709 WideShufOpIdx * WideNumElts) == M &&
26710 "Shuffle mask full decomposition failure.");
26711
26712 SDValue Op = WideShuffleVector->getOperand(Num: WideShufOpIdx);
26713
26714 if (Op.isUndef()) {
26715 // Picking from an undef operand. Let's adjust mask instead.
26716 NewMask.emplace_back(Args: -1);
26717 continue;
26718 }
26719
26720 const std::pair<SDValue, int> DemandedSubvector =
26721 std::make_pair(x&: Op, y&: OpSubvecIdx);
26722
26723 if (DemandedSubvectors.insert(X: DemandedSubvector)) {
26724 if (DemandedSubvectors.size() > 2)
26725 return SDValue(); // We can't handle more than two subvectors.
26726 // How many elements into the WideVT does this subvector start?
26727 int Index = NumEltsExtracted * OpSubvecIdx;
26728 // Bail out if the extraction isn't going to be cheap.
26729 if (!TLI.isExtractSubvectorCheap(ResVT: NarrowVT, SrcVT: WideVT, Index))
26730 return SDValue();
26731 }
26732
26733 // Ok, but from which operand of the new shuffle will this element pick?
26734 int NewOpIdx =
26735 getFirstIndexOf(Range: DemandedSubvectors.getArrayRef(), Val: DemandedSubvector);
26736 assert((NewOpIdx == 0 || NewOpIdx == 1) && "Unexpected operand index.");
26737
26738 int AdjM = OpEltIdxInSubvec + NewOpIdx * NumEltsExtracted;
26739 NewMask.emplace_back(Args&: AdjM);
26740 }
26741 assert(NewMask.size() == (unsigned)NumEltsExtracted && "Produced bad mask.");
26742 assert(DemandedSubvectors.size() <= 2 &&
26743 "Should have ended up demanding at most two subvectors.");
26744
26745 // Did we discover that the shuffle does not actually depend on operands?
26746 if (DemandedSubvectors.empty())
26747 return DAG.getPOISON(VT: NarrowVT);
26748
26749 // Profitability check: only deal with extractions from the first subvector
26750 // unless the mask becomes an identity mask.
26751 if (!ShuffleVectorInst::isIdentityMask(Mask: NewMask, NumSrcElts: NewMask.size()) ||
26752 any_of(Range&: NewMask, P: [](int M) { return M < 0; }))
26753 for (auto &DemandedSubvector : DemandedSubvectors)
26754 if (DemandedSubvector.second != 0)
26755 return SDValue();
26756
26757 // We still perform the exact same EXTRACT_SUBVECTOR, just on different
26758 // operand[s]/index[es], so there is no point in checking for its legality.
26759
26760 // Do not turn a legal shuffle into an illegal one.
26761 if (TLI.isShuffleMaskLegal(WideShuffleVector->getMask(), WideVT) &&
26762 !TLI.isShuffleMaskLegal(NewMask, NarrowVT))
26763 return SDValue();
26764
26765 SmallVector<SDValue, 2> NewOps;
26766 for (const std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>
26767 &DemandedSubvector : DemandedSubvectors) {
26768 // How many elements into the WideVT does this subvector start?
26769 int Index = NumEltsExtracted * DemandedSubvector.second;
26770 SDValue IndexC = DAG.getVectorIdxConstant(Val: Index, DL);
26771 NewOps.emplace_back(Args: DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowVT,
26772 N1: DemandedSubvector.first, N2: IndexC));
26773 }
26774 assert((NewOps.size() == 1 || NewOps.size() == 2) &&
26775 "Should end up with either one or two ops");
26776
26777 // If we ended up with only one operand, pad with poison.
26778 if (NewOps.size() == 1)
26779 NewOps.emplace_back(Args: DAG.getPOISON(VT: NarrowVT));
26780
26781 return DAG.getVectorShuffle(VT: NarrowVT, dl: DL, N1: NewOps[0], N2: NewOps[1], Mask: NewMask);
26782}
26783
26784SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
26785 EVT NVT = N->getValueType(ResNo: 0);
26786 SDValue V = N->getOperand(Num: 0);
26787 uint64_t ExtIdx = N->getConstantOperandVal(Num: 1);
26788 SDLoc DL(N);
26789
26790 // Extract from UNDEF is UNDEF.
26791 if (V.isUndef())
26792 return DAG.getUNDEF(VT: NVT);
26793
26794 if (SDValue NarrowLoad = narrowExtractedVectorLoad(VT: NVT, Src: V, Index: ExtIdx, DL, DAG))
26795 return NarrowLoad;
26796
26797 // Combine an extract of an extract into a single extract_subvector.
26798 // ext (ext X, C), 0 --> ext X, C
26799 if (ExtIdx == 0 && V.getOpcode() == ISD::EXTRACT_SUBVECTOR && V.hasOneUse()) {
26800 // The index has to be a multiple of the new result type's known minimum
26801 // vector length.
26802 if (V.getConstantOperandVal(i: 1) % NVT.getVectorMinNumElements() == 0 &&
26803 TLI.isExtractSubvectorCheap(ResVT: NVT, SrcVT: V.getOperand(i: 0).getValueType(),
26804 Index: V.getConstantOperandVal(i: 1)) &&
26805 TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_SUBVECTOR, VT: NVT)) {
26806 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NVT, N1: V.getOperand(i: 0),
26807 N2: V.getOperand(i: 1));
26808 }
26809 }
26810
26811 // ty1 extract_vector(ty2 splat(V))) -> ty1 splat(V)
26812 if (V.getOpcode() == ISD::SPLAT_VECTOR)
26813 if (DAG.isConstantValueOfAnyType(N: V.getOperand(i: 0)) || V.hasOneUse())
26814 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::SPLAT_VECTOR, VT: NVT))
26815 return DAG.getSplatVector(VT: NVT, DL, Op: V.getOperand(i: 0));
26816
26817 // extract_subvector(insert_subvector(x,y,c1),c2)
26818 // --> extract_subvector(y,c2-c1)
26819 // iff we're just extracting from the inserted subvector.
26820 if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
26821 SDValue InsSub = V.getOperand(i: 1);
26822 EVT InsSubVT = InsSub.getValueType();
26823 unsigned NumInsElts = InsSubVT.getVectorMinNumElements();
26824 unsigned InsIdx = V.getConstantOperandVal(i: 2);
26825 unsigned NumSubElts = NVT.getVectorMinNumElements();
26826 if (InsIdx <= ExtIdx && (ExtIdx + NumSubElts) <= (InsIdx + NumInsElts) &&
26827 TLI.isExtractSubvectorCheap(ResVT: NVT, SrcVT: InsSubVT, Index: ExtIdx - InsIdx) &&
26828 InsSubVT.isFixedLengthVector() && NVT.isFixedLengthVector() &&
26829 V.getValueType().isFixedLengthVector())
26830 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NVT, N1: InsSub,
26831 N2: DAG.getVectorIdxConstant(Val: ExtIdx - InsIdx, DL));
26832 }
26833
26834 // Try to move vector bitcast after extract_subv by scaling extraction index:
26835 // extract_subv (bitcast X), Index --> bitcast (extract_subv X, Index')
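// Illustrative example (arbitrary types):
//   (v2i64 extract_subvector (v4i64 bitcast (v8i32 X)), 2)
//     --> (v2i64 bitcast (v4i32 extract_subvector v8i32:X, 4))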
26836 if (V.getOpcode() == ISD::BITCAST &&
26837 V.getOperand(i: 0).getValueType().isVector() &&
26838 (!LegalOperations || TLI.isOperationLegal(Op: ISD::BITCAST, VT: NVT))) {
26839 SDValue SrcOp = V.getOperand(i: 0);
26840 EVT SrcVT = SrcOp.getValueType();
26841 unsigned SrcNumElts = SrcVT.getVectorMinNumElements();
26842 unsigned DestNumElts = V.getValueType().getVectorMinNumElements();
26843 if ((SrcNumElts % DestNumElts) == 0) {
26844 unsigned SrcDestRatio = SrcNumElts / DestNumElts;
26845 ElementCount NewExtEC = NVT.getVectorElementCount() * SrcDestRatio;
26846 EVT NewExtVT =
26847 EVT::getVectorVT(Context&: *DAG.getContext(), VT: SrcVT.getScalarType(), EC: NewExtEC);
26848 if (TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_SUBVECTOR, VT: NewExtVT)) {
26849 SDValue NewIndex = DAG.getVectorIdxConstant(Val: ExtIdx * SrcDestRatio, DL);
26850 SDValue NewExtract = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NewExtVT,
26851 N1: V.getOperand(i: 0), N2: NewIndex);
26852 return DAG.getBitcast(VT: NVT, V: NewExtract);
26853 }
26854 }
26855 if ((DestNumElts % SrcNumElts) == 0) {
26856 unsigned DestSrcRatio = DestNumElts / SrcNumElts;
26857 if (NVT.getVectorElementCount().isKnownMultipleOf(RHS: DestSrcRatio)) {
26858 ElementCount NewExtEC =
26859 NVT.getVectorElementCount().divideCoefficientBy(RHS: DestSrcRatio);
26860 EVT ScalarVT = SrcVT.getScalarType();
26861 if ((ExtIdx % DestSrcRatio) == 0) {
26862 unsigned IndexValScaled = ExtIdx / DestSrcRatio;
26863 EVT NewExtVT =
26864 EVT::getVectorVT(Context&: *DAG.getContext(), VT: ScalarVT, EC: NewExtEC);
26865 if (TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_SUBVECTOR, VT: NewExtVT)) {
26866 SDValue NewIndex = DAG.getVectorIdxConstant(Val: IndexValScaled, DL);
26867 SDValue NewExtract =
26868 DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NewExtVT,
26869 N1: V.getOperand(i: 0), N2: NewIndex);
26870 return DAG.getBitcast(VT: NVT, V: NewExtract);
26871 }
26872 if (NewExtEC.isScalar() &&
26873 TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_VECTOR_ELT, VT: ScalarVT)) {
26874 SDValue NewIndex = DAG.getVectorIdxConstant(Val: IndexValScaled, DL);
26875 SDValue NewExtract =
26876 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ScalarVT,
26877 N1: V.getOperand(i: 0), N2: NewIndex);
26878 return DAG.getBitcast(VT: NVT, V: NewExtract);
26879 }
26880 }
26881 }
26882 }
26883 }
26884
26885 if (V.getOpcode() == ISD::CONCAT_VECTORS) {
26886 unsigned ExtNumElts = NVT.getVectorMinNumElements();
26887 EVT ConcatSrcVT = V.getOperand(i: 0).getValueType();
26888 assert(ConcatSrcVT.getVectorElementType() == NVT.getVectorElementType() &&
26889 "Concat and extract subvector do not change element type");
26890
26891 unsigned ConcatSrcNumElts = ConcatSrcVT.getVectorMinNumElements();
26892 unsigned ConcatOpIdx = ExtIdx / ConcatSrcNumElts;
26893
26894 // If the concatenated source types match this extract, it's a direct
26895 // simplification:
26896 // extract_subvec (concat V1, V2, ...), i --> Vi
26897 if (NVT.getVectorElementCount() == ConcatSrcVT.getVectorElementCount())
26898 return V.getOperand(i: ConcatOpIdx);
26899
26900 // If the concatenated source vectors' length is a multiple of the length of
26901 // this extract, then extract a fraction of one of those source vectors
26902 // directly from a concat operand. Example:
26903 // v2i8 extract_subvec (v16i8 concat (v8i8 X), (v8i8 Y)), 14 -->
26904 // v2i8 extract_subvec v8i8 Y, 6
26905 if (NVT.isFixedLengthVector() && ConcatSrcVT.isFixedLengthVector() &&
26906 ConcatSrcNumElts % ExtNumElts == 0) {
26907 unsigned NewExtIdx = ExtIdx - ConcatOpIdx * ConcatSrcNumElts;
26908 assert(NewExtIdx + ExtNumElts <= ConcatSrcNumElts &&
26909 "Trying to extract from >1 concat operand?");
26910 assert(NewExtIdx % ExtNumElts == 0 &&
26911 "Extract index is not a multiple of the input vector length.");
26912 SDValue NewIndexC = DAG.getVectorIdxConstant(Val: NewExtIdx, DL);
26913 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NVT,
26914 N1: V.getOperand(i: ConcatOpIdx), N2: NewIndexC);
26915 }
26916 }
26917
26918 if (SDValue Shuffle = foldExtractSubvectorFromShuffleVector(
26919 NarrowVT: NVT, Src: V, Index: ExtIdx, DL, DAG, LegalOperations))
26920 return Shuffle;
26921
26922 if (SDValue NarrowBOp =
26923 narrowExtractedVectorBinOp(VT: NVT, Src: V, Index: ExtIdx, DL, DAG, LegalOperations))
26924 return NarrowBOp;
26925
26926 V = peekThroughBitcasts(V);
26927
26928 // If the input is a build vector, try to make a smaller build vector.
26929 if (V.getOpcode() == ISD::BUILD_VECTOR) {
26930 EVT InVT = V.getValueType();
26931 unsigned ExtractSize = NVT.getSizeInBits();
26932 unsigned EltSize = InVT.getScalarSizeInBits();
26933 // Only do this if we won't split any elements.
26934 if (ExtractSize % EltSize == 0) {
26935 unsigned NumElems = ExtractSize / EltSize;
26936 EVT EltVT = InVT.getVectorElementType();
26937 EVT ExtractVT =
26938 NumElems == 1 ? EltVT
26939 : EVT::getVectorVT(Context&: *DAG.getContext(), VT: EltVT, NumElements: NumElems);
26940 if ((Level < AfterLegalizeDAG ||
26941 (NumElems == 1 ||
26942 TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT: ExtractVT))) &&
26943 (!LegalTypes || TLI.isTypeLegal(VT: ExtractVT))) {
26944 unsigned IdxVal = (ExtIdx * NVT.getScalarSizeInBits()) / EltSize;
26945
26946 if (NumElems == 1) {
26947 SDValue Src = V->getOperand(Num: IdxVal);
26948 if (EltVT != Src.getValueType())
26949 Src = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: EltVT, Operand: Src);
26950 return DAG.getBitcast(VT: NVT, V: Src);
26951 }
26952
26953 // Extract the pieces from the original build_vector.
26954 SDValue BuildVec =
26955 DAG.getBuildVector(VT: ExtractVT, DL, Ops: V->ops().slice(N: IdxVal, M: NumElems));
26956 return DAG.getBitcast(VT: NVT, V: BuildVec);
26957 }
26958 }
26959 }
26960
26961 if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
26962 // Handle only the simple case where the vector being inserted and the
26963 // vector being extracted are of the same size.
26964 EVT SmallVT = V.getOperand(i: 1).getValueType();
26965 if (NVT.bitsEq(VT: SmallVT)) {
26966 // Combine:
26967 // (extract_subvec (insert_subvec V1, V2, InsIdx), ExtIdx)
26968 // Into:
26969 // indices are equal or bit offsets are equal => V1
26970 // otherwise => (extract_subvec V1, ExtIdx)
26971 uint64_t InsIdx = V.getConstantOperandVal(i: 2);
26972 if (InsIdx * SmallVT.getScalarSizeInBits() ==
26973 ExtIdx * NVT.getScalarSizeInBits()) {
26974 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::BITCAST, VT: NVT))
26975 return DAG.getBitcast(VT: NVT, V: V.getOperand(i: 1));
26976 } else {
26977 return DAG.getNode(
26978 Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NVT,
26979 N1: DAG.getBitcast(VT: N->getOperand(Num: 0).getValueType(), V: V.getOperand(i: 0)),
26980 N2: N->getOperand(Num: 1));
26981 }
26982 }
26983 }
26984
26985 // If only EXTRACT_SUBVECTOR nodes use the source vector we can
26986 // simplify it based on the (valid) extractions.
26987 if (!V.getValueType().isScalableVector() &&
26988 llvm::all_of(Range: V->users(), P: [&](SDNode *Use) {
26989 return Use->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
26990 Use->getOperand(Num: 0) == V;
26991 })) {
26992 unsigned NumElts = V.getValueType().getVectorNumElements();
26993 APInt DemandedElts = APInt::getZero(numBits: NumElts);
26994 for (SDNode *User : V->users()) {
26995 unsigned ExtIdx = User->getConstantOperandVal(Num: 1);
26996 unsigned NumSubElts = User->getValueType(ResNo: 0).getVectorNumElements();
26997 DemandedElts.setBits(loBit: ExtIdx, hiBit: ExtIdx + NumSubElts);
26998 }
26999 if (SimplifyDemandedVectorElts(Op: V, DemandedElts, /*AssumeSingleUse=*/true)) {
27000 // We simplified the vector operand of this extract subvector. If this
27001 // extract is not dead, visit it again so it is folded properly.
27002 if (N->getOpcode() != ISD::DELETED_NODE)
27003 AddToWorklist(N);
27004 return SDValue(N, 0);
27005 }
27006 } else {
27007 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
27008 return SDValue(N, 0);
27009 }
27010
27011 return SDValue();
27012}
27013
27014/// Try to convert a wide shuffle of concatenated vectors into 2 narrow shuffles
27015/// followed by concatenation. Narrow vector ops may have better performance
27016/// than wide ops, and this can unlock further narrowing of other vector ops.
27017/// Targets can invert this transform later if it is not profitable.
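/// Illustrative example (v4i32 result built from v2i32 halves X and Y,
/// assuming the narrow shuffle masks are legal for the target):
///   shuffle (concat X, undef), (concat Y, undef), <0,4,1,5>
///     --> concat (shuffle X, Y, <0,2>), (shuffle X, Y, <1,3>)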
27018static SDValue foldShuffleOfConcatUndefs(ShuffleVectorSDNode *Shuf,
27019 SelectionDAG &DAG) {
27020 SDValue N0 = Shuf->getOperand(Num: 0), N1 = Shuf->getOperand(Num: 1);
27021 if (N0.getOpcode() != ISD::CONCAT_VECTORS || N0.getNumOperands() != 2 ||
27022 N1.getOpcode() != ISD::CONCAT_VECTORS || N1.getNumOperands() != 2 ||
27023 !N0.getOperand(i: 1).isUndef() || !N1.getOperand(i: 1).isUndef())
27024 return SDValue();
27025
27026 // Split the wide shuffle mask into halves. Any mask element that accesses
27027 // operand 1 is offset down to account for narrowing of the vectors.
27028 ArrayRef<int> Mask = Shuf->getMask();
27029 EVT VT = Shuf->getValueType(ResNo: 0);
27030 unsigned NumElts = VT.getVectorNumElements();
27031 unsigned HalfNumElts = NumElts / 2;
27032 SmallVector<int, 16> Mask0(HalfNumElts, -1);
27033 SmallVector<int, 16> Mask1(HalfNumElts, -1);
27034 for (unsigned i = 0; i != NumElts; ++i) {
27035 if (Mask[i] == -1)
27036 continue;
27037 // If we reference the upper (undef) subvector then the element is undef.
27038 if ((Mask[i] % NumElts) >= HalfNumElts)
27039 continue;
27040 int M = Mask[i] < (int)NumElts ? Mask[i] : Mask[i] - (int)HalfNumElts;
27041 if (i < HalfNumElts)
27042 Mask0[i] = M;
27043 else
27044 Mask1[i - HalfNumElts] = M;
27045 }
27046
27047 // Ask the target if this is a valid transform.
27048 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
27049 EVT HalfVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: VT.getScalarType(),
27050 NumElements: HalfNumElts);
27051 if (!TLI.isShuffleMaskLegal(Mask0, HalfVT) ||
27052 !TLI.isShuffleMaskLegal(Mask1, HalfVT))
27053 return SDValue();
27054
27055 // shuffle (concat X, undef), (concat Y, undef), Mask -->
27056 // concat (shuffle X, Y, Mask0), (shuffle X, Y, Mask1)
27057 SDValue X = N0.getOperand(i: 0), Y = N1.getOperand(i: 0);
27058 SDLoc DL(Shuf);
27059 SDValue Shuf0 = DAG.getVectorShuffle(VT: HalfVT, dl: DL, N1: X, N2: Y, Mask: Mask0);
27060 SDValue Shuf1 = DAG.getVectorShuffle(VT: HalfVT, dl: DL, N1: X, N2: Y, Mask: Mask1);
27061 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT, N1: Shuf0, N2: Shuf1);
27062}
27063
27064// Tries to turn a shuffle of two CONCAT_VECTORS into a single concat,
27065 // or turn a shuffle of a single concat into a simpler shuffle then concat.
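// Illustrative example (v4i32 shuffle of v2i32 concat operands):
//   shuffle (concat A, B), (concat C, D), <4,5,2,3> --> concat C, B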
27066static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) {
27067 EVT VT = N->getValueType(ResNo: 0);
27068 unsigned NumElts = VT.getVectorNumElements();
27069
27070 SDValue N0 = N->getOperand(Num: 0);
27071 SDValue N1 = N->getOperand(Num: 1);
27072 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Val: N);
27073 ArrayRef<int> Mask = SVN->getMask();
27074
27075 SmallVector<SDValue, 4> Ops;
27076 EVT ConcatVT = N0.getOperand(i: 0).getValueType();
27077 unsigned NumElemsPerConcat = ConcatVT.getVectorNumElements();
27078 unsigned NumConcats = NumElts / NumElemsPerConcat;
27079
27080 auto IsUndefMaskElt = [](int i) { return i == -1; };
27081
27082 // Special case: shuffle(concat(A,B)) can be more efficiently represented
27083 // as concat(shuffle(A,B),UNDEF) if the shuffle doesn't set any of the high
27084 // half vector elements.
27085 if (NumElemsPerConcat * 2 == NumElts && N1.isUndef() &&
27086 llvm::all_of(Range: Mask.slice(N: NumElemsPerConcat, M: NumElemsPerConcat),
27087 P: IsUndefMaskElt)) {
27088 N0 = DAG.getVectorShuffle(VT: ConcatVT, dl: SDLoc(N), N1: N0.getOperand(i: 0),
27089 N2: N0.getOperand(i: 1),
27090 Mask: Mask.slice(N: 0, M: NumElemsPerConcat));
27091 N1 = DAG.getPOISON(VT: ConcatVT);
27092 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, N1: N0, N2: N1);
27093 }
27094
27095 // Look at every vector that's inserted. We're looking for exact
27096 // subvector-sized copies from a concatenated vector.
27097 for (unsigned I = 0; I != NumConcats; ++I) {
27098 unsigned Begin = I * NumElemsPerConcat;
27099 ArrayRef<int> SubMask = Mask.slice(N: Begin, M: NumElemsPerConcat);
27100
27101 // Make sure we're dealing with a copy.
27102 if (llvm::all_of(Range&: SubMask, P: IsUndefMaskElt)) {
27103 Ops.push_back(Elt: DAG.getUNDEF(VT: ConcatVT));
27104 continue;
27105 }
27106
27107 int OpIdx = -1;
27108 for (int i = 0; i != (int)NumElemsPerConcat; ++i) {
27109 if (IsUndefMaskElt(SubMask[i]))
27110 continue;
27111 if ((SubMask[i] % (int)NumElemsPerConcat) != i)
27112 return SDValue();
27113 int EltOpIdx = SubMask[i] / NumElemsPerConcat;
27114 if (0 <= OpIdx && EltOpIdx != OpIdx)
27115 return SDValue();
27116 OpIdx = EltOpIdx;
27117 }
27118 assert(0 <= OpIdx && "Unknown concat_vectors op");
27119
27120 if (OpIdx < (int)N0.getNumOperands())
27121 Ops.push_back(Elt: N0.getOperand(i: OpIdx));
27122 else
27123 Ops.push_back(Elt: N1.getOperand(i: OpIdx - N0.getNumOperands()));
27124 }
27125
27126 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, Ops);
27127}
27128
27129// Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
27130// BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
27131//
27132// SHUFFLE(BUILD_VECTOR(), BUILD_VECTOR()) -> BUILD_VECTOR() is always
27133// a simplification in some sense, but it isn't appropriate in general: some
27134// BUILD_VECTORs are substantially cheaper than others. The general case
27135// of a BUILD_VECTOR requires inserting each element individually (or
27136// performing the equivalent in a temporary stack variable). A BUILD_VECTOR of
27137// all constants is a single constant pool load. A BUILD_VECTOR where each
27138// element is identical is a splat. A BUILD_VECTOR where most of the operands
27139// are undef lowers to a small number of element insertions.
27140//
27141// To deal with this, we currently use a bunch of mostly arbitrary heuristics.
27142// We don't fold shuffles where one side is a non-zero constant, and we don't
27143// fold shuffles if the resulting (non-splat) BUILD_VECTOR would have duplicate
27144// non-constant operands. This seems to work out reasonably well in practice.
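// Illustrative example (subject to the heuristics above and the one-use
// checks below):
//   shuffle (build_vector a, b, c, d), (build_vector e, f, g, h), <0,5,2,7>
//     --> build_vector a, f, c, h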
27145static SDValue combineShuffleOfScalars(ShuffleVectorSDNode *SVN,
27146 SelectionDAG &DAG,
27147 const TargetLowering &TLI) {
27148 EVT VT = SVN->getValueType(ResNo: 0);
27149 unsigned NumElts = VT.getVectorNumElements();
27150 SDValue N0 = SVN->getOperand(Num: 0);
27151 SDValue N1 = SVN->getOperand(Num: 1);
27152
27153 if (!N0->hasOneUse())
27154 return SDValue();
27155
27156 // If only one of N0,N1 is constant, bail out if it is not ALL_ZEROS as
27157 // discussed above.
27158 if (!N1.isUndef()) {
27159 if (!N1->hasOneUse())
27160 return SDValue();
27161
27162 bool N0AnyConst = isAnyConstantBuildVector(V: N0);
27163 bool N1AnyConst = isAnyConstantBuildVector(V: N1);
27164 if (N0AnyConst && !N1AnyConst && !ISD::isBuildVectorAllZeros(N: N0.getNode()))
27165 return SDValue();
27166 if (!N0AnyConst && N1AnyConst && !ISD::isBuildVectorAllZeros(N: N1.getNode()))
27167 return SDValue();
27168 }
27169
27170 // If both inputs are splats of the same value then we can safely merge this
27171 // to a single BUILD_VECTOR with undef elements based on the shuffle mask.
27172 bool IsSplat = false;
27173 auto *BV0 = dyn_cast<BuildVectorSDNode>(Val&: N0);
27174 auto *BV1 = dyn_cast<BuildVectorSDNode>(Val&: N1);
27175 if (BV0 && BV1)
27176 if (SDValue Splat0 = BV0->getSplatValue())
27177 IsSplat = (Splat0 == BV1->getSplatValue());
27178
27179 SmallVector<SDValue, 8> Ops;
27180 SmallSet<SDValue, 16> DuplicateOps;
27181 for (int M : SVN->getMask()) {
27182 SDValue Op = DAG.getPOISON(VT: VT.getScalarType());
27183 if (M >= 0) {
27184 int Idx = M < (int)NumElts ? M : M - NumElts;
27185 SDValue &S = (M < (int)NumElts ? N0 : N1);
27186 if (S.getOpcode() == ISD::BUILD_VECTOR) {
27187 Op = S.getOperand(i: Idx);
27188 } else if (S.getOpcode() == ISD::SCALAR_TO_VECTOR) {
27189 SDValue Op0 = S.getOperand(i: 0);
27190 Op = Idx == 0 ? Op0 : DAG.getPOISON(VT: Op0.getValueType());
27191 } else {
27192 // Operand can't be combined - bail out.
27193 return SDValue();
27194 }
27195 }
27196
27197 // Don't duplicate a non-constant BUILD_VECTOR operand unless we're
27198 // generating a splat; semantically, this is fine, but it's likely to
27199 // generate low-quality code if the target can't reconstruct an appropriate
27200 // shuffle.
27201 if (!Op.isUndef() && !isIntOrFPConstant(V: Op))
27202 if (!IsSplat && !DuplicateOps.insert(V: Op).second)
27203 return SDValue();
27204
27205 Ops.push_back(Elt: Op);
27206 }
27207
27208 // BUILD_VECTOR requires all inputs to be of the same type, find the
27209 // maximum type and extend them all.
27210 EVT SVT = VT.getScalarType();
27211 if (SVT.isInteger())
27212 for (SDValue &Op : Ops)
27213 SVT = (SVT.bitsLT(VT: Op.getValueType()) ? Op.getValueType() : SVT);
27214 if (SVT != VT.getScalarType())
27215 for (SDValue &Op : Ops)
27216 Op = Op.isUndef() ? DAG.getUNDEF(VT: SVT)
27217 : (TLI.isZExtFree(FromTy: Op.getValueType(), ToTy: SVT)
27218 ? DAG.getZExtOrTrunc(Op, DL: SDLoc(SVN), VT: SVT)
27219 : DAG.getSExtOrTrunc(Op, DL: SDLoc(SVN), VT: SVT));
27220 return DAG.getBuildVector(VT, DL: SDLoc(SVN), Ops);
27221}
27222
27223// Match shuffles that can be converted to *_vector_extend_in_reg.
27224// This is often generated during legalization.
27225// e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src)),
27226// and returns the EVT to which the extension should be performed.
27227// NOTE: this assumes that the src is the first operand of the shuffle.
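// e.g. for VT = v8i16 the search below tries OutVT = v4i32 (Scale == 2) and
// OutVT = v2i64 (Scale == 4), returning the first type that passes the
// legality checks and is accepted by Match.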
27228static std::optional<EVT> canCombineShuffleToExtendVectorInreg(
27229 unsigned Opcode, EVT VT, std::function<bool(unsigned)> Match,
27230 SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes,
27231 bool LegalOperations) {
27232 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
27233
27234 // TODO Add support for big-endian when we have a test case.
27235 if (!VT.isInteger() || IsBigEndian)
27236 return std::nullopt;
27237
27238 unsigned NumElts = VT.getVectorNumElements();
27239 unsigned EltSizeInBits = VT.getScalarSizeInBits();
27240
27241 // Attempt to match a '*_extend_vector_inreg' shuffle, we just search for
27242 // power-of-2 extensions as they are the most likely.
27243 // FIXME: should try the Scale == NumElts case too.
27244 for (unsigned Scale = 2; Scale < NumElts; Scale *= 2) {
27245 // The vector width must be a multiple of Scale.
27246 if (NumElts % Scale != 0)
27247 continue;
27248
27249 EVT OutSVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: EltSizeInBits * Scale);
27250 EVT OutVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: OutSVT, NumElements: NumElts / Scale);
27251
27252 if ((LegalTypes && !TLI.isTypeLegal(VT: OutVT)) ||
27253 (LegalOperations && !TLI.isOperationLegalOrCustom(Op: Opcode, VT: OutVT)))
27254 continue;
27255
27256 if (Match(Scale))
27257 return OutVT;
27258 }
27259
27260 return std::nullopt;
27261}
27262
27263// Match shuffles that can be converted to any_vector_extend_in_reg.
27264// This is often generated during legalization.
27265// e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src))
27266static SDValue combineShuffleToAnyExtendVectorInreg(ShuffleVectorSDNode *SVN,
27267 SelectionDAG &DAG,
27268 const TargetLowering &TLI,
27269 bool LegalOperations) {
27270 EVT VT = SVN->getValueType(ResNo: 0);
27271 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
27272
27273 // TODO Add support for big-endian when we have a test case.
27274 if (!VT.isInteger() || IsBigEndian)
27275 return SDValue();
27276
27277 // shuffle<0,-1,1,-1> == (v2i64 anyextend_vector_inreg(v4i32))
27278 auto isAnyExtend = [NumElts = VT.getVectorNumElements(),
27279 Mask = SVN->getMask()](unsigned Scale) {
27280 for (unsigned i = 0; i != NumElts; ++i) {
27281 if (Mask[i] < 0)
27282 continue;
27283 if ((i % Scale) == 0 && Mask[i] == (int)(i / Scale))
27284 continue;
27285 return false;
27286 }
27287 return true;
27288 };
27289
27290 unsigned Opcode = ISD::ANY_EXTEND_VECTOR_INREG;
27291 SDValue N0 = SVN->getOperand(Num: 0);
27292 // Never create an illegal type. Only create unsupported operations if we
27293 // are pre-legalization.
27294 std::optional<EVT> OutVT = canCombineShuffleToExtendVectorInreg(
27295 Opcode, VT, Match: isAnyExtend, DAG, TLI, /*LegalTypes=*/true, LegalOperations);
27296 if (!OutVT)
27297 return SDValue();
27298 return DAG.getBitcast(VT, V: DAG.getNode(Opcode, DL: SDLoc(SVN), VT: *OutVT, Operand: N0));
27299}
27300
27301// Match shuffles that can be converted to zero_extend_vector_inreg.
27302// This is often generated during legalization.
27303// e.g. v4i32 <0,z,1,u> -> (v2i64 zero_extend_vector_inreg(v4i32 src))
27304static SDValue combineShuffleToZeroExtendVectorInReg(ShuffleVectorSDNode *SVN,
27305 SelectionDAG &DAG,
27306 const TargetLowering &TLI,
27307 bool LegalOperations) {
27308 bool LegalTypes = true;
27309 EVT VT = SVN->getValueType(ResNo: 0);
27310 assert(!VT.isScalableVector() && "Encountered scalable shuffle?");
27311 unsigned NumElts = VT.getVectorNumElements();
27312 unsigned EltSizeInBits = VT.getScalarSizeInBits();
27313
27314 // TODO: add support for big-endian when we have a test case.
27315 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
27316 if (!VT.isInteger() || IsBigEndian)
27317 return SDValue();
27318
27319 SmallVector<int, 16> Mask(SVN->getMask());
27320 auto ForEachDecomposedIndice = [NumElts, &Mask](auto Fn) {
27321 for (int &Indice : Mask) {
27322 if (Indice < 0)
27323 continue;
27324 int OpIdx = (unsigned)Indice < NumElts ? 0 : 1;
27325 int OpEltIdx = (unsigned)Indice < NumElts ? Indice : Indice - NumElts;
27326 Fn(Indice, OpIdx, OpEltIdx);
27327 }
27328 };
27329
27330 // Which elements of which operand does this shuffle demand?
27331 std::array<APInt, 2> OpsDemandedElts;
27332 for (APInt &OpDemandedElts : OpsDemandedElts)
27333 OpDemandedElts = APInt::getZero(numBits: NumElts);
27334 ForEachDecomposedIndice(
27335 [&OpsDemandedElts](int &Indice, int OpIdx, int OpEltIdx) {
27336 OpsDemandedElts[OpIdx].setBit(OpEltIdx);
27337 });
27338
27339 // Element-wise(!), which of these demanded elements are known to be zero?
27340 std::array<APInt, 2> OpsKnownZeroElts;
27341 for (auto I : zip(t: SVN->ops(), u&: OpsDemandedElts, args&: OpsKnownZeroElts))
27342 std::get<2>(t&: I) =
27343 DAG.computeVectorKnownZeroElements(Op: std::get<0>(t&: I), DemandedElts: std::get<1>(t&: I));
27344
27345 // Manifest zeroable element knowledge in the shuffle mask.
27346 // NOTE: there is no 'zeroable' sentinel value in the generic DAG;
27347 // this is a local invention, and it won't leak into the DAG.
27348 // FIXME: should we not manifest them, but just check when matching?
27349 bool HadZeroableElts = false;
27350 ForEachDecomposedIndice([&OpsKnownZeroElts, &HadZeroableElts](
27351 int &Indice, int OpIdx, int OpEltIdx) {
27352 if (OpsKnownZeroElts[OpIdx][OpEltIdx]) {
27353 Indice = -2; // Zeroable element.
27354 HadZeroableElts = true;
27355 }
27356 });
27357
27358 // Don't proceed unless we've refined at least one zeroable mask index.
27359 // If we didn't, then we are still trying to match the same shuffle mask
27360 // we previously tried to match as ISD::ANY_EXTEND_VECTOR_INREG,
27361 // and evidently failed. Proceeding will lead to endless combine loops.
27362 if (!HadZeroableElts)
27363 return SDValue();
27364
27365 // The shuffle may be more fine-grained than we want. Widen elements first.
27366 // FIXME: should we do this before manifesting zeroable shuffle mask indices?
27367 SmallVector<int, 16> ScaledMask;
27368 getShuffleMaskWithWidestElts(Mask, ScaledMask);
27369 assert(Mask.size() >= ScaledMask.size() &&
27370 Mask.size() % ScaledMask.size() == 0 && "Unexpected mask widening.");
27371 int Prescale = Mask.size() / ScaledMask.size();
27372
27373 NumElts = ScaledMask.size();
27374 EltSizeInBits *= Prescale;
27375
27376 EVT PrescaledVT = EVT::getVectorVT(
27377 Context&: *DAG.getContext(), VT: EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: EltSizeInBits),
27378 NumElements: NumElts);
27379
27380 if (LegalTypes && !TLI.isTypeLegal(VT: PrescaledVT) && TLI.isTypeLegal(VT))
27381 return SDValue();
27382
27383 // For example,
27384 // shuffle<0,z,1,-1> == (v2i64 zero_extend_vector_inreg(v4i32))
27385 // But not shuffle<z,z,1,-1> and not shuffle<0,z,z,-1> ! (for same types)
27386 auto isZeroExtend = [NumElts, &ScaledMask](unsigned Scale) {
27387 assert(Scale >= 2 && Scale <= NumElts && NumElts % Scale == 0 &&
27388 "Unexpected mask scaling factor.");
27389 ArrayRef<int> Mask = ScaledMask;
27390 for (unsigned SrcElt = 0, NumSrcElts = NumElts / Scale;
27391 SrcElt != NumSrcElts; ++SrcElt) {
27392 // Analyze the shuffle mask in Scale-sized chunks.
27393 ArrayRef<int> MaskChunk = Mask.take_front(N: Scale);
27394 assert(MaskChunk.size() == Scale && "Unexpected mask size.");
27395 Mask = Mask.drop_front(N: MaskChunk.size());
27396 // The first index in this chunk must be SrcElt, but not zero!
27397 // FIXME: undef should be fine, but that results in a more-defined result.
27398 if (int FirstIndice = MaskChunk[0]; (unsigned)FirstIndice != SrcElt)
27399 return false;
27400 // The rest of the indices in this chunk must be zeros.
27401 // FIXME: undef should be fine, but that results in a more-defined result.
27402 if (!all_of(Range: MaskChunk.drop_front(N: 1),
27403 P: [](int Indice) { return Indice == -2; }))
27404 return false;
27405 }
27406 assert(Mask.empty() && "Did not process the whole mask?");
27407 return true;
27408 };
27409
27410 unsigned Opcode = ISD::ZERO_EXTEND_VECTOR_INREG;
27411 for (bool Commuted : {false, true}) {
27412 SDValue Op = SVN->getOperand(Num: !Commuted ? 0 : 1);
27413 if (Commuted)
27414 ShuffleVectorSDNode::commuteMask(Mask: ScaledMask);
27415 std::optional<EVT> OutVT = canCombineShuffleToExtendVectorInreg(
27416 Opcode, VT: PrescaledVT, Match: isZeroExtend, DAG, TLI, LegalTypes,
27417 LegalOperations);
27418 if (OutVT)
27419 return DAG.getBitcast(VT, V: DAG.getNode(Opcode, DL: SDLoc(SVN), VT: *OutVT,
27420 Operand: DAG.getBitcast(VT: PrescaledVT, V: Op)));
27421 }
27422 return SDValue();
27423}
27424
27425// Detect 'truncate_vector_inreg' style shuffles that pack the lower parts of
27426// each source element of a large type into the lowest elements of a smaller
27427// destination type. This is often generated during legalization.
27428// If the source node itself was a '*_extend_vector_inreg' node then we should
27429 // be able to remove it.
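// Illustrative example (little-endian, X : v4i32):
//   shuffle<0,2,-1,-1> (v4i32 bitcast (v2i64 any_extend_vector_inreg X)), undef
//     --> X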
27430static SDValue combineTruncationShuffle(ShuffleVectorSDNode *SVN,
27431 SelectionDAG &DAG) {
27432 EVT VT = SVN->getValueType(ResNo: 0);
27433 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
27434
27435 // TODO Add support for big-endian when we have a test case.
27436 if (!VT.isInteger() || IsBigEndian)
27437 return SDValue();
27438
27439 SDValue N0 = peekThroughBitcasts(V: SVN->getOperand(Num: 0));
27440
27441 unsigned Opcode = N0.getOpcode();
27442 if (!ISD::isExtVecInRegOpcode(Opcode))
27443 return SDValue();
27444
27445 SDValue N00 = N0.getOperand(i: 0);
27446 ArrayRef<int> Mask = SVN->getMask();
27447 unsigned NumElts = VT.getVectorNumElements();
27448 unsigned EltSizeInBits = VT.getScalarSizeInBits();
27449 unsigned ExtSrcSizeInBits = N00.getScalarValueSizeInBits();
27450 unsigned ExtDstSizeInBits = N0.getScalarValueSizeInBits();
27451
27452 if (ExtDstSizeInBits % ExtSrcSizeInBits != 0)
27453 return SDValue();
27454 unsigned ExtScale = ExtDstSizeInBits / ExtSrcSizeInBits;
27455
27456 // (v4i32 truncate_vector_inreg(v2i64)) == shuffle<0,2,-1,-1>
27457 // (v8i16 truncate_vector_inreg(v4i32)) == shuffle<0,2,4,6,-1,-1,-1,-1>
27458 // (v8i16 truncate_vector_inreg(v2i64)) == shuffle<0,4,-1,-1,-1,-1,-1,-1>
27459 auto isTruncate = [&Mask, &NumElts](unsigned Scale) {
27460 for (unsigned i = 0; i != NumElts; ++i) {
27461 if (Mask[i] < 0)
27462 continue;
27463 if ((i * Scale) < NumElts && Mask[i] == (int)(i * Scale))
27464 continue;
27465 return false;
27466 }
27467 return true;
27468 };
27469
27470 // At the moment we just handle the case where we've truncated back to the
27471 // same size as before the extension.
27472 // TODO: handle more extension/truncation cases as cases arise.
27473 if (EltSizeInBits != ExtSrcSizeInBits)
27474 return SDValue();
27475 if (VT.getSizeInBits() != N00.getValueSizeInBits())
27476 return SDValue();
27477
27478 // We can remove *extend_vector_inreg only if the truncation happens at
27479 // the same scale as the extension.
27480 if (isTruncate(ExtScale))
27481 return DAG.getBitcast(VT, V: N00);
27482
27483 return SDValue();
27484}
27485
27486// Combine shuffles of splat-shuffles of the form:
27487// shuffle (shuffle V, undef, splat-mask), undef, M
27488// If splat-mask contains undef elements, we need to be careful about
27489// introducing undef's in the folded mask which are not the result of composing
27490// the masks of the shuffles.
27491static SDValue combineShuffleOfSplatVal(ShuffleVectorSDNode *Shuf,
27492 SelectionDAG &DAG) {
27493 EVT VT = Shuf->getValueType(ResNo: 0);
27494 unsigned NumElts = VT.getVectorNumElements();
27495
27496 if (!Shuf->getOperand(Num: 1).isUndef())
27497 return SDValue();
27498
27499 // See if this unary non-splat shuffle actually *is* a splat shuffle,
27500 // in disguise, with all demanded elements being identical.
27501 // FIXME: this can be done per-operand.
27502 if (!Shuf->isSplat()) {
27503 APInt DemandedElts(NumElts, 0);
27504 for (int Idx : Shuf->getMask()) {
27505 if (Idx < 0)
27506 continue; // Ignore sentinel indices.
27507 assert((unsigned)Idx < NumElts && "Out-of-bounds shuffle indice?");
27508 DemandedElts.setBit(Idx);
27509 }
27510 assert(DemandedElts.popcount() > 1 && "Is a splat shuffle already?");
27511 APInt UndefElts;
27512 if (DAG.isSplatValue(V: Shuf->getOperand(Num: 0), DemandedElts, UndefElts)) {
27513 // Even if all demanded elements are splat, some of them could be undef.
27514 // Which lowest demanded element is *not* known-undef?
27515 std::optional<unsigned> MinNonUndefIdx;
27516 for (int Idx : Shuf->getMask()) {
27517 if (Idx < 0 || UndefElts[Idx])
27518 continue; // Ignore sentinel indices, and undef elements.
27519 MinNonUndefIdx = std::min<unsigned>(a: Idx, b: MinNonUndefIdx.value_or(u: ~0U));
27520 }
27521 if (!MinNonUndefIdx)
27522 return DAG.getUNDEF(VT); // All undef - result is undef.
27523 assert(*MinNonUndefIdx < NumElts && "Expected valid element index.");
27524 SmallVector<int, 8> SplatMask(Shuf->getMask());
27525 for (int &Idx : SplatMask) {
27526 if (Idx < 0)
27527 continue; // Passthrough sentinel indices.
27528 // Otherwise, just pick the lowest demanded non-undef element.
27529 // Or sentinel undef, if we know we'd pick a known-undef element.
27530 Idx = UndefElts[Idx] ? -1 : *MinNonUndefIdx;
27531 }
27532 assert(SplatMask != Shuf->getMask() && "Expected mask to change!");
27533 return DAG.getVectorShuffle(VT, dl: SDLoc(Shuf), N1: Shuf->getOperand(Num: 0),
27534 N2: Shuf->getOperand(Num: 1), Mask: SplatMask);
27535 }
27536 }
27537
27538 // If the inner operand is a known splat with no undefs, just return that directly.
27539 // TODO: Create DemandedElts mask from Shuf's mask.
27540 // TODO: Allow undef elements and merge with the shuffle code below.
27541 if (DAG.isSplatValue(V: Shuf->getOperand(Num: 0), /*AllowUndefs*/ false))
27542 return Shuf->getOperand(Num: 0);
27543
27544 auto *Splat = dyn_cast<ShuffleVectorSDNode>(Val: Shuf->getOperand(Num: 0));
27545 if (!Splat || !Splat->isSplat())
27546 return SDValue();
27547
27548 ArrayRef<int> ShufMask = Shuf->getMask();
27549 ArrayRef<int> SplatMask = Splat->getMask();
27550 assert(ShufMask.size() == SplatMask.size() && "Mask length mismatch");
27551
27552 // Prefer simplifying to the splat-shuffle, if possible. This is legal if
27553 // every undef mask element in the splat-shuffle has a corresponding undef
27554 // element in the user-shuffle's mask or if the composition of mask elements
27555 // would result in undef.
27556 // Examples for (shuffle (shuffle v, undef, SplatMask), undef, UserMask):
27557 // * UserMask=[0,2,u,u], SplatMask=[2,u,2,u] -> [2,2,u,u]
27558 // In this case it is not legal to simplify to the splat-shuffle because we
27559 // may expose to the users of the shuffle an undef element at index 1
27560 // which was not there before the combine.
27561 // * UserMask=[0,u,2,u], SplatMask=[2,u,2,u] -> [2,u,2,u]
27562 // In this case the composition of masks yields SplatMask, so it's ok to
27563 // simplify to the splat-shuffle.
27564 // * UserMask=[3,u,2,u], SplatMask=[2,u,2,u] -> [u,u,2,u]
27565 // In this case the composed mask includes all undef elements of SplatMask
27566 // and in addition sets element zero to undef. It is safe to simplify to
27567 // the splat-shuffle.
27568 auto CanSimplifyToExistingSplat = [](ArrayRef<int> UserMask,
27569 ArrayRef<int> SplatMask) {
27570 for (unsigned i = 0, e = UserMask.size(); i != e; ++i)
27571 if (UserMask[i] != -1 && SplatMask[i] == -1 &&
27572 SplatMask[UserMask[i]] != -1)
27573 return false;
27574 return true;
27575 };
27576 if (CanSimplifyToExistingSplat(ShufMask, SplatMask))
27577 return Shuf->getOperand(Num: 0);
27578
27579 // Create a new shuffle with a mask that is composed of the two shuffles'
27580 // masks.
27581 SmallVector<int, 32> NewMask;
27582 for (int Idx : ShufMask)
27583 NewMask.push_back(Elt: Idx == -1 ? -1 : SplatMask[Idx]);
27584
27585 return DAG.getVectorShuffle(VT: Splat->getValueType(ResNo: 0), dl: SDLoc(Splat),
27586 N1: Splat->getOperand(Num: 0), N2: Splat->getOperand(Num: 1),
27587 Mask: NewMask);
27588}
27589
27590// Combine shuffles of bitcasts into a shuffle of the bitcast type, providing
27591// the mask can be treated as a larger type.
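// Illustrative example (little-endian, assuming the widened mask is legal):
//   v4i32 shuffle (bitcast (v2i64 X)), (bitcast (v2i64 Y)), <0,1,6,7>
//     --> bitcast (v2i64 shuffle X, Y, <0,3>)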
27592static SDValue combineShuffleOfBitcast(ShuffleVectorSDNode *SVN,
27593 SelectionDAG &DAG,
27594 const TargetLowering &TLI,
27595 bool LegalOperations) {
27596 SDValue Op0 = SVN->getOperand(Num: 0);
27597 SDValue Op1 = SVN->getOperand(Num: 1);
27598 EVT VT = SVN->getValueType(ResNo: 0);
27599 if (Op0.getOpcode() != ISD::BITCAST)
27600 return SDValue();
27601 EVT InVT = Op0.getOperand(i: 0).getValueType();
27602 if (!InVT.isVector() ||
27603 (!Op1.isUndef() && (Op1.getOpcode() != ISD::BITCAST ||
27604 Op1.getOperand(i: 0).getValueType() != InVT)))
27605 return SDValue();
27606 if (isAnyConstantBuildVector(V: Op0.getOperand(i: 0)) &&
27607 (Op1.isUndef() || isAnyConstantBuildVector(V: Op1.getOperand(i: 0))))
27608 return SDValue();
27609
27610 int VTLanes = VT.getVectorNumElements();
27611 int InLanes = InVT.getVectorNumElements();
27612 if (VTLanes <= InLanes || VTLanes % InLanes != 0 ||
27613 (LegalOperations &&
27614 !TLI.isOperationLegalOrCustom(Op: ISD::VECTOR_SHUFFLE, VT: InVT)))
27615 return SDValue();
27616 int Factor = VTLanes / InLanes;
27617
27618 // Check that each group of lanes in the mask is either undef or makes a valid
27619 // mask for the wider lane type.
27620 ArrayRef<int> Mask = SVN->getMask();
27621 SmallVector<int> NewMask;
27622 if (!widenShuffleMaskElts(Scale: Factor, Mask, ScaledMask&: NewMask))
27623 return SDValue();
27624
27625 if (!TLI.isShuffleMaskLegal(NewMask, InVT))
27626 return SDValue();
27627
27628 // Create the new shuffle with the new mask and bitcast it back to the
27629 // original type.
27630 SDLoc DL(SVN);
27631 Op0 = Op0.getOperand(i: 0);
27632 Op1 = Op1.isUndef() ? DAG.getUNDEF(VT: InVT) : Op1.getOperand(i: 0);
27633 SDValue NewShuf = DAG.getVectorShuffle(VT: InVT, dl: DL, N1: Op0, N2: Op1, Mask: NewMask);
27634 return DAG.getBitcast(VT, V: NewShuf);
27635}
27636
27637/// Combine shuffle of shuffle of the form:
27638/// shuf (shuf X, undef, InnerMask), undef, OuterMask --> splat X
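/// Illustrative example (assuming the combined mask is legal for the target):
///   shuf (shuf X, undef, <2,u,1,0>), undef, <0,0,u,0>
///     --> shuf X, undef, <2,2,u,2>, i.e. a splat of element 2 of X.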
27639static SDValue formSplatFromShuffles(ShuffleVectorSDNode *OuterShuf,
27640 SelectionDAG &DAG) {
27641 if (!OuterShuf->getOperand(Num: 1).isUndef())
27642 return SDValue();
27643 auto *InnerShuf = dyn_cast<ShuffleVectorSDNode>(Val: OuterShuf->getOperand(Num: 0));
27644 if (!InnerShuf || !InnerShuf->getOperand(Num: 1).isUndef())
27645 return SDValue();
27646
27647 ArrayRef<int> OuterMask = OuterShuf->getMask();
27648 ArrayRef<int> InnerMask = InnerShuf->getMask();
27649 unsigned NumElts = OuterMask.size();
27650 assert(NumElts == InnerMask.size() && "Mask length mismatch");
27651 SmallVector<int, 32> CombinedMask(NumElts, -1);
27652 int SplatIndex = -1;
27653 for (unsigned i = 0; i != NumElts; ++i) {
27654 // Undef lanes remain undef.
27655 int OuterMaskElt = OuterMask[i];
27656 if (OuterMaskElt == -1)
27657 continue;
27658
27659 // Peek through the shuffle masks to get the underlying source element.
27660 int InnerMaskElt = InnerMask[OuterMaskElt];
27661 if (InnerMaskElt == -1)
27662 continue;
27663
27664 // Initialize the splatted element.
27665 if (SplatIndex == -1)
27666 SplatIndex = InnerMaskElt;
27667
27668 // Non-matching index - this is not a splat.
27669 if (SplatIndex != InnerMaskElt)
27670 return SDValue();
27671
27672 CombinedMask[i] = InnerMaskElt;
27673 }
27674 assert((all_of(CombinedMask, equal_to(-1)) ||
27675 getSplatIndex(CombinedMask) != -1) &&
27676 "Expected a splat mask");
27677
27678 // TODO: The transform may be a win even if the mask is not legal.
27679 EVT VT = OuterShuf->getValueType(ResNo: 0);
27680 assert(VT == InnerShuf->getValueType(0) && "Expected matching shuffle types");
27681 if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(CombinedMask, VT))
27682 return SDValue();
27683
27684 return DAG.getVectorShuffle(VT, dl: SDLoc(OuterShuf), N1: InnerShuf->getOperand(Num: 0),
27685 N2: InnerShuf->getOperand(Num: 1), Mask: CombinedMask);
27686}
27687
27688/// If the shuffle mask is taking exactly one element from the first vector
27689/// operand and passing through all other elements from the second vector
27690/// operand, return the index of the mask element that is choosing an element
27691/// from the first operand. Otherwise, return -1.
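/// e.g. for a 4-element mask, <4,5,0,7> returns 2 (only lane 2 reads from
/// operand 0), while <4,5,0,1> returns -1 (two lanes read from operand 0).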
27692static int getShuffleMaskIndexOfOneElementFromOp0IntoOp1(ArrayRef<int> Mask) {
27693 int MaskSize = Mask.size();
27694 int EltFromOp0 = -1;
27695 // TODO: This does not match if there are undef elements in the shuffle mask.
27696 // Should we ignore undefs in the shuffle mask instead? The trade-off is
27697 // removing an instruction (a shuffle), but losing the knowledge that some
27698 // vector lanes are not needed.
27699 for (int i = 0; i != MaskSize; ++i) {
27700 if (Mask[i] >= 0 && Mask[i] < MaskSize) {
27701 // We're looking for a shuffle of exactly one element from operand 0.
27702 if (EltFromOp0 != -1)
27703 return -1;
27704 EltFromOp0 = i;
27705 } else if (Mask[i] != i + MaskSize) {
27706 // Nothing from operand 1 can change lanes.
27707 return -1;
27708 }
27709 }
27710 return EltFromOp0;
27711}
27712
27713/// If a shuffle inserts exactly one element from a source vector operand into
27714/// another vector operand and we can access the specified element as a scalar,
27715/// then we can eliminate the shuffle.
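/// Illustrative example:
///   shuffle (insertelt V1, x, 3), V2, <4,5,3,7> --> insertelt V2, x, 2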
27716SDValue DAGCombiner::replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf) {
27717 // First, check if we are taking one element of a vector and shuffling that
27718 // element into another vector.
27719 ArrayRef<int> Mask = Shuf->getMask();
27720 SmallVector<int, 16> CommutedMask(Mask);
27721 SDValue Op0 = Shuf->getOperand(Num: 0);
27722 SDValue Op1 = Shuf->getOperand(Num: 1);
27723 int ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(Mask);
27724 if (ShufOp0Index == -1) {
27725 // Commute mask and check again.
27726 ShuffleVectorSDNode::commuteMask(Mask: CommutedMask);
27727 ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(Mask: CommutedMask);
27728 if (ShufOp0Index == -1)
27729 return SDValue();
27730 // Commute operands to match the commuted shuffle mask.
27731 std::swap(a&: Op0, b&: Op1);
27732 Mask = CommutedMask;
27733 }
27734
27735 // The shuffle inserts exactly one element from operand 0 into operand 1.
27736 // Now see if we can access that element as a scalar via a real insert element
27737 // instruction.
27738 // TODO: We can try harder to locate the element as a scalar. Examples: it
27739 // could be an operand of BUILD_VECTOR, or a constant.
27740 assert(Mask[ShufOp0Index] >= 0 && Mask[ShufOp0Index] < (int)Mask.size() &&
27741 "Shuffle mask value must be from operand 0");
27742
27743 SDValue Elt;
27744 if (sd_match(N: Op0, P: m_InsertElt(Vec: m_Value(), Val: m_Value(N&: Elt),
27745 Idx: m_SpecificInt(V: Mask[ShufOp0Index])))) {
27746 // There's an existing insertelement with constant insertion index, so we
27747 // don't need to check the legality/profitability of a replacement operation
27748 // that differs at most in the constant value. The target should be able to
27749 // lower any of those in a similar way. If not, legalization will expand
27750 // this to a scalar-to-vector plus shuffle.
27751 //
27752 // Note that the shuffle may move the scalar from the position that the
27753 // insert element used. Therefore, our new insert element occurs at the
27754 // shuffle's mask index value, not the insert's index value.
27755 //
27756 // shuffle (insertelt v1, x, C), v2, mask --> insertelt v2, x, C'
27757 SDValue NewInsIndex = DAG.getVectorIdxConstant(Val: ShufOp0Index, DL: SDLoc(Shuf));
27758 return DAG.getNode(Opcode: ISD::INSERT_VECTOR_ELT, DL: SDLoc(Shuf), VT: Op0.getValueType(),
27759 N1: Op1, N2: Elt, N3: NewInsIndex);
27760 }
27761
27762 if (!hasOperation(Opcode: ISD::INSERT_VECTOR_ELT, VT: Op0.getValueType()))
27763 return SDValue();
27764
27765 if (sd_match(N: Op0, P: m_UnaryOp(Opc: ISD::SCALAR_TO_VECTOR, Op: m_Value(N&: Elt))) &&
27766 Mask[ShufOp0Index] == 0) {
27767 SDValue NewInsIndex = DAG.getVectorIdxConstant(Val: ShufOp0Index, DL: SDLoc(Shuf));
27768 return DAG.getNode(Opcode: ISD::INSERT_VECTOR_ELT, DL: SDLoc(Shuf), VT: Op0.getValueType(),
27769 N1: Op1, N2: Elt, N3: NewInsIndex);
27770 }
27771
27772 return SDValue();
27773}
27774
27775/// If we have a unary shuffle of a shuffle, see if it can be folded away
27776/// completely. This has the potential to lose undef knowledge because the first
27777/// shuffle may not have an undef mask element where the second one does. So
27778/// only call this after doing simplifications based on demanded elements.
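/// Illustrative example:
///   shuf (shuf0 X, Y, <5,1,5,u>), undef, <2,1,2,3> --> shuf0 X, Y, <5,1,5,u>,
/// because every lane of the outer shuffle selects the value the inner
/// shuffle already produces in that lane.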
27779static SDValue simplifyShuffleOfShuffle(ShuffleVectorSDNode *Shuf) {
27780 // shuf (shuf0 X, Y, Mask0), undef, Mask
27781 auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(Val: Shuf->getOperand(Num: 0));
27782 if (!Shuf0 || !Shuf->getOperand(Num: 1).isUndef())
27783 return SDValue();
27784
27785 ArrayRef<int> Mask = Shuf->getMask();
27786 ArrayRef<int> Mask0 = Shuf0->getMask();
27787 for (int i = 0, e = (int)Mask.size(); i != e; ++i) {
27788 // Ignore undef elements.
27789 if (Mask[i] == -1)
27790 continue;
27791 assert(Mask[i] >= 0 && Mask[i] < e && "Unexpected shuffle mask value");
27792
27793 // Is the element of the shuffle operand chosen by this shuffle the same as
27794 // the element chosen by the shuffle operand itself?
27795 if (Mask0[Mask[i]] != Mask0[i])
27796 return SDValue();
27797 }
27798 // Every element of this shuffle is identical to the result of the previous
27799 // shuffle, so we can replace this value.
27800 return Shuf->getOperand(Num: 0);
27801}
27802
27803SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
27804 EVT VT = N->getValueType(ResNo: 0);
27805 unsigned NumElts = VT.getVectorNumElements();
27806
27807 SDValue N0 = N->getOperand(Num: 0);
27808 SDValue N1 = N->getOperand(Num: 1);
27809
27810 assert(N0.getValueType() == VT && "Vector shuffle must be normalized in DAG");
27811
27812 // Canonicalize shuffle undef, undef -> undef
27813 if (N0.isUndef() && N1.isUndef())
27814 return DAG.getUNDEF(VT);
27815
27816 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Val: N);
27817
27818 // Canonicalize shuffle v, v -> v, poison
27819 if (N0 == N1)
27820 return DAG.getVectorShuffle(VT, dl: SDLoc(N), N1: N0, N2: DAG.getPOISON(VT),
27821 Mask: createUnaryMask(Mask: SVN->getMask(), NumElts));
27822
27823 // Canonicalize shuffle undef, v -> v, undef. Commute the shuffle mask.
27824 if (N0.isUndef())
27825 return DAG.getCommutedVectorShuffle(SV: *SVN);
27826
27827 // Remove references to rhs if it is undef
27828 if (N1.isUndef()) {
27829 bool Changed = false;
27830 SmallVector<int, 8> NewMask;
27831 for (unsigned i = 0; i != NumElts; ++i) {
27832 int Idx = SVN->getMaskElt(Idx: i);
27833 if (Idx >= (int)NumElts) {
27834 Idx = -1;
27835 Changed = true;
27836 }
27837 NewMask.push_back(Elt: Idx);
27838 }
27839 if (Changed)
27840 return DAG.getVectorShuffle(VT, dl: SDLoc(N), N1: N0, N2: N1, Mask: NewMask);
27841 }
27842
27843 if (SDValue InsElt = replaceShuffleOfInsert(Shuf: SVN))
27844 return InsElt;
27845
27846 // A shuffle of a single vector that is a splatted value can always be folded.
27847 if (SDValue V = combineShuffleOfSplatVal(Shuf: SVN, DAG))
27848 return V;
27849
27850 if (SDValue V = formSplatFromShuffles(OuterShuf: SVN, DAG))
27851 return V;
27852
27853 // If it is a splat, check if the argument vector is another splat or a
27854 // build_vector.
27855 if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) {
27856 int SplatIndex = SVN->getSplatIndex();
27857 if (N0.hasOneUse() && TLI.isExtractVecEltCheap(VT, Index: SplatIndex) &&
27858 TLI.isBinOp(Opcode: N0.getOpcode()) && N0->getNumValues() == 1) {
27859 // splat (vector_bo L, R), Index -->
27860 // splat (scalar_bo (extelt L, Index), (extelt R, Index))
27861 SDValue L = N0.getOperand(i: 0), R = N0.getOperand(i: 1);
27862 SDLoc DL(N);
27863 EVT EltVT = VT.getScalarType();
27864 SDValue Index = DAG.getVectorIdxConstant(Val: SplatIndex, DL);
27865 SDValue ExtL = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: L, N2: Index);
27866 SDValue ExtR = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: R, N2: Index);
27867 SDValue NewBO =
27868 DAG.getNode(Opcode: N0.getOpcode(), DL, VT: EltVT, N1: ExtL, N2: ExtR, Flags: N0->getFlags());
27869 SDValue Insert = DAG.getNode(Opcode: ISD::SCALAR_TO_VECTOR, DL, VT, Operand: NewBO);
27870 SmallVector<int, 16> ZeroMask(VT.getVectorNumElements(), 0);
27871 return DAG.getVectorShuffle(VT, dl: DL, N1: Insert, N2: DAG.getPOISON(VT), Mask: ZeroMask);
27872 }
27873
27874 // splat(scalar_to_vector(x), 0) -> build_vector(x,...,x)
27875 // splat(insert_vector_elt(v, x, c), c) -> build_vector(x,...,x)
27876 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT)) &&
27877 N0.hasOneUse()) {
27878 if (N0.getOpcode() == ISD::SCALAR_TO_VECTOR && SplatIndex == 0)
27879 return DAG.getSplatBuildVector(VT, DL: SDLoc(N), Op: N0.getOperand(i: 0));
27880
27881 if (N0.getOpcode() == ISD::INSERT_VECTOR_ELT)
27882 if (auto *Idx = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 2)))
27883 if (Idx->getAPIntValue() == SplatIndex)
27884 return DAG.getSplatBuildVector(VT, DL: SDLoc(N), Op: N0.getOperand(i: 1));
27885
27886 // Look through a bitcast if little-endian and splatting lane 0, through to a
27887 // scalar_to_vector or a build_vector.
27888 if (N0.getOpcode() == ISD::BITCAST && N0.getOperand(i: 0).hasOneUse() &&
27889 SplatIndex == 0 && DAG.getDataLayout().isLittleEndian() &&
27890 (N0.getOperand(i: 0).getOpcode() == ISD::SCALAR_TO_VECTOR ||
27891 N0.getOperand(i: 0).getOpcode() == ISD::BUILD_VECTOR)) {
27892 EVT N00VT = N0.getOperand(i: 0).getValueType();
27893 if (VT.getScalarSizeInBits() <= N00VT.getScalarSizeInBits() &&
27894 VT.isInteger() && N00VT.isInteger()) {
27895 EVT InVT =
27896 TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: VT.getScalarType());
27897 SDValue Op = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 0).getOperand(i: 0),
27898 DL: SDLoc(N), VT: InVT);
27899 return DAG.getSplatBuildVector(VT, DL: SDLoc(N), Op);
27900 }
27901 }
27902 }
27903
27904 // If this is a bit convert that changes the element type of the vector but
27905 // not the number of vector elements, look through it. Be careful not to
27906 // look through conversions that change things like v4f32 to v2f64.
27907 SDNode *V = N0.getNode();
27908 if (V->getOpcode() == ISD::BITCAST) {
27909 SDValue ConvInput = V->getOperand(Num: 0);
27910 if (ConvInput.getValueType().isVector() &&
27911 ConvInput.getValueType().getVectorNumElements() == NumElts)
27912 V = ConvInput.getNode();
27913 }
27914
27915 if (V->getOpcode() == ISD::BUILD_VECTOR) {
27916 assert(V->getNumOperands() == NumElts &&
27917 "BUILD_VECTOR has wrong number of operands");
27918 SDValue Base;
27919 bool AllSame = true;
27920 for (unsigned i = 0; i != NumElts; ++i) {
27921 if (!V->getOperand(Num: i).isUndef()) {
27922 Base = V->getOperand(Num: i);
27923 break;
27924 }
27925 }
27926 // Splat of <u, u, u, u>, return <u, u, u, u>
27927 if (!Base.getNode())
27928 return N0;
27929 for (unsigned i = 0; i != NumElts; ++i) {
27930 if (V->getOperand(Num: i) != Base) {
27931 AllSame = false;
27932 break;
27933 }
27934 }
27935 // Splat of <x, x, x, x>, return <x, x, x, x>
27936 if (AllSame)
27937 return N0;
27938
27939 // Canonicalize any other splat as a build_vector, but avoid defining any
27940 // undefined elements in the mask.
27941 SDValue Splatted = V->getOperand(Num: SplatIndex);
27942 SmallVector<SDValue, 8> Ops(NumElts, Splatted);
27943 EVT EltVT = Splatted.getValueType();
27944
27945 for (unsigned i = 0; i != NumElts; ++i) {
27946 if (SVN->getMaskElt(Idx: i) < 0)
27947 Ops[i] = DAG.getPOISON(VT: EltVT);
27948 }
27949
27950 SDValue NewBV = DAG.getBuildVector(VT: V->getValueType(ResNo: 0), DL: SDLoc(N), Ops);
27951
27952 // We may have jumped through bitcasts, so the type of the
27953 // BUILD_VECTOR may not match the type of the shuffle.
27954 if (V->getValueType(ResNo: 0) != VT)
27955 NewBV = DAG.getBitcast(VT, V: NewBV);
27956 return NewBV;
27957 }
27958 }
27959
27960 // Simplify source operands based on shuffle mask.
27961 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
27962 return SDValue(N, 0);
27963
27964 // This is intentionally placed after demanded elements simplification because
27965 // it could eliminate knowledge of undef elements created by this shuffle.
27966 if (SDValue ShufOp = simplifyShuffleOfShuffle(Shuf: SVN))
27967 return ShufOp;
27968
27969 // Match shuffles that can be converted to any_vector_extend_in_reg.
27970 if (SDValue V =
27971 combineShuffleToAnyExtendVectorInreg(SVN, DAG, TLI, LegalOperations))
27972 return V;
27973
27974 // Combine "truncate_vector_in_reg" style shuffles.
27975 if (SDValue V = combineTruncationShuffle(SVN, DAG))
27976 return V;
27977
27978 if (N0.getOpcode() == ISD::CONCAT_VECTORS &&
27979 Level < AfterLegalizeVectorOps &&
27980 (N1.isUndef() ||
27981 (N1.getOpcode() == ISD::CONCAT_VECTORS &&
27982 N0.getOperand(i: 0).getValueType() == N1.getOperand(i: 0).getValueType()))) {
27983 if (SDValue V = partitionShuffleOfConcats(N, DAG))
27984 return V;
27985 }
27986
27987 // A shuffle of a concat of the same narrow vector can be reduced to use
27988 // only low-half elements of a concat with undef:
27989 // shuf (concat X, X), undef, Mask --> shuf (concat X, undef), undef, Mask'
27990 if (N0.getOpcode() == ISD::CONCAT_VECTORS && N1.isUndef() &&
27991 N0.getNumOperands() == 2 &&
27992 N0.getOperand(i: 0) == N0.getOperand(i: 1)) {
27993 int HalfNumElts = (int)NumElts / 2;
27994 SmallVector<int, 8> NewMask;
27995 for (unsigned i = 0; i != NumElts; ++i) {
27996 int Idx = SVN->getMaskElt(Idx: i);
27997 if (Idx >= HalfNumElts) {
27998 assert(Idx < (int)NumElts && "Shuffle mask chooses undef op");
27999 Idx -= HalfNumElts;
28000 }
28001 NewMask.push_back(Elt: Idx);
28002 }
28003 if (TLI.isShuffleMaskLegal(NewMask, VT)) {
28004 SDValue UndefVec = DAG.getPOISON(VT: N0.getOperand(i: 0).getValueType());
28005 SDValue NewCat = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT,
28006 N1: N0.getOperand(i: 0), N2: UndefVec);
28007 return DAG.getVectorShuffle(VT, dl: SDLoc(N), N1: NewCat, N2: N1, Mask: NewMask);
28008 }
28009 }
28010
28011 // See if we can replace a shuffle with an insert_subvector.
28012 // e.g. v2i32 into v8i32:
28013 // shuffle(lhs,concat(rhs0,rhs1,rhs2,rhs3),0,1,2,3,10,11,6,7).
28014 // --> insert_subvector(lhs,rhs1,4).
28015 if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT) &&
28016 TLI.isOperationLegalOrCustom(Op: ISD::INSERT_SUBVECTOR, VT)) {
28017 auto ShuffleToInsert = [&](SDValue LHS, SDValue RHS, ArrayRef<int> Mask) {
28018 // Ensure RHS subvectors are legal.
28019 assert(RHS.getOpcode() == ISD::CONCAT_VECTORS && "Can't find subvectors");
28020 EVT SubVT = RHS.getOperand(i: 0).getValueType();
28021 int NumSubVecs = RHS.getNumOperands();
28022 int NumSubElts = SubVT.getVectorNumElements();
28023 assert((NumElts % NumSubElts) == 0 && "Subvector mismatch");
28024 if (!TLI.isTypeLegal(VT: SubVT))
28025 return SDValue();
28026
28027 // Don't bother if we have a unary shuffle (matches undef + LHS elts).
28028 if (all_of(Range&: Mask, P: [NumElts](int M) { return M < (int)NumElts; }))
28029 return SDValue();
28030
28031 // Search [NumSubElts] spans for RHS sequence.
28032 // TODO: Can we avoid nested loops to increase performance?
28033 SmallVector<int> InsertionMask(NumElts);
28034 for (int SubVec = 0; SubVec != NumSubVecs; ++SubVec) {
28035 for (int SubIdx = 0; SubIdx != (int)NumElts; SubIdx += NumSubElts) {
28036 // Reset mask to identity.
28037 std::iota(first: InsertionMask.begin(), last: InsertionMask.end(), value: 0);
28038
28039 // Add subvector insertion.
28040 std::iota(first: InsertionMask.begin() + SubIdx,
28041 last: InsertionMask.begin() + SubIdx + NumSubElts,
28042 value: NumElts + (SubVec * NumSubElts));
28043
28044 // See if the shuffle mask matches the reference insertion mask.
28045 bool MatchingShuffle = true;
28046 for (int i = 0; i != (int)NumElts; ++i) {
28047 int ExpectIdx = InsertionMask[i];
28048 int ActualIdx = Mask[i];
28049 if (0 <= ActualIdx && ExpectIdx != ActualIdx) {
28050 MatchingShuffle = false;
28051 break;
28052 }
28053 }
28054
28055 if (MatchingShuffle)
28056 return DAG.getInsertSubvector(DL: SDLoc(N), Vec: LHS, SubVec: RHS.getOperand(i: SubVec),
28057 Idx: SubIdx);
28058 }
28059 }
28060 return SDValue();
28061 };
28062 ArrayRef<int> Mask = SVN->getMask();
28063 if (N1.getOpcode() == ISD::CONCAT_VECTORS)
28064 if (SDValue InsertN1 = ShuffleToInsert(N0, N1, Mask))
28065 return InsertN1;
28066 if (N0.getOpcode() == ISD::CONCAT_VECTORS) {
28067 SmallVector<int> CommuteMask(Mask);
28068 ShuffleVectorSDNode::commuteMask(Mask: CommuteMask);
28069 if (SDValue InsertN0 = ShuffleToInsert(N1, N0, CommuteMask))
28070 return InsertN0;
28071 }
28072 }
28073
28074 // If we're not performing a select/blend shuffle, see if we can convert the
28075 // shuffle into an AND node, where all the out-of-lane elements are known zero.
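// Illustrative example (when the referenced elements of the zero operand are
// known zero and the integer AND is legal):
//   shuffle X, zeros, <0,4,2,6> --> and X, <-1,0,-1,0>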
28076 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
28077 bool IsInLaneMask = true;
28078 ArrayRef<int> Mask = SVN->getMask();
28079 SmallVector<int, 16> ClearMask(NumElts, -1);
28080 APInt DemandedLHS = APInt::getZero(numBits: NumElts);
28081 APInt DemandedRHS = APInt::getZero(numBits: NumElts);
28082 for (int I = 0; I != (int)NumElts; ++I) {
28083 int M = Mask[I];
28084 if (M < 0)
28085 continue;
28086 ClearMask[I] = M == I ? I : (I + NumElts);
28087 IsInLaneMask &= (M == I) || (M == (int)(I + NumElts));
28088 if (M != I) {
28089 APInt &Demanded = M < (int)NumElts ? DemandedLHS : DemandedRHS;
28090 Demanded.setBit(M % NumElts);
28091 }
28092 }
28093 // TODO: Should we try to mask with N1 as well?
28094 if (!IsInLaneMask && (!DemandedLHS.isZero() || !DemandedRHS.isZero()) &&
28095 (DemandedLHS.isZero() || DAG.MaskedVectorIsZero(Op: N0, DemandedElts: DemandedLHS)) &&
28096 (DemandedRHS.isZero() || DAG.MaskedVectorIsZero(Op: N1, DemandedElts: DemandedRHS))) {
28097 SDLoc DL(N);
28098 EVT IntVT = VT.changeVectorElementTypeToInteger();
28099 EVT IntSVT = VT.getVectorElementType().changeTypeToInteger();
28100 // Transform the type to a legal type so that the buildvector constant
28101 // elements are not illegal. Make sure that the result is larger than the
28102 // original type, in case the value is split into two (e.g. i64->i32).
28103 if (!TLI.isTypeLegal(VT: IntSVT) && LegalTypes)
28104 IntSVT = TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: IntSVT);
28105 if (IntSVT.getSizeInBits() >= IntVT.getScalarSizeInBits()) {
28106 SDValue ZeroElt = DAG.getConstant(Val: 0, DL, VT: IntSVT);
28107 SDValue AllOnesElt = DAG.getAllOnesConstant(DL, VT: IntSVT);
28108 SmallVector<SDValue, 16> AndMask(NumElts, DAG.getPOISON(VT: IntSVT));
28109 for (int I = 0; I != (int)NumElts; ++I)
28110 if (0 <= Mask[I])
28111 AndMask[I] = Mask[I] == I ? AllOnesElt : ZeroElt;
28112
28113 // See if a clear mask is legal instead of going via
28114 // XformToShuffleWithZero which loses UNDEF mask elements.
28115 if (TLI.isVectorClearMaskLegal(ClearMask, IntVT))
28116 return DAG.getBitcast(
28117 VT, V: DAG.getVectorShuffle(VT: IntVT, dl: DL, N1: DAG.getBitcast(VT: IntVT, V: N0),
28118 N2: DAG.getConstant(Val: 0, DL, VT: IntVT), Mask: ClearMask));
28119
28120 if (TLI.isOperationLegalOrCustom(Op: ISD::AND, VT: IntVT))
28121 return DAG.getBitcast(
28122 VT, V: DAG.getNode(Opcode: ISD::AND, DL, VT: IntVT, N1: DAG.getBitcast(VT: IntVT, V: N0),
28123 N2: DAG.getBuildVector(VT: IntVT, DL, Ops: AndMask)));
28124 }
28125 }
28126 }
28127
28128 // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
28129 // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
28130 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT))
28131 if (SDValue Res = combineShuffleOfScalars(SVN, DAG, TLI))
28132 return Res;
28133
28134 // If this shuffle only has a single input that is a bitcasted shuffle,
28135 // attempt to merge the 2 shuffles and suitably bitcast the inputs/output
28136 // back to their original types.
28137 if (N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
28138 N1.isUndef() && Level < AfterLegalizeVectorOps &&
28139 TLI.isTypeLegal(VT)) {
28140
28141 SDValue BC0 = peekThroughOneUseBitcasts(V: N0);
28142 if (BC0.getOpcode() == ISD::VECTOR_SHUFFLE && BC0.hasOneUse()) {
28143 EVT SVT = VT.getScalarType();
28144 EVT InnerVT = BC0->getValueType(ResNo: 0);
28145 EVT InnerSVT = InnerVT.getScalarType();
28146
28147 // Determine which shuffle works with the smaller scalar type.
28148 EVT ScaleVT = SVT.bitsLT(VT: InnerSVT) ? VT : InnerVT;
28149 EVT ScaleSVT = ScaleVT.getScalarType();
28150
28151 if (TLI.isTypeLegal(VT: ScaleVT) &&
28152 0 == (InnerSVT.getSizeInBits() % ScaleSVT.getSizeInBits()) &&
28153 0 == (SVT.getSizeInBits() % ScaleSVT.getSizeInBits())) {
28154 int InnerScale = InnerSVT.getSizeInBits() / ScaleSVT.getSizeInBits();
28155 int OuterScale = SVT.getSizeInBits() / ScaleSVT.getSizeInBits();
28156
28157 // Scale the shuffle masks to the smaller scalar type.
28158 ShuffleVectorSDNode *InnerSVN = cast<ShuffleVectorSDNode>(Val&: BC0);
28159 SmallVector<int, 8> InnerMask;
28160 SmallVector<int, 8> OuterMask;
28161 narrowShuffleMaskElts(Scale: InnerScale, Mask: InnerSVN->getMask(), ScaledMask&: InnerMask);
28162 narrowShuffleMaskElts(Scale: OuterScale, Mask: SVN->getMask(), ScaledMask&: OuterMask);
28163
28164 // Merge the shuffle masks.
28165 SmallVector<int, 8> NewMask;
28166 for (int M : OuterMask)
28167 NewMask.push_back(Elt: M < 0 ? -1 : InnerMask[M]);
28168
28169 // Test for shuffle mask legality over both commutations.
28170 SDValue SV0 = BC0->getOperand(Num: 0);
28171 SDValue SV1 = BC0->getOperand(Num: 1);
28172 bool LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
28173 if (!LegalMask) {
28174 std::swap(a&: SV0, b&: SV1);
28175 ShuffleVectorSDNode::commuteMask(Mask: NewMask);
28176 LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
28177 }
28178
28179 if (LegalMask) {
28180 SV0 = DAG.getBitcast(VT: ScaleVT, V: SV0);
28181 SV1 = DAG.getBitcast(VT: ScaleVT, V: SV1);
28182 return DAG.getBitcast(
28183 VT, V: DAG.getVectorShuffle(VT: ScaleVT, dl: SDLoc(N), N1: SV0, N2: SV1, Mask: NewMask));
28184 }
28185 }
28186 }
28187 }
28188
28189 // Match shuffles of bitcasts, so long as the mask can be treated as the
28190 // larger type.
28191 if (SDValue V = combineShuffleOfBitcast(SVN, DAG, TLI, LegalOperations))
28192 return V;
28193
28194 // Compute the combined shuffle mask for a shuffle with SV0 as the first
28195 // operand, and SV1 as the second operand.
28196 // i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask) iff Commute = false
28197 // Merge SVN(N1, OtherSVN) -> shuffle(SV0, SV1, Mask') iff Commute = true
28198 auto MergeInnerShuffle =
28199 [NumElts, &VT](bool Commute, ShuffleVectorSDNode *SVN,
28200 ShuffleVectorSDNode *OtherSVN, SDValue N1,
28201 const TargetLowering &TLI, SDValue &SV0, SDValue &SV1,
28202 SmallVectorImpl<int> &Mask) -> bool {
28203 // Don't try to fold splats; they're likely to simplify somehow, or they
28204 // might be free.
28205 if (OtherSVN->isSplat())
28206 return false;
28207
28208 SV0 = SV1 = SDValue();
28209 Mask.clear();
28210
28211 for (unsigned i = 0; i != NumElts; ++i) {
28212 int Idx = SVN->getMaskElt(Idx: i);
28213 if (Idx < 0) {
28214 // Propagate Undef.
28215 Mask.push_back(Elt: Idx);
28216 continue;
28217 }
28218
28219 if (Commute)
28220 Idx = (Idx < (int)NumElts) ? (Idx + NumElts) : (Idx - NumElts);
28221
28222 SDValue CurrentVec;
28223 if (Idx < (int)NumElts) {
28224            // This shuffle index refers to the inner shuffle (OtherSVN). Look up the
28225            // inner shuffle mask to identify which vector is actually referenced.
28226 Idx = OtherSVN->getMaskElt(Idx);
28227 if (Idx < 0) {
28228 // Propagate Undef.
28229 Mask.push_back(Elt: Idx);
28230 continue;
28231 }
28232 CurrentVec = (Idx < (int)NumElts) ? OtherSVN->getOperand(Num: 0)
28233 : OtherSVN->getOperand(Num: 1);
28234 } else {
28235 // This shuffle index references an element within N1.
28236 CurrentVec = N1;
28237 }
28238
28239 // Simple case where 'CurrentVec' is UNDEF.
28240 if (CurrentVec.isUndef()) {
28241 Mask.push_back(Elt: -1);
28242 continue;
28243 }
28244
28245 // Canonicalize the shuffle index. We don't know yet if CurrentVec
28246 // will be the first or second operand of the combined shuffle.
28247 Idx = Idx % NumElts;
28248 if (!SV0.getNode() || SV0 == CurrentVec) {
28249 // Ok. CurrentVec is the left hand side.
28250 // Update the mask accordingly.
28251 SV0 = CurrentVec;
28252 Mask.push_back(Elt: Idx);
28253 continue;
28254 }
28255 if (!SV1.getNode() || SV1 == CurrentVec) {
28256 // Ok. CurrentVec is the right hand side.
28257 // Update the mask accordingly.
28258 SV1 = CurrentVec;
28259 Mask.push_back(Elt: Idx + NumElts);
28260 continue;
28261 }
28262
28263 // Last chance - see if the vector is another shuffle and if it
28264 // uses one of the existing candidate shuffle ops.
28265 if (auto *CurrentSVN = dyn_cast<ShuffleVectorSDNode>(Val&: CurrentVec)) {
28266 int InnerIdx = CurrentSVN->getMaskElt(Idx);
28267 if (InnerIdx < 0) {
28268 Mask.push_back(Elt: -1);
28269 continue;
28270 }
28271 SDValue InnerVec = (InnerIdx < (int)NumElts)
28272 ? CurrentSVN->getOperand(Num: 0)
28273 : CurrentSVN->getOperand(Num: 1);
28274 if (InnerVec.isUndef()) {
28275 Mask.push_back(Elt: -1);
28276 continue;
28277 }
28278 InnerIdx %= NumElts;
28279 if (InnerVec == SV0) {
28280 Mask.push_back(Elt: InnerIdx);
28281 continue;
28282 }
28283 if (InnerVec == SV1) {
28284 Mask.push_back(Elt: InnerIdx + NumElts);
28285 continue;
28286 }
28287 }
28288
28289 // Bail out if we cannot convert the shuffle pair into a single shuffle.
28290 return false;
28291 }
28292
28293 if (llvm::all_of(Range&: Mask, P: [](int M) { return M < 0; }))
28294 return true;
28295
28296 // Avoid introducing shuffles with illegal mask.
28297 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
28298 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
28299 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
28300 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2)
28301 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2)
28302 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2)
28303 if (TLI.isShuffleMaskLegal(Mask, VT))
28304 return true;
28305
28306 std::swap(a&: SV0, b&: SV1);
28307 ShuffleVectorSDNode::commuteMask(Mask);
28308 return TLI.isShuffleMaskLegal(Mask, VT);
28309 };
28310
28311 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
28312 // Canonicalize shuffles according to rules:
28313 // shuffle(A, shuffle(A, B)) -> shuffle(shuffle(A,B), A)
28314 // shuffle(B, shuffle(A, B)) -> shuffle(shuffle(A,B), B)
28315 // shuffle(B, shuffle(A, Undef)) -> shuffle(shuffle(A, Undef), B)
28316 if (N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
28317 N0.getOpcode() != ISD::VECTOR_SHUFFLE) {
28318 // The incoming shuffle must be of the same type as the result of the
28319 // current shuffle.
28320 assert(N1->getOperand(0).getValueType() == VT &&
28321 "Shuffle types don't match");
28322
28323 SDValue SV0 = N1->getOperand(Num: 0);
28324 SDValue SV1 = N1->getOperand(Num: 1);
28325 bool HasSameOp0 = N0 == SV0;
28326 bool IsSV1Undef = SV1.isUndef();
28327 if (HasSameOp0 || IsSV1Undef || N0 == SV1)
28328 // Commute the operands of this shuffle so merging below will trigger.
28329 return DAG.getCommutedVectorShuffle(SV: *SVN);
28330 }
28331
28332 // Canonicalize splat shuffles to the RHS to improve merging below.
28333 // shuffle(splat(A,u), shuffle(C,D)) -> shuffle'(shuffle(C,D), splat(A,u))
28334 if (N0.getOpcode() == ISD::VECTOR_SHUFFLE &&
28335 N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
28336 cast<ShuffleVectorSDNode>(Val&: N0)->isSplat() &&
28337 !cast<ShuffleVectorSDNode>(Val&: N1)->isSplat()) {
28338 return DAG.getCommutedVectorShuffle(SV: *SVN);
28339 }
28340
28341 // Try to fold according to rules:
28342 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
28343 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
28344 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
28345 // Don't try to fold shuffles with illegal type.
28346 // Only fold if this shuffle is the only user of the other shuffle.
28347     // Try matching shuffle(C,shuffle(A,B)) commuted patterns as well.
28348 for (int i = 0; i != 2; ++i) {
28349 if (N->getOperand(Num: i).getOpcode() == ISD::VECTOR_SHUFFLE &&
28350 N->isOnlyUserOf(N: N->getOperand(Num: i).getNode())) {
28351 // The incoming shuffle must be of the same type as the result of the
28352 // current shuffle.
28353 auto *OtherSV = cast<ShuffleVectorSDNode>(Val: N->getOperand(Num: i));
28354 assert(OtherSV->getOperand(0).getValueType() == VT &&
28355 "Shuffle types don't match");
28356
28357 SDValue SV0, SV1;
28358 SmallVector<int, 4> Mask;
28359 if (MergeInnerShuffle(i != 0, SVN, OtherSV, N->getOperand(Num: 1 - i), TLI,
28360 SV0, SV1, Mask)) {
28361           // If all indices in Mask are negative (undef/poison), propagate poison.
28362 if (llvm::all_of(Range&: Mask, P: [](int M) { return M < 0; }))
28363 return DAG.getPOISON(VT);
28364
28365 return DAG.getVectorShuffle(VT, dl: SDLoc(N),
28366 N1: SV0 ? SV0 : DAG.getPOISON(VT),
28367 N2: SV1 ? SV1 : DAG.getPOISON(VT), Mask);
28368 }
28369 }
28370 }
28371
28372     // Merge the shuffle through a binop if we can merge it with at least one
28373     // of the binop's inner shuffle operands.
28374 // shuffle(bop(shuffle(x,y),shuffle(z,w)),undef)
28375 // shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d)))
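    // e.g. (illustrative) shuffle (add (shuffle X, Y, M0), Z), undef, M1
    //      --> add (shuffle X, Y, M2), (shuffle Z, poison, M1)
    // where M2 is M1 composed with M0, so the outer shuffle is absorbed into
    // the binop's operands.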
28376 unsigned SrcOpcode = N0.getOpcode();
28377 if (TLI.isBinOp(Opcode: SrcOpcode) && N->isOnlyUserOf(N: N0.getNode()) &&
28378 (N1.isUndef() ||
28379 (SrcOpcode == N1.getOpcode() && N->isOnlyUserOf(N: N1.getNode()) &&
28380 N0.getResNo() == N1.getResNo()))) {
28381 // Get binop source ops, or just pass on the undef.
28382 SDValue Op00 = N0.getOperand(i: 0);
28383 SDValue Op01 = N0.getOperand(i: 1);
28384 SDValue Op10 = N1.isUndef() ? N1 : N1.getOperand(i: 0);
28385 SDValue Op11 = N1.isUndef() ? N1 : N1.getOperand(i: 1);
28386 // TODO: We might be able to relax the VT check but we don't currently
28387 // have any isBinOp() that has different result/ops VTs so play safe until
28388 // we have test coverage.
28389 if (Op00.getValueType() == VT && Op10.getValueType() == VT &&
28390 Op01.getValueType() == VT && Op11.getValueType() == VT &&
28391 (Op00.getOpcode() == ISD::VECTOR_SHUFFLE ||
28392 Op10.getOpcode() == ISD::VECTOR_SHUFFLE ||
28393 Op01.getOpcode() == ISD::VECTOR_SHUFFLE ||
28394 Op11.getOpcode() == ISD::VECTOR_SHUFFLE)) {
28395 auto CanMergeInnerShuffle = [&](SDValue &SV0, SDValue &SV1,
28396 SmallVectorImpl<int> &Mask, bool LeftOp,
28397 bool Commute) {
28398 SDValue InnerN = Commute ? N1 : N0;
28399 SDValue Op0 = LeftOp ? Op00 : Op01;
28400 SDValue Op1 = LeftOp ? Op10 : Op11;
28401 if (Commute)
28402 std::swap(a&: Op0, b&: Op1);
28403 // Only accept the merged shuffle if we don't introduce undef elements,
28404 // or the inner shuffle already contained undef elements.
28405 auto *SVN0 = dyn_cast<ShuffleVectorSDNode>(Val&: Op0);
28406 return SVN0 && InnerN->isOnlyUserOf(N: SVN0) &&
28407 MergeInnerShuffle(Commute, SVN, SVN0, Op1, TLI, SV0, SV1,
28408 Mask) &&
28409 (llvm::any_of(Range: SVN0->getMask(), P: [](int M) { return M < 0; }) ||
28410 llvm::none_of(Range&: Mask, P: [](int M) { return M < 0; }));
28411 };
28412
28413 // Ensure we don't increase the number of shuffles - we must merge a
28414 // shuffle from at least one of the LHS and RHS ops.
28415 bool MergedLeft = false;
28416 SDValue LeftSV0, LeftSV1;
28417 SmallVector<int, 4> LeftMask;
28418 if (CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, false) ||
28419 CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, true)) {
28420 MergedLeft = true;
28421 } else {
28422 LeftMask.assign(in_start: SVN->getMask().begin(), in_end: SVN->getMask().end());
28423 LeftSV0 = Op00, LeftSV1 = Op10;
28424 }
28425
28426 bool MergedRight = false;
28427 SDValue RightSV0, RightSV1;
28428 SmallVector<int, 4> RightMask;
28429 if (CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, false) ||
28430 CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, true)) {
28431 MergedRight = true;
28432 } else {
28433 RightMask.assign(in_start: SVN->getMask().begin(), in_end: SVN->getMask().end());
28434 RightSV0 = Op01, RightSV1 = Op11;
28435 }
28436
28437 if (MergedLeft || MergedRight) {
28438 SDLoc DL(N);
28439 SDValue LHS = DAG.getVectorShuffle(
28440 VT, dl: DL, N1: LeftSV0 ? LeftSV0 : DAG.getPOISON(VT),
28441 N2: LeftSV1 ? LeftSV1 : DAG.getPOISON(VT), Mask: LeftMask);
28442 SDValue RHS = DAG.getVectorShuffle(
28443 VT, dl: DL, N1: RightSV0 ? RightSV0 : DAG.getPOISON(VT),
28444 N2: RightSV1 ? RightSV1 : DAG.getPOISON(VT), Mask: RightMask);
28445 return DAG.getNode(Opcode: SrcOpcode, DL, VTList: N0->getVTList(), N1: LHS, N2: RHS)
28446 .getValue(R: N0.getResNo());
28447 }
28448 }
28449 }
28450 }
28451
28452 if (SDValue V = foldShuffleOfConcatUndefs(Shuf: SVN, DAG))
28453 return V;
28454
28455 // Match shuffles that can be converted to ISD::ZERO_EXTEND_VECTOR_INREG.
28456 // Perform this really late, because it could eliminate knowledge
28457 // of undef elements created by this shuffle.
28458 if (Level < AfterLegalizeTypes)
28459 if (SDValue V = combineShuffleToZeroExtendVectorInReg(SVN, DAG, TLI,
28460 LegalOperations))
28461 return V;
28462
28463 return SDValue();
28464}
28465
28466SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
28467 EVT VT = N->getValueType(ResNo: 0);
28468 if (!VT.isFixedLengthVector())
28469 return SDValue();
28470
28471 // Try to convert a scalar binop with an extracted vector element to a vector
28472 // binop. This is intended to reduce potentially expensive register moves.
28473 // TODO: Check if both operands are extracted.
28474   // TODO: How to prefer scalar/vector ops with multiple uses of the extract?
28475 // TODO: Generalize this, so it can be called from visitINSERT_VECTOR_ELT().
28476 SDValue Scalar = N->getOperand(Num: 0);
28477 unsigned Opcode = Scalar.getOpcode();
28478 EVT VecEltVT = VT.getScalarType();
28479 if (Scalar.hasOneUse() && Scalar->getNumValues() == 1 &&
28480 TLI.isBinOp(Opcode) && Scalar.getValueType() == VecEltVT &&
28481 Scalar.getOperand(i: 0).getValueType() == VecEltVT &&
28482 Scalar.getOperand(i: 1).getValueType() == VecEltVT &&
28483 Scalar->isOnlyUserOf(N: Scalar.getOperand(i: 0).getNode()) &&
28484 Scalar->isOnlyUserOf(N: Scalar.getOperand(i: 1).getNode()) &&
28485 DAG.isSafeToSpeculativelyExecute(Opcode) && hasOperation(Opcode, VT)) {
28486 // Match an extract element and get a shuffle mask equivalent.
28487 SmallVector<int, 8> ShufMask(VT.getVectorNumElements(), -1);
28488
28489 for (int i : {0, 1}) {
28490 // s2v (bo (extelt V, Idx), C) --> shuffle (bo V, C'), {Idx, -1, -1...}
28491 // s2v (bo C, (extelt V, Idx)) --> shuffle (bo C', V), {Idx, -1, -1...}
28492 SDValue EE = Scalar.getOperand(i);
28493 auto *C = dyn_cast<ConstantSDNode>(Val: Scalar.getOperand(i: i ? 0 : 1));
28494 if (C && EE.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
28495 EE.getOperand(i: 0).getValueType() == VT &&
28496 isa<ConstantSDNode>(Val: EE.getOperand(i: 1))) {
28497 // Mask = {ExtractIndex, undef, undef....}
28498 ShufMask[0] = EE.getConstantOperandVal(i: 1);
28499 // Make sure the shuffle is legal if we are crossing lanes.
28500 if (TLI.isShuffleMaskLegal(ShufMask, VT)) {
28501 SDLoc DL(N);
28502 SDValue V[] = {EE.getOperand(i: 0),
28503 DAG.getConstant(Val: C->getAPIntValue(), DL, VT)};
28504 SDValue VecBO = DAG.getNode(Opcode, DL, VT, N1: V[i], N2: V[1 - i]);
28505 return DAG.getVectorShuffle(VT, dl: DL, N1: VecBO, N2: DAG.getPOISON(VT),
28506 Mask: ShufMask);
28507 }
28508 }
28509 }
28510 }
28511
28512 // Replace a SCALAR_TO_VECTOR(EXTRACT_VECTOR_ELT(V,C0)) pattern
28513 // with a VECTOR_SHUFFLE and possible truncate.
28514 if (Opcode != ISD::EXTRACT_VECTOR_ELT ||
28515 !Scalar.getOperand(i: 0).getValueType().isFixedLengthVector())
28516 return SDValue();
28517
28518 // If we have an implicit truncate, truncate here if it is legal.
28519 if (VecEltVT != Scalar.getValueType() &&
28520 Scalar.getValueType().isScalarInteger() && isTypeLegal(VT: VecEltVT)) {
28521 SDValue Val = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(Scalar), VT: VecEltVT, Operand: Scalar);
28522 return DAG.getNode(Opcode: ISD::SCALAR_TO_VECTOR, DL: SDLoc(N), VT, Operand: Val);
28523 }
28524
28525 auto *ExtIndexC = dyn_cast<ConstantSDNode>(Val: Scalar.getOperand(i: 1));
28526 if (!ExtIndexC)
28527 return SDValue();
28528
28529 SDValue SrcVec = Scalar.getOperand(i: 0);
28530 EVT SrcVT = SrcVec.getValueType();
28531 unsigned SrcNumElts = SrcVT.getVectorNumElements();
28532 unsigned VTNumElts = VT.getVectorNumElements();
28533 if (VecEltVT == SrcVT.getScalarType() && VTNumElts <= SrcNumElts) {
28534 // Create a shuffle equivalent for scalar-to-vector: {ExtIndex, -1, -1, ...}
28535 SmallVector<int, 8> Mask(SrcNumElts, -1);
28536 Mask[0] = ExtIndexC->getZExtValue();
28537 SDValue LegalShuffle = TLI.buildLegalVectorShuffle(
28538 VT: SrcVT, DL: SDLoc(N), N0: SrcVec, N1: DAG.getPOISON(VT: SrcVT), Mask, DAG);
28539 if (!LegalShuffle)
28540 return SDValue();
28541
28542 // If the initial vector is the same size, the shuffle is the result.
28543 if (VT == SrcVT)
28544 return LegalShuffle;
28545
28546 // If not, shorten the shuffled vector.
28547 if (VTNumElts != SrcNumElts) {
28548 SDValue ZeroIdx = DAG.getVectorIdxConstant(Val: 0, DL: SDLoc(N));
28549 EVT SubVT = EVT::getVectorVT(Context&: *DAG.getContext(),
28550 VT: SrcVT.getVectorElementType(), NumElements: VTNumElts);
28551 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N), VT: SubVT, N1: LegalShuffle,
28552 N2: ZeroIdx);
28553 }
28554 }
28555
28556 return SDValue();
28557}
28558
28559SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) {
28560 EVT VT = N->getValueType(ResNo: 0);
28561 SDValue N0 = N->getOperand(Num: 0);
28562 SDValue N1 = N->getOperand(Num: 1);
28563 SDValue N2 = N->getOperand(Num: 2);
28564 uint64_t InsIdx = N->getConstantOperandVal(Num: 2);
28565
28566 // Remove insert of UNDEF/POISON.
28567 if (N1.isUndef()) {
28568 if (N1.getOpcode() == ISD::POISON || N0.getOpcode() == ISD::UNDEF)
28569 return N0;
28570 return DAG.getFreeze(V: N0);
28571 }
28572
28573 // If this is an insert of an extracted vector into an undef/poison vector, we
28574 // can just use the input to the extract if the types match, and can simplify
28575 // in some cases even if they don't.
28576 if (N0.isUndef() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
28577 N1.getOperand(i: 1) == N2) {
28578 EVT N1VT = N1.getValueType();
28579 EVT SrcVT = N1.getOperand(i: 0).getValueType();
28580 if (SrcVT == VT) {
28581       // Need to ensure that the result isn't more poisonous when skipping both
28582       // the extract and the insert.
28583 if (N0.getOpcode() == ISD::POISON)
28584 return N1.getOperand(i: 0);
28585 if (VT.isFixedLengthVector() && N1VT.isFixedLengthVector()) {
28586 unsigned SubVecNumElts = N1VT.getVectorNumElements();
28587 APInt EltMask = APInt::getBitsSet(numBits: VT.getVectorNumElements(), loBit: InsIdx,
28588 hiBit: InsIdx + SubVecNumElts);
28589 if (DAG.isGuaranteedNotToBePoison(Op: N1.getOperand(i: 0), DemandedElts: ~EltMask))
28590 return N1.getOperand(i: 0);
28591 } else if (DAG.isGuaranteedNotToBePoison(Op: N1.getOperand(i: 0)))
28592 return N1.getOperand(i: 0);
28593 }
28594     // TODO: To remove the zero check, we would need to adjust the offset to
28595     // a multiple of the new src type.
28596 if (isNullConstant(V: N2)) {
28597 if (VT.knownBitsGE(VT: SrcVT) &&
28598 !(VT.isFixedLengthVector() && SrcVT.isScalableVector()))
28599 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N),
28600 VT, N1: N0, N2: N1.getOperand(i: 0), N3: N2);
28601 else if (VT.knownBitsLE(VT: SrcVT) &&
28602 !(VT.isScalableVector() && SrcVT.isFixedLengthVector()))
28603 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N),
28604 VT, N1: N1.getOperand(i: 0), N2);
28605 }
28606 }
28607
28608 // Handle case where we've ended up inserting back into the source vector
28609 // we extracted the subvector from.
28610 // insert_subvector(N0, extract_subvector(N0, N2), N2) --> N0
28611 if (N1.getOpcode() == ISD::EXTRACT_SUBVECTOR && N1.getOperand(i: 0) == N0 &&
28612 N1.getOperand(i: 1) == N2)
28613 return N0;
28614
28615 // Simplify scalar inserts into an undef vector:
28616 // insert_subvector undef, (splat X), N2 -> splat X
28617 if (N0.isUndef() && N1.getOpcode() == ISD::SPLAT_VECTOR)
28618 if (DAG.isConstantValueOfAnyType(N: N1.getOperand(i: 0)) || N1.hasOneUse())
28619 return DAG.getNode(Opcode: ISD::SPLAT_VECTOR, DL: SDLoc(N), VT, Operand: N1.getOperand(i: 0));
28620
28621 // insert_subvector (splat X), (splat X), N2 -> splat X
28622 if (N0.getOpcode() == ISD::SPLAT_VECTOR && N0.getOpcode() == N1.getOpcode() &&
28623 N0.getOperand(i: 0) == N1.getOperand(i: 0))
28624 return N0;
28625
28626 // If we are inserting a bitcast value into an undef, with the same
28627 // number of elements, just use the bitcast input of the extract.
28628 // i.e. INSERT_SUBVECTOR UNDEF (BITCAST N1) N2 ->
28629 // BITCAST (INSERT_SUBVECTOR UNDEF N1 N2)
28630 if (N0.isUndef() && N1.getOpcode() == ISD::BITCAST &&
28631 N1.getOperand(i: 0).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
28632 N1.getOperand(i: 0).getOperand(i: 1) == N2 &&
28633 N1.getOperand(i: 0).getOperand(i: 0).getValueType().getVectorElementCount() ==
28634 VT.getVectorElementCount() &&
28635 N1.getOperand(i: 0).getOperand(i: 0).getValueType().getSizeInBits() ==
28636 VT.getSizeInBits()) {
28637 return DAG.getBitcast(VT, V: N1.getOperand(i: 0).getOperand(i: 0));
28638 }
28639
28640   // If both N0 and N1 are bitcast values on which insert_subvector
28641   // would make sense, pull the bitcast through.
28642 // i.e. INSERT_SUBVECTOR (BITCAST N0) (BITCAST N1) N2 ->
28643 // BITCAST (INSERT_SUBVECTOR N0 N1 N2)
28644 if (N0.getOpcode() == ISD::BITCAST && N1.getOpcode() == ISD::BITCAST) {
28645 SDValue CN0 = N0.getOperand(i: 0);
28646 SDValue CN1 = N1.getOperand(i: 0);
28647 EVT CN0VT = CN0.getValueType();
28648 EVT CN1VT = CN1.getValueType();
28649 if (CN0VT.isVector() && CN1VT.isVector() &&
28650 CN0VT.getVectorElementType() == CN1VT.getVectorElementType() &&
28651 CN0VT.getVectorElementCount() == VT.getVectorElementCount()) {
28652 SDValue NewINSERT = DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N),
28653 VT: CN0.getValueType(), N1: CN0, N2: CN1, N3: N2);
28654 return DAG.getBitcast(VT, V: NewINSERT);
28655 }
28656 }
28657
28658 // Combine INSERT_SUBVECTORs where we are inserting to the same index.
28659 // INSERT_SUBVECTOR( INSERT_SUBVECTOR( Vec, SubOld, Idx ), SubNew, Idx )
28660 // --> INSERT_SUBVECTOR( Vec, SubNew, Idx )
28661 if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
28662 N0.getOperand(i: 1).getValueType() == N1.getValueType() &&
28663 N0.getOperand(i: 2) == N2)
28664 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N), VT, N1: N0.getOperand(i: 0),
28665 N2: N1, N3: N2);
28666
28667 // Eliminate an intermediate insert into an undef vector:
28668 // insert_subvector undef, (insert_subvector undef, X, 0), 0 -->
28669 // insert_subvector undef, X, 0
28670 if (N0.isUndef() && N1.getOpcode() == ISD::INSERT_SUBVECTOR &&
28671 N1.getOperand(i: 0).isUndef() && isNullConstant(V: N1.getOperand(i: 2)) &&
28672 isNullConstant(V: N2))
28673 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N), VT, N1: N0,
28674 N2: N1.getOperand(i: 1), N3: N2);
28675
28676 // Push subvector bitcasts to the output, adjusting the index as we go.
28677 // insert_subvector(bitcast(v), bitcast(s), c1)
28678 // -> bitcast(insert_subvector(v, s, c2))
28679 if ((N0.isUndef() || N0.getOpcode() == ISD::BITCAST) &&
28680 N1.getOpcode() == ISD::BITCAST) {
28681 SDValue N0Src = peekThroughBitcasts(V: N0);
28682 SDValue N1Src = peekThroughBitcasts(V: N1);
28683 EVT N0SrcSVT = N0Src.getValueType().getScalarType();
28684 EVT N1SrcSVT = N1Src.getValueType().getScalarType();
28685 if ((N0.isUndef() || N0SrcSVT == N1SrcSVT) &&
28686 N0Src.getValueType().isVector() && N1Src.getValueType().isVector()) {
28687 EVT NewVT;
28688 SDLoc DL(N);
28689 SDValue NewIdx;
28690 LLVMContext &Ctx = *DAG.getContext();
28691 ElementCount NumElts = VT.getVectorElementCount();
28692 unsigned EltSizeInBits = VT.getScalarSizeInBits();
28693 if ((EltSizeInBits % N1SrcSVT.getSizeInBits()) == 0) {
28694 unsigned Scale = EltSizeInBits / N1SrcSVT.getSizeInBits();
28695 NewVT = EVT::getVectorVT(Context&: Ctx, VT: N1SrcSVT, EC: NumElts * Scale);
28696 NewIdx = DAG.getVectorIdxConstant(Val: InsIdx * Scale, DL);
28697 } else if ((N1SrcSVT.getSizeInBits() % EltSizeInBits) == 0) {
28698 unsigned Scale = N1SrcSVT.getSizeInBits() / EltSizeInBits;
28699 if (NumElts.isKnownMultipleOf(RHS: Scale) && (InsIdx % Scale) == 0) {
28700 NewVT = EVT::getVectorVT(Context&: Ctx, VT: N1SrcSVT,
28701 EC: NumElts.divideCoefficientBy(RHS: Scale));
28702 NewIdx = DAG.getVectorIdxConstant(Val: InsIdx / Scale, DL);
28703 }
28704 }
28705 if (NewIdx && hasOperation(Opcode: ISD::INSERT_SUBVECTOR, VT: NewVT)) {
28706 SDValue Res = DAG.getBitcast(VT: NewVT, V: N0Src);
28707 Res = DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL, VT: NewVT, N1: Res, N2: N1Src, N3: NewIdx);
28708 return DAG.getBitcast(VT, V: Res);
28709 }
28710 }
28711 }
28712
28713 // Canonicalize insert_subvector dag nodes.
28714 // Example:
28715   // (insert_subvector (insert_subvector A, B, Idx0), C, Idx1)
28716   // -> (insert_subvector (insert_subvector A, C, Idx1), B, Idx0), when Idx1 < Idx0
28717 if (N0.getOpcode() == ISD::INSERT_SUBVECTOR && N0.hasOneUse() &&
28718 N1.getValueType() == N0.getOperand(i: 1).getValueType()) {
28719 unsigned OtherIdx = N0.getConstantOperandVal(i: 2);
28720 if (InsIdx < OtherIdx) {
28721 // Swap nodes.
28722 SDValue NewOp = DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N), VT,
28723 N1: N0.getOperand(i: 0), N2: N1, N3: N2);
28724 AddToWorklist(N: NewOp.getNode());
28725 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N0.getNode()),
28726 VT, N1: NewOp, N2: N0.getOperand(i: 1), N3: N0.getOperand(i: 2));
28727 }
28728 }
28729
28730 // If the input vector is a concatenation, and the insert replaces
28731 // one of the pieces, we can optimize into a single concat_vectors.
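  // e.g. (illustrative) insert_subvector (concat_vectors A, B, C), X, <#elts(A)>
  //      --> concat_vectors A, X, C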
28732 if (N0.getOpcode() == ISD::CONCAT_VECTORS && N0.hasOneUse() &&
28733 N0.getOperand(i: 0).getValueType() == N1.getValueType() &&
28734 N0.getOperand(i: 0).getValueType().isScalableVector() ==
28735 N1.getValueType().isScalableVector()) {
28736 unsigned Factor = N1.getValueType().getVectorMinNumElements();
28737 SmallVector<SDValue, 8> Ops(N0->ops());
28738 Ops[InsIdx / Factor] = N1;
28739 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, Ops);
28740 }
28741
28742 // Simplify source operands based on insertion.
28743 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
28744 return SDValue(N, 0);
28745
28746 return SDValue();
28747}
28748
28749SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) {
28750 SDValue N0 = N->getOperand(Num: 0);
28751
28752 // fold (fp_to_fp16 (fp16_to_fp op)) -> op
28753 if (N0->getOpcode() == ISD::FP16_TO_FP)
28754 return N0->getOperand(Num: 0);
28755
28756 return SDValue();
28757}
28758
28759SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
28760 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
28761 auto Op = N->getOpcode();
28762 assert((Op == ISD::FP16_TO_FP || Op == ISD::BF16_TO_FP) &&
28763 "opcode should be FP16_TO_FP or BF16_TO_FP.");
28764 SDValue N0 = N->getOperand(Num: 0);
28765
28766 // fold fp16_to_fp(op & 0xffff) -> fp16_to_fp(op) or
28767 // fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
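  // The conversion only reads the low 16 bits of its operand, so the mask is
  // redundant unless the target prefers to keep the zero-extension
  // (shouldKeepZExtForFP16Conv).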
28768 if (!TLI.shouldKeepZExtForFP16Conv() && N0->getOpcode() == ISD::AND) {
28769 ConstantSDNode *AndConst = getAsNonOpaqueConstant(N: N0.getOperand(i: 1));
28770 if (AndConst && AndConst->getAPIntValue() == 0xffff) {
28771 return DAG.getNode(Opcode: Op, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Operand: N0.getOperand(i: 0));
28772 }
28773 }
28774
28775 if (SDValue CastEliminated = eliminateFPCastPair(N))
28776 return CastEliminated;
28777
28778 // Sometimes constants manage to survive very late in the pipeline, e.g.,
28779 // because they are wrapped inside the <1 x f16> type. Try one last time to
28780 // get rid of them.
28781 SDValue Folded = DAG.FoldConstantArithmetic(Opcode: N->getOpcode(), DL: SDLoc(N),
28782 VT: N->getValueType(ResNo: 0), Ops: {N0});
28783 return Folded;
28784}
28785
28786SDValue DAGCombiner::visitFP_TO_BF16(SDNode *N) {
28787 SDValue N0 = N->getOperand(Num: 0);
28788
28789 // fold (fp_to_bf16 (bf16_to_fp op)) -> op
28790 if (N0->getOpcode() == ISD::BF16_TO_FP)
28791 return N0->getOperand(Num: 0);
28792
28793 return SDValue();
28794}
28795
28796SDValue DAGCombiner::visitBF16_TO_FP(SDNode *N) {
28797 // fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
28798 return visitFP16_TO_FP(N);
28799}
28800
28801SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
28802 SDValue N0 = N->getOperand(Num: 0);
28803 EVT VT = N0.getValueType();
28804 unsigned Opcode = N->getOpcode();
28805
28806 // VECREDUCE over 1-element vector is just an extract.
28807 if (VT.getVectorElementCount().isScalar()) {
28808 SDLoc dl(N);
28809 SDValue Res =
28810 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT: VT.getVectorElementType(), N1: N0,
28811 N2: DAG.getVectorIdxConstant(Val: 0, DL: dl));
28812 if (Res.getValueType() != N->getValueType(ResNo: 0))
28813 Res = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL: dl, VT: N->getValueType(ResNo: 0), Operand: Res);
28814 return Res;
28815 }
28816
28817   // On a boolean vector, an and/or reduction is the same as a umin/umax
28818 // reduction. Convert them if the latter is legal while the former isn't.
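  // e.g. (illustrative) when every element is known to be all-ones or all-zeros,
  // vecreduce_and x == vecreduce_umin x and vecreduce_or x == vecreduce_umax x.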
28819 if (Opcode == ISD::VECREDUCE_AND || Opcode == ISD::VECREDUCE_OR) {
28820 unsigned NewOpcode = Opcode == ISD::VECREDUCE_AND
28821 ? ISD::VECREDUCE_UMIN : ISD::VECREDUCE_UMAX;
28822 if (!TLI.isOperationLegalOrCustom(Op: Opcode, VT) &&
28823 TLI.isOperationLegalOrCustom(Op: NewOpcode, VT) &&
28824 DAG.ComputeNumSignBits(Op: N0) == VT.getScalarSizeInBits())
28825 return DAG.getNode(Opcode: NewOpcode, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Operand: N0);
28826 }
28827
28828 // vecreduce_or(insert_subvector(zero or undef, val)) -> vecreduce_or(val)
28829 // vecreduce_and(insert_subvector(ones or undef, val)) -> vecreduce_and(val)
28830 if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
28831 TLI.isTypeLegal(VT: N0.getOperand(i: 1).getValueType())) {
28832 SDValue Vec = N0.getOperand(i: 0);
28833 SDValue Subvec = N0.getOperand(i: 1);
28834 if ((Opcode == ISD::VECREDUCE_OR &&
28835 (N0.getOperand(i: 0).isUndef() || isNullOrNullSplat(V: Vec))) ||
28836 (Opcode == ISD::VECREDUCE_AND &&
28837 (N0.getOperand(i: 0).isUndef() || isAllOnesOrAllOnesSplat(V: Vec))))
28838 return DAG.getNode(Opcode, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Operand: Subvec);
28839 }
28840
28841 // vecreduce_or(sext(x)) -> sext(vecreduce_or(x))
28842 // Same for zext and anyext, and for and/or/xor reductions.
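  // e.g. (illustrative) vecreduce_or (zext v4i8 X to v4i32)
  //      --> zext (vecreduce_or X) to i32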
28843 if ((Opcode == ISD::VECREDUCE_OR || Opcode == ISD::VECREDUCE_AND ||
28844 Opcode == ISD::VECREDUCE_XOR) &&
28845 (N0.getOpcode() == ISD::SIGN_EXTEND ||
28846 N0.getOpcode() == ISD::ZERO_EXTEND ||
28847 N0.getOpcode() == ISD::ANY_EXTEND) &&
28848 TLI.isOperationLegalOrCustom(Op: Opcode, VT: N0.getOperand(i: 0).getValueType())) {
28849 SDValue Red = DAG.getNode(Opcode, DL: SDLoc(N),
28850 VT: N0.getOperand(i: 0).getValueType().getScalarType(),
28851 Operand: N0.getOperand(i: 0));
28852 return DAG.getNode(Opcode: N0.getOpcode(), DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Operand: Red);
28853 }
28854 return SDValue();
28855}
28856
28857SDValue DAGCombiner::visitVP_FSUB(SDNode *N) {
28858 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
28859
28860 // FSUB -> FMA combines:
28861 if (SDValue Fused = visitFSUBForFMACombine<VPMatchContext>(N)) {
28862 AddToWorklist(N: Fused.getNode());
28863 return Fused;
28864 }
28865 return SDValue();
28866}
28867
28868SDValue DAGCombiner::visitVPOp(SDNode *N) {
28869
28870 if (N->getOpcode() == ISD::VP_GATHER)
28871 if (SDValue SD = visitVPGATHER(N))
28872 return SD;
28873
28874 if (N->getOpcode() == ISD::VP_SCATTER)
28875 if (SDValue SD = visitVPSCATTER(N))
28876 return SD;
28877
28878 if (N->getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_LOAD)
28879 if (SDValue SD = visitVP_STRIDED_LOAD(N))
28880 return SD;
28881
28882 if (N->getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_STORE)
28883 if (SDValue SD = visitVP_STRIDED_STORE(N))
28884 return SD;
28885
28886 // VP operations in which all vector elements are disabled - either by
28887 // determining that the mask is all false or that the EVL is 0 - can be
28888 // eliminated.
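  // e.g. (illustrative) a vp.add whose mask is an all-zeros splat, or whose EVL
  // is 0, has no active lanes and can be replaced by UNDEF below.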
28889 bool AreAllEltsDisabled = false;
28890 if (auto EVLIdx = ISD::getVPExplicitVectorLengthIdx(Opcode: N->getOpcode()))
28891 AreAllEltsDisabled |= isNullConstant(V: N->getOperand(Num: *EVLIdx));
28892 if (auto MaskIdx = ISD::getVPMaskIdx(Opcode: N->getOpcode()))
28893 AreAllEltsDisabled |=
28894 ISD::isConstantSplatVectorAllZeros(N: N->getOperand(Num: *MaskIdx).getNode());
28895
28896   // Eliminating fully-disabled ops (below) is our only generic VP combine for now.
28897 if (!AreAllEltsDisabled) {
28898 switch (N->getOpcode()) {
28899 case ISD::VP_FADD:
28900 return visitVP_FADD(N);
28901 case ISD::VP_FSUB:
28902 return visitVP_FSUB(N);
28903 case ISD::VP_FMA:
28904 return visitFMA<VPMatchContext>(N);
28905 case ISD::VP_SELECT:
28906 return visitVP_SELECT(N);
28907 case ISD::VP_MUL:
28908 return visitMUL<VPMatchContext>(N);
28909 case ISD::VP_SUB:
28910 return foldSubCtlzNot<VPMatchContext>(N, DAG);
28911 default:
28912 break;
28913 }
28914 return SDValue();
28915 }
28916
28917 // Binary operations can be replaced by UNDEF.
28918 if (ISD::isVPBinaryOp(Opcode: N->getOpcode()))
28919 return DAG.getUNDEF(VT: N->getValueType(ResNo: 0));
28920
28921 // VP Memory operations can be replaced by either the chain (stores) or the
28922 // chain + undef (loads).
28923 if (const auto *MemSD = dyn_cast<MemSDNode>(Val: N)) {
28924 if (MemSD->writeMem())
28925 return MemSD->getChain();
28926 return CombineTo(N, Res0: DAG.getUNDEF(VT: N->getValueType(ResNo: 0)), Res1: MemSD->getChain());
28927 }
28928
28929 // Reduction operations return the start operand when no elements are active.
28930 if (ISD::isVPReduction(Opcode: N->getOpcode()))
28931 return N->getOperand(Num: 0);
28932
28933 return SDValue();
28934}
28935
28936SDValue DAGCombiner::visitGET_FPENV_MEM(SDNode *N) {
28937 SDValue Chain = N->getOperand(Num: 0);
28938 SDValue Ptr = N->getOperand(Num: 1);
28939 EVT MemVT = cast<FPStateAccessSDNode>(Val: N)->getMemoryVT();
28940
28941   // Check if the memory where the FP state is written is used only in a single
28942   // load operation.
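  // If so, the sequence (illustratively)
  //   get_fpenv_mem Tmp; X = load Tmp; store X, Dst
  // can be rewritten as get_fpenv_mem Dst, writing the FP state directly to
  // its final destination.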
28943 LoadSDNode *LdNode = nullptr;
28944 for (auto *U : Ptr->users()) {
28945 if (U == N)
28946 continue;
28947 if (auto *Ld = dyn_cast<LoadSDNode>(Val: U)) {
28948 if (LdNode && LdNode != Ld)
28949 return SDValue();
28950 LdNode = Ld;
28951 continue;
28952 }
28953 return SDValue();
28954 }
28955 if (!LdNode || !LdNode->isSimple() || LdNode->isIndexed() ||
28956 !LdNode->getOffset().isUndef() || LdNode->getMemoryVT() != MemVT ||
28957 !LdNode->getChain().reachesChainWithoutSideEffects(Dest: SDValue(N, 0)))
28958 return SDValue();
28959
28960 // Check if the loaded value is used only in a store operation.
28961 StoreSDNode *StNode = nullptr;
28962 for (SDUse &U : LdNode->uses()) {
28963 if (U.getResNo() == 0) {
28964 if (auto *St = dyn_cast<StoreSDNode>(Val: U.getUser())) {
28965 if (StNode)
28966 return SDValue();
28967 StNode = St;
28968 } else {
28969 return SDValue();
28970 }
28971 }
28972 }
28973 if (!StNode || !StNode->isSimple() || StNode->isIndexed() ||
28974 !StNode->getOffset().isUndef() || StNode->getMemoryVT() != MemVT ||
28975 !StNode->getChain().reachesChainWithoutSideEffects(Dest: SDValue(LdNode, 1)))
28976 return SDValue();
28977
28978   // Create a new GET_FPENV_MEM node, which uses the store address to write the
28979   // FP environment.
28980 SDValue Res = DAG.getGetFPEnv(Chain, dl: SDLoc(N), Ptr: StNode->getBasePtr(), MemVT,
28981 MMO: StNode->getMemOperand());
28982 CombineTo(N: StNode, Res, AddTo: false);
28983 return Res;
28984}
28985
28986SDValue DAGCombiner::visitSET_FPENV_MEM(SDNode *N) {
28987 SDValue Chain = N->getOperand(Num: 0);
28988 SDValue Ptr = N->getOperand(Num: 1);
28989 EVT MemVT = cast<FPStateAccessSDNode>(Val: N)->getMemoryVT();
28990
28991   // Check if the FP state address is also used only in a single store operation.
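  // Ultimately we are looking for the pattern (illustratively)
  //   X = load Src; store X, Tmp; set_fpenv_mem Tmp
  // which can be rewritten as set_fpenv_mem Src, reading the FP state directly
  // from its original source.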
28992 StoreSDNode *StNode = nullptr;
28993 for (auto *U : Ptr->users()) {
28994 if (U == N)
28995 continue;
28996 if (auto *St = dyn_cast<StoreSDNode>(Val: U)) {
28997 if (StNode && StNode != St)
28998 return SDValue();
28999 StNode = St;
29000 continue;
29001 }
29002 return SDValue();
29003 }
29004 if (!StNode || !StNode->isSimple() || StNode->isIndexed() ||
29005 !StNode->getOffset().isUndef() || StNode->getMemoryVT() != MemVT ||
29006 !Chain.reachesChainWithoutSideEffects(Dest: SDValue(StNode, 0)))
29007 return SDValue();
29008
29009 // Check if the stored value is loaded from some location and the loaded
29010 // value is used only in the store operation.
29011 SDValue StValue = StNode->getValue();
29012 auto *LdNode = dyn_cast<LoadSDNode>(Val&: StValue);
29013 if (!LdNode || !LdNode->isSimple() || LdNode->isIndexed() ||
29014 !LdNode->getOffset().isUndef() || LdNode->getMemoryVT() != MemVT ||
29015 !StNode->getChain().reachesChainWithoutSideEffects(Dest: SDValue(LdNode, 1)))
29016 return SDValue();
29017
29018   // Create a new SET_FPENV_MEM node, which uses the load address to read the
29019   // FP environment.
29020 SDValue Res =
29021 DAG.getSetFPEnv(Chain: LdNode->getChain(), dl: SDLoc(N), Ptr: LdNode->getBasePtr(), MemVT,
29022 MMO: LdNode->getMemOperand());
29023 return Res;
29024}
29025
29026/// Returns a vector_shuffle if it is able to transform an AND to a vector_shuffle
29027/// with the destination vector and a zero vector.
29028/// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0> ==>
29029/// vector_shuffle V, Zero, <0, 4, 2, 4>
29030SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) {
29031 assert(N->getOpcode() == ISD::AND && "Unexpected opcode!");
29032
29033 EVT VT = N->getValueType(ResNo: 0);
29034 SDValue LHS = N->getOperand(Num: 0);
29035 SDValue RHS = peekThroughBitcasts(V: N->getOperand(Num: 1));
29036 SDLoc DL(N);
29037
29038 // Make sure we're not running after operation legalization where it
29039 // may have custom lowered the vector shuffles.
29040 if (LegalOperations)
29041 return SDValue();
29042
29043 if (RHS.getOpcode() != ISD::BUILD_VECTOR)
29044 return SDValue();
29045
29046 EVT RVT = RHS.getValueType();
29047 unsigned NumElts = RHS.getNumOperands();
29048
29049   // Attempt to create a valid clear mask by splitting the mask into
29050   // sub-elements and checking whether each one is all zeros or all ones,
29051   // making it suitable for shuffle masking.
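  // e.g. (illustrative) for a v2i64 AND with <0x00000000FFFFFFFF, -1>, splitting
  // each i64 into two i32 sub-elements yields lanes that are each entirely ones
  // or zeros, so the AND can become a v4i32 shuffle with a zero vector.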
29052 auto BuildClearMask = [&](int Split) {
29053 int NumSubElts = NumElts * Split;
29054 int NumSubBits = RVT.getScalarSizeInBits() / Split;
29055
29056 SmallVector<int, 8> Indices;
29057 for (int i = 0; i != NumSubElts; ++i) {
29058 int EltIdx = i / Split;
29059 int SubIdx = i % Split;
29060 SDValue Elt = RHS.getOperand(i: EltIdx);
29061 // X & undef --> 0 (not undef). So this lane must be converted to choose
29062 // from the zero constant vector (same as if the element had all 0-bits).
29063 if (Elt.isUndef()) {
29064 Indices.push_back(Elt: i + NumSubElts);
29065 continue;
29066 }
29067
29068 std::optional<APInt> Bits = Elt->bitcastToAPInt();
29069 if (!Bits)
29070 return SDValue();
29071
29072 // Extract the sub element from the constant bit mask.
29073 if (DAG.getDataLayout().isBigEndian())
29074 *Bits =
29075 Bits->extractBits(numBits: NumSubBits, bitPosition: (Split - SubIdx - 1) * NumSubBits);
29076 else
29077 *Bits = Bits->extractBits(numBits: NumSubBits, bitPosition: SubIdx * NumSubBits);
29078
29079 if (Bits->isAllOnes())
29080 Indices.push_back(Elt: i);
29081 else if (*Bits == 0)
29082 Indices.push_back(Elt: i + NumSubElts);
29083 else
29084 return SDValue();
29085 }
29086
29087 // Let's see if the target supports this vector_shuffle.
29088 EVT ClearSVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: NumSubBits);
29089 EVT ClearVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: ClearSVT, NumElements: NumSubElts);
29090 if (!TLI.isVectorClearMaskLegal(Indices, ClearVT))
29091 return SDValue();
29092
29093 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: ClearVT);
29094 return DAG.getBitcast(VT, V: DAG.getVectorShuffle(VT: ClearVT, dl: DL,
29095 N1: DAG.getBitcast(VT: ClearVT, V: LHS),
29096 N2: Zero, Mask: Indices));
29097 };
29098
29099 // Determine maximum split level (byte level masking).
29100 int MaxSplit = 1;
29101 if (RVT.getScalarSizeInBits() % 8 == 0)
29102 MaxSplit = RVT.getScalarSizeInBits() / 8;
29103
29104 for (int Split = 1; Split <= MaxSplit; ++Split)
29105 if (RVT.getScalarSizeInBits() % Split == 0)
29106 if (SDValue S = BuildClearMask(Split))
29107 return S;
29108
29109 return SDValue();
29110}
29111
29112/// If a vector binop is performed on splat values, it may be profitable to
29113/// extract, scalarize, and insert/splat.
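/// e.g. (illustrative) add (splat X), (splat Y) --> splat (add X, Y), so the
/// operation is performed once on scalars rather than across the whole vector.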
29114static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG,
29115 const SDLoc &DL, bool LegalTypes) {
29116 SDValue N0 = N->getOperand(Num: 0);
29117 SDValue N1 = N->getOperand(Num: 1);
29118 unsigned Opcode = N->getOpcode();
29119 EVT VT = N->getValueType(ResNo: 0);
29120 EVT EltVT = VT.getVectorElementType();
29121 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
29122
29123 // TODO: Remove/replace the extract cost check? If the elements are available
29124 // as scalars, then there may be no extract cost. Should we ask if
29125 // inserting a scalar back into a vector is cheap instead?
29126 int Index0, Index1;
29127 SDValue Src0 = DAG.getSplatSourceVector(V: N0, SplatIndex&: Index0);
29128 SDValue Src1 = DAG.getSplatSourceVector(V: N1, SplatIndex&: Index1);
29129   // Extracting an element from a splat_vector should be free.
29130 // TODO: use DAG.isSplatValue instead?
29131 bool IsBothSplatVector = N0.getOpcode() == ISD::SPLAT_VECTOR &&
29132 N1.getOpcode() == ISD::SPLAT_VECTOR;
29133 if (!Src0 || !Src1 || Index0 != Index1 ||
29134 Src0.getValueType().getVectorElementType() != EltVT ||
29135 Src1.getValueType().getVectorElementType() != EltVT ||
29136 !(IsBothSplatVector || TLI.isExtractVecEltCheap(VT, Index: Index0)) ||
29137 // If before type legalization, allow scalar types that will eventually be
29138 // made legal.
29139 !TLI.isOperationLegalOrCustom(
29140 Op: Opcode, VT: LegalTypes
29141 ? EltVT
29142 : TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: EltVT)))
29143 return SDValue();
29144
29145 // FIXME: Type legalization can't handle illegal MULHS/MULHU.
29146 if ((Opcode == ISD::MULHS || Opcode == ISD::MULHU) && !TLI.isTypeLegal(VT: EltVT))
29147 return SDValue();
29148
29149 if (N0.getOpcode() == ISD::BUILD_VECTOR && N0.getOpcode() == N1.getOpcode()) {
29150 // All but one element should have an undef input, which will fold to a
29151 // constant or undef. Avoid splatting which would over-define potentially
29152 // undefined elements.
29153
29154 // bo (build_vec ..undef, X, undef...), (build_vec ..undef, Y, undef...) -->
29155 // build_vec ..undef, (bo X, Y), undef...
29156 SmallVector<SDValue, 16> EltsX, EltsY, EltsResult;
29157 DAG.ExtractVectorElements(Op: Src0, Args&: EltsX);
29158 DAG.ExtractVectorElements(Op: Src1, Args&: EltsY);
29159
29160 for (auto [X, Y] : zip(t&: EltsX, u&: EltsY))
29161 EltsResult.push_back(Elt: DAG.getNode(Opcode, DL, VT: EltVT, N1: X, N2: Y, Flags: N->getFlags()));
29162 return DAG.getBuildVector(VT, DL, Ops: EltsResult);
29163 }
29164
29165 SDValue IndexC = DAG.getVectorIdxConstant(Val: Index0, DL);
29166 SDValue X = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Src0, N2: IndexC);
29167 SDValue Y = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Src1, N2: IndexC);
29168 SDValue ScalarBO = DAG.getNode(Opcode, DL, VT: EltVT, N1: X, N2: Y, Flags: N->getFlags());
29169
29170 // bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index
29171 return DAG.getSplat(VT, DL, Op: ScalarBO);
29172}
29173
29174/// Visit a vector cast operation, like FP_EXTEND.
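/// e.g. (illustrative) fp_extend (splat X) --> splat (fp_extend X), when the
/// target prefers to scalarize the splat.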
29175SDValue DAGCombiner::SimplifyVCastOp(SDNode *N, const SDLoc &DL) {
29176 EVT VT = N->getValueType(ResNo: 0);
29177 assert(VT.isVector() && "SimplifyVCastOp only works on vectors!");
29178 EVT EltVT = VT.getVectorElementType();
29179 unsigned Opcode = N->getOpcode();
29180
29181 SDValue N0 = N->getOperand(Num: 0);
29182 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
29183
29184   // TODO: Promoting the operation might also be good here?
29185 int Index0;
29186 SDValue Src0 = DAG.getSplatSourceVector(V: N0, SplatIndex&: Index0);
29187 if (Src0 &&
29188 (N0.getOpcode() == ISD::SPLAT_VECTOR ||
29189 TLI.isExtractVecEltCheap(VT, Index: Index0)) &&
29190 TLI.isOperationLegalOrCustom(Op: Opcode, VT: EltVT) &&
29191 TLI.preferScalarizeSplat(N)) {
29192 EVT SrcVT = N0.getValueType();
29193 EVT SrcEltVT = SrcVT.getVectorElementType();
29194 if (!LegalTypes || TLI.isTypeLegal(VT: SrcEltVT)) {
29195 SDValue IndexC = DAG.getVectorIdxConstant(Val: Index0, DL);
29196 SDValue Elt =
29197 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: SrcEltVT, N1: Src0, N2: IndexC);
29198 SDValue ScalarBO = DAG.getNode(Opcode, DL, VT: EltVT, Operand: Elt, Flags: N->getFlags());
29199 if (VT.isScalableVector())
29200 return DAG.getSplatVector(VT, DL, Op: ScalarBO);
29201 SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), ScalarBO);
29202 return DAG.getBuildVector(VT, DL, Ops);
29203 }
29204 }
29205
29206 return SDValue();
29207}
29208
29209/// Visit a binary vector operation, like ADD.
29210SDValue DAGCombiner::SimplifyVBinOp(SDNode *N, const SDLoc &DL) {
29211 EVT VT = N->getValueType(ResNo: 0);
29212 assert(VT.isVector() && "SimplifyVBinOp only works on vectors!");
29213
29214 SDValue LHS = N->getOperand(Num: 0);
29215 SDValue RHS = N->getOperand(Num: 1);
29216 unsigned Opcode = N->getOpcode();
29217 SDNodeFlags Flags = N->getFlags();
29218
29219 // Move unary shuffles with identical masks after a vector binop:
29220 // VBinOp (shuffle A, Undef, Mask), (shuffle B, Undef, Mask))
29221 // --> shuffle (VBinOp A, B), Undef, Mask
29222 // This does not require type legality checks because we are creating the
29223 // same types of operations that are in the original sequence. We do have to
29224   // restrict ops like integer div that have immediate UB (e.g., div-by-zero)
29225 // though. This code is adapted from the identical transform in instcombine.
29226 if (DAG.isSafeToSpeculativelyExecute(Opcode)) {
29227 auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(Val&: LHS);
29228 auto *Shuf1 = dyn_cast<ShuffleVectorSDNode>(Val&: RHS);
29229 if (Shuf0 && Shuf1 && Shuf0->getMask().equals(RHS: Shuf1->getMask()) &&
29230 LHS.getOperand(i: 1).isUndef() && RHS.getOperand(i: 1).isUndef() &&
29231 (LHS.hasOneUse() || RHS.hasOneUse() || LHS == RHS)) {
29232 SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, N1: LHS.getOperand(i: 0),
29233 N2: RHS.getOperand(i: 0), Flags);
29234 SDValue UndefV = LHS.getOperand(i: 1);
29235 return DAG.getVectorShuffle(VT, dl: DL, N1: NewBinOp, N2: UndefV, Mask: Shuf0->getMask());
29236 }
29237
29238 // Try to sink a splat shuffle after a binop with a uniform constant.
29239 // This is limited to cases where neither the shuffle nor the constant have
29240 // undefined elements because that could be poison-unsafe or inhibit
29241 // demanded elements analysis. It is further limited to not change a splat
29242 // of an inserted scalar because that may be optimized better by
29243 // load-folding or other target-specific behaviors.
29244 if (isConstOrConstSplat(N: RHS) && Shuf0 && all_equal(Range: Shuf0->getMask()) &&
29245 Shuf0->hasOneUse() && Shuf0->getOperand(Num: 1).isUndef() &&
29246 Shuf0->getOperand(Num: 0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
29247 // binop (splat X), (splat C) --> splat (binop X, C)
29248 SDValue X = Shuf0->getOperand(Num: 0);
29249 SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, N1: X, N2: RHS, Flags);
29250 return DAG.getVectorShuffle(VT, dl: DL, N1: NewBinOp, N2: DAG.getPOISON(VT),
29251 Mask: Shuf0->getMask());
29252 }
29253 if (isConstOrConstSplat(N: LHS) && Shuf1 && all_equal(Range: Shuf1->getMask()) &&
29254 Shuf1->hasOneUse() && Shuf1->getOperand(Num: 1).isUndef() &&
29255 Shuf1->getOperand(Num: 0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
29256 // binop (splat C), (splat X) --> splat (binop C, X)
29257 SDValue X = Shuf1->getOperand(Num: 0);
29258 SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, N1: LHS, N2: X, Flags);
29259 return DAG.getVectorShuffle(VT, dl: DL, N1: NewBinOp, N2: DAG.getPOISON(VT),
29260 Mask: Shuf1->getMask());
29261 }
29262 }
29263
29264 // The following pattern is likely to emerge with vector reduction ops. Moving
29265 // the binary operation ahead of insertion may allow using a narrower vector
29266 // instruction that has better performance than the wide version of the op:
29267 // VBinOp (ins undef, X, Z), (ins undef, Y, Z) --> ins VecC, (VBinOp X, Y), Z
29268 if (LHS.getOpcode() == ISD::INSERT_SUBVECTOR && LHS.getOperand(i: 0).isUndef() &&
29269 RHS.getOpcode() == ISD::INSERT_SUBVECTOR && RHS.getOperand(i: 0).isUndef() &&
29270 LHS.getOperand(i: 2) == RHS.getOperand(i: 2) &&
29271 (LHS.hasOneUse() || RHS.hasOneUse())) {
29272 SDValue X = LHS.getOperand(i: 1);
29273 SDValue Y = RHS.getOperand(i: 1);
29274 SDValue Z = LHS.getOperand(i: 2);
29275 EVT NarrowVT = X.getValueType();
29276 if (NarrowVT == Y.getValueType() &&
29277 TLI.isOperationLegalOrCustomOrPromote(Op: Opcode, VT: NarrowVT,
29278 LegalOnly: LegalOperations)) {
29279 // (binop undef, undef) may not return undef, so compute that result.
29280 SDValue VecC =
29281 DAG.getNode(Opcode, DL, VT, N1: DAG.getUNDEF(VT), N2: DAG.getUNDEF(VT));
29282 SDValue NarrowBO = DAG.getNode(Opcode, DL, VT: NarrowVT, N1: X, N2: Y);
29283 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL, VT, N1: VecC, N2: NarrowBO, N3: Z);
29284 }
29285 }
29286
29287 // Make sure all but the first op are undef or constant.
29288 auto ConcatWithConstantOrUndef = [](SDValue Concat) {
29289 return Concat.getOpcode() == ISD::CONCAT_VECTORS &&
29290 all_of(Range: drop_begin(RangeOrContainer: Concat->ops()), P: [](const SDValue &Op) {
29291 return Op.isUndef() ||
29292 ISD::isBuildVectorOfConstantSDNodes(N: Op.getNode());
29293 });
29294 };
29295
29296 // The following pattern is likely to emerge with vector reduction ops. Moving
29297 // the binary operation ahead of the concat may allow using a narrower vector
29298 // instruction that has better performance than the wide version of the op:
29299 // VBinOp (concat X, undef/constant), (concat Y, undef/constant) -->
29300 // concat (VBinOp X, Y), VecC
29301 if (ConcatWithConstantOrUndef(LHS) && ConcatWithConstantOrUndef(RHS) &&
29302 (LHS.hasOneUse() || RHS.hasOneUse())) {
29303 EVT NarrowVT = LHS.getOperand(i: 0).getValueType();
29304 if (NarrowVT == RHS.getOperand(i: 0).getValueType() &&
29305 TLI.isOperationLegalOrCustomOrPromote(Op: Opcode, VT: NarrowVT)) {
29306 unsigned NumOperands = LHS.getNumOperands();
29307 SmallVector<SDValue, 4> ConcatOps;
29308 for (unsigned i = 0; i != NumOperands; ++i) {
29309         // This constant-folds for operands 1 and up.
29310 ConcatOps.push_back(Elt: DAG.getNode(Opcode, DL, VT: NarrowVT, N1: LHS.getOperand(i),
29311 N2: RHS.getOperand(i)));
29312 }
29313
29314 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT, Ops: ConcatOps);
29315 }
29316 }
29317
29318 if (SDValue V = scalarizeBinOpOfSplats(N, DAG, DL, LegalTypes))
29319 return V;
29320
29321 return SDValue();
29322}
29323
29324SDValue DAGCombiner::SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1,
29325 SDValue N2) {
29326 assert(N0.getOpcode() == ISD::SETCC &&
29327 "First argument must be a SetCC node!");
29328
29329 SDValue SCC = SimplifySelectCC(DL, N0: N0.getOperand(i: 0), N1: N0.getOperand(i: 1), N2: N1, N3: N2,
29330 CC: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get());
29331
29332 // If we got a simplified select_cc node back from SimplifySelectCC, then
29333 // break it down into a new SETCC node, and a new SELECT node, and then return
29334 // the SELECT node, since we were called with a SELECT node.
29335 if (SCC.getNode()) {
29336 // Check to see if we got a select_cc back (to turn into setcc/select).
29337 // Otherwise, just return whatever node we got back, like fabs.
29338 if (SCC.getOpcode() == ISD::SELECT_CC) {
29339 const SDNodeFlags Flags = N0->getFlags();
29340 SDValue SETCC = DAG.getNode(Opcode: ISD::SETCC, DL: SDLoc(N0),
29341 VT: N0.getValueType(),
29342 N1: SCC.getOperand(i: 0), N2: SCC.getOperand(i: 1),
29343 N3: SCC.getOperand(i: 4), Flags);
29344 AddToWorklist(N: SETCC.getNode());
29345 return DAG.getSelect(DL: SDLoc(SCC), VT: SCC.getValueType(), Cond: SETCC,
29346 LHS: SCC.getOperand(i: 2), RHS: SCC.getOperand(i: 3), Flags);
29347 }
29348
29349 return SCC;
29350 }
29351 return SDValue();
29352}
29353
29354/// Given a SELECT or a SELECT_CC node, where LHS and RHS are the two values
29355/// being selected between, see if we can simplify the select. Callers of this
29356/// should assume that TheSelect is deleted if this returns true. As such, they
29357/// should return the appropriate thing (e.g. the node) back to the top-level of
29358/// the DAG combiner loop to avoid it being looked at.
29359bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS,
29360 SDValue RHS) {
29361 // fold (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
29362 // The select + setcc is redundant, because fsqrt returns NaN for X < 0.
29363 if (const ConstantFPSDNode *NaN = isConstOrConstSplatFP(N: LHS)) {
29364 if (NaN->isNaN() && RHS.getOpcode() == ISD::FSQRT) {
29365 // We have: (select (setcc ?, ?, ?), NaN, (fsqrt ?))
29366 SDValue Sqrt = RHS;
29367 ISD::CondCode CC;
29368 SDValue CmpLHS;
29369 const ConstantFPSDNode *Zero = nullptr;
29370
29371 if (TheSelect->getOpcode() == ISD::SELECT_CC) {
29372 CC = cast<CondCodeSDNode>(Val: TheSelect->getOperand(Num: 4))->get();
29373 CmpLHS = TheSelect->getOperand(Num: 0);
29374 Zero = isConstOrConstSplatFP(N: TheSelect->getOperand(Num: 1));
29375 } else {
29376 // SELECT or VSELECT
29377 SDValue Cmp = TheSelect->getOperand(Num: 0);
29378 if (Cmp.getOpcode() == ISD::SETCC) {
29379 CC = cast<CondCodeSDNode>(Val: Cmp.getOperand(i: 2))->get();
29380 CmpLHS = Cmp.getOperand(i: 0);
29381 Zero = isConstOrConstSplatFP(N: Cmp.getOperand(i: 1));
29382 }
29383 }
29384 if (Zero && Zero->isZero() &&
29385 Sqrt.getOperand(i: 0) == CmpLHS && (CC == ISD::SETOLT ||
29386 CC == ISD::SETULT || CC == ISD::SETLT)) {
29387 // We have: (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
29388 CombineTo(N: TheSelect, Res: Sqrt);
29389 return true;
29390 }
29391 }
29392 }
29393   // Cannot simplify a select with a vector condition.
29394 if (TheSelect->getOperand(Num: 0).getValueType().isVector()) return false;
29395
29396 // If this is a select from two identical things, try to pull the operation
29397 // through the select.
29398 if (LHS.getOpcode() != RHS.getOpcode() ||
29399 !LHS.hasOneUse() || !RHS.hasOneUse())
29400 return false;
29401
29402 // If this is a load and the token chain is identical, replace the select
29403 // of two loads with a load through a select of the address to load from.
29404 // This triggers in things like "select bool X, 10.0, 123.0" after the FP
29405 // constants have been dropped into the constant pool.
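  // e.g. (illustrative) select C, (load P1), (load P2)
  //      --> load (select C, P1, P2)
  // provided the loads are simple, compatible and independent (checked below).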
29406 if (LHS.getOpcode() == ISD::LOAD) {
29407 LoadSDNode *LLD = cast<LoadSDNode>(Val&: LHS);
29408 LoadSDNode *RLD = cast<LoadSDNode>(Val&: RHS);
29409
29410 // Token chains must be identical.
29411 if (LHS.getOperand(i: 0) != RHS.getOperand(i: 0) ||
29412 // Do not let this transformation reduce the number of volatile loads.
29413 // Be conservative for atomics for the moment
29414 // TODO: This does appear to be legal for unordered atomics (see D66309)
29415 !LLD->isSimple() || !RLD->isSimple() ||
29416 // FIXME: If either is a pre/post inc/dec load,
29417 // we'd need to split out the address adjustment.
29418 LLD->isIndexed() || RLD->isIndexed() ||
29419 // If this is an EXTLOAD, the VT's must match.
29420 LLD->getMemoryVT() != RLD->getMemoryVT() ||
29421 // If this is an EXTLOAD, the kind of extension must match.
29422 (LLD->getExtensionType() != RLD->getExtensionType() &&
29423 // The only exception is if one of the extensions is anyext.
29424 LLD->getExtensionType() != ISD::EXTLOAD &&
29425 RLD->getExtensionType() != ISD::EXTLOAD) ||
29426 // FIXME: this discards src value information. This is
29427 // over-conservative. It would be beneficial to be able to remember
29428 // both potential memory locations. Since we are discarding
29429 // src value info, don't do the transformation if the memory
29430 // locations are not in the same address space.
29431 LLD->getPointerInfo().getAddrSpace() !=
29432 RLD->getPointerInfo().getAddrSpace() ||
29433 // We can't produce a CMOV of a TargetFrameIndex since we won't
29434 // generate the address generation required.
29435 LLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
29436 RLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
29437 !TLI.isOperationLegalOrCustom(Op: TheSelect->getOpcode(),
29438 VT: LLD->getBasePtr().getValueType()))
29439 return false;
29440
29441 // The loads must not depend on one another.
29442 if (LLD->isPredecessorOf(N: RLD) || RLD->isPredecessorOf(N: LLD))
29443 return false;
29444
29445 // Check that the select condition doesn't reach either load. If so,
29446 // folding this will induce a cycle into the DAG. If not, this is safe to
29447 // xform, so create a select of the addresses.
29448
29449 SmallPtrSet<const SDNode *, 32> Visited;
29450 SmallVector<const SDNode *, 16> Worklist;
29451
29452 // Always fail if LLD and RLD are not independent. TheSelect is a
29453 // predecessor to all Nodes in question so we need not search past it.
29454
29455 Visited.insert(Ptr: TheSelect);
29456 Worklist.push_back(Elt: LLD);
29457 Worklist.push_back(Elt: RLD);
29458
29459 if (SDNode::hasPredecessorHelper(N: LLD, Visited, Worklist) ||
29460 SDNode::hasPredecessorHelper(N: RLD, Visited, Worklist))
29461 return false;
29462
29463 SDValue Addr;
29464 if (TheSelect->getOpcode() == ISD::SELECT) {
29465 // We cannot do this optimization if any pair of {RLD, LLD} is a
29466 // predecessor to {RLD, LLD, CondNode}. As we've already compared the
29467 // Loads, we only need to check if CondNode is a successor to one of the
29468 // loads. We can further avoid this if there's no use of their chain
29469 // value.
29470 SDNode *CondNode = TheSelect->getOperand(Num: 0).getNode();
29471 Worklist.push_back(Elt: CondNode);
29472
29473 if ((LLD->hasAnyUseOfValue(Value: 1) &&
29474 SDNode::hasPredecessorHelper(N: LLD, Visited, Worklist)) ||
29475 (RLD->hasAnyUseOfValue(Value: 1) &&
29476 SDNode::hasPredecessorHelper(N: RLD, Visited, Worklist)))
29477 return false;
29478
29479 Addr = DAG.getSelect(DL: SDLoc(TheSelect),
29480 VT: LLD->getBasePtr().getValueType(),
29481 Cond: TheSelect->getOperand(Num: 0), LHS: LLD->getBasePtr(),
29482 RHS: RLD->getBasePtr());
29483 } else { // Otherwise SELECT_CC
29484     // We cannot do this optimization if any member of {RLD, LLD} is a
29485     // predecessor of any member of {RLD, LLD, CondLHS, CondRHS}. As we've
29486     // already compared the loads against each other, we only need to check
29487     // whether CondLHS/CondRHS is a successor of one of the loads. We can
29488     // further skip that check when there is no use of their chain value.
29489
29490 SDNode *CondLHS = TheSelect->getOperand(Num: 0).getNode();
29491 SDNode *CondRHS = TheSelect->getOperand(Num: 1).getNode();
29492 Worklist.push_back(Elt: CondLHS);
29493 Worklist.push_back(Elt: CondRHS);
29494
29495 if ((LLD->hasAnyUseOfValue(Value: 1) &&
29496 SDNode::hasPredecessorHelper(N: LLD, Visited, Worklist)) ||
29497 (RLD->hasAnyUseOfValue(Value: 1) &&
29498 SDNode::hasPredecessorHelper(N: RLD, Visited, Worklist)))
29499 return false;
29500
29501 Addr = DAG.getNode(Opcode: ISD::SELECT_CC, DL: SDLoc(TheSelect),
29502 VT: LLD->getBasePtr().getValueType(),
29503 N1: TheSelect->getOperand(Num: 0),
29504 N2: TheSelect->getOperand(Num: 1),
29505 N3: LLD->getBasePtr(), N4: RLD->getBasePtr(),
29506 N5: TheSelect->getOperand(Num: 4));
29507 }
29508
29509 SDValue Load;
29510 // It is safe to replace the two loads if they have different alignments,
29511 // but the new load must be the minimum (most restrictive) alignment of the
29512 // inputs.
29513 Align Alignment = std::min(a: LLD->getAlign(), b: RLD->getAlign());
29514 unsigned AddrSpace = LLD->getAddressSpace();
29515 assert(AddrSpace == RLD->getAddressSpace());
29516
29517 MachineMemOperand::Flags MMOFlags = LLD->getMemOperand()->getFlags();
29518 if (!RLD->isInvariant())
29519 MMOFlags &= ~MachineMemOperand::MOInvariant;
29520 if (!RLD->isDereferenceable())
29521 MMOFlags &= ~MachineMemOperand::MODereferenceable;
29522 if (LLD->getExtensionType() == ISD::NON_EXTLOAD) {
29523 // FIXME: Discards pointer and AA info.
29524 Load = DAG.getLoad(VT: TheSelect->getValueType(ResNo: 0), dl: SDLoc(TheSelect),
29525 Chain: LLD->getChain(), Ptr: Addr, PtrInfo: MachinePointerInfo(AddrSpace),
29526 Alignment, MMOFlags);
29527 } else {
29528 // FIXME: Discards pointer and AA info.
29529 Load = DAG.getExtLoad(
29530 ExtType: LLD->getExtensionType() == ISD::EXTLOAD ? RLD->getExtensionType()
29531 : LLD->getExtensionType(),
29532 dl: SDLoc(TheSelect), VT: TheSelect->getValueType(ResNo: 0), Chain: LLD->getChain(), Ptr: Addr,
29533 PtrInfo: MachinePointerInfo(AddrSpace), MemVT: LLD->getMemoryVT(), Alignment,
29534 MMOFlags);
29535 }
29536
29537 // Users of the select now use the result of the load.
29538 CombineTo(N: TheSelect, Res: Load);
29539
29540 // Users of the old loads now use the new load's chain. We know the
29541 // old-load value is dead now.
29542 CombineTo(N: LHS.getNode(), Res0: Load.getValue(R: 0), Res1: Load.getValue(R: 1));
29543 CombineTo(N: RHS.getNode(), Res0: Load.getValue(R: 0), Res1: Load.getValue(R: 1));
29544 return true;
29545 }
29546
29547 return false;
29548}
29549
29550/// Try to fold an expression of the form (N0 cond N1) ? N2 : N3 to a shift and
29551/// bitwise 'and'.
29552SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
29553 SDValue N1, SDValue N2, SDValue N3,
29554 ISD::CondCode CC) {
29555 // If this is a select where the false operand is zero and the compare is a
29556 // check of the sign bit, see if we can perform the "gzip trick":
29557 // select_cc setlt X, 0, A, 0 -> and (sra X, size(X)-1), A
29558 // select_cc setgt X, 0, A, 0 -> and (not (sra X, size(X)-1)), A
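  // For illustration (assuming i32 X): sra X, 31 yields all-ones when X is
  // negative and zero otherwise, so the AND selects A for X < 0 and 0
  // otherwise; the setgt form needs the extra NOT to invert that mask.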
29559 EVT XType = N0.getValueType();
29560 EVT AType = N2.getValueType();
29561 if (!isNullConstant(V: N3) || !XType.bitsGE(VT: AType))
29562 return SDValue();
29563
29564 // If the comparison is testing for a positive value, we have to invert
29565 // the sign bit mask, so only do that transform if the target has a bitwise
29566 // 'and not' instruction (the invert is free).
29567 if (CC == ISD::SETGT && TLI.hasAndNot(X: N2)) {
29568 // (X > -1) ? A : 0
29569 // (X > 0) ? X : 0 <-- This is canonical signed max.
29570 if (!(isAllOnesConstant(V: N1) || (isNullConstant(V: N1) && N0 == N2)))
29571 return SDValue();
29572 } else if (CC == ISD::SETLT) {
29573 // (X < 0) ? A : 0
29574 // (X < 1) ? X : 0 <-- This is un-canonicalized signed min.
29575 if (!(isNullConstant(V: N1) || (isOneConstant(V: N1) && N0 == N2)))
29576 return SDValue();
29577 } else {
29578 return SDValue();
29579 }
29580
29581 // and (sra X, size(X)-1), A -> "and (srl X, C2), A" iff A is a single-bit
29582 // constant.
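  // E.g. (illustrative, i32): for A == 8 (bit 3), C2 = 32 - 3 - 1 = 28, so
  // (srl X, 28) & 8 produces 8 exactly when the sign bit of X is set.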
29583 auto *N2C = dyn_cast<ConstantSDNode>(Val: N2.getNode());
29584 if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) {
29585 unsigned ShCt = XType.getSizeInBits() - N2C->getAPIntValue().logBase2() - 1;
29586 if (!TLI.shouldAvoidTransformToShift(VT: XType, Amount: ShCt)) {
29587 SDValue ShiftAmt = DAG.getShiftAmountConstant(Val: ShCt, VT: XType, DL);
29588 SDValue Shift = DAG.getNode(Opcode: ISD::SRL, DL, VT: XType, N1: N0, N2: ShiftAmt);
29589 AddToWorklist(N: Shift.getNode());
29590
29591 if (XType.bitsGT(VT: AType)) {
29592 Shift = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: AType, Operand: Shift);
29593 AddToWorklist(N: Shift.getNode());
29594 }
29595
29596 if (CC == ISD::SETGT)
29597 Shift = DAG.getNOT(DL, Val: Shift, VT: AType);
29598
29599 return DAG.getNode(Opcode: ISD::AND, DL, VT: AType, N1: Shift, N2);
29600 }
29601 }
29602
29603 unsigned ShCt = XType.getSizeInBits() - 1;
29604 if (TLI.shouldAvoidTransformToShift(VT: XType, Amount: ShCt))
29605 return SDValue();
29606
29607 SDValue ShiftAmt = DAG.getShiftAmountConstant(Val: ShCt, VT: XType, DL);
29608 SDValue Shift = DAG.getNode(Opcode: ISD::SRA, DL, VT: XType, N1: N0, N2: ShiftAmt);
29609 AddToWorklist(N: Shift.getNode());
29610
29611 if (XType.bitsGT(VT: AType)) {
29612 Shift = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: AType, Operand: Shift);
29613 AddToWorklist(N: Shift.getNode());
29614 }
29615
29616 if (CC == ISD::SETGT)
29617 Shift = DAG.getNOT(DL, Val: Shift, VT: AType);
29618
29619 return DAG.getNode(Opcode: ISD::AND, DL, VT: AType, N1: Shift, N2);
29620}
29621
29622 // Fold select(cc, binop(x, y), binop(z, y)) -> binop(select(cc, x, z), y), etc.
29623SDValue DAGCombiner::foldSelectOfBinops(SDNode *N) {
29624 SDValue N0 = N->getOperand(Num: 0);
29625 SDValue N1 = N->getOperand(Num: 1);
29626 SDValue N2 = N->getOperand(Num: 2);
29627 SDLoc DL(N);
29628
29629 unsigned BinOpc = N1.getOpcode();
29630 if (!TLI.isBinOp(Opcode: BinOpc) || (N2.getOpcode() != BinOpc) ||
29631 (N1.getResNo() != N2.getResNo()))
29632 return SDValue();
29633
29634 // The use checks are intentionally on SDNode because we may be dealing
29635 // with opcodes that produce more than one SDValue.
29636 // TODO: Do we really need to check N0 (the condition operand of the select)?
29637 // But removing that clause could cause an infinite loop...
29638 if (!N0->hasOneUse() || !N1->hasOneUse() || !N2->hasOneUse())
29639 return SDValue();
29640
29641 // Binops may include opcodes that return multiple values, so all values
29642 // must be created/propagated from the newly created binops below.
29643 SDVTList OpVTs = N1->getVTList();
29644
29645 // Fold select(cond, binop(x, y), binop(z, y))
29646 // --> binop(select(cond, x, z), y)
29647 if (N1.getOperand(i: 1) == N2.getOperand(i: 1)) {
29648 SDValue N10 = N1.getOperand(i: 0);
29649 SDValue N20 = N2.getOperand(i: 0);
29650 SDValue NewSel = DAG.getSelect(DL, VT: N10.getValueType(), Cond: N0, LHS: N10, RHS: N20);
29651 SDNodeFlags Flags = N1->getFlags() & N2->getFlags();
29652 SDValue NewBinOp =
29653 DAG.getNode(Opcode: BinOpc, DL, VTList: OpVTs, Ops: {NewSel, N1.getOperand(i: 1)}, Flags);
29654 return SDValue(NewBinOp.getNode(), N1.getResNo());
29655 }
29656
29657 // Fold select(cond, binop(x, y), binop(x, z))
29658 // --> binop(x, select(cond, y, z))
29659 if (N1.getOperand(i: 0) == N2.getOperand(i: 0)) {
29660 SDValue N11 = N1.getOperand(i: 1);
29661 SDValue N21 = N2.getOperand(i: 1);
29662 // Second op VT might be different (e.g. shift amount type)
29663 if (N11.getValueType() == N21.getValueType()) {
29664 SDValue NewSel = DAG.getSelect(DL, VT: N11.getValueType(), Cond: N0, LHS: N11, RHS: N21);
29665 SDNodeFlags Flags = N1->getFlags() & N2->getFlags();
29666 SDValue NewBinOp =
29667 DAG.getNode(Opcode: BinOpc, DL, VTList: OpVTs, Ops: {N1.getOperand(i: 0), NewSel}, Flags);
29668 return SDValue(NewBinOp.getNode(), N1.getResNo());
29669 }
29670 }
29671
29672 // TODO: Handle isCommutativeBinOp patterns as well?
29673 return SDValue();
29674}
29675
29676// Transform (fneg/fabs (bitconvert x)) to avoid loading constant pool values.
29677SDValue DAGCombiner::foldSignChangeInBitcast(SDNode *N) {
29678 SDValue N0 = N->getOperand(Num: 0);
29679 EVT VT = N->getValueType(ResNo: 0);
29680 bool IsFabs = N->getOpcode() == ISD::FABS;
29681 bool IsFree = IsFabs ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);
29682
29683 if (IsFree || N0.getOpcode() != ISD::BITCAST || !N0.hasOneUse())
29684 return SDValue();
29685
29686 SDValue Int = N0.getOperand(i: 0);
29687 EVT IntVT = Int.getValueType();
29688
29689   // The operand of the cast should be a scalar integer.
29690 if (!IntVT.isInteger() || IntVT.isVector())
29691 return SDValue();
29692
29693 // (fneg (bitconvert x)) -> (bitconvert (xor x sign))
29694 // (fabs (bitconvert x)) -> (bitconvert (and x ~sign))
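  // E.g. for an f32 bitcast from i32 (illustrative), the sign mask is
  // 0x80000000, so fneg becomes an XOR with 0x80000000 and fabs becomes an
  // AND with 0x7FFFFFFF.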
29695 APInt SignMask;
29696 if (N0.getValueType().isVector()) {
29697 // For vector, create a sign mask (0x80...) or its inverse (for fabs,
29698 // 0x7f...) per element and splat it.
29699 SignMask = APInt::getSignMask(BitWidth: N0.getScalarValueSizeInBits());
29700 if (IsFabs)
29701 SignMask = ~SignMask;
29702 SignMask = APInt::getSplat(NewLen: IntVT.getSizeInBits(), V: SignMask);
29703 } else {
29704 // For scalar, just use the sign mask (0x80... or the inverse, 0x7f...)
29705 SignMask = APInt::getSignMask(BitWidth: IntVT.getSizeInBits());
29706 if (IsFabs)
29707 SignMask = ~SignMask;
29708 }
29709 SDLoc DL(N0);
29710 Int = DAG.getNode(Opcode: IsFabs ? ISD::AND : ISD::XOR, DL, VT: IntVT, N1: Int,
29711 N2: DAG.getConstant(Val: SignMask, DL, VT: IntVT));
29712 AddToWorklist(N: Int.getNode());
29713 return DAG.getBitcast(VT, V: Int);
29714}
29715
29716 /// Turn "(a cond b) ? 1.0f : 2.0f" into "load (tmp + ((a cond b) ? 0 : 4))"
29717/// where "tmp" is a constant pool entry containing an array with 1.0 and 2.0
29718/// in it. This may be a win when the constant is not otherwise available
29719/// because it replaces two constant pool loads with one.
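/// For example (illustrative): with the array {2.0f, 1.0f} placed in the
/// constant pool at "tmp", the compare selects an offset of 0 or 4 bytes and a
/// single load fetches whichever constant is needed.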
29720SDValue DAGCombiner::convertSelectOfFPConstantsToLoadOffset(
29721 const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
29722 ISD::CondCode CC) {
29723 if (!TLI.reduceSelectOfFPConstantLoads(CmpOpVT: N0.getValueType()))
29724 return SDValue();
29725
29726 // If we are before legalize types, we want the other legalization to happen
29727 // first (for example, to avoid messing with soft float).
29728 auto *TV = dyn_cast<ConstantFPSDNode>(Val&: N2);
29729 auto *FV = dyn_cast<ConstantFPSDNode>(Val&: N3);
29730 EVT VT = N2.getValueType();
29731 if (!TV || !FV || !TLI.isTypeLegal(VT))
29732 return SDValue();
29733
29734 // If a constant can be materialized without loads, this does not make sense.
29735 if (TLI.getOperationAction(Op: ISD::ConstantFP, VT) == TargetLowering::Legal ||
29736 TLI.isFPImmLegal(TV->getValueAPF(), TV->getValueType(ResNo: 0), ForCodeSize) ||
29737 TLI.isFPImmLegal(FV->getValueAPF(), FV->getValueType(ResNo: 0), ForCodeSize))
29738 return SDValue();
29739
29740 // If both constants have multiple uses, then we won't need to do an extra
29741 // load. The values are likely around in registers for other users.
29742 if (!TV->hasOneUse() && !FV->hasOneUse())
29743 return SDValue();
29744
29745 Constant *Elts[] = { const_cast<ConstantFP*>(FV->getConstantFPValue()),
29746 const_cast<ConstantFP*>(TV->getConstantFPValue()) };
29747 Type *FPTy = Elts[0]->getType();
29748 const DataLayout &TD = DAG.getDataLayout();
29749
29750 // Create a ConstantArray of the two constants.
29751 Constant *CA = ConstantArray::get(T: ArrayType::get(ElementType: FPTy, NumElements: 2), V: Elts);
29752 SDValue CPIdx = DAG.getConstantPool(C: CA, VT: TLI.getPointerTy(DL: DAG.getDataLayout()),
29753 Align: TD.getPrefTypeAlign(Ty: FPTy));
29754 Align Alignment = cast<ConstantPoolSDNode>(Val&: CPIdx)->getAlign();
29755
29756 // Get offsets to the 0 and 1 elements of the array, so we can select between
29757 // them.
29758 SDValue Zero = DAG.getIntPtrConstant(Val: 0, DL);
29759 unsigned EltSize = (unsigned)TD.getTypeAllocSize(Ty: Elts[0]->getType());
29760 SDValue One = DAG.getIntPtrConstant(Val: EltSize, DL: SDLoc(FV));
29761 SDValue Cond =
29762 DAG.getSetCC(DL, VT: getSetCCResultType(VT: N0.getValueType()), LHS: N0, RHS: N1, Cond: CC);
29763 AddToWorklist(N: Cond.getNode());
29764 SDValue CstOffset = DAG.getSelect(DL, VT: Zero.getValueType(), Cond, LHS: One, RHS: Zero);
29765 AddToWorklist(N: CstOffset.getNode());
29766 CPIdx = DAG.getNode(Opcode: ISD::ADD, DL, VT: CPIdx.getValueType(), N1: CPIdx, N2: CstOffset);
29767 AddToWorklist(N: CPIdx.getNode());
29768 return DAG.getLoad(VT: TV->getValueType(ResNo: 0), dl: DL, Chain: DAG.getEntryNode(), Ptr: CPIdx,
29769 PtrInfo: MachinePointerInfo::getConstantPool(
29770 MF&: DAG.getMachineFunction()), Alignment);
29771}
29772
29773/// Simplify an expression of the form (N0 cond N1) ? N2 : N3
29774/// where 'cond' is the comparison specified by CC.
29775SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
29776 SDValue N2, SDValue N3, ISD::CondCode CC,
29777 bool NotExtCompare) {
29778 // (x ? y : y) -> y.
29779 if (N2 == N3) return N2;
29780
29781 EVT CmpOpVT = N0.getValueType();
29782 EVT CmpResVT = getSetCCResultType(VT: CmpOpVT);
29783 EVT VT = N2.getValueType();
29784 auto *N1C = dyn_cast<ConstantSDNode>(Val: N1.getNode());
29785 auto *N2C = dyn_cast<ConstantSDNode>(Val: N2.getNode());
29786 auto *N3C = dyn_cast<ConstantSDNode>(Val: N3.getNode());
29787
29788 // Determine if the condition we're dealing with is constant.
29789 if (SDValue SCC = DAG.FoldSetCC(VT: CmpResVT, N1: N0, N2: N1, Cond: CC, dl: DL)) {
29790 AddToWorklist(N: SCC.getNode());
29791 if (auto *SCCC = dyn_cast<ConstantSDNode>(Val&: SCC)) {
29792 // fold select_cc true, x, y -> x
29793 // fold select_cc false, x, y -> y
29794 return !(SCCC->isZero()) ? N2 : N3;
29795 }
29796 }
29797
29798 if (SDValue V =
29799 convertSelectOfFPConstantsToLoadOffset(DL, N0, N1, N2, N3, CC))
29800 return V;
29801
29802 if (SDValue V = foldSelectCCToShiftAnd(DL, N0, N1, N2, N3, CC))
29803 return V;
29804
29805 // fold (select_cc seteq (and x, y), 0, 0, A) -> (and (sra (shl x)) A)
29806   // where y has a single bit set.
29807   // In plain terms: we can turn the SELECT_CC into an AND when the condition
29808   // can be materialized as an all-ones register, and any single-bit test can
29809   // be materialized that way with a shift-left followed by an arithmetic
29810   // shift-right.
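  // E.g. (illustrative, i32, y == 4): shl x by 29 moves bit 2 into the sign
  // bit, and sra by 31 then yields all-ones when that bit was set (condition
  // false, so A is masked in) and zero when it was clear.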
29811 if (CC == ISD::SETEQ && N0->getOpcode() == ISD::AND &&
29812 N0->getValueType(ResNo: 0) == VT && isNullConstant(V: N1) && isNullConstant(V: N2)) {
29813 SDValue AndLHS = N0->getOperand(Num: 0);
29814 auto *ConstAndRHS = dyn_cast<ConstantSDNode>(Val: N0->getOperand(Num: 1));
29815 if (ConstAndRHS && ConstAndRHS->getAPIntValue().popcount() == 1) {
29816 // Shift the tested bit over the sign bit.
29817 const APInt &AndMask = ConstAndRHS->getAPIntValue();
29818 if (TLI.shouldFoldSelectWithSingleBitTest(VT, AndMask)) {
29819 unsigned ShCt = AndMask.getBitWidth() - 1;
29820 SDValue ShlAmt = DAG.getShiftAmountConstant(Val: AndMask.countl_zero(), VT,
29821 DL: SDLoc(AndLHS));
29822 SDValue Shl = DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N0), VT, N1: AndLHS, N2: ShlAmt);
29823
29824 // Now arithmetic right shift it all the way over, so the result is
29825 // either all-ones, or zero.
29826 SDValue ShrAmt = DAG.getShiftAmountConstant(Val: ShCt, VT, DL: SDLoc(Shl));
29827 SDValue Shr = DAG.getNode(Opcode: ISD::SRA, DL: SDLoc(N0), VT, N1: Shl, N2: ShrAmt);
29828
29829 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shr, N2: N3);
29830 }
29831 }
29832 }
29833
29834 // fold select C, 16, 0 -> shl C, 4
29835 bool Fold = N2C && isNullConstant(V: N3) && N2C->getAPIntValue().isPowerOf2();
29836 bool Swap = N3C && isNullConstant(V: N2) && N3C->getAPIntValue().isPowerOf2();
29837
29838 if ((Fold || Swap) &&
29839 TLI.getBooleanContents(Type: CmpOpVT) ==
29840 TargetLowering::ZeroOrOneBooleanContent &&
29841 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SETCC, VT: CmpOpVT)) &&
29842 TLI.convertSelectOfConstantsToMath(VT)) {
29843
29844 if (Swap) {
29845 CC = ISD::getSetCCInverse(Operation: CC, Type: CmpOpVT);
29846 std::swap(a&: N2C, b&: N3C);
29847 }
29848
29849 // If the caller doesn't want us to simplify this into a zext of a compare,
29850 // don't do it.
29851 if (NotExtCompare && N2C->isOne())
29852 return SDValue();
29853
29854 SDValue Temp, SCC;
29855 // zext (setcc n0, n1)
29856 if (LegalTypes) {
29857 SCC = DAG.getSetCC(DL, VT: CmpResVT, LHS: N0, RHS: N1, Cond: CC);
29858 Temp = DAG.getZExtOrTrunc(Op: SCC, DL: SDLoc(N2), VT);
29859 } else {
29860 SCC = DAG.getSetCC(DL: SDLoc(N0), VT: MVT::i1, LHS: N0, RHS: N1, Cond: CC);
29861 Temp = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: SDLoc(N2), VT, Operand: SCC);
29862 }
29863
29864 AddToWorklist(N: SCC.getNode());
29865 AddToWorklist(N: Temp.getNode());
29866
29867 if (N2C->isOne())
29868 return Temp;
29869
29870 unsigned ShCt = N2C->getAPIntValue().logBase2();
29871 if (TLI.shouldAvoidTransformToShift(VT, Amount: ShCt))
29872 return SDValue();
29873
29874 // shl setcc result by log2 n2c
29875 return DAG.getNode(
29876 Opcode: ISD::SHL, DL, VT: N2.getValueType(), N1: Temp,
29877 N2: DAG.getShiftAmountConstant(Val: ShCt, VT: N2.getValueType(), DL: SDLoc(Temp)));
29878 }
29879
29880 // select_cc seteq X, 0, sizeof(X), ctlz(X) -> ctlz(X)
29881 // select_cc seteq X, 0, sizeof(X), ctlz_zero_undef(X) -> ctlz(X)
29882 // select_cc seteq X, 0, sizeof(X), cttz(X) -> cttz(X)
29883 // select_cc seteq X, 0, sizeof(X), cttz_zero_undef(X) -> cttz(X)
29884 // select_cc setne X, 0, ctlz(X), sizeof(X) -> ctlz(X)
29885 // select_cc setne X, 0, ctlz_zero_undef(X), sizeof(X) -> ctlz(X)
29886 // select_cc setne X, 0, cttz(X), sizeof(X) -> cttz(X)
29887 // select_cc setne X, 0, cttz_zero_undef(X), sizeof(X) -> cttz(X)
29888 if (N1C && N1C->isZero() && (CC == ISD::SETEQ || CC == ISD::SETNE)) {
29889 SDValue ValueOnZero = N2;
29890 SDValue Count = N3;
29891     // If the condition is NE instead of EQ, swap the operands.
29892 if (CC == ISD::SETNE)
29893 std::swap(a&: ValueOnZero, b&: Count);
29894     // Check if the value on zero is a constant equal to the bit width of the type.
29895 if (auto *ValueOnZeroC = dyn_cast<ConstantSDNode>(Val&: ValueOnZero)) {
29896 if (ValueOnZeroC->getAPIntValue() == VT.getSizeInBits()) {
29897 // If the other operand is cttz/cttz_zero_undef of N0, and cttz is
29898 // legal, combine to just cttz.
29899 if ((Count.getOpcode() == ISD::CTTZ ||
29900 Count.getOpcode() == ISD::CTTZ_ZERO_UNDEF) &&
29901 N0 == Count.getOperand(i: 0) &&
29902 (!LegalOperations || TLI.isOperationLegal(Op: ISD::CTTZ, VT)))
29903 return DAG.getNode(Opcode: ISD::CTTZ, DL, VT, Operand: N0);
29904 // If the other operand is ctlz/ctlz_zero_undef of N0, and ctlz is
29905 // legal, combine to just ctlz.
29906 if ((Count.getOpcode() == ISD::CTLZ ||
29907 Count.getOpcode() == ISD::CTLZ_ZERO_UNDEF) &&
29908 N0 == Count.getOperand(i: 0) &&
29909 (!LegalOperations || TLI.isOperationLegal(Op: ISD::CTLZ, VT)))
29910 return DAG.getNode(Opcode: ISD::CTLZ, DL, VT, Operand: N0);
29911 }
29912 }
29913 }
29914
29915 // Fold select_cc setgt X, -1, C, ~C -> xor (ashr X, BW-1), C
29916 // Fold select_cc setlt X, 0, C, ~C -> xor (ashr X, BW-1), ~C
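  // E.g. (illustrative): for setgt X, -1 the ashr is 0 when X >= 0 (select C)
  // and all-ones when X < 0, so the XOR flips C into ~C exactly when needed.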
29917 if (!NotExtCompare && N1C && N2C && N3C &&
29918 N2C->getAPIntValue() == ~N3C->getAPIntValue() &&
29919 ((N1C->isAllOnes() && CC == ISD::SETGT) ||
29920 (N1C->isZero() && CC == ISD::SETLT)) &&
29921 !TLI.shouldAvoidTransformToShift(VT, Amount: CmpOpVT.getScalarSizeInBits() - 1)) {
29922 SDValue ASHR =
29923 DAG.getNode(Opcode: ISD::SRA, DL, VT: CmpOpVT, N1: N0,
29924 N2: DAG.getShiftAmountConstant(
29925 Val: CmpOpVT.getScalarSizeInBits() - 1, VT: CmpOpVT, DL));
29926 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: DAG.getSExtOrTrunc(Op: ASHR, DL, VT),
29927 N2: DAG.getSExtOrTrunc(Op: CC == ISD::SETLT ? N3 : N2, DL, VT));
29928 }
29929
29930 // Fold sign pattern select_cc setgt X, -1, 1, -1 -> or (ashr X, BW-1), 1
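  // E.g. (illustrative, i32): ashr X, 31 is 0 for X >= 0 and -1 for X < 0, so
  // ORing with 1 yields +1 or -1 respectively (0 maps to +1 under setgt -1).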
29931 if (CC == ISD::SETGT && N1C && N2C && N3C && N1C->isAllOnes() &&
29932 N2C->isOne() && N3C->isAllOnes() &&
29933 !TLI.shouldAvoidTransformToShift(VT: CmpOpVT,
29934 Amount: CmpOpVT.getScalarSizeInBits() - 1)) {
29935 SDValue ASHR =
29936 DAG.getNode(Opcode: ISD::SRA, DL, VT: CmpOpVT, N1: N0,
29937 N2: DAG.getShiftAmountConstant(
29938 Val: CmpOpVT.getScalarSizeInBits() - 1, VT: CmpOpVT, DL));
29939 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: DAG.getSExtOrTrunc(Op: ASHR, DL, VT),
29940 N2: DAG.getConstant(Val: 1, DL, VT));
29941 }
29942
29943 if (SDValue S = PerformMinMaxFpToSatCombine(N0, N1, N2, N3, CC, DAG))
29944 return S;
29945 if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N2, N3, CC, DAG))
29946 return S;
29947 if (SDValue ABD = foldSelectToABD(LHS: N0, RHS: N1, True: N2, False: N3, CC, DL))
29948 return ABD;
29949
29950 return SDValue();
29951}
29952
29953static SDValue matchMergedBFX(SDValue Root, SelectionDAG &DAG,
29954 const TargetLowering &TLI) {
29955 // Match a pattern such as:
29956 // (X | (X >> C0) | (X >> C1) | ...) & Mask
29957 // This extracts contiguous parts of X and ORs them together before comparing.
29958 // We can optimize this so that we directly check (X & SomeMask) instead,
29959 // eliminating the shifts.
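  // E.g. (illustrative, i32): ((X | (X >> 8)) & 0xFF) == 0 can be rewritten as
  // (X & 0xFFFF) == 0, since each masked result bit is the OR of the
  // corresponding bits of X and X >> 8. This only holds for equality
  // comparisons against zero, which is how SimplifySetCC uses this helper.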
29960
29961 EVT VT = Root.getValueType();
29962
29963 // TODO: Support vectors?
29964 if (!VT.isScalarInteger() || Root.getOpcode() != ISD::AND)
29965 return SDValue();
29966
29967 SDValue N0 = Root.getOperand(i: 0);
29968 SDValue N1 = Root.getOperand(i: 1);
29969
29970 if (N0.getOpcode() != ISD::OR || !isa<ConstantSDNode>(Val: N1))
29971 return SDValue();
29972
29973 APInt RootMask = cast<ConstantSDNode>(Val&: N1)->getAsAPIntVal();
29974
29975 SDValue Src;
29976 const auto IsSrc = [&](SDValue V) {
29977 if (!Src) {
29978 Src = V;
29979 return true;
29980 }
29981
29982 return Src == V;
29983 };
29984
29985 SmallVector<SDValue> Worklist = {N0};
29986 APInt PartsMask(VT.getSizeInBits(), 0);
29987 while (!Worklist.empty()) {
29988 SDValue V = Worklist.pop_back_val();
29989 if (!V.hasOneUse() && (Src && Src != V))
29990 return SDValue();
29991
29992 if (V.getOpcode() == ISD::OR) {
29993 Worklist.push_back(Elt: V.getOperand(i: 0));
29994 Worklist.push_back(Elt: V.getOperand(i: 1));
29995 continue;
29996 }
29997
29998 if (V.getOpcode() == ISD::SRL) {
29999 SDValue ShiftSrc = V.getOperand(i: 0);
30000 SDValue ShiftAmt = V.getOperand(i: 1);
30001
30002 if (!IsSrc(ShiftSrc) || !isa<ConstantSDNode>(Val: ShiftAmt))
30003 return SDValue();
30004
30005 auto ShiftAmtVal = cast<ConstantSDNode>(Val&: ShiftAmt)->getAsZExtVal();
30006 if (ShiftAmtVal > RootMask.getBitWidth())
30007 return SDValue();
30008
30009 PartsMask |= (RootMask << ShiftAmtVal);
30010 continue;
30011 }
30012
30013 if (IsSrc(V)) {
30014 PartsMask |= RootMask;
30015 continue;
30016 }
30017
30018 return SDValue();
30019 }
30020
30021 if (!Src)
30022 return SDValue();
30023
30024 SDLoc DL(Root);
30025 return DAG.getNode(Opcode: ISD::AND, DL, VT,
30026 Ops: {Src, DAG.getConstant(Val: PartsMask, DL, VT)});
30027}
30028
30029 /// This is a thin wrapper around TargetLowering::SimplifySetCC.
30030SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
30031 ISD::CondCode Cond, const SDLoc &DL,
30032 bool foldBooleans) {
30033 TargetLowering::DAGCombinerInfo
30034 DagCombineInfo(DAG, Level, false, this);
30035 if (SDValue C =
30036 TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DCI&: DagCombineInfo, dl: DL))
30037 return C;
30038
30039 if (ISD::isIntEqualitySetCC(Code: Cond) && N0.getOpcode() == ISD::AND &&
30040 isNullConstant(V: N1)) {
30041
30042 if (SDValue Res = matchMergedBFX(Root: N0, DAG, TLI))
30043 return DAG.getSetCC(DL, VT, LHS: Res, RHS: N1, Cond);
30044 }
30045
30046 return SDValue();
30047}
30048
30049/// Given an ISD::SDIV node expressing a divide by constant, return
30050 /// a DAG expression that will generate the same value by multiplying
30051/// by a magic number.
30052/// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
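/// For example (illustrative, i32): X sdiv 3 is typically expanded to roughly
/// add (mulhs X, 0x55555556), (srl X, 31), replacing the divide with a
/// multiply-high, a shift and an add.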
30053SDValue DAGCombiner::BuildSDIV(SDNode *N) {
30054   // When optimizing for minimum size, we don't want to expand a div to a mul
30055   // and a shift.
30056 if (DAG.getMachineFunction().getFunction().hasMinSize())
30057 return SDValue();
30058
30059 SmallVector<SDNode *, 8> Built;
30060 if (SDValue S = TLI.BuildSDIV(N, DAG, IsAfterLegalization: LegalOperations, IsAfterLegalTypes: LegalTypes, Created&: Built)) {
30061 for (SDNode *N : Built)
30062 AddToWorklist(N);
30063 return S;
30064 }
30065
30066 return SDValue();
30067}
30068
30069/// Given an ISD::SDIV node expressing a divide by constant power of 2, return a
30070/// DAG expression that will generate the same value by right shifting.
30071SDValue DAGCombiner::BuildSDIVPow2(SDNode *N) {
30072 ConstantSDNode *C = isConstOrConstSplat(N: N->getOperand(Num: 1));
30073 if (!C)
30074 return SDValue();
30075
30076 // Avoid division by zero.
30077 if (C->isZero())
30078 return SDValue();
30079
30080 SmallVector<SDNode *, 8> Built;
30081 if (SDValue S = TLI.BuildSDIVPow2(N, Divisor: C->getAPIntValue(), DAG, Created&: Built)) {
30082 for (SDNode *N : Built)
30083 AddToWorklist(N);
30084 return S;
30085 }
30086
30087 return SDValue();
30088}
30089
30090/// Given an ISD::UDIV node expressing a divide by constant, return a DAG
30091/// expression that will generate the same value by multiplying by a magic
30092/// number.
30093/// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
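/// For example (illustrative, i32): X udiv 3 is typically expanded to
/// srl (mulhu X, 0xAAAAAAAB), 1, i.e. a multiply-high followed by a shift.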
30094SDValue DAGCombiner::BuildUDIV(SDNode *N) {
30095   // When optimizing for minimum size, we don't want to expand a div to a mul
30096   // and a shift.
30097 if (DAG.getMachineFunction().getFunction().hasMinSize())
30098 return SDValue();
30099
30100 SmallVector<SDNode *, 8> Built;
30101 if (SDValue S = TLI.BuildUDIV(N, DAG, IsAfterLegalization: LegalOperations, IsAfterLegalTypes: LegalTypes, Created&: Built)) {
30102 for (SDNode *N : Built)
30103 AddToWorklist(N);
30104 return S;
30105 }
30106
30107 return SDValue();
30108}
30109
30110/// Given an ISD::SREM node expressing a remainder by constant power of 2,
30111/// return a DAG expression that will generate the same value.
30112SDValue DAGCombiner::BuildSREMPow2(SDNode *N) {
30113 ConstantSDNode *C = isConstOrConstSplat(N: N->getOperand(Num: 1));
30114 if (!C)
30115 return SDValue();
30116
30117 // Avoid division by zero.
30118 if (C->isZero())
30119 return SDValue();
30120
30121 SmallVector<SDNode *, 8> Built;
30122 if (SDValue S = TLI.BuildSREMPow2(N, Divisor: C->getAPIntValue(), DAG, Created&: Built)) {
30123 for (SDNode *N : Built)
30124 AddToWorklist(N);
30125 return S;
30126 }
30127
30128 return SDValue();
30129}
30130
30131// This is basically just a port of takeLog2 from InstCombineMulDivRem.cpp
30132//
30133// Returns the node that represents `Log2(Op)`. This may create a new node. If
30134 // we are unable to compute `Log2(Op)`, this returns `SDValue()`.
30135//
30136// All nodes will be created at `DL` and the output will be of type `VT`.
30137//
30138 // This will only return `Log2(Op)` if we can prove `Op` is non-zero. Set
30139 // `AssumeNonZero` if this function should simply assume (rather than require
30140 // proving) that `Op` is non-zero.
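// For example (illustrative): a power-of-2 constant folds to a constant,
// log2(1 << Y) folds to Y, and log2(select(C, 8, 16)) folds to
// select(C, 3, 4).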
30141static SDValue takeInexpensiveLog2(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
30142 SDValue Op, unsigned Depth,
30143 bool AssumeNonZero) {
30144 assert(VT.isInteger() && "Only integer types are supported!");
30145
30146 auto PeekThroughCastsAndTrunc = [](SDValue V) {
30147 while (true) {
30148 switch (V.getOpcode()) {
30149 case ISD::TRUNCATE:
30150 case ISD::ZERO_EXTEND:
30151 V = V.getOperand(i: 0);
30152 break;
30153 default:
30154 return V;
30155 }
30156 }
30157 };
30158
30159 if (VT.isScalableVector())
30160 return SDValue();
30161
30162 Op = PeekThroughCastsAndTrunc(Op);
30163
30164   // Helper for determining whether a value is a power-of-2 constant scalar or
30165   // a vector of such elements.
30166 SmallVector<APInt> Pow2Constants;
30167 auto IsPowerOfTwo = [&Pow2Constants](ConstantSDNode *C) {
30168 if (C->isZero() || C->isOpaque())
30169 return false;
30170 // TODO: We may also be able to support negative powers of 2 here.
30171 if (C->getAPIntValue().isPowerOf2()) {
30172 Pow2Constants.emplace_back(Args: C->getAPIntValue());
30173 return true;
30174 }
30175 return false;
30176 };
30177
30178 if (ISD::matchUnaryPredicate(Op, Match: IsPowerOfTwo, /*AllowUndefs=*/false,
30179 /*AllowTruncation=*/true)) {
30180 if (!VT.isVector())
30181 return DAG.getConstant(Val: Pow2Constants.back().logBase2(), DL, VT);
30182 // We need to create a build vector
30183 if (Op.getOpcode() == ISD::SPLAT_VECTOR)
30184 return DAG.getSplat(VT, DL,
30185 Op: DAG.getConstant(Val: Pow2Constants.back().logBase2(), DL,
30186 VT: VT.getScalarType()));
30187 SmallVector<SDValue> Log2Ops;
30188 for (const APInt &Pow2 : Pow2Constants)
30189 Log2Ops.emplace_back(
30190 Args: DAG.getConstant(Val: Pow2.logBase2(), DL, VT: VT.getScalarType()));
30191 return DAG.getBuildVector(VT, DL, Ops: Log2Ops);
30192 }
30193
30194 if (Depth >= DAG.MaxRecursionDepth)
30195 return SDValue();
30196
30197 auto CastToVT = [&](EVT NewVT, SDValue ToCast) {
30198 // Peek through zero extend. We can't peek through truncates since this
30199 // function is called on a shift amount. We must ensure that all of the bits
30200 // above the original shift amount are zeroed by this function.
30201 while (ToCast.getOpcode() == ISD::ZERO_EXTEND)
30202 ToCast = ToCast.getOperand(i: 0);
30203 EVT CurVT = ToCast.getValueType();
30204 if (NewVT == CurVT)
30205 return ToCast;
30206
30207 if (NewVT.getSizeInBits() == CurVT.getSizeInBits())
30208 return DAG.getBitcast(VT: NewVT, V: ToCast);
30209
30210 return DAG.getZExtOrTrunc(Op: ToCast, DL, VT: NewVT);
30211 };
30212
30213 // log2(X << Y) -> log2(X) + Y
30214 if (Op.getOpcode() == ISD::SHL) {
30215 // 1 << Y and X nuw/nsw << Y are all non-zero.
30216 if (AssumeNonZero || Op->getFlags().hasNoUnsignedWrap() ||
30217 Op->getFlags().hasNoSignedWrap() || isOneConstant(V: Op.getOperand(i: 0)))
30218 if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op: Op.getOperand(i: 0),
30219 Depth: Depth + 1, AssumeNonZero))
30220 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: LogX,
30221 N2: CastToVT(VT, Op.getOperand(i: 1)));
30222 }
30223
30224 // c ? X : Y -> c ? Log2(X) : Log2(Y)
30225 SDValue Cond, TVal, FVal;
30226 if (sd_match(N: Op, P: m_OneUse(P: m_SelectLike(Cond: m_Value(N&: Cond), T: m_Value(N&: TVal),
30227 F: m_Value(N&: FVal))))) {
30228 if (SDValue LogX =
30229 takeInexpensiveLog2(DAG, DL, VT, Op: TVal, Depth: Depth + 1, AssumeNonZero))
30230 if (SDValue LogY =
30231 takeInexpensiveLog2(DAG, DL, VT, Op: FVal, Depth: Depth + 1, AssumeNonZero))
30232 return DAG.getSelect(DL, VT, Cond, LHS: LogX, RHS: LogY);
30233 }
30234
30235 // log2(umin(X, Y)) -> umin(log2(X), log2(Y))
30236 // log2(umax(X, Y)) -> umax(log2(X), log2(Y))
30237 if ((Op.getOpcode() == ISD::UMIN || Op.getOpcode() == ISD::UMAX) &&
30238 Op.hasOneUse()) {
30239     // Pass AssumeNonZero as false here. Otherwise we can hit a case where
30240     // log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because of overflow).
30241 if (SDValue LogX =
30242 takeInexpensiveLog2(DAG, DL, VT, Op: Op.getOperand(i: 0), Depth: Depth + 1,
30243 /*AssumeNonZero*/ false))
30244 if (SDValue LogY =
30245 takeInexpensiveLog2(DAG, DL, VT, Op: Op.getOperand(i: 1), Depth: Depth + 1,
30246 /*AssumeNonZero*/ false))
30247 return DAG.getNode(Opcode: Op.getOpcode(), DL, VT, N1: LogX, N2: LogY);
30248 }
30249
30250 return SDValue();
30251}
30252
30253/// Determines the LogBase2 value for a non-null input value using the
30254/// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
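/// For example (illustrative, i32): for V == 8, ctlz(8) == 28, so
/// LogBase2(8) == 31 - 28 == 3.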
30255SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL,
30256 bool KnownNonZero, bool InexpensiveOnly,
30257 std::optional<EVT> OutVT) {
30258 EVT VT = OutVT ? *OutVT : V.getValueType();
30259 SDValue InexpensiveLogBase2 =
30260 takeInexpensiveLog2(DAG, DL, VT, Op: V, /*Depth*/ 0, AssumeNonZero: KnownNonZero);
30261 if (InexpensiveLogBase2 || InexpensiveOnly || !DAG.isKnownToBeAPowerOfTwo(Val: V))
30262 return InexpensiveLogBase2;
30263
30264 SDValue Ctlz = DAG.getNode(Opcode: ISD::CTLZ, DL, VT, Operand: V);
30265 SDValue Base = DAG.getConstant(Val: VT.getScalarSizeInBits() - 1, DL, VT);
30266 SDValue LogBase2 = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Base, N2: Ctlz);
30267 return LogBase2;
30268}
30269
30270/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
30271/// For the reciprocal, we need to find the zero of the function:
30272/// F(X) = 1/X - A [which has a zero at X = 1/A]
30273/// =>
30274/// X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
30275/// does not require additional intermediate precision]
30276/// For the last iteration, put numerator N into it to gain more precision:
30277/// Result = N X_i + X_i (N - N A X_i)
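/// As a numeric illustration (not taken from the code): for A = 4 with initial
/// estimate X_0 = 0.2, X_1 = 0.2 * (2 - 4 * 0.2) = 0.24 and X_2 = 0.2496,
/// converging quadratically towards 1/A = 0.25.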
30278SDValue DAGCombiner::BuildDivEstimate(SDValue N, SDValue Op,
30279 SDNodeFlags Flags) {
30280 if (LegalDAG)
30281 return SDValue();
30282
30283 // TODO: Handle extended types?
30284 EVT VT = Op.getValueType();
30285 if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
30286 VT.getScalarType() != MVT::f64)
30287 return SDValue();
30288
30289 // If estimates are explicitly disabled for this function, we're done.
30290 MachineFunction &MF = DAG.getMachineFunction();
30291 int Enabled = TLI.getRecipEstimateDivEnabled(VT, MF);
30292 if (Enabled == TLI.ReciprocalEstimate::Disabled)
30293 return SDValue();
30294
30295 // Estimates may be explicitly enabled for this type with a custom number of
30296 // refinement steps.
30297 int Iterations = TLI.getDivRefinementSteps(VT, MF);
30298 if (SDValue Est = TLI.getRecipEstimate(Operand: Op, DAG, Enabled, RefinementSteps&: Iterations)) {
30299 AddToWorklist(N: Est.getNode());
30300
30301 SDLoc DL(Op);
30302 if (Iterations) {
30303 SDValue FPOne = DAG.getConstantFP(Val: 1.0, DL, VT);
30304
30305 // Newton iterations: Est = Est + Est (N - Arg * Est)
30306 // If this is the last iteration, also multiply by the numerator.
30307 for (int i = 0; i < Iterations; ++i) {
30308 SDValue MulEst = Est;
30309
30310 if (i == Iterations - 1) {
30311 MulEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N, N2: Est, Flags);
30312 AddToWorklist(N: MulEst.getNode());
30313 }
30314
30315 SDValue NewEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Op, N2: MulEst, Flags);
30316 AddToWorklist(N: NewEst.getNode());
30317
30318 NewEst = DAG.getNode(Opcode: ISD::FSUB, DL, VT,
30319 N1: (i == Iterations - 1 ? N : FPOne), N2: NewEst, Flags);
30320 AddToWorklist(N: NewEst.getNode());
30321
30322 NewEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: NewEst, Flags);
30323 AddToWorklist(N: NewEst.getNode());
30324
30325 Est = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: MulEst, N2: NewEst, Flags);
30326 AddToWorklist(N: Est.getNode());
30327 }
30328 } else {
30329 // If no iterations are available, multiply with N.
30330 Est = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: N, Flags);
30331 AddToWorklist(N: Est.getNode());
30332 }
30333
30334 return Est;
30335 }
30336
30337 return SDValue();
30338}
30339
30340/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
30341/// For the reciprocal sqrt, we need to find the zero of the function:
30342/// F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
30343/// =>
30344/// X_{i+1} = X_i (1.5 - A X_i^2 / 2)
30345/// As a result, we precompute A/2 prior to the iteration loop.
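/// As a numeric illustration (not taken from the code): for A = 4 with initial
/// estimate X_0 = 0.45, X_1 = 0.45 * (1.5 - 4 * 0.45 * 0.45 / 2) = 0.49275,
/// approaching 1/sqrt(4) = 0.5.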
30346SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
30347 unsigned Iterations, bool Reciprocal) {
30348 EVT VT = Arg.getValueType();
30349 SDLoc DL(Arg);
30350 SDValue ThreeHalves = DAG.getConstantFP(Val: 1.5, DL, VT);
30351
30352 // We now need 0.5 * Arg which we can write as (1.5 * Arg - Arg) so that
30353 // this entire sequence requires only one FP constant.
30354 SDValue HalfArg = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: ThreeHalves, N2: Arg);
30355 HalfArg = DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1: HalfArg, N2: Arg);
30356
30357 // Newton iterations: Est = Est * (1.5 - HalfArg * Est * Est)
30358 for (unsigned i = 0; i < Iterations; ++i) {
30359 SDValue NewEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: Est);
30360 NewEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: HalfArg, N2: NewEst);
30361 NewEst = DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1: ThreeHalves, N2: NewEst);
30362 Est = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: NewEst);
30363 }
30364
30365 // If non-reciprocal square root is requested, multiply the result by Arg.
30366 if (!Reciprocal)
30367 Est = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: Arg);
30368
30369 return Est;
30370}
30371
30372/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
30373/// For the reciprocal sqrt, we need to find the zero of the function:
30374/// F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
30375/// =>
30376/// X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0))
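/// Algebraically this is the same update as the one-constant form, since
/// (-0.5 * X_i) * (A * X_i * X_i - 3.0) == X_i * (1.5 - A * X_i * X_i / 2),
/// but it is expressed with the constants -0.5 and -3.0 instead of 1.5.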
30377SDValue DAGCombiner::buildSqrtNRTwoConst(SDValue Arg, SDValue Est,
30378 unsigned Iterations, bool Reciprocal) {
30379 EVT VT = Arg.getValueType();
30380 SDLoc DL(Arg);
30381 SDValue MinusThree = DAG.getConstantFP(Val: -3.0, DL, VT);
30382 SDValue MinusHalf = DAG.getConstantFP(Val: -0.5, DL, VT);
30383
30384 // This routine must enter the loop below to work correctly
30385 // when (Reciprocal == false).
30386 assert(Iterations > 0);
30387
30388 // Newton iterations for reciprocal square root:
30389 // E = (E * -0.5) * ((A * E) * E + -3.0)
30390 for (unsigned i = 0; i < Iterations; ++i) {
30391 SDValue AE = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Arg, N2: Est);
30392 SDValue AEE = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: AE, N2: Est);
30393 SDValue RHS = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: AEE, N2: MinusThree);
30394
30395 // When calculating a square root at the last iteration build:
30396 // S = ((A * E) * -0.5) * ((A * E) * E + -3.0)
30397 // (notice a common subexpression)
30398 SDValue LHS;
30399 if (Reciprocal || (i + 1) < Iterations) {
30400 // RSQRT: LHS = (E * -0.5)
30401 LHS = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: MinusHalf);
30402 } else {
30403 // SQRT: LHS = (A * E) * -0.5
30404 LHS = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: AE, N2: MinusHalf);
30405 }
30406
30407 Est = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: LHS, N2: RHS);
30408 }
30409
30410 return Est;
30411}
30412
30413/// Build code to calculate either rsqrt(Op) or sqrt(Op). In the latter case
30414/// Op*rsqrt(Op) is actually computed, so additional postprocessing is needed if
30415/// Op can be zero.
30416SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, bool Reciprocal) {
30417 if (LegalDAG)
30418 return SDValue();
30419
30420 // TODO: Handle extended types?
30421 EVT VT = Op.getValueType();
30422 if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
30423 VT.getScalarType() != MVT::f64)
30424 return SDValue();
30425
30426 // If estimates are explicitly disabled for this function, we're done.
30427 MachineFunction &MF = DAG.getMachineFunction();
30428 int Enabled = TLI.getRecipEstimateSqrtEnabled(VT, MF);
30429 if (Enabled == TLI.ReciprocalEstimate::Disabled)
30430 return SDValue();
30431
30432 // Estimates may be explicitly enabled for this type with a custom number of
30433 // refinement steps.
30434 int Iterations = TLI.getSqrtRefinementSteps(VT, MF);
30435
30436 bool UseOneConstNR = false;
30437 if (SDValue Est =
30438 TLI.getSqrtEstimate(Operand: Op, DAG, Enabled, RefinementSteps&: Iterations, UseOneConstNR,
30439 Reciprocal)) {
30440 AddToWorklist(N: Est.getNode());
30441
30442 if (Iterations > 0)
30443 Est = UseOneConstNR
30444 ? buildSqrtNROneConst(Arg: Op, Est, Iterations, Reciprocal)
30445 : buildSqrtNRTwoConst(Arg: Op, Est, Iterations, Reciprocal);
30446 if (!Reciprocal) {
30447 SDLoc DL(Op);
30448 // Try the target specific test first.
30449 SDValue Test = TLI.getSqrtInputTest(Operand: Op, DAG, Mode: DAG.getDenormalMode(VT));
30450
30451 // The estimate is now completely wrong if the input was exactly 0.0 or
30452       // possibly a denormal. Force the answer to 0.0 or the value provided by
30453       // the target for those cases.
30454 Est = DAG.getSelect(DL, VT, Cond: Test,
30455 LHS: TLI.getSqrtResultForDenormInput(Operand: Op, DAG), RHS: Est);
30456 }
30457 return Est;
30458 }
30459
30460 return SDValue();
30461}
30462
30463SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op) {
30464 return buildSqrtEstimateImpl(Op, Reciprocal: true);
30465}
30466
30467SDValue DAGCombiner::buildSqrtEstimate(SDValue Op) {
30468 return buildSqrtEstimateImpl(Op, Reciprocal: false);
30469}
30470
30471/// Return true if there is any possibility that the two addresses overlap.
30472bool DAGCombiner::mayAlias(SDNode *Op0, SDNode *Op1) const {
30473
30474 struct MemUseCharacteristics {
30475 bool IsVolatile;
30476 bool IsAtomic;
30477 SDValue BasePtr;
30478 int64_t Offset;
30479 LocationSize NumBytes;
30480 MachineMemOperand *MMO;
30481 };
30482
30483 auto getCharacteristics = [this](SDNode *N) -> MemUseCharacteristics {
30484 if (const auto *LSN = dyn_cast<LSBaseSDNode>(Val: N)) {
30485 int64_t Offset = 0;
30486 if (auto *C = dyn_cast<ConstantSDNode>(Val: LSN->getOffset()))
30487 Offset = (LSN->getAddressingMode() == ISD::PRE_INC) ? C->getSExtValue()
30488 : (LSN->getAddressingMode() == ISD::PRE_DEC)
30489 ? -1 * C->getSExtValue()
30490 : 0;
30491 TypeSize Size = LSN->getMemoryVT().getStoreSize();
30492 return {.IsVolatile: LSN->isVolatile(), .IsAtomic: LSN->isAtomic(),
30493 .BasePtr: LSN->getBasePtr(), .Offset: Offset /*base offset*/,
30494 .NumBytes: LocationSize::precise(Value: Size), .MMO: LSN->getMemOperand()};
30495 }
30496 if (const auto *LN = cast<LifetimeSDNode>(Val: N)) {
30497 MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
30498 return {.IsVolatile: false /*isVolatile*/,
30499 /*isAtomic*/ .IsAtomic: false,
30500 .BasePtr: LN->getOperand(Num: 1),
30501 .Offset: 0,
30502 .NumBytes: LocationSize::precise(Value: MFI.getObjectSize(ObjectIdx: LN->getFrameIndex())),
30503 .MMO: (MachineMemOperand *)nullptr};
30504 }
30505 // Default.
30506 return {.IsVolatile: false /*isvolatile*/,
30507 /*isAtomic*/ .IsAtomic: false,
30508 .BasePtr: SDValue(),
30509 .Offset: (int64_t)0 /*offset*/,
30510 .NumBytes: LocationSize::beforeOrAfterPointer() /*size*/,
30511 .MMO: (MachineMemOperand *)nullptr};
30512 };
30513
30514 MemUseCharacteristics MUC0 = getCharacteristics(Op0),
30515 MUC1 = getCharacteristics(Op1);
30516
30517 // If they are to the same address, then they must be aliases.
30518 if (MUC0.BasePtr.getNode() && MUC0.BasePtr == MUC1.BasePtr &&
30519 MUC0.Offset == MUC1.Offset)
30520 return true;
30521
30522 // If they are both volatile then they cannot be reordered.
30523 if (MUC0.IsVolatile && MUC1.IsVolatile)
30524 return true;
30525
30526 // Be conservative about atomics for the moment
30527 // TODO: This is way overconservative for unordered atomics (see D66309)
30528 if (MUC0.IsAtomic && MUC1.IsAtomic)
30529 return true;
30530
30531 if (MUC0.MMO && MUC1.MMO) {
30532 if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
30533 (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
30534 return false;
30535 }
30536
30537   // If NumBytes is scalable and the offset is not 0, conservatively return
30538   // "may alias".
30539 if ((MUC0.NumBytes.hasValue() && MUC0.NumBytes.isScalable() &&
30540 MUC0.Offset != 0) ||
30541 (MUC1.NumBytes.hasValue() && MUC1.NumBytes.isScalable() &&
30542 MUC1.Offset != 0))
30543 return true;
30544 // Try to prove that there is aliasing, or that there is no aliasing. Either
30545 // way, we can return now. If nothing can be proved, proceed with more tests.
30546 bool IsAlias;
30547 if (BaseIndexOffset::computeAliasing(Op0, NumBytes0: MUC0.NumBytes, Op1, NumBytes1: MUC1.NumBytes,
30548 DAG, IsAlias))
30549 return IsAlias;
30550
30551 // The following all rely on MMO0 and MMO1 being valid. Fail conservatively if
30552 // either are not known.
30553 if (!MUC0.MMO || !MUC1.MMO)
30554 return true;
30555
30556   // If one operation reads from invariant memory and the other may store, they
30557   // cannot alias. This should really check the equivalent of mayWrite, but it
30558   // only matters for memory nodes other than load/store.
30559 if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
30560 (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
30561 return false;
30562
30563   // If the underlying values SrcValue1 and SrcValue2 are known to have
30564   // relatively large alignment compared to the size and offset of the access,
30565   // we may be able to prove they do not alias. This check is conservative for
30566   // now: it is aimed at cases created by splitting vector types and only works
30567   // when the offsets are multiples of the size of the data.
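  // E.g. (illustrative): two 4-byte accesses from bases with 16-byte base
  // alignment at offsets 4 and 8 cannot overlap, because 4 + 4 <= 8 within the
  // 16-byte aligned window.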
30568 int64_t SrcValOffset0 = MUC0.MMO->getOffset();
30569 int64_t SrcValOffset1 = MUC1.MMO->getOffset();
30570 Align OrigAlignment0 = MUC0.MMO->getBaseAlign();
30571 Align OrigAlignment1 = MUC1.MMO->getBaseAlign();
30572 LocationSize Size0 = MUC0.NumBytes;
30573 LocationSize Size1 = MUC1.NumBytes;
30574
30575 if (OrigAlignment0 == OrigAlignment1 && SrcValOffset0 != SrcValOffset1 &&
30576 Size0.hasValue() && Size1.hasValue() && !Size0.isScalable() &&
30577 !Size1.isScalable() && Size0 == Size1 &&
30578 OrigAlignment0 > Size0.getValue().getKnownMinValue() &&
30579 SrcValOffset0 % Size0.getValue().getKnownMinValue() == 0 &&
30580 SrcValOffset1 % Size1.getValue().getKnownMinValue() == 0) {
30581 int64_t OffAlign0 = SrcValOffset0 % OrigAlignment0.value();
30582 int64_t OffAlign1 = SrcValOffset1 % OrigAlignment1.value();
30583
30584 // There is no overlap between these relatively aligned accesses of
30585 // similar size. Return no alias.
30586 if ((OffAlign0 + static_cast<int64_t>(
30587 Size0.getValue().getKnownMinValue())) <= OffAlign1 ||
30588 (OffAlign1 + static_cast<int64_t>(
30589 Size1.getValue().getKnownMinValue())) <= OffAlign0)
30590 return false;
30591 }
30592
30593 bool UseAA = CombinerGlobalAA.getNumOccurrences() > 0
30594 ? CombinerGlobalAA
30595 : DAG.getSubtarget().useAA();
30596#ifndef NDEBUG
30597 if (CombinerAAOnlyFunc.getNumOccurrences() &&
30598 CombinerAAOnlyFunc != DAG.getMachineFunction().getName())
30599 UseAA = false;
30600#endif
30601
30602 if (UseAA && BatchAA && MUC0.MMO->getValue() && MUC1.MMO->getValue() &&
30603 Size0.hasValue() && Size1.hasValue() &&
30604 // Can't represent a scalable size + fixed offset in LocationSize
30605 (!Size0.isScalable() || SrcValOffset0 == 0) &&
30606 (!Size1.isScalable() || SrcValOffset1 == 0)) {
30607 // Use alias analysis information.
30608 int64_t MinOffset = std::min(a: SrcValOffset0, b: SrcValOffset1);
30609 int64_t Overlap0 =
30610 Size0.getValue().getKnownMinValue() + SrcValOffset0 - MinOffset;
30611 int64_t Overlap1 =
30612 Size1.getValue().getKnownMinValue() + SrcValOffset1 - MinOffset;
30613 LocationSize Loc0 =
30614 Size0.isScalable() ? Size0 : LocationSize::precise(Value: Overlap0);
30615 LocationSize Loc1 =
30616 Size1.isScalable() ? Size1 : LocationSize::precise(Value: Overlap1);
30617 if (BatchAA->isNoAlias(
30618 LocA: MemoryLocation(MUC0.MMO->getValue(), Loc0,
30619 UseTBAA ? MUC0.MMO->getAAInfo() : AAMDNodes()),
30620 LocB: MemoryLocation(MUC1.MMO->getValue(), Loc1,
30621 UseTBAA ? MUC1.MMO->getAAInfo() : AAMDNodes())))
30622 return false;
30623 }
30624
30625 // Otherwise we have to assume they alias.
30626 return true;
30627}
30628
30629/// Walk up chain skipping non-aliasing memory nodes,
30630/// looking for aliasing nodes and adding them to the Aliases vector.
30631void DAGCombiner::GatherAllAliases(SDNode *N, SDValue OriginalChain,
30632 SmallVectorImpl<SDValue> &Aliases) {
30633 SmallVector<SDValue, 8> Chains; // List of chains to visit.
30634 SmallPtrSet<SDNode *, 16> Visited; // Visited node set.
30635
30636 // Get alias information for node.
30637 // TODO: relax aliasing for unordered atomics (see D66309)
30638 const bool IsLoad = isa<LoadSDNode>(Val: N) && cast<LoadSDNode>(Val: N)->isSimple();
30639
30640 // Starting off.
30641 Chains.push_back(Elt: OriginalChain);
30642 unsigned Depth = 0;
30643
30644 // Attempt to improve chain by a single step
30645 auto ImproveChain = [&](SDValue &C) -> bool {
30646 switch (C.getOpcode()) {
30647 case ISD::EntryToken:
30648 // No need to mark EntryToken.
30649 C = SDValue();
30650 return true;
30651 case ISD::LOAD:
30652 case ISD::STORE: {
30653 // Get alias information for C.
30654 // TODO: Relax aliasing for unordered atomics (see D66309)
30655 bool IsOpLoad = isa<LoadSDNode>(Val: C.getNode()) &&
30656 cast<LSBaseSDNode>(Val: C.getNode())->isSimple();
30657 if ((IsLoad && IsOpLoad) || !mayAlias(Op0: N, Op1: C.getNode())) {
30658 // Look further up the chain.
30659 C = C.getOperand(i: 0);
30660 return true;
30661 }
30662 // Alias, so stop here.
30663 return false;
30664 }
30665
30666 case ISD::CopyFromReg:
30667 // Always forward past CopyFromReg.
30668 C = C.getOperand(i: 0);
30669 return true;
30670
30671 case ISD::LIFETIME_START:
30672 case ISD::LIFETIME_END: {
30673 // We can forward past any lifetime start/end that can be proven not to
30674 // alias the memory access.
30675 if (!mayAlias(Op0: N, Op1: C.getNode())) {
30676 // Look further up the chain.
30677 C = C.getOperand(i: 0);
30678 return true;
30679 }
30680 return false;
30681 }
30682 default:
30683 return false;
30684 }
30685 };
30686
30687 // Look at each chain and determine if it is an alias. If so, add it to the
30688 // aliases list. If not, then continue up the chain looking for the next
30689 // candidate.
30690 while (!Chains.empty()) {
30691 SDValue Chain = Chains.pop_back_val();
30692
30693 // Don't bother if we've seen Chain before.
30694 if (!Visited.insert(Ptr: Chain.getNode()).second)
30695 continue;
30696
30697 // For TokenFactor nodes, look at each operand and only continue up the
30698 // chain until we reach the depth limit.
30699 //
30700 // FIXME: The depth check could be made to return the last non-aliasing
30701 // chain we found before we hit a tokenfactor rather than the original
30702 // chain.
30703 if (Depth > TLI.getGatherAllAliasesMaxDepth()) {
30704 Aliases.clear();
30705 Aliases.push_back(Elt: OriginalChain);
30706 return;
30707 }
30708
30709 if (Chain.getOpcode() == ISD::TokenFactor) {
30710 // We have to check each of the operands of the token factor for "small"
30711 // token factors, so we queue them up. Adding the operands to the queue
30712 // (stack) in reverse order maintains the original order and increases the
30713       // likelihood that getNode will find a matching token factor (CSE).
30714 if (Chain.getNumOperands() > 16) {
30715 Aliases.push_back(Elt: Chain);
30716 continue;
30717 }
30718 for (unsigned n = Chain.getNumOperands(); n;)
30719 Chains.push_back(Elt: Chain.getOperand(i: --n));
30720 ++Depth;
30721 continue;
30722 }
30723 // Everything else
30724 if (ImproveChain(Chain)) {
30725 // Updated Chain Found, Consider new chain if one exists.
30726 if (Chain.getNode())
30727 Chains.push_back(Elt: Chain);
30728 ++Depth;
30729 continue;
30730 }
30731 // No Improved Chain Possible, treat as Alias.
30732 Aliases.push_back(Elt: Chain);
30733 }
30734}
30735
30736/// Walk up chain skipping non-aliasing memory nodes, looking for a better chain
30737 /// (aliasing node).
30738SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) {
30739 if (OptLevel == CodeGenOptLevel::None)
30740 return OldChain;
30741
30742 // Ops for replacing token factor.
30743 SmallVector<SDValue, 8> Aliases;
30744
30745 // Accumulate all the aliases to this node.
30746 GatherAllAliases(N, OriginalChain: OldChain, Aliases);
30747
30748 // If no operands then chain to entry token.
30749 if (Aliases.empty())
30750 return DAG.getEntryNode();
30751
30752 // If a single operand then chain to it. We don't need to revisit it.
30753 if (Aliases.size() == 1)
30754 return Aliases[0];
30755
30756 // Construct a custom tailored token factor.
30757 return DAG.getTokenFactor(DL: SDLoc(N), Vals&: Aliases);
30758}
30759
30760// This function tries to collect a bunch of potentially interesting
30761// nodes to improve the chains of, all at once. This might seem
30762// redundant, as this function gets called when visiting every store
30763// node, so why not let the work be done on each store as it's visited?
30764//
30765// I believe this is mainly important because mergeConsecutiveStores
30766// is unable to deal with merging stores of different sizes, so unless
30767// we improve the chains of all the potential candidates up-front
30768// before running mergeConsecutiveStores, it might only see some of
30769// the nodes that will eventually be candidates, and then not be able
30770// to go from a partially-merged state to the desired final
30771// fully-merged state.
30772
30773bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) {
30774 SmallVector<StoreSDNode *, 8> ChainedStores;
30775 StoreSDNode *STChain = St;
30776   // Intervals records which offsets from BaseIndex have been covered. In the
30777   // common case, every store writes to an address adjacent to the previous one
30778   // and is thus merged with the previous interval at insertion time.
30779
30780 using IMap = llvm::IntervalMap<int64_t, std::monostate, 8,
30781 IntervalMapHalfOpenInfo<int64_t>>;
30782 IMap::Allocator A;
30783 IMap Intervals(A);
30784
30785 // This holds the base pointer, index, and the offset in bytes from the base
30786 // pointer.
30787 const BaseIndexOffset BasePtr = BaseIndexOffset::match(N: St, DAG);
30788
30789 // We must have a base and an offset.
30790 if (!BasePtr.getBase().getNode())
30791 return false;
30792
30793 // Do not handle stores to undef base pointers.
30794 if (BasePtr.getBase().isUndef())
30795 return false;
30796
30797 // Do not handle stores to opaque types
30798 if (St->getMemoryVT().isZeroSized())
30799 return false;
30800
30801 // BaseIndexOffset assumes that offsets are fixed-size, which
30802 // is not valid for scalable vectors where the offsets are
30803 // scaled by `vscale`, so bail out early.
30804 if (St->getMemoryVT().isScalableVT())
30805 return false;
30806
30807 // Add ST's interval.
30808 Intervals.insert(a: 0, b: (St->getMemoryVT().getSizeInBits() + 7) / 8,
30809 y: std::monostate{});
30810
30811 while (StoreSDNode *Chain = dyn_cast<StoreSDNode>(Val: STChain->getChain())) {
30812 if (Chain->getMemoryVT().isScalableVector())
30813 return false;
30814
30815 // If the chain has more than one use, then we can't reorder the mem ops.
30816 if (!SDValue(Chain, 0)->hasOneUse())
30817 break;
30818 // TODO: Relax for unordered atomics (see D66309)
30819 if (!Chain->isSimple() || Chain->isIndexed())
30820 break;
30821
30822 // Find the base pointer and offset for this memory node.
30823 const BaseIndexOffset Ptr = BaseIndexOffset::match(N: Chain, DAG);
30824 // Check that the base pointer is the same as the original one.
30825 int64_t Offset;
30826 if (!BasePtr.equalBaseIndex(Other: Ptr, DAG, Off&: Offset))
30827 break;
30828 int64_t Length = (Chain->getMemoryVT().getSizeInBits() + 7) / 8;
30829 // Make sure we don't overlap with other intervals by checking the ones to
30830 // the left or right before inserting.
30831 auto I = Intervals.find(x: Offset);
30832 // If there's a next interval, we should end before it.
30833 if (I != Intervals.end() && I.start() < (Offset + Length))
30834 break;
30835 // If there's a previous interval, we should start after it.
30836 if (I != Intervals.begin() && (--I).stop() <= Offset)
30837 break;
30838 Intervals.insert(a: Offset, b: Offset + Length, y: std::monostate{});
30839
30840 ChainedStores.push_back(Elt: Chain);
30841 STChain = Chain;
30842 }
30843
30844 // If we didn't find a chained store, exit.
30845 if (ChainedStores.empty())
30846 return false;
30847
30848 // Improve all chained stores (St and ChainedStores members) starting from
30849 // where the store chain ended and return single TokenFactor.
30850 SDValue NewChain = STChain->getChain();
30851 SmallVector<SDValue, 8> TFOps;
30852 for (unsigned I = ChainedStores.size(); I;) {
30853 StoreSDNode *S = ChainedStores[--I];
30854 SDValue BetterChain = FindBetterChain(N: S, OldChain: NewChain);
30855 S = cast<StoreSDNode>(Val: DAG.UpdateNodeOperands(
30856 N: S, Op1: BetterChain, Op2: S->getOperand(Num: 1), Op3: S->getOperand(Num: 2), Op4: S->getOperand(Num: 3)));
30857 TFOps.push_back(Elt: SDValue(S, 0));
30858 ChainedStores[I] = S;
30859 }
30860
30861 // Improve St's chain. Use a new node to avoid creating a loop from CombineTo.
30862 SDValue BetterChain = FindBetterChain(N: St, OldChain: NewChain);
30863 SDValue NewST;
30864 if (St->isTruncatingStore())
30865 NewST = DAG.getTruncStore(Chain: BetterChain, dl: SDLoc(St), Val: St->getValue(),
30866 Ptr: St->getBasePtr(), SVT: St->getMemoryVT(),
30867 MMO: St->getMemOperand());
30868 else
30869 NewST = DAG.getStore(Chain: BetterChain, dl: SDLoc(St), Val: St->getValue(),
30870 Ptr: St->getBasePtr(), MMO: St->getMemOperand());
30871
30872 TFOps.push_back(Elt: NewST);
30873
30874 // If we improved every element of TFOps, then we've lost the dependence on
30875 // NewChain to successors of St and we need to add it back to TFOps. Do so at
30876 // the beginning to keep relative order consistent with FindBetterChains.
30877 auto hasImprovedChain = [&](SDValue ST) -> bool {
30878 return ST->getOperand(Num: 0) != NewChain;
30879 };
30880 bool AddNewChain = llvm::all_of(Range&: TFOps, P: hasImprovedChain);
30881 if (AddNewChain)
30882 TFOps.insert(I: TFOps.begin(), Elt: NewChain);
30883
30884 SDValue TF = DAG.getTokenFactor(DL: SDLoc(STChain), Vals&: TFOps);
30885 CombineTo(N: St, Res: TF);
30886
30887 // Add TF and its operands to the worklist.
30888 AddToWorklist(N: TF.getNode());
30889 for (const SDValue &Op : TF->ops())
30890 AddToWorklist(N: Op.getNode());
30891 AddToWorklist(N: STChain);
30892 return true;
30893}
30894
30895bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
30896 if (OptLevel == CodeGenOptLevel::None)
30897 return false;
30898
30899 const BaseIndexOffset BasePtr = BaseIndexOffset::match(N: St, DAG);
30900
30901 // We must have a base and an offset.
30902 if (!BasePtr.getBase().getNode())
30903 return false;
30904
30905 // Do not handle stores to undef base pointers.
30906 if (BasePtr.getBase().isUndef())
30907 return false;
30908
30909 // Directly improve a chain of disjoint stores starting at St.
30910 if (parallelizeChainedStores(St))
30911 return true;
30912
30913   // Improve St's chain.
30914 SDValue BetterChain = FindBetterChain(N: St, OldChain: St->getChain());
30915 if (St->getChain() != BetterChain) {
30916 replaceStoreChain(ST: St, BetterChain);
30917 return true;
30918 }
30919 return false;
30920}
30921
30922/// This is the entry point for the file.
30923void SelectionDAG::Combine(CombineLevel Level, BatchAAResults *BatchAA,
30924 CodeGenOptLevel OptLevel) {
30925 /// This is the main entry point to this class.
30926 DAGCombiner(*this, BatchAA, OptLevel).Run(AtLevel: Level);
30927}
30928